Module auton_survival.models.cmhe

Cox Mixtures With Heterogenous Effects

Cox Mixtures with Heterogenous Effects (CMHE) is a flexible approach to recover counterfactual phenotypes of individuals that demonstrate heterogeneous effects to an intervention in terms of censored time-to-event outcomes. CMHE is not restricted by the strong Cox proportional hazards assumption or by any parametric assumption on the time-to-event distributions. CMHE achieves this by describing each individual as belonging to two different latent groups: \mathcal{Z}, which mediates the base survival rate, and \phi, which mediates the effect of the treatment. CMHE can also be employed to model individual-level counterfactuals or for standard factual survival regression.
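
To make the two-level latent structure concrete, the following toy sketch marginalizes a survival probability over both groupings. It is only an illustration under stated assumptions, not the library's implementation: the function name `mixture_survival`, the proportional-hazards form within each group, the additive treatment shift on the log-hazard, and the independence of the two group memberships given x are all assumptions made here for illustration.

```python
import numpy as np


def mixture_survival(base_surv, z_probs, phi_probs, log_hr, treat_shift, a):
    """Toy marginal survival probability at one fixed time for one individual.

    All arguments are hypothetical quantities used for illustration:
      base_surv   : (k,) baseline survival of each base survival group
      z_probs     : (k,) membership probabilities for the base groups
      phi_probs   : (g,) membership probabilities for the effect groups
      log_hr      : (k,) log hazard ratio of the individual in each base group
      treat_shift : (g,) log-hazard shift applied when treated (a = 1)
      a           : treatment indicator, 0 or 1
    """
    base_surv = np.asarray(base_surv, dtype=float)
    z_probs = np.asarray(z_probs, dtype=float)
    phi_probs = np.asarray(phi_probs, dtype=float)
    log_hr = np.asarray(log_hr, dtype=float)
    treat_shift = np.asarray(treat_shift, dtype=float)

    # Log hazard multiplier for every (base group, effect group) pair: (k, g).
    log_mult = log_hr[:, None] + a * treat_shift[None, :]
    # Under proportional hazards, S(t) = S_0(t) ** exp(log hazard ratio).
    surv = base_surv[:, None] ** np.exp(log_mult)
    # Assumed-independent group memberships give product weights.
    weights = z_probs[:, None] * phi_probs[None, :]
    return float(np.sum(weights * surv))


untreated = mixture_survival([0.9, 0.6], [0.7, 0.3], [0.5, 0.5],
                             [0.0, 0.2], [-0.5, 0.5], a=0)
treated = mixture_survival([0.9, 0.6], [0.7, 0.3], [0.5, 0.5],
                           [0.0, 0.2], [-0.5, 0.5], a=1)
print(untreated, treated)
```

Even though each (base group, effect group) pair follows a proportional-hazards form in this sketch, the weighted average over pairs generally does not, which is the sense in which a mixture relaxes the single-model Cox assumption.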

For full details on Cox Mixtures with Heterogenous Effects, please refer to our preprint:

Counterfactual Phenotyping with Censored Time-to-Events, arXiv preprint, C. Nagpal, M. Goswami, K. Dufendach, A. Dubrawski

Example Usage

>>> from auton_survival import DeepCoxMixturesHeterogenousEffects
>>> from auton_survival import datasets
>>> # load the SYNTHETIC dataset.
>>> x, t, e, a = datasets.load_dataset('SYNTHETIC')
>>> # instantiate a Cox Mixtures with Heterogenous Effects model with
>>> # k=2 base survival phenotypes and g=3 treatment effect phenotypes.
>>> model = DeepCoxMixturesHeterogenousEffects(k=2, g=3)
>>> # fit the model to the dataset.
>>> model.fit(x, t, e, a)
>>> # estimate the predicted risks at time t = 10 under the observed treatment
>>> model.predict_risk(x, a, t=10)
>>> # estimate the treatment effect phenogroups
>>> model.predict_latent_phi(x)




class DeepCoxMixturesHeterogenousEffects (k, g, layers=None, gamma=100, smoothing_factor=0.0001, gate_l2_penalty=0.0001, random_seed=0)

A Deep Cox Mixtures with Heterogenous Effects model.

This is the main interface to a Deep Cox Mixtures with Heterogenous Effects model. A model is instantiated with an appropriate set of hyperparameters and fit on numpy arrays consisting of the features, event/censoring times, event/censoring indicators and treatment assignments.

For full details on Deep Cox Mixtures with Heterogenous Effects, refer to the paper [1].


[1] Nagpal, C., Goswami, M., Dufendach, K., and Dubrawski, A. "Counterfactual Phenotyping with Censored Time-to-Events" (2022).


k : int
The number of underlying base survival phenotypes.
g : int
The number of underlying treatment effect phenotypes.
layers : list
A list of integers consisting of the number of neurons in each hidden layer.
gate_l2_penalty : float
Strength of the l2 penalty term for the gate layers. Higher means stronger regularization.
random_seed : int
Controls the reproducibility of called functions.


>>> from auton_survival import DeepCoxMixturesHeterogenousEffects
>>> model = DeepCoxMixturesHeterogenousEffects(k=2, g=3)
>>> model.fit(x, t, e, a)
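
Since k, g and layers are hyperparameters, candidate settings are often enumerated and compared on held-out data. A minimal sketch of such an enumeration, using only the standard library, follows. The candidate values and the helper name `candidate_configs` are illustrative assumptions; only the keyword names k, g and layers come from the signature above.

```python
from itertools import product


def candidate_configs(ks, gs, layer_options):
    """Enumerate keyword-argument dicts for every (k, g, layers) combination."""
    return [
        {"k": k, "g": g, "layers": list(layers)}
        for k, g, layers in product(ks, gs, layer_options)
    ]


configs = candidate_configs(ks=[1, 2, 3], gs=[1, 2], layer_options=[[50], [50, 50]])
for config in configs[:3]:
    print(config)
```

Each dictionary could then be unpacked as DeepCoxMixturesHeterogenousEffects(**config) before calling fit.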


def fit(self, x, t, e, a, vsize=0.15, val_data=None, iters=1, learning_rate=0.001, batch_size=100, patience=2, optimizer='Adam')

This method is used to train an instance of the CMHE model.


x : np.ndarray
A numpy array of the input features, x.
t : np.ndarray
A numpy array of the event/censoring times, t.
e : np.ndarray
A numpy array of the event/censoring indicators, \delta. \delta = 1 means the event took place.
a : np.ndarray
A numpy array of the treatment assignment indicators, a. a = 1 means the individual was treated.
vsize : float
Fraction of the data to set aside as the validation set.
val_data : tuple
A tuple of the validation dataset. If passed, vsize is ignored.
iters : int
The maximum number of training iterations on the training dataset.
learning_rate : float
The learning rate for the chosen optimizer.
batch_size : int
Learning is performed on mini-batches of input data. This parameter specifies the size of each mini-batch.
optimizer : str
The choice of the gradient based optimization method. One of 'Adam', 'RMSProp' or 'SGD'.
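
The role of vsize can be illustrated with a manual hold-out split, which is also one way to build a val_data tuple. This is a sketch rather than the library's internal logic: the helper name `holdout_split`, the shuffling strategy, and the (x, t, e, a) ordering of the tuple are assumptions that should be checked against the installed version.

```python
import numpy as np


def holdout_split(x, t, e, a, vsize=0.15, random_seed=0):
    """Shuffle the rows and hold out a fraction `vsize` for validation."""
    n = x.shape[0]
    rng = np.random.default_rng(random_seed)
    idx = rng.permutation(n)
    n_val = int(round(vsize * n))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    train = (x[train_idx], t[train_idx], e[train_idx], a[train_idx])
    val = (x[val_idx], t[val_idx], e[val_idx], a[val_idx])
    return train, val


# Synthetic arrays purely for demonstration.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 5))
t = rng.exponential(scale=10.0, size=200)
e = rng.integers(0, 2, size=200)
a = rng.integers(0, 2, size=200)

(x_tr, t_tr, e_tr, a_tr), val_data = holdout_split(x, t, e, a, vsize=0.15)
print(x_tr.shape, val_data[0].shape)
```

The resulting tuple could then be passed as model.fit(x_tr, t_tr, e_tr, a_tr, val_data=val_data), in which case vsize is ignored.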

def predict_risk(self, x, a, t=None)

def predict_survival(self, x, a, t=None)

Returns the estimated survival probability at time t, \widehat{\mathbb{P}}(T > t|X), for some input data x.


x : np.ndarray
A numpy array of the input features, x.
a : np.ndarray
A numpy array of the treatment assignment, a.
t : list or float
A list or float of the times at which the survival probability is to be computed.


A numpy array of the survival probabilities at each time in t.
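
Because predict_survival takes the treatment assignment a as an argument, individual-level counterfactuals can be obtained by querying the same individuals under a = 1 and a = 0. The sketch below shows this pattern. The helper `survival_difference` and the `_ToyModel` stand-in are hypothetical and exist only so that the example runs without a fitted model; the (individuals, times) output shape used by the stand-in is likewise an assumption.

```python
import numpy as np


def survival_difference(model, x, t):
    """Treated-minus-untreated survival probability for each row of x."""
    treated = np.ones(x.shape[0])
    untreated = np.zeros(x.shape[0])
    s_treated = np.asarray(model.predict_survival(x, treated, t))
    s_untreated = np.asarray(model.predict_survival(x, untreated, t))
    return s_treated - s_untreated


class _ToyModel:
    """Hypothetical stand-in exposing the documented predict_survival signature."""

    def predict_survival(self, x, a, t=None):
        times = np.atleast_1d(np.asarray(t, dtype=float))
        # Exponential survival whose hazard is halved under treatment.
        hazard = 0.1 * np.where(np.asarray(a) == 1, 0.5, 1.0)
        return np.exp(-hazard[:, None] * times[None, :])


x = np.zeros((4, 3))
effect = survival_difference(_ToyModel(), x, t=[5.0, 10.0])
print(effect.shape)
```

A positive entry means the individual is estimated to be more likely to survive past that time when treated than when untreated.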

def predict_latent_z(self, x)

Returns the estimated latent base survival group z given the confounders x.

def predict_latent_phi(self, x)

Returns the estimated latent treatment effect group \phi given the confounders x.
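
If the returned value is an array of per-group membership probabilities with one row per individual (an assumption; the return type is not documented here), hard phenogroup labels can be derived by taking the most probable group. A minimal sketch on a hand-written probability array follows; the helper name `hard_assignments` is hypothetical.

```python
import numpy as np


def hard_assignments(group_probs):
    """Return the index of the most probable latent group for each row."""
    group_probs = np.asarray(group_probs, dtype=float)
    return group_probs.argmax(axis=1)


# Hand-written stand-in for an (n, g) array of membership probabilities.
phi_probs = np.array([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6],
                      [0.2, 0.5, 0.3]])
labels = hard_assignments(phi_probs)
print(labels)
# Number of individuals assigned to each group.
print(np.bincount(labels, minlength=phi_probs.shape[1]))
```

The same pattern would apply to the output of predict_latent_z if it has the same layout.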