Module auton_survival.estimators
Utilities to train survival regression models and estimate survival.
Classes
class SurvivalModel (model, random_seed=0, **hyperparams)-
Universal interface to train multiple different survival models.
Parameters
model:str-
A string that determines the choice of the survival analysis model. Survival model choices include:
dsm: Deep Survival Machines [3] model
dcph: Deep Cox Proportional Hazards [2] model
dcm: Deep Cox Mixtures [4] model
rsf: Random Survival Forests [1] model
cph: Cox Proportional Hazards [2] model
random_seed:int- Controls the reproducibility of called functions.
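The model argument is one of the five short codes listed above. As a minimal sketch, a caller could validate a user-supplied code before constructing the model. The helper check_model_name and the SUPPORTED_MODELS table are hypothetical and are not part of auton_survival.

```python
# Model codes and full names, copied from the parameter description above.
SUPPORTED_MODELS = {
    "dsm": "Deep Survival Machines",
    "dcph": "Deep Cox Proportional Hazards",
    "dcm": "Deep Cox Mixtures",
    "rsf": "Random Survival Forests",
    "cph": "Cox Proportional Hazards",
}


def check_model_name(model):
    """Hypothetical helper: return the full model name or raise ValueError."""
    if model not in SUPPORTED_MODELS:
        choices = ", ".join(sorted(SUPPORTED_MODELS))
        raise ValueError(f"Unknown model {model!r}; expected one of: {choices}")
    return SUPPORTED_MODELS[model]


full_name = check_model_name("dsm")
```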
References
[1] Hemant Ishwaran et al. Random survival forests. The annals of applied statistics, 2(3):841–860, 2008.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, vsize=0.15, val_data=None, weights=None, weights_val=None, resample_size=1.0)-
This method is used to train an instance of the survival model.
Parameters
features:pd.DataFrame- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
outcomes:pd.DataFrame- a pandas dataframe with columns 'time' and 'event'.
vsize:float, default=0.15- Amount of data to set aside as the validation set. Not applicable to 'rsf' and 'cph' models.
val_data:tuple- A tuple of the validation dataset features and outcomes of 'time' and 'event'. If passed, vsize is ignored. Not applicable to 'rsf' and 'cph' models.
weights:list or np.array- a list or numpy array of importance weights for each training sample.
weights_val:list or np.array- a list or numpy array of importance weights for each validation set sample. Ignored if val_data is None.
resample_size:float- a float between 0 and 1 that controls the size of the resampled dataset.
Returns
self- Trained instance of a survival model.
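To illustrate what the vsize parameter means, the sketch below sets aside a fraction of the row indices as a validation set. It is an illustration only: split_train_val is a hypothetical helper, and the library's own splitting logic may differ.

```python
import random


def split_train_val(n_samples, vsize=0.15, random_seed=0):
    """Hypothetical helper: shuffle row indices and hold out a vsize fraction."""
    rng = random.Random(random_seed)  # fixed seed makes the split reproducible
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_val = round(vsize * n_samples)  # number of rows held out for validation
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_train_val(100, vsize=0.15)
```

With 100 rows and vsize=0.15, 15 rows are held out and 85 remain for training. Passing val_data to fit makes this kind of split unnecessary, since vsize is then ignored.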
def predict_survival(self, features, times)-
Predict survival probabilities at specified time(s).
Parameters
features:pd.DataFrame- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
times:float or list- a float or list of the times at which to compute the survival probability.
Returns
np.array- An array of the survival probabilities at each time point in times.
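The sketch below shows the kind of output this method describes, using a toy exponential model in place of a fitted one. The function toy_predict_survival, the hazard rates, and the layout (one row per sample, one column per time) are assumptions made for illustration and are not taken from the library.

```python
import math


def toy_predict_survival(hazard_rates, times):
    """Toy stand-in for a fitted model: S(t) = exp(-rate * t) for each sample."""
    # Accept either a single time or a list of times.
    if isinstance(times, (int, float)):
        times = [times]
    # Assumed layout: one row per sample, one column per requested time.
    return [[math.exp(-rate * t) for t in times] for rate in hazard_rates]


survival = toy_predict_survival([0.1, 0.5], times=[1.0, 2.0, 5.0])
```

Each row is non-increasing from left to right, since the probability of surviving past a later time cannot exceed that of surviving past an earlier one.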
def predict_risk(self, features, times)-
Predict risk of an outcome occurring within the specified time(s).
Parameters
features:pd.DataFrame- a pandas dataframe with rows corresponding to individual samples and columns as covariates.
times:float or list- a float or list of the times at which to compute the risk.
Returns
np.array- numpy array of the outcome risks at each time point in times.
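For a single event type, the risk of the event occurring by time t is the complement of the survival probability at t. The sketch below assumes predict_risk follows that convention, which the description above does not state explicitly. The helper risk_from_survival and the sample values are hypothetical.

```python
def risk_from_survival(survival):
    """Complement each survival probability: risk(t) = 1 - S(t)."""
    return [[1.0 - s for s in row] for row in survival]


# Made-up survival probabilities for two samples at three time points.
survival = [[0.9, 0.75, 0.5], [0.6, 0.4, 0.1]]
risk = risk_from_survival(survival)
```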
class CounterfactualSurvivalModel (treated_model, control_model)-
Universal interface to train multiple different counterfactual survival models.
Methods
def predict_counterfactual_survival(self, features, times)
def predict_counterfactual_risk(self, features, times)
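The class wraps two models, one for the treated group and one for the control group, so each individual can be scored under both scenarios. A common use of the two sets of predictions is the per-individual difference in survival probability. The return format of these methods is not described above, so the sketch assumes both predictions are already available as nested lists with one row per sample and one column per time. The helper survival_difference and the values are hypothetical.

```python
def survival_difference(treated, control):
    """Per-sample, per-time difference: S_treated(t) - S_control(t)."""
    return [
        [t - c for t, c in zip(treated_row, control_row)]
        for treated_row, control_row in zip(treated, control)
    ]


# Made-up survival probabilities for two samples at two time points.
treated = [[0.95, 0.80], [0.70, 0.50]]
control = [[0.90, 0.60], [0.70, 0.40]]
effect = survival_difference(treated, control)
```

A positive entry means the individual is predicted to be more likely to survive past that time under treatment than under control.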