Module auton_survival.estimators
Utilities to train survival regression models and estimate survival.
Classes
class SurvivalModel (model, random_seed=0, **hyperparams)
-
Universal interface to train multiple different survival models.
Parameters
model
:str
-
A string that determines the choice of the survival analysis model. Survival model choices include:
dsm
: Deep Survival Machines [3] model
dcph
: Deep Cox Proportional Hazards [2] model
dcm
: Deep Cox Mixtures [4] model
rsf
: Random Survival Forests [1] model
cph
: Cox Proportional Hazards [2] model
random_seed
:int
- Controls the reproducibility of called functions.
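The model strings documented for the model parameter can be checked before a SurvivalModel is constructed. The following is a minimal sketch: MODEL_CHOICES and check_model_name are hypothetical names introduced here for illustration and are not part of auton_survival.

```python
# Hypothetical lookup table built from the documented model choices.
MODEL_CHOICES = {
    "dsm": "Deep Survival Machines",
    "dcph": "Deep Cox Proportional Hazards",
    "dcm": "Deep Cox Mixtures",
    "rsf": "Random Survival Forests",
    "cph": "Cox Proportional Hazards",
}


def check_model_name(model):
    """Return the model string unchanged if it is a documented choice."""
    if model not in MODEL_CHOICES:
        valid = ", ".join(sorted(MODEL_CHOICES))
        raise ValueError(f"Unknown model {model!r}; expected one of: {valid}")
    return model
```

The validated string could then be passed as the first argument, for example SurvivalModel(check_model_name("cph"), random_seed=0).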
References
[1] Ishwaran, H., et al. (2008). Random survival forests. The Annals of Applied Statistics, 2(3):841–860.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Nagpal, C., Li, X., and Dubrawski, A. (2020). Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, vsize=0.15, val_data=None, weights=None, weights_val=None, resample_size=1.0)
-
This method is used to train an instance of the survival model.
Parameters
features
:pd.DataFrame
- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
outcomes
:pd.DataFrame
- a pandas dataframe with columns 'time' and 'event'.
vsize
:float, default=0.15
- Amount of data to set aside as the validation set. Not applicable to 'rsf' and 'cph' models.
val_data
:tuple
- A tuple of the validation features and the validation outcomes, where the outcomes have columns 'time' and 'event'. If passed, vsize is ignored. Not applicable to 'rsf' and 'cph' models.
weights
:list or np.array
- a list or numpy array of importance weights for each sample.
weights_val
:list or np.array
- a list or numpy array of importance weights for each validation set sample. Ignored if val_data is None.
resample_size
:float
- a float between 0 and 1 that controls the size of the resampled dataset.
Returns
self
- Trained instance of a survival model.
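A call to fit might look like the sketch below. It assumes pandas and auton_survival are installed, uses randomly generated toy data rather than a real dataset, and leaves all hyperparameters at their defaults. The helper names make_toy_data and fit_toy_cph, the column names 'age' and 'biomarker', and the 1/0 coding of 'event' are assumptions made for illustration.

```python
import random


def make_toy_data(n, seed=0):
    """Build toy covariates and outcomes as plain dictionaries of columns."""
    rng = random.Random(seed)
    features = {
        "age": [rng.uniform(30, 80) for _ in range(n)],
        "biomarker": [rng.gauss(0.0, 1.0) for _ in range(n)],
    }
    outcomes = {
        # Strictly positive follow-up times.
        "time": [rng.expovariate(0.1) + 0.1 for _ in range(n)],
        # Assumed coding: 1 = event observed, 0 = censored.
        "event": [int(rng.random() < 0.7) for _ in range(n)],
    }
    return features, outcomes


def fit_toy_cph(n=200):
    """Fit a 'cph' model on toy data; requires pandas and auton_survival."""
    # Imported lazily so make_toy_data works without these packages installed.
    import pandas as pd
    from auton_survival.estimators import SurvivalModel

    features, outcomes = make_toy_data(n)
    model = SurvivalModel("cph", random_seed=0)
    # vsize and val_data stay at their defaults; they are documented as
    # not applicable to 'cph'.
    return model.fit(pd.DataFrame(features), pd.DataFrame(outcomes))
```

Because fit returns the trained instance, the result of fit_toy_cph() can be used directly for prediction.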
def predict_survival(self, features, times)
-
Predict survival probabilities at specified time(s).
Parameters
features
:pd.DataFrame
- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
times
:float or list
- a float or list of the times at which to compute the survival probability.
Returns
np.array
- numpy array of the survival probabilities at each time point in times.
def predict_risk(self, features, times)
-
Predict risk of an outcome occurring within the specified time(s).
Parameters
features
:pd.DataFrame
- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
times
:float or list
- a float or list of the times at which to compute the risk.
Returns
np.array
- numpy array of the outcome risks at each time point in times.
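If the risk of an outcome occurring within time t is read as the complement of the survival probability, 1 - S(t), the two predictions are related as in the sketch below. This relationship is an assumption made for illustration and is not stated in the text above. The layout of one row per sample and one column per time point is also assumed, and risk_from_survival is a hypothetical helper.

```python
def risk_from_survival(survival):
    """Convert survival probabilities S(t) into risks 1 - S(t)."""
    return [[1.0 - s for s in row] for row in survival]


# Toy survival probabilities for two samples at three time points
# (assumed layout: rows are samples, columns are time points).
survival = [
    [0.875, 0.75, 0.5],
    [0.75, 0.5, 0.25],
]
risk = risk_from_survival(survival)
```

Under this reading, risk never decreases over time for a given sample, because survival never increases.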
class CounterfactualSurvivalModel (treated_model, control_model)
-
Universal interface to train multiple different counterfactual survival models.
Methods
def predict_counterfactual_survival(self, features, times)
def predict_counterfactual_risk(self, features, times)
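Since the class is built from a treated_model and a control_model, one plausible reading is that each prediction method queries both models on the same features. The sketch below illustrates that idea with stand-in classes. ConstantSurvival and CounterfactualSketch are hypothetical, the return order (control, treated) is an assumption, and none of this is the auton_survival implementation.

```python
class ConstantSurvival:
    """Hypothetical stand-in for a fitted model: one curve for every sample."""

    def __init__(self, curve):
        self.curve = curve  # maps time -> survival probability

    def predict_survival(self, features, times):
        return [[self.curve[t] for t in times] for _ in features]


class CounterfactualSketch:
    """Hypothetical wrapper holding one model per treatment arm."""

    def __init__(self, treated_model, control_model):
        self.treated_model = treated_model
        self.control_model = control_model

    def predict_counterfactual_survival(self, features, times):
        # Query both arms on the same samples (assumed return order).
        control = self.control_model.predict_survival(features, times)
        treated = self.treated_model.predict_survival(features, times)
        return control, treated


treated = ConstantSurvival({1: 0.875, 5: 0.75})
control = ConstantSurvival({1: 0.75, 5: 0.5})
wrapper = CounterfactualSketch(treated, control)

features = [{"age": 60}, {"age": 45}]
control_s, treated_s = wrapper.predict_counterfactual_survival(features, [1, 5])

# Per-sample difference in survival probability between the two arms.
effect = [[t - c for t, c in zip(tr, cr)] for tr, cr in zip(treated_s, control_s)]
```

Comparing the two sets of predictions for the same sample gives a simple estimate of how survival would differ between treatment and control.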