Module auton_survival.models.cmhe
Cox Mixtures With Heterogenous Effects
Cox Mixtures with Heterogenous Effects (CMHE) is a flexible approach to recover counterfactual phenotypes of individuals who demonstrate heterogeneous effects to an intervention in terms of censored time-to-event outcomes. CMHE is not restricted by the strong Cox proportional hazards assumption or by any parametric assumption on the time-to-event distributions. It achieves this by describing each individual as belonging to two different latent groups: \mathcal{Z}, which mediates the base survival rate, and \phi, which mediates the effect of the treatment. CMHE can also be employed to model individual-level counterfactuals or for standard factual survival regression.
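As a rough illustration of the two-group idea described above, the sketch below marginalizes a Cox-style survival function over every pair of latent groups. It is not the library's implementation: the function name `mixture_survival`, its arguments, the proportional form of the treatment effect within a group, and the assumption that the two group memberships are independent given the features are all simplifications made here for illustration.

```python
import numpy as np


def mixture_survival(cum_base_hazard, omega, p_z, p_phi, a):
    """Marginal survival probability for one individual at one time point.

    cum_base_hazard: shape (k,), cumulative baseline hazard of each base
        survival group at the time of interest (hypothetical input).
    omega: shape (g,), log hazard ratio of treatment in each treatment
        effect group (hypothetical input).
    p_z: shape (k,), membership probabilities for the base survival groups.
    p_phi: shape (g,), membership probabilities for the treatment effect groups.
    a: treatment indicator, 0 or 1.
    """
    cum_base_hazard = np.asarray(cum_base_hazard, dtype=float)
    omega = np.asarray(omega, dtype=float)
    p_z = np.asarray(p_z, dtype=float)
    p_phi = np.asarray(p_phi, dtype=float)

    # Survival within each (z, phi) pair: exp(-Lambda_z(t) * exp(omega_phi * a)).
    surv = np.exp(-cum_base_hazard[:, None] * np.exp(omega[None, :] * a))  # (k, g)

    # Weight each pair by its (assumed independent) membership probability.
    weights = p_z[:, None] * p_phi[None, :]  # (k, g)
    return float((weights * surv).sum())


# Two base survival groups and three treatment effect groups
# (harmful, no effect, beneficial); all numbers are made up.
base = [0.2, 1.0]
effects = [0.7, 0.0, -0.7]
control = mixture_survival(base, effects, [0.5, 0.5], [0.2, 0.3, 0.5], a=0)
treated = mixture_survival(base, effects, [0.5, 0.5], [0.2, 0.3, 0.5], a=1)
```

Without treatment (`a=0`) the treatment effect groups drop out and only the base survival groups matter; with treatment, each pair of groups contributes its own survival curve.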
For full details on Cox Mixtures with Heterogenous Effects, please refer to our preprint:
Example Usage
>>> from auton_survival import DeepCoxMixturesHeterogenousEffects
>>> from auton_survival import datasets
>>> # load the SYNTHETIC dataset.
>>> x, t, e, a = datasets.load_dataset('SYNTHETIC')
>>> # instantiate a Cox Mixtures with Heterogenous Effects model.
>>> model = DeepCoxMixturesHeterogenousEffects(k=2, g=3)
>>> # fit the model to the dataset.
>>> model.fit(x, t, e, a)
>>> # estimate the predicted risks at time 10 for the observed treatment assignments
>>> model.predict_risk(x, a, t=10)
>>> # estimate the treatment effect phenogroups
>>> model.predict_latent_phi(x)
Sub-modules
auton_survival.models.cmhe.cmhe_torch
auton_survival.models.cmhe.cmhe_utilities
Classes
class DeepCoxMixturesHeterogenousEffects (k, g, layers=None, gamma=100, smoothing_factor=0.0001, gate_l2_penalty=0.0001, random_seed=0)
-
A Deep Cox Mixtures with Heterogenous Effects model.
This is the main interface to a Deep Cox Mixtures with Heterogenous Effects model. A model is instantiated with an appropriate set of hyperparameters and fit on numpy arrays consisting of the features, the event/censoring times, the event/censoring indicators and the treatment assignment indicators.
For full details on Deep Cox Mixtures with Heterogenous Effects, refer to the paper [1].
References
[1] Nagpal, C., Goswami, M., Dufendach, K., and Dubrawski, A. "Counterfactual phenotyping for censored Time-to-Events" (2022).
Parameters
k
:int
- The number of underlying base survival phenotypes.
g
:int
- The number of underlying treatment effect phenotypes.
layers
:list
- A list of integers consisting of the number of neurons in each hidden layer.
gate_l2_penalty
:float
- Strength of the l2 penalty term for the gate layers. Higher means stronger regularization.
random_seed
:int
- Controls the reproducibility of called functions.
Example
>>> from auton_survival import DeepCoxMixturesHeterogenousEffects
>>> model = DeepCoxMixturesHeterogenousEffects(k=2, g=3)
>>> model.fit(x, t, e, a)
Methods
def fit(self, x, t, e, a, vsize=0.15, val_data=None, iters=1, learning_rate=0.001, batch_size=100, patience=2, optimizer='Adam')
-
This method is used to train an instance of the CMHE model.
Parameters
x
:np.ndarray
- A numpy array of the input features, x.
t
:np.ndarray
- A numpy array of the event/censoring times, t.
e
:np.ndarray
- A numpy array of the event/censoring indicators, \delta. \delta = 1 means the event took place.
a
:np.ndarray
- A numpy array of the treatment assignment indicators, a. a = 1 means the individual was treated.
vsize
:float
- Fraction of the data to set aside as the validation set.
val_data
:tuple
- A tuple of the validation dataset. If passed, vsize is ignored.
iters
:int
- The maximum number of training iterations on the training dataset.
learning_rate
:float
- The learning rate for the Adam optimizer.
batch_size
:int
- Learning is performed on mini-batches of input data. This parameter specifies the size of each mini-batch.
optimizer
:str
- The choice of the gradient based optimization method. One of 'Adam', 'RMSProp' or 'SGD'.
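To make the roles of `vsize` and `val_data` concrete, here is a minimal sketch of holding out a validation set by hand. The helper `holdout_split` is hypothetical and the library's internal split may differ; the `(x, t, e, a)` ordering used for the validation tuple is likewise an assumption that mirrors the argument order of `fit`, so check it against your installed version before relying on it.

```python
import numpy as np


def holdout_split(x, t, e, a, vsize=0.15, random_seed=0):
    """Shuffle the rows and hold out a fraction `vsize` for validation."""
    n = x.shape[0]
    rng = np.random.default_rng(random_seed)
    idx = rng.permutation(n)
    n_val = int(round(vsize * n))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    train = tuple(arr[train_idx] for arr in (x, t, e, a))
    val = tuple(arr[val_idx] for arr in (x, t, e, a))
    return train, val


# Toy data with the shapes described above (all values are made up).
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 5))          # input features
t = rng.exponential(10.0, size=200)    # event/censoring times
e = rng.integers(0, 2, size=200)       # 1 means the event took place
a = rng.integers(0, 2, size=200)       # 1 means the individual was treated

(x_tr, t_tr, e_tr, a_tr), val_data = holdout_split(x, t, e, a, vsize=0.15)
```

Under the stated assumption about the tuple ordering, the held-out data could then be supplied as `model.fit(x_tr, t_tr, e_tr, a_tr, val_data=val_data)`, in which case `vsize` is ignored.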
def predict_risk(self, x, a, t=None)
def predict_survival(self, x, a, t=None)
-
Returns the estimated survival probability at time t, \widehat{\mathbb{P}}(T > t|X), for some input data x.
Parameters
x
:np.ndarray
- A numpy array of the input features, x.
a
:np.ndarray
- A numpy array of the treatment assignment, a.
t
:list or float
- A list or float of the times at which the survival probability is to be computed.
Returns
np.array
- A numpy array of the survival probabilities at each time in t.
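Because `predict_survival` takes the treatment assignment as an argument, individual-level counterfactuals can be obtained by querying the same features under both assignments. The sketch below shows the idea with a toy stand-in for a fitted model; `toy_predict_survival` and `survival_difference` are hypothetical names, and the `(n, len(t))` output shape is an assumption rather than something this page specifies.

```python
import numpy as np


def toy_predict_survival(x, a, t):
    """Stand-in with the same argument order as predict_survival(x, a, t).

    Uses an exponential survival curve whose rate is halved by treatment;
    a fitted model would be used here instead.
    """
    t = np.atleast_1d(np.asarray(t, dtype=float))
    rate = 0.1 * np.where(np.asarray(a) == 1, 0.5, 1.0)  # shape (n,)
    return np.exp(-rate[:, None] * t[None, :])           # shape (n, len(t))


def survival_difference(predict_survival, x, t):
    """Estimated S(t | x, a=1) - S(t | x, a=0) for every row of x."""
    n = x.shape[0]
    treated = predict_survival(x, np.ones(n), t)
    control = predict_survival(x, np.zeros(n), t)
    return treated - control


x = np.zeros((4, 3))  # four individuals, three features (values are made up)
effect = survival_difference(toy_predict_survival, x, t=[5, 10])
```

A positive entry means the individual is estimated to be more likely to survive past that time when treated than when untreated.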
def predict_latent_z(self, x)
-
Returns the estimated latent base survival group z given the confounders x.
def predict_latent_phi(self, x)
-
Returns the estimated latent treatment effect group \phi given the confounders x.
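One common way to turn latent group estimates into phenotypes is to assign each individual to their most probable group. The sketch below assumes the estimates arrive as an `(n, g)` array of membership probabilities, which this page does not state; the array `phi_probs` is made-up illustrative data, not library output.

```python
import numpy as np

# Hypothetical membership probabilities for four individuals and g=3
# treatment effect groups; each row sums to one.
phi_probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])

# Hard assignment: index of the most probable group for each individual.
phi_hat = phi_probs.argmax(axis=1)

# Number of individuals assigned to each group.
group_sizes = np.bincount(phi_hat, minlength=phi_probs.shape[1])
```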