Module auton_survival.experiments
Utilities to perform cross-validation.
Classes
class SurvivalRegressionCV (model='dcph', folds=None, num_folds=5, random_seed=0, hyperparam_grid={})
-
Universal interface to train survival analysis models in a cross-validation fashion.
The model is trained in a CV fashion over the user-specified hyperparameter grid. Model hyperparameters are selected based on the user-specified metric.
Parameters
model
:str
, default='dcph'
- A string that determines the choice of the survival regression model. Survival model choices include:
  - 'dsm' : Deep Survival Machines [3] model
  - 'dcph' : Deep Cox Proportional Hazards [2] model
  - 'dcm' : Deep Cox Mixtures [4] model
  - 'rsf' : Random Survival Forests [1] model
  - 'cph' : Cox Proportional Hazards [2] model
folds
:list
, default=None
- A list of fold assignment values for each sample. For regular (unnested) cross-validation, folds correspond to train and validation set. For nested cross-validation, folds correspond to train and test set.
num_folds
:int
, default=5
- The number of folds. Ignored if folds is specified.
random_seed
:int
, default=0
- Controls reproducibility of results.
hyperparam_grid
:dict
- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
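As a rough illustration of the `folds`, `num_folds`, and `random_seed` parameters above, the following sketch builds a reproducible list with one fold index per sample and splits the sample indices for one held-out fold. The helpers `make_folds` and `split_by_fold` are hypothetical and not part of auton_survival; the library's internal assignment scheme may differ.

```python
import random


def make_folds(num_samples, num_folds=5, random_seed=0):
    """Return a list with one fold index (0..num_folds-1) per sample."""
    # Cycle through the fold indices so fold sizes differ by at most one.
    folds = [i % num_folds for i in range(num_samples)]
    # Shuffle with a seeded generator so the assignment is reproducible.
    random.Random(random_seed).shuffle(folds)
    return folds


def split_by_fold(folds, held_out):
    """Return (train_idx, val_idx) sample indices for one held-out fold."""
    train_idx = [i for i, f in enumerate(folds) if f != held_out]
    val_idx = [i for i, f in enumerate(folds) if f == held_out]
    return train_idx, val_idx


folds = make_folds(num_samples=10, num_folds=5, random_seed=0)
train_idx, val_idx = split_by_fold(folds, held_out=0)
```

A list built this way has the shape the `folds` parameter describes: one assignment value per sample, in sample order.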
References
[1] Hemant Ishwaran et al. Random survival forests. The annals of applied statistics, 2(3):841–860, 2008.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, horizons, metric='ibs')
-
Fits the survival regression model to the data in a cross-validation or nested cross-validation fashion.
Parameters
features
:pd.DataFrame
- A pandas dataframe with rows corresponding to individual samples and columns as covariates.
outcomes
:pd.DataFrame
- A pandas dataframe with columns 'time' and 'event' that contain the survival time and the censoring status, respectively, where \delta_i = 1 if the event is observed.
horizons
:int or float or list
- Event-horizons at which to evaluate model performance.
metric
:str
, default='ibs'
- Metric used to evaluate model performance and tune hyperparameters. Options include:
  - 'auc' : Dynamic area under the ROC curve
  - 'brs' : Brier Score
  - 'ibs' : Integrated Brier Score
  - 'ctd' : Concordance Index
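To make the 'brs' and 'ibs' options more concrete, here is a simplified sketch of a Brier Score at a single horizon and its integral over several horizons. It assumes every subject is uncensored, so it omits the censoring adjustment that a survival-analysis implementation would normally apply; the function names and toy numbers are illustrative only and are not the library's code.

```python
def brier_score(times, surv_probs, horizon):
    """Mean squared error between the survival indicator 1[T > horizon]
    and the predicted survival probability at that horizon."""
    errors = [(float(t > horizon) - s) ** 2 for t, s in zip(times, surv_probs)]
    return sum(errors) / len(errors)


def integrated_brier_score(times, surv_probs_by_horizon, horizons):
    """Trapezoidal integral of the Brier Score over the (sorted) horizons,
    normalized by the width of the horizon range."""
    if len(horizons) < 2:
        raise ValueError("need at least two horizons to integrate")
    scores = [brier_score(times, surv_probs_by_horizon[h], h) for h in horizons]
    area = 0.0
    for k in range(1, len(horizons)):
        width = horizons[k] - horizons[k - 1]
        area += 0.5 * (scores[k] + scores[k - 1]) * width
    return area / (horizons[-1] - horizons[0])


# Toy data: observed event times and predicted survival probabilities
# for four subjects at three horizons.
times = [2.0, 5.0, 8.0, 12.0]
horizons = [3.0, 6.0, 9.0]
surv_probs_by_horizon = {
    3.0: [0.2, 0.9, 0.9, 1.0],
    6.0: [0.1, 0.4, 0.8, 0.9],
    9.0: [0.0, 0.2, 0.3, 0.8],
}

brs_at_6 = brier_score(times, surv_probs_by_horizon[6.0], 6.0)
ibs = integrated_brier_score(times, surv_probs_by_horizon, horizons)
```

Lower values are better for both scores, which is why a grid search that tunes on them would pick the configuration with the smallest value.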
Returns
Trained survival regression model(s).
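A possible end-to-end call, based only on the constructor and `fit` signatures documented above, might look like the sketch below. The wrapper `run_survival_cv` is hypothetical, and the hyperparameter name 'l2' for the 'cph' model is an assumption that should be checked against the estimator's documentation.

```python
def run_survival_cv(features, outcomes, horizons):
    """Cross-validate a Cox model over a small grid and return the result of fit().

    features: pandas DataFrame of covariates (one row per sample).
    outcomes: pandas DataFrame with 'time' and 'event' columns.
    horizons: int, float, or list of evaluation horizons.
    """
    # Imported lazily so this sketch can be loaded without auton_survival installed.
    from auton_survival.experiments import SurvivalRegressionCV

    # Assumed hyperparameter name for the 'cph' model (not confirmed here).
    hyperparam_grid = {'l2': [1e-3, 1e-4]}

    experiment = SurvivalRegressionCV(
        model='cph',
        num_folds=5,
        random_seed=0,
        hyperparam_grid=hyperparam_grid,
    )
    # Tune on the Integrated Brier Score, the documented default metric.
    return experiment.fit(features, outcomes, horizons, metric='ibs')
```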
class CounterfactualSurvivalRegressionCV (model, cv_folds=5, random_seed=0, hyperparam_grid={})
-
Universal interface to train counterfactual survival analysis models in a cross-validation fashion.
Each model is trained in a CV fashion over the user-specified hyperparameter grid. The best model (in terms of Integrated Brier Score) is then selected.
Parameters
model
:str
- A string that determines the choice of the survival analysis model. Survival model choices include:
  - 'dsm' : Deep Survival Machines [3] model
  - 'dcph' : Deep Cox Proportional Hazards [2] model
  - 'dcm' : Deep Cox Mixtures [4] model
  - 'rsf' : Random Survival Forests [1] model
  - 'cph' : Cox Proportional Hazards [2] model
cv_folds
:int
- Number of folds in the cross validation.
random_seed
:int
- Random seed for reproducibility.
hyperparam_grid
:dict
- A dictionary that contains the hyperparameters for grid search. The keys of the dictionary are the hyperparameter names and the values are lists of hyperparameter values.
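The grid format described above (hyperparameter names as keys, lists of candidate values as values) can be expanded into individual configurations as in the following sketch. The helper `expand_grid` and the hyperparameter names in the example are hypothetical; this shows the general grid-search idea rather than the library's own implementation.

```python
import itertools


def expand_grid(hyperparam_grid):
    """Return one dict per combination of values in the grid."""
    names = list(hyperparam_grid)
    value_lists = [hyperparam_grid[name] for name in names]
    # The Cartesian product enumerates every combination of candidate values.
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]


# Hypothetical hyperparameter names, for illustration only.
grid = {'n_estimators': [50, 100], 'max_depth': [3, 5]}
candidates = expand_grid(grid)

# An empty grid yields a single empty configuration, i.e. "use the defaults".
default_only = expand_grid({})
```

The number of candidate configurations is the product of the list lengths, so each added hyperparameter multiplies the number of models trained per fold.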
References
[1] Hemant Ishwaran et al. Random survival forests. The annals of applied statistics, 2(3):841–860, 2008.
[2] Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society: Series B (Methodological).
[3] Chirag Nagpal, Xinyu Li, and Artur Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. 2020.
[4] Nagpal, C., Yadlowsky, S., Rostamzadeh, N., and Heller, K. (2021). Deep cox mixtures for survival regression. In Machine Learning for Healthcare Conference, pages 674–708. PMLR.
Methods
def fit(self, features, outcomes, interventions, horizons, metric)
-
Fits the survival regression model to the data in a cross-validation fashion.
Parameters
features
:pandas.DataFrame
- A pandas dataframe containing the features to use as covariates.
outcomes
:pandas.DataFrame
- A pandas dataframe containing the survival outcomes. Its index should match the index of the features dataframe. It should contain a column named 'time' with the survival time and a column named 'event' with the censoring status, where \delta_i = 1 if the event is observed.
interventions
:pandas.Series
- A pandas series containing the treatment status of each subject. a_i = 1 if the subject is treated; otherwise the subject is considered a control.
horizons
:int or float or list
- Event-horizons at which to evaluate model performance.
metric
:str
, default='ibs'
- Metric used to evaluate model performance and tune hyperparameters. Options include:
  - 'auc' : Dynamic area under the ROC curve
  - 'brs' : Brier Score
  - 'ibs' : Integrated Brier Score
  - 'ctd' : Concordance Index
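For intuition about the 'ctd' option, the sketch below computes a plain Harrell-style concordance index from survival times, event indicators, and predicted risk scores. This is a simplified stand-in: the metric used by the library may be a time-dependent, censoring-adjusted variant, and the function name and toy data here are illustrative assumptions.

```python
def concordance_index(times, events, risk_scores):
    """Fraction of comparable pairs in which the subject with the earlier
    observed event also has the higher predicted risk (ties count as 0.5)."""
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            # A censored subject cannot be the earlier member of a pair.
            continue
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: the second subject is censored (event = 0).
times = [2.0, 4.0, 6.0, 8.0]
events = [1, 0, 1, 1]

perfect = concordance_index(times, events, [0.9, 0.3, 0.5, 0.1])
mixed = concordance_index(times, events, [0.9, 0.3, 0.1, 0.5])
```

Unlike the Brier-based metrics, higher is better here: 1.0 means every comparable pair is ranked correctly and 0.5 corresponds to random ranking.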
Returns
auton_survival.estimators.CounterfactualSurvivalModel:
- The trained counterfactual survival model.