Module auton_survival.models.dcm
Deep Cox Mixtures
The Cox Mixture assumes that the survival function of each individual is a mixture of K Cox models. Conditioned on each subgroup Z=k, the proportional hazards (PH) assumption is assumed to hold, and the baseline hazard rate is estimated non-parametrically using a spline-interpolated Breslow's estimator.
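Concretely, the mixture assumption can be written as \mathbb{P}(T > t | X = x) = \sum_{k=1}^{K} \mathbb{P}(Z = k | X = x) \, S_k(t | x), where S_k is the survival function of the k-th Cox model. (This formula is a sketch for illustration, written to be consistent with the \widehat{\mathbb{P}}(T > t|X) notation used below.)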
For full details on Deep Cox Mixtures, refer to the paper [1].
References
[1] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. Deep Cox Mixtures for Survival Regression. Machine Learning for Healthcare Conference (2021).
@article{nagpal2021dcm,
title={Deep Cox mixtures for survival regression},
author={Nagpal, Chirag and Yadlowsky, Steve and Rostamzadeh, Negar and Heller, Katherine},
journal={arXiv preprint arXiv:2101.06536},
year={2021}
}
Sub-modules
auton_survival.models.dcm.dcm_torch
auton_survival.models.dcm.dcm_utilities
Classes
class DeepCoxMixtures(k=3, layers=None, gamma=10, smoothing_factor=0.0001, use_activation=False, random_seed=0)
A Deep Cox Mixture model.
This is the main interface to a Deep Cox Mixture model. A model is instantiated with an appropriate set of hyperparameters and fit on numpy arrays consisting of the features, event/censoring times, and the event/censoring indicators.
For full details on Deep Cox Mixtures, refer to the paper [1].
References
[1] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. Deep Cox Mixtures for Survival Regression. Machine Learning for Healthcare Conference (2021).
Parameters
k:int- The number of underlying Cox distributions.
layers:list- A list of integers consisting of the number of neurons in each hidden layer.
random_seed:int- Controls the reproducibility of called functions.
Example
>>> from auton_survival.models.dcm import DeepCoxMixtures
>>> model = DeepCoxMixtures()
>>> model.fit(x, t, e)
Methods
def fit(self, x, t, e, vsize=0.15, val_data=None, iters=1, learning_rate=0.001, batch_size=100, optimizer='Adam')
This method is used to train an instance of the Deep Cox Mixtures model.
Parameters
x:np.ndarray- A numpy array of the input features, x .
t:np.ndarray- A numpy array of the event/censoring times, t .
e:np.ndarray- A numpy array of the event/censoring indicators, \delta . \delta = 1 means the event took place.
vsize:float- Amount of data to set aside as the validation set.
val_data:tuple- A tuple of the validation dataset. If passed, vsize is ignored.
iters:int- The maximum number of training iterations on the training dataset.
learning_rate:float- The learning rate for the Adam optimizer.
batch_size:int- Learning is performed on mini-batches of input data. This parameter specifies the size of each mini-batch.
optimizer:str- The choice of the gradient based optimization method. One of 'Adam', 'RMSProp' or 'SGD'.
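A minimal training sketch, assuming x, t, and e are numpy arrays of features, times, and event indicators, and (x_val, t_val, e_val) is a held-out validation split (these variable names are illustrative, not part of the API):
>>> model = DeepCoxMixtures(k=3, layers=[100])
>>> model.fit(x, t, e,
...           val_data=(x_val, t_val, e_val),
...           iters=50, learning_rate=1e-3,
...           batch_size=128, optimizer='Adam')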
def predict_survival(self, x, t)
Returns the estimated survival probability at time t , \widehat{\mathbb{P}}(T > t|X) for some input data x .
Parameters
x:np.ndarray- A numpy array of the input features, x .
t:list or float- A list or float of the times at which the survival probability is to be computed.
Returns
np.array- A numpy array of the survival probabilities at each time in t.
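A usage sketch, assuming the model has been fit as above; the time horizons 1, 5, and 10 are arbitrary examples, and the output is expected to contain one survival probability per sample at each requested time:
>>> out = model.predict_survival(x, t=[1, 5, 10])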
def predict_latent_z(self, x)
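This method returns the estimated latent subgroup assignment probabilities for the input data x. A brief usage sketch (both the description and the comment below are inferred from the mixture formulation above, not taken verbatim from the docstring):
>>> z_probs = model.predict_latent_z(x)  # per-sample probabilities over the K latent Cox subgroups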