Module auton_survival.models.dcm.dcm_utilities
Functions
def randargmax(b, **kw)
Argmax that breaks ties between equal maximum values at random.
def partial_ll_loss(lrisks, tb, eb, eps=0.01)
def fit_spline(t, surv, s=0.0001)
def smooth_bl_survival(breslow, smoothing_factor)
def get_probability_(lrisks, ts, spl)
def get_survival_(lrisks, ts, spl)
def get_probability(lrisks, breslow_splines, t)
def get_survival(lrisks, breslow_splines, t)
def get_posteriors(probs)
def get_hard_z(gates_prob)
def sample_hard_z(gates_prob)
def repair_probs(probs)
def get_likelihood(model, breslow_splines, x, t, e)
def q_function(model, x, t, e, posteriors, typ='soft')
def e_step(model, breslow_splines, x, t, e)
def m_step(model, optimizer, x, t, e, posteriors, typ='soft')
def fit_breslow(model, x, t, e, posteriors=None, smoothing_factor=0.0001, typ='soft')
def train_step(model, x, t, e, breslow_splines, optimizer, bs=256, seed=100, typ='soft', use_posteriors=False, update_splines_after=10, smoothing_factor=0.0001)
def test_step(model, x, t, e, breslow_splines, loss='q', typ='soft')
def train_dcm(model, train_data, val_data, epochs=50, patience=3, vloss='q', bs=256, typ='soft', lr=0.001, use_posteriors=True, debug=False, random_seed=0, return_losses=False, update_splines_after=10, smoothing_factor=0.01)
def predict_survival(model, x, t)
def predict_latent_z(model, x)