Module auton_survival.models.dsm.losses
Loss function definitions for the Deep Survival Machines model
This module defines the losses for censored and uncensored instances under the Weibull and LogNormal distributions. DSM is trained by optimizing these losses.
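The standard right-censored likelihood works as follows. An uncensored instance contributes its density f(t). A censored instance contributes its survival probability S(t) = P(T > t). The minimal sketch below applies this to a single Weibull and a single LogNormal. The helpers are hypothetical (they are not this module's functions) and use a plain shape/scale and mu/sigma parameterization, which may differ from the one DSM uses internally.

```python
import math


def weibull_log_pdf(t, shape, scale):
    """log f(t) for a Weibull with shape k > 0 and scale lam > 0."""
    z = t / scale
    return math.log(shape / scale) + (shape - 1) * math.log(z) - z ** shape


def weibull_log_survival(t, shape, scale):
    """log S(t) = -(t / lam) ** k."""
    return -((t / scale) ** shape)


def lognormal_log_pdf(t, mu, sigma):
    """log f(t) when log T ~ Normal(mu, sigma ** 2)."""
    z = (math.log(t) - mu) / sigma
    return -math.log(t * sigma * math.sqrt(2 * math.pi)) - 0.5 * z * z


def lognormal_log_survival(t, mu, sigma):
    """log S(t) = log(1 - Phi(z)), written with erfc."""
    z = (math.log(t) - mu) / sigma
    # Can underflow to log(0) far in the right tail; acceptable for a sketch.
    return math.log(0.5 * math.erfc(z / math.sqrt(2)))


def censored_nll(times, events, log_pdf, log_surv):
    """Mean negative log-likelihood over right-censored data.

    Uncensored rows (event == 1) contribute -log f(t);
    censored rows (event == 0) contribute -log S(t).
    """
    total = 0.0
    for t, e in zip(times, events):
        total -= log_pdf(t) if e == 1 else log_surv(t)
    return total / len(times)


# The same three observations scored under each distribution.
times, events = [2.0, 5.0, 3.5], [1, 0, 1]
weibull_nll = censored_nll(
    times, events,
    lambda t: weibull_log_pdf(t, 1.5, 4.0),
    lambda t: weibull_log_survival(t, 1.5, 4.0),
)
lognormal_nll = censored_nll(
    times, events,
    lambda t: lognormal_log_pdf(t, 1.2, 0.6),
    lambda t: lognormal_log_survival(t, 1.2, 0.6),
)
```

Only the censoring indicator decides which term an instance contributes. Swapping the distribution only swaps the two log-functions.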
TODO
Use torch.distributions
Warning
These functions are not designed to be called directly.
Functions
def unconditional_loss(model, t, e, risk='1')
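`unconditional_loss` takes no covariates `x`. Its loss therefore presumably depends only on parameters shared by every instance. The sketch below illustrates that idea for a single Weibull. It also shows one assumed reading of the `risk` argument: an instance counts as uncensored only when its event code equals `int(risk)`. Every other instance (censored, or a competing event) is scored through its survival term. The function name and both behaviors are assumptions and are not documented on this page.

```python
import math


def unconditional_weibull_nll(times, events, shape, scale, risk="1"):
    """Hypothetical covariate-free loss: one (shape, scale) for all rows."""
    target = int(risk)  # assumed: risk names the event type being modeled
    total = 0.0
    for t, e in zip(times, events):
        z = t / scale
        log_surv = -(z ** shape)  # log S(t)
        if e == target:
            # Observed event of this type: add -log f(t).
            total -= math.log(shape / scale) + (shape - 1) * math.log(z) + log_surv
        else:
            # Censored or competing event: add -log S(t).
            total -= log_surv
    return total / len(times)


# Event code 2 is a competing risk, so for risk="1" it is treated as censored.
times, events = [1.0, 2.0, 4.0], [1, 0, 2]
loss_risk1 = unconditional_weibull_nll(times, events, shape=1.0, scale=2.0, risk="1")
```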
def conditional_loss(model, x, t, e, elbo=True, risk='1')
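`conditional_loss` conditions on covariates `x`. Its `elbo` flag suggests a choice between two objectives for a mixture model. The sketch below contrasts them under that assumption:

- The exact mixture log-likelihood, log Σ_k w_k L_k, computed with log-sum-exp.
- The Jensen lower bound, Σ_k w_k log L_k, which never exceeds the exact value when the weights sum to one.

Mixture weights come from a softmax over per-instance gating logits. All names are hypothetical. This page does not say whether `elbo=True` selects exactly this bound.

```python
import math


def log_softmax(logits):
    """Convert gating logits into log mixture weights."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def exact_mixture_log_lik(logits, component_log_liks):
    """log sum_k w_k * L_k, evaluated stably in log space."""
    terms = [lw + ll for lw, ll in zip(log_softmax(logits), component_log_liks)]
    m = max(terms)
    return m + math.log(sum(math.exp(v - m) for v in terms))


def jensen_lower_bound(logits, component_log_liks):
    """sum_k w_k * log L_k, a lower bound on the exact value (Jensen)."""
    weights = [math.exp(lw) for lw in log_softmax(logits)]
    return sum(w * ll for w, ll in zip(weights, component_log_liks))


# One instance: gating logits for three components and each component's
# log-likelihood of the observed (t, e) pair.
logits = [0.2, -1.0, 0.5]
component_log_liks = [-2.3, -0.7, -1.9]
exact = exact_mixture_log_lik(logits, component_log_liks)
bound = jensen_lower_bound(logits, component_log_liks)
```

Training would minimize the negative of whichever quantity is chosen.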
def predict_mean(model, x, risk='1')
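By linearity of expectation, the mean of a mixture is the weighted average of its component means. The component means are:

- Weibull with shape k and scale λ: λΓ(1 + 1/k).
- LogNormal with parameters μ and σ: exp(μ + σ²/2).

The sketch below assumes `predict_mean` combines per-instance mixture weights this way. The helper names are hypothetical.

```python
import math


def weibull_mean(shape, scale):
    # E[T] = scale * Gamma(1 + 1 / shape)
    return scale * math.gamma(1.0 + 1.0 / shape)


def lognormal_mean(mu, sigma):
    # E[T] = exp(mu + sigma ** 2 / 2)
    return math.exp(mu + 0.5 * sigma ** 2)


def mixture_mean(weights, component_means):
    """Weighted average of component means; weights must sum to 1."""
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("mixture weights must sum to 1")
    return sum(w * m for w, m in zip(weights, component_means))


# Two Weibull components for one instance, e.g. weights from a gating network.
weights = [0.25, 0.75]
means = [weibull_mean(1.0, 2.0), weibull_mean(2.0, 4.0)]
expected_time = mixture_mean(weights, means)
```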
def predict_pdf(model, x, t_horizon, risk='1')
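`t_horizon` presumably holds one or more evaluation times. The minimal sketch below assumes the predicted density is the mixture f(t) = Σ_k w_k f_k(t), evaluated at each horizon, here with Weibull components. `mixture_pdf` and its arguments are hypothetical.

```python
import math


def weibull_pdf(t, shape, scale):
    # f(t) = (k / lam) * (t / lam) ** (k - 1) * exp(-(t / lam) ** k), t > 0
    z = t / scale
    return (shape / scale) * z ** (shape - 1) * math.exp(-(z ** shape))


def mixture_pdf(weights, shapes, scales, t_horizon):
    """Return one mixture density value per requested horizon."""
    return [
        sum(w * weibull_pdf(t, k, lam) for w, k, lam in zip(weights, shapes, scales))
        for t in t_horizon
    ]


densities = mixture_pdf([0.4, 0.6], [1.0, 2.0], [1.0, 3.0], t_horizon=[0.5, 1.0, 2.0])
```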
def predict_cdf(model, x, t_horizon, risk='1')
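A mixture's CDF is F(t) = Σ_k w_k F_k(t), and its survival function is 1 - F(t). The sketch below uses LogNormal components, with F_k(t) = Φ((log t - μ_k)/σ_k). It assumes `predict_cdf` evaluates such a quantity at each horizon in `t_horizon`. This page does not say whether the returned values are the CDF, the survival probability, or a log of either, so check the source before relying on their scale. All names are hypothetical.

```python
import math


def lognormal_cdf(t, mu, sigma):
    # F(t) = Phi((log t - mu) / sigma), written with erfc
    z = (math.log(t) - mu) / sigma
    return 0.5 * math.erfc(-z / math.sqrt(2))


def mixture_cdf(weights, mus, sigmas, t_horizon):
    """Return one mixture CDF value per requested horizon."""
    return [
        sum(w * lognormal_cdf(t, mu, s) for w, mu, s in zip(weights, mus, sigmas))
        for t in t_horizon
    ]


horizons = [1.0, 5.0, 20.0]
cdf = mixture_cdf([0.3, 0.7], [0.5, 2.0], [0.4, 0.8], horizons)
survival = [1.0 - f for f in cdf]  # P(T > t) at each horizon
```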