Module auton_survival.models.dsm
Deep Survival Machines
Deep Survival Machines (DSM) is a fully parametric approach to modelling Time-to-Event outcomes in the presence of Censoring, first introduced in [1]. In the context of Healthcare ML and Biostatistics, this is known as 'Survival Analysis'. The key idea behind Deep Survival Machines is to model the underlying event outcome distribution as a mixture of a fixed number k of parametric distributions. The parameters of these mixture distributions, as well as the mixing weights, are modelled using Neural Networks.
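As a rough illustration of the mixture idea, the survival function of a k-component Weibull mixture can be written out directly. This is a plain-Python sketch, not the library's neural-network parameterization: the function names and parameter values are hypothetical, and in DSM the weights, shapes and scales would be produced per individual by the network.

```python
import math


def weibull_survival(t, shape, scale):
    """Survival function S(t) = exp(-(t / scale) ** shape) of a Weibull."""
    return math.exp(-((t / scale) ** shape))


def mixture_survival(t, weights, shapes, scales):
    """Survival function of a mixture: weighted sum of component survivals."""
    return sum(
        w * weibull_survival(t, b, s)
        for w, b, s in zip(weights, shapes, scales)
    )


# Hypothetical parameters for k = 3 components (weights sum to 1).
weights = [0.5, 0.3, 0.2]
shapes = [1.0, 1.5, 2.0]
scales = [5.0, 10.0, 20.0]

survival_at_10 = mixture_survival(10.0, weights, shapes, scales)
risk_at_10 = 1.0 - survival_at_10  # probability the event occurs by t = 10
print(f"S(10) = {survival_at_10:.3f}, risk = {risk_at_10:.3f}")
```

The risk at a time horizon is one minus the mixture survival at that horizon, which is the quantity a risk prediction at a fixed time is meant to estimate.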
Example Usage
>>> from auton_survival.models.dsm import DeepSurvivalMachines
>>> from auton_survival.models.dsm import datasets
>>> # load the SUPPORT dataset.
>>> x, t, e = datasets.load_dataset('SUPPORT')
>>> # instantiate a DeepSurvivalMachines model.
>>> model = DeepSurvivalMachines()
>>> # fit the model to the dataset.
>>> model.fit(x, t, e)
>>> # estimate the predicted risk at time 10.
>>> model.predict_risk(x, 10)
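The three inputs used in the example are the features `x`, the observed times `t`, and the event indicators `e`. The sketch below shows how `t` and `e` are typically derived from a latent event time and a censoring time in survival analysis. The helper name and toy values are hypothetical, and encoding an observed event as 1 and a censored subject as 0 is an assumption about the indicator convention.

```python
def observe(event_time, censor_time):
    """Return (t, e): the observed time and whether the event was seen."""
    if event_time <= censor_time:
        return event_time, 1  # event observed before follow-up ended
    return censor_time, 0  # subject censored at the end of follow-up


# Toy (latent event time, censoring time) pairs, one per subject.
subjects = [(4.0, 10.0), (12.0, 10.0), (7.5, 7.5), (30.0, 6.0)]

observed = [observe(ev, c) for ev, c in subjects]
t = [time for time, _ in observed]
e = [indicator for _, indicator in observed]

print(t)
print(e)
```

For a censored subject the true event time is only known to exceed `t`, which is why the indicator has to be passed alongside the times.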
Deep Recurrent Survival Machines
Deep Recurrent Survival Machines (DRSM) builds on the original DSM
model and learns representations of the input covariates using
Recurrent Neural Networks such as LSTMs and GRUs. DRSM is a natural fit
for problems with time-dependent covariates, for example streaming
data such as vital signs or degradation monitoring signals in predictive
maintenance. DRSM allows the learnt representation at each time step to
incorporate historical context from previous time steps. The DRSM
implementation in dsm is exposed through an easy-to-use API,
DeepRecurrentSurvivalMachines, that accepts lists of data streams and
corresponding failure times. The module automatically takes care of appropriate
batching and padding of variable-length sequences.
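To make the padding of variable-length sequences concrete, here is a minimal plain-Python sketch. It is not the module's internal routine: `pad_sequences`, the pad value and the mask layout are all hypothetical choices made for illustration.

```python
def pad_sequences(sequences, pad_value=0.0):
    """Pad a list of per-step feature lists to a common length.

    Returns the padded sequences and a mask marking real (1) and
    padded (0) time steps.
    """
    max_len = max(len(seq) for seq in sequences)
    n_features = len(sequences[0][0])
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        filler = [[pad_value] * n_features for _ in range(n_pad)]
        padded.append([list(step) for step in seq] + filler)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


# Two hypothetical data streams with two features per time step.
streams = [
    [[0.1, 1.0], [0.2, 1.1], [0.3, 1.2]],
    [[0.5, 0.9]],
]
padded, mask = pad_sequences(streams)
print(mask)
```

The mask lets a downstream loss ignore padded steps, so shorter streams do not contribute spurious observations.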
Deep Convolutional Survival Machines
Predictive maintenance and medical imaging sometimes require working with image streams. Deep Convolutional Survival Machines extends DSM and DRSM to learn representations of the input image data using convolutional layers. When working with streaming data, the learnt representations are then passed through an LSTM to model temporal dependencies before determining the underlying survival distributions.
Warning: Not Implemented Yet!
References
Please cite the following papers if you are using the auton_survival
package.
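[1] Deep Survival Machines: Fully Parametric Survival Regression and Representation Learning for Censored Data with Competing Risks. IEEE Journal of Biomedical and Health Informatics (2021)
@article{nagpal2021dsm,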
title={Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks},
author={Nagpal, Chirag and Li, Xinyu and Dubrawski, Artur},
journal={IEEE Journal of Biomedical and Health Informatics},
volume={25},
number={8},
pages={3163--3175},
year={2021},
publisher={IEEE}
}
[2] Deep Parametric Time-to-Event Regression with Time-Varying Covariates. AAAI Spring Symposium (2021)
@InProceedings{pmlr-v146-nagpal21a,
title={Deep Parametric Time-to-Event Regression with Time-Varying Covariates},
author={Nagpal, Chirag and Jeanselme, Vincent and Dubrawski, Artur},
booktitle={Proceedings of AAAI Spring Symposium on Survival Prediction - Algorithms, Challenges, and Applications 2021},
series={Proceedings of Machine Learning Research},
publisher={PMLR},
}
[3] Deep Cox Mixtures for Survival Regression. Conference on Machine Learning for Healthcare (2021)
@inproceedings{nagpal2021dcm,
title={Deep Cox mixtures for survival regression},
author={Nagpal, Chirag and Yadlowsky, Steve and Rostamzadeh, Negar and Heller, Katherine},
booktitle={Machine Learning for Healthcare Conference},
pages={674--708},
year={2021},
organization={PMLR}
}
[4] Counterfactual Phenotyping with Censored Time-to-Events (2022)
@article{nagpal2022counterfactual,
title={Counterfactual Phenotyping with Censored Time-to-Events},
author={Nagpal, Chirag and Goswami, Mononito and Dufendach, Keith and Dubrawski, Artur},
journal={arXiv preprint arXiv:2202.11089},
year={2022}
}
Sub-modules
auton_survival.models.dsm.datasets
-
Utility functions to load standard datasets to train and evaluate the Deep Survival Machines models.
auton_survival.models.dsm.dsm_torch
-
Torch model definitions for the Deep Survival Machines model
This includes definitions for the Torch Deep Survival Machines module. The main interface is the DeepSurvivalMachines class, which inherits from torch.nn.Module.
Note: NOT DESIGNED TO BE CALLED DIRECTLY!!!
auton_survival.models.dsm.losses
-
Loss function definitions for the Deep Survival Machines model
In this module we define the various losses for the censored and uncensored instances of data corresponding to Weibull and LogNormal distributions. These losses are optimized when training DSM.
TODO
Use torch.distributions
Warning
NOT DESIGNED TO BE CALLED DIRECTLY!!!
auton_survival.models.dsm.utilities
-
Utility functions to train the Deep Survival Machines models
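To illustrate the censored and uncensored losses described for the losses sub-module, the sketch below computes a negative log-likelihood for a single Weibull distribution. It is a simplification, not the module's code: the function names are hypothetical, there is only one component rather than a mixture, and event indicators are assumed to be 1 for observed events and 0 for censored ones.

```python
import math


def weibull_log_pdf(t, shape, scale):
    """log f(t) for a Weibull with the given shape and scale."""
    z = t / scale
    return math.log(shape / scale) + (shape - 1.0) * math.log(z) - z ** shape


def weibull_log_survival(t, shape, scale):
    """log S(t) = -(t / scale) ** shape."""
    return -((t / scale) ** shape)


def negative_log_likelihood(times, events, shape, scale):
    """Observed events contribute log f(t); censored ones contribute log S(t)."""
    total = 0.0
    for t, e in zip(times, events):
        if e == 1:
            total += weibull_log_pdf(t, shape, scale)
        else:
            total += weibull_log_survival(t, shape, scale)
    return -total / len(times)


# One observed event and one censored subject, both at time 2.
nll = negative_log_likelihood([2.0, 2.0], [1, 0], shape=1.0, scale=2.0)
print(nll)
```

A censored subject only tells us the event happened after `t`, so its contribution uses the survival function instead of the density.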
Classes
class DeepSurvivalMachines (k=3, layers=None, distribution='Weibull', temp=1000.0, discount=1.0, random_seed=0)
-
A Deep Survival Machines model.
This is the main interface to a Deep Survival Machines model. A model is instantiated with an appropriate set of hyperparameters and fit on numpy arrays consisting of the features, event/censoring times and the event/censoring indicators.
For full details on Deep Survival Machines, refer to our paper [1].
Parameters
k : int - The number of underlying parametric distributions.
layers : list - A list of integers consisting of the number of neurons in each hidden layer.
distribution : str - Choice of the underlying survival distributions. One of 'Weibull', 'LogNormal'. Default is 'Weibull'.
temp : float - The logits for the gate are rescaled with this value. Default is 1000.
discount : float - A float in [0, 1] that determines how to discount the tail bias from the uncensored instances. Default is 1.
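The effect of the temp parameter on the gate can be sketched as a temperature-scaled softmax. The source only says that the logits are rescaled with this value, so dividing the logits by temp before a softmax is an assumption here, and `gate_weights` is a hypothetical name rather than part of the library.

```python
import math


def gate_weights(logits, temp=1000.0):
    """Softmax over logits divided by temp (assumed form of the rescaling)."""
    scaled = [z / temp for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [x / total for x in exps]


logits = [2.0, 0.0, -1.0]
sharp = gate_weights(logits, temp=1.0)    # weights follow the logits closely
flat = gate_weights(logits, temp=1000.0)  # weights are close to uniform
print(sharp)
print(flat)
```

Under this assumption a large temp keeps the mixing weights near uniform, so no single component dominates the mixture.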
Example
>>> from auton_survival.models.dsm import DeepSurvivalMachines
>>> model = DeepSurvivalMachines()
>>> model.fit(x, t, e)
class DeepRecurrentSurvivalMachines (k=3, layers=None, hidden=None, distribution='Weibull', temp=1000.0, discount=1.0, typ='LSTM', random_seed=0)
-
The Deep Recurrent Survival Machines model to handle data with time-dependent covariates.
For full details on Deep Recurrent Survival Machines, refer to our paper [2].
class DeepConvolutionalSurvivalMachines (k=3, layers=None, hidden=None, distribution='Weibull', temp=1000.0, discount=1.0, typ='ConvNet')
-
The Deep Convolutional Survival Machines model to handle data with image-based covariates.
class DeepCNNRNNSurvivalMachines (k=3, layers=None, hidden=None, distribution='Weibull', temp=1000.0, discount=1.0, typ='LSTM')
-
The Deep CNN-RNN Survival Machines model to handle data with moving image streams.