Module auton_survival.experiments
Utilities to perform cross-validation.
Classes
- class SurvivalRegressionCV (model='dcph', folds=None, num_folds=5, random_seed=0, hyperparam_grid={})
- 
Universal interface to train survival analysis models in a cross-validation fashion. The model is trained over the user-specified hyperparameter grid, and hyperparameters are selected based on the user-specified metric.
Parameters
- model: str, default='dcph'
- A string that determines the choice of the survival regression model. Survival model choices include:
  - 'dsm' : Deep Survival Machines [3] model
  - 'dcph' : Deep Cox Proportional Hazards [2] model
  - 'dcm' : Deep Cox Mixtures [4] model
  - 'rsf' : Random Survival Forests [1] model
  - 'cph' : Cox Proportional Hazards [2] model
- folds: list, default=None
- A list of fold assignment values, one for each sample. For regular (unnested) cross-validation, folds correspond to train and validation sets. For nested cross-validation, folds correspond to train and test sets.
- num_folds: int, default=5
- The number of folds. Ignored if folds is specified.
- random_seed: int, default=0
- Controls reproducibility of results.
- hyperparam_grid: dict, default={}
- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
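The shape of `hyperparam_grid` and `folds` can be illustrated with a minimal standard-library sketch. It does not call auton_survival; `expand_grid` and `assign_folds` are hypothetical helpers, and the hyperparameter names in the grid are illustrative assumptions that depend on the chosen model.

```python
import itertools
import random

# Hypothetical grid: keys are hyperparameter names, values are lists of
# candidate values. The names below are illustrative assumptions.
hyperparam_grid = {
    "k": [3, 4, 6],
    "learning_rate": [1e-3, 1e-4],
    "layers": [[100], [100, 100]],
}


def expand_grid(grid):
    """Return one dict per combination of hyperparameter values."""
    names = list(grid)
    return [
        dict(zip(names, values))
        for values in itertools.product(*(grid[name] for name in names))
    ]


def assign_folds(num_samples, num_folds=5, random_seed=0):
    """Return a shuffled list of fold indices, one entry per sample."""
    rng = random.Random(random_seed)
    folds = [i % num_folds for i in range(num_samples)]
    rng.shuffle(folds)
    return folds


configs = expand_grid(hyperparam_grid)  # 3 * 2 * 2 = 12 combinations
folds = assign_folds(num_samples=12, num_folds=5, random_seed=0)
```

A list like `folds` has the layout described for the `folds` parameter: one fold assignment value per sample. Fixing the seed makes the assignment repeatable across runs.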
References
[1] Hemant Ishwaran et al. Random survival forests. The Annals of Applied Statistics, 2(3):841–860, 2008.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
- def fit(self, features, outcomes, horizons, metric='ibs')
- 
Fits the survival regression model to the data in a cross-validation or nested cross-validation fashion.
Parameters
- features: pd.DataFrame
- A pandas dataframe with rows corresponding to individual samples and columns as covariates.
- outcomes: pd.DataFrame
- A pandas dataframe with columns 'time' and 'event' that contain the survival time and censoring status, respectively. \(\delta_i = 1\) if the event is observed.
- horizons: int or float or list
- Event-horizons at which to evaluate model performance.
- metric: str, default='ibs'
- Metric used to evaluate model performance and tune hyperparameters. Options include:
  - 'auc' : Dynamic area under the ROC curve
  - 'brs' : Brier Score
  - 'ibs' : Integrated Brier Score
  - 'ctd' : Concordance Index
Returns
Trained survival regression model(s).
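How a metric drives hyperparameter selection can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: it assumes Brier-type scores ('ibs', 'brs') are minimized and 'auc' and 'ctd' are maximized, and that per-fold scores are averaged before comparison. `select_best` is a hypothetical helper and the scores are made-up numbers.

```python
# Assumption: lower is better for Brier-type scores, higher is better for
# dynamic AUC and the concordance index.
LOWER_IS_BETTER = {"ibs": True, "brs": True, "auc": False, "ctd": False}


def select_best(scores, metric="ibs"):
    """Return the index of the configuration with the best mean fold score.

    scores[i][j] is the score of configuration i on validation fold j.
    """
    if metric not in LOWER_IS_BETTER:
        raise ValueError(f"Unknown metric: {metric!r}")
    means = [sum(fold_scores) / len(fold_scores) for fold_scores in scores]
    pick = min if LOWER_IS_BETTER[metric] else max
    return pick(range(len(means)), key=means.__getitem__)


# Made-up validation scores: three configurations, three folds each.
cv_scores = [
    [0.21, 0.19, 0.20],
    [0.17, 0.18, 0.16],
    [0.25, 0.22, 0.24],
]
best_if_ibs = select_best(cv_scores, metric="ibs")  # smallest mean wins
best_if_ctd = select_best(cv_scores, metric="ctd")  # largest mean wins
```

The same table of scores selects a different configuration depending on the direction of the metric, which is why the choice of `metric` matters for tuning.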
 
- class CounterfactualSurvivalRegressionCV (model, cv_folds=5, random_seed=0, hyperparam_grid={})
- 
Universal interface to train counterfactual survival analysis models in a cross-validation fashion. Each model is trained over the user-specified hyperparameter grid. The best model (in terms of Integrated Brier Score) is then selected.
Parameters
- model: str
- A string that determines the choice of the survival analysis model. Survival model choices include:
  - 'dsm' : Deep Survival Machines [3] model
  - 'dcph' : Deep Cox Proportional Hazards [2] model
  - 'dcm' : Deep Cox Mixtures [4] model
  - 'rsf' : Random Survival Forests [1] model
  - 'cph' : Cox Proportional Hazards [2] model
- cv_folds: int, default=5
- Number of folds in the cross-validation.
- random_seed: int, default=0
- Random seed for reproducibility.
- hyperparam_grid: dict, default={}
- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
References
[1] Hemant Ishwaran et al. Random survival forests. The Annals of Applied Statistics, 2(3):841–860, 2008.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep Cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
- def fit(self, features, outcomes, interventions, horizons, metric)
- 
Fits the survival regression model to the data in a cross-validation fashion.
Parameters
- features: pandas.DataFrame
- A pandas dataframe containing the features to use as covariates.
- outcomes: pandas.DataFrame
- A pandas dataframe containing the survival outcomes. The index of the dataframe should be the same as the index of the features dataframe. It should contain a column named 'time' with the survival time and a column named 'event' with the censoring status. \(\delta_i = 1\) if the event is observed.
- interventions: pandas.Series
- A pandas series containing the treatment status of each subject. \(a_i = 1\) if the subject is treated; otherwise the subject is considered control.
- horizons: int or float or list
- Event-horizons at which to evaluate model performance.
- metric: str, default='ibs'
- Metric used to evaluate model performance and tune hyperparameters. Options include:
  - 'auc' : Dynamic area under the ROC curve
  - 'brs' : Brier Score
  - 'ibs' : Integrated Brier Score
  - 'ctd' : Concordance Index
Returns
- auton_survival.estimators.CounterfactualSurvivalModel
- The trained counterfactual survival model.
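The role of `interventions` can be illustrated with a small standard-library sketch that partitions row-aligned samples into treated and control groups. Fitting one model per group is a common approach to counterfactual estimation; whether this class does so internally is an assumption, not something this page states. `split_by_treatment` is a hypothetical helper, and plain lists of dicts stand in for the pandas objects.

```python
def split_by_treatment(features, outcomes, interventions):
    """Partition row-aligned samples into treated (a_i == 1) and control."""
    if not (len(features) == len(outcomes) == len(interventions)):
        raise ValueError("features, outcomes and interventions must align")
    treated, control = [], []
    for x, y, a in zip(features, outcomes, interventions):
        # Any status other than 1 is treated as control.
        (treated if a == 1 else control).append((x, y))
    return treated, control


# Made-up samples: one covariate, a survival time, and an event indicator.
features = [{"age": 61}, {"age": 54}, {"age": 70}, {"age": 48}]
outcomes = [
    {"time": 5.0, "event": 1},
    {"time": 8.5, "event": 0},
    {"time": 2.0, "event": 1},
    {"time": 9.0, "event": 0},
]
interventions = [1, 0, 1, 0]

treated, control = split_by_treatment(features, outcomes, interventions)
```

Each group keeps its features paired with its outcomes, mirroring the requirement that the outcomes dataframe share its index with the features dataframe.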