Module auton_survival.models.cmhe.cmhe_utilities

Functions

def randargmax(b, **kw)

Return the argmax of `b`, breaking ties at random.

def partial_ll_loss(lrisks, tb, eb, eps=0.01)
def fit_spline(t, surv, smoothing_factor=0.0001)
def smooth_bl_survival(breslow, smoothing_factor)
def get_probability_(lrisks, ts, spl)
def get_survival_(lrisks, ts, spl)
def get_probability(lrisks, breslow_splines, t)
def get_survival(lrisks, breslow_splines, t)
def get_posteriors(probs)
def get_hard_z(gates_prob)
def sample_hard_z(gates_prob)
def repair_probs(probs)
def get_likelihood(model, breslow_splines, x, t, e, a)
def q_function(model, x, t, e, a, log_likelihoods, typ='soft')
def e_step(model, breslow_splines, x, t, e, a)
def m_step(model, optimizer, x, t, e, a, log_likelihoods, typ='soft')
def fit_breslow(model, x, t, e, a, log_likelihoods=None, smoothing_factor=0.0001, typ='soft')
def train_step(model, x, t, e, a, breslow_splines, optimizer, bs=256, seed=100, typ='soft', use_posteriors=False, update_splines_after=10, smoothing_factor=0.0001)
def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft')
def train_cmhe(model, train_data, val_data, epochs=50, patience=2, vloss='q', bs=256, typ='soft', lr=0.001, use_posteriors=False, debug=False, return_losses=False, update_splines_after=1, smoothing_factor=0.0001, random_seed=0)
def predict_survival(model, x, a, t)
def predict_latent_z(model, x)
def predict_latent_phi(model, x)