Module auton_survival.experiments
Utilities to perform cross-validation.
Classes
class SurvivalRegressionCV (model='dcph', folds=None, num_folds=5, random_seed=0, hyperparam_grid={})-
Universal interface to train survival analysis models in a cross-validation fashion.
The model is trained in a CV fashion over the user-specified hyperparameter grid. Model hyperparameters are selected based on the user-specified metric.
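The selection procedure described above can be illustrated with a generic grid-search cross-validation loop. This is a minimal sketch, not the library's implementation. `expand_grid`, `select_hyperparams`, and the `evaluate` callable are hypothetical helpers. The directions in `LOWER_IS_BETTER` follow the usual conventions for these metrics: Brier-type scores are losses, while AUC and concordance are higher-is-better.

```python
from itertools import product
from statistics import mean

# Usual conventions (assumed here): Brier-type metrics are losses,
# AUC and concordance are scores where higher is better.
LOWER_IS_BETTER = {"ibs": True, "brs": True, "auc": False, "ctd": False}


def expand_grid(hyperparam_grid):
    """Expand {name: [values, ...]} into a list of {name: value} dicts.

    An empty grid yields a single empty configuration.
    """
    names = sorted(hyperparam_grid)
    return [dict(zip(names, combo))
            for combo in product(*(hyperparam_grid[n] for n in names))]


def select_hyperparams(hyperparam_grid, fold_ids, evaluate, metric="ibs"):
    """Return (best_params, best_score) by mean validation score over folds.

    `evaluate(params, val_fold)` is a hypothetical callable that trains on
    all other folds, scores on `val_fold`, and returns a float for `metric`.
    """
    lower = LOWER_IS_BETTER[metric]
    best_params, best_score = None, None
    for params in expand_grid(hyperparam_grid):
        # Average the validation metric across every fold.
        score = mean(evaluate(params, k) for k in sorted(set(fold_ids)))
        if best_score is None:
            improved = True
        else:
            improved = score < best_score if lower else score > best_score
        if improved:
            best_params, best_score = params, score
    return best_params, best_score
```

The hyperparameter names in a grid such as `{"l2": [0.0, 0.1], "lr": [1e-3, 1e-2]}` are purely illustrative; the valid keys depend on the chosen model.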
Parameters
model:str, default='dcph'- A string that determines the choice of the survival regression model. Survival model choices include:
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
- 'rsf' : Random Survival Forests [1] model
- 'cph' : Cox Proportional Hazards [2] model
folds:list, default=None- A list of fold assignment values for each sample. For regular (unnested) cross-validation, folds correspond to the train and validation sets. For nested cross-validation, folds correspond to the train and test sets.
num_folds:int, default=5- The number of folds. Ignored if folds is specified.
random_seed:int, default=0- Controls reproducibility of results.
hyperparam_grid:dict, default={}- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
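A minimal sketch of how `folds`, `num_folds`, and `random_seed` could interact, assuming that when `folds` is not given, fold ids are spread evenly and shuffled with a seeded random generator. `assign_folds` is a hypothetical helper, not part of auton_survival.

```python
import random


def assign_folds(num_samples, num_folds=5, random_seed=0, folds=None):
    """Return one fold id per sample (hypothetical helper).

    If `folds` is given it is used as-is, so `num_folds` is ignored.
    Otherwise ids 0..num_folds-1 are spread as evenly as possible and
    shuffled with a seeded RNG, so the same seed gives the same split.
    """
    if folds is not None:
        if len(folds) != num_samples:
            raise ValueError("folds must have one entry per sample")
        return list(folds)
    ids = [i % num_folds for i in range(num_samples)]
    random.Random(random_seed).shuffle(ids)
    return ids
```

Using a dedicated `random.Random(random_seed)` instance, rather than the global generator, keeps the split reproducible without affecting other random state in the program.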
References
[1] Ishwaran, H., et al. (2008). Random survival forests. The Annals of Applied Statistics, 2(3):841–860.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Nagpal, C., Li, X., and Dubrawski, A. (2020). Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, horizons, metric='ibs')-
Fits the survival regression model to the data in a cross-validation or nested cross-validation fashion.
Parameters
features:pd.DataFrame- A pandas dataframe with rows corresponding to individual samples and columns as covariates.
outcomes:pd.DataFrame- A pandas dataframe with columns 'time' and 'event' that contain the survival time and the censoring status \delta_i, respectively, where \delta_i = 1 if the event is observed.
horizons:int or float or list- Event-horizons at which to evaluate model performance.
metric:str, default='ibs'- Metric used to evaluate model performance and tune hyperparameters. Options include:
- 'auc' : Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
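The 'ctd' option is a concordance index. To show what concordance measures, here is a sketch of Harrell's C-index on right-censored data. It is a simplification: it skips pairs tied in time and applies no inverse-probability-of-censoring weighting or horizon truncation, so it is not necessarily the time-dependent estimator behind 'ctd'. Higher values are better, and 0.5 corresponds to a random ordering.

```python
from itertools import combinations


def harrell_c_index(times, events, risk_scores):
    """Harrell's concordance index for right-censored data.

    A pair is comparable when the subject with the shorter time had an
    observed event (event == 1). It is concordant when that subject also
    has the higher predicted risk; tied risk scores count as half.
    """
    concordant, comparable = 0.0, 0
    for i, j in combinations(range(len(times)), 2):
        # Order the pair so that `a` has the shorter observed time.
        a, b = (i, j) if times[i] < times[j] else (j, i)
        if times[a] == times[b] or not events[a]:
            continue  # tied times and censored-first pairs are skipped
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5
    return concordant / comparable if comparable else float("nan")
```

The double loop is O(n^2), which is fine for illustration; production implementations typically sort by time to avoid enumerating every pair.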
Returns
Trained survival regression model(s).
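A sketch of inputs in the format fit() expects, using hypothetical toy covariates (`age`, `bmi`). The horizons are taken as quantiles of the observed event times, which is one common choice rather than a requirement. `check_inputs` is a hypothetical helper; its shared-index check borrows the requirement stated for the counterfactual fit() below. The final lines are commented out so the block runs without auton_survival installed; they follow only the constructor and fit() signatures documented above.

```python
import pandas as pd

# Hypothetical toy data: two covariates for four subjects.
features = pd.DataFrame({"age": [54, 61, 47, 70],
                         "bmi": [22.1, 27.4, 24.9, 30.2]})

# 'time' is the observed survival or censoring time; 'event' is delta_i
# (1 = event observed, 0 = censored).
outcomes = pd.DataFrame({"time": [5.0, 12.0, 3.5, 8.0],
                         "event": [1, 0, 1, 1]},
                        index=features.index)

# One common choice of horizons: quantiles of the observed event times.
event_times = outcomes.loc[outcomes["event"] == 1, "time"]
horizons = event_times.quantile([0.25, 0.5, 0.75]).tolist()


def check_inputs(features, outcomes):
    """Raise ValueError if the inputs do not match the documented format."""
    if not {"time", "event"} <= set(outcomes.columns):
        raise ValueError("outcomes needs 'time' and 'event' columns")
    if not outcomes.index.equals(features.index):
        raise ValueError("outcomes and features must share an index")
    if not set(outcomes["event"].unique()) <= {0, 1}:
        raise ValueError("'event' must be 0 (censored) or 1 (observed)")


check_inputs(features, outcomes)

# With auton_survival installed, the call follows the documented signatures:
# from auton_survival.experiments import SurvivalRegressionCV
# experiment = SurvivalRegressionCV(model='cph', hyperparam_grid={})
# model = experiment.fit(features, outcomes, horizons, metric='ibs')
```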
class CounterfactualSurvivalRegressionCV (model, cv_folds=5, random_seed=0, hyperparam_grid={})-
Universal interface to train counterfactual survival analysis models in a cross-validation fashion.
Each model is trained in a CV fashion over the user-specified hyperparameter grid. The best model (in terms of the Integrated Brier Score) is then selected.
Parameters
model:str- A string that determines the choice of the survival analysis model. Survival model choices include:
- 'dsm' : Deep Survival Machines [3] model
- 'dcph' : Deep Cox Proportional Hazards [2] model
- 'dcm' : Deep Cox Mixtures [4] model
- 'rsf' : Random Survival Forests [1] model
- 'cph' : Cox Proportional Hazards [2] model
cv_folds:int, default=5- Number of folds in the cross-validation.
random_seed:int, default=0- Random seed for reproducibility.
hyperparam_grid:dict, default={}- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
References
[1] Ishwaran, H., et al. (2008). Random survival forests. The Annals of Applied Statistics, 2(3):841–860.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Nagpal, C., Li, X., and Dubrawski, A. (2020). Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, interventions, horizons, metric)-
Fits the survival regression model to the data in a cross-validation fashion.
Parameters
features:pandas.DataFrame- A pandas dataframe containing the features to use as covariates.
outcomes:pandas.DataFrame- A pandas dataframe containing the survival outcomes. The index of the dataframe should be the same as the index of the features dataframe. It should contain a column named 'time' with the survival time and a column named 'event' with the censoring status \delta_i, where \delta_i = 1 if the event is observed.
interventions:pandas.Series- A pandas series containing the treatment status a_i of each subject: a_i = 1 if the subject is treated; otherwise the subject is considered a control.
horizons:int or float or list- Event-horizons at which to evaluate model performance.
metric:str, default='ibs'- Metric used to evaluate model performance and tune hyperparameters. Options include:
- 'auc' : Dynamic area under the ROC curve
- 'brs' : Brier Score
- 'ibs' : Integrated Brier Score
- 'ctd' : Concordance Index
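The class description says each model is trained in a CV fashion, which suggests that more than one model is fit. Assuming, as in a T-learner, that one model is fit per treatment arm, the data would first be split by a_i. `split_by_intervention` below is a hypothetical helper showing that split with pandas; each subset could then go through a grid-search cross-validation like the one SurvivalRegressionCV performs.

```python
import pandas as pd


def split_by_intervention(features, outcomes, interventions):
    """Split inputs into treated (a_i == 1) and control (a_i != 1) subsets.

    Hypothetical helper. Returns ((treated_features, treated_outcomes),
    (control_features, control_outcomes)). All three inputs must share
    the same index so that boolean masks align row by row.
    """
    if not (features.index.equals(outcomes.index)
            and features.index.equals(interventions.index)):
        raise ValueError("features, outcomes and interventions must share an index")
    treated = interventions == 1  # boolean Series aligned on the index
    return ((features[treated], outcomes[treated]),
            (features[~treated], outcomes[~treated]))
```

Splitting before cross-validation means each arm gets its own hyperparameter search, so the treated and control models may end up with different settings.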
Returns
auton_survival.estimators.CounterfactualSurvivalModel- The trained counterfactual survival model.