diff --git a/coix/loss.py b/coix/loss.py
index 2989971..995ec98 100644
--- a/coix/loss.py
+++ b/coix/loss.py
@@ -29,7 +29,13 @@
 ]
 
 
-def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def apg_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """RWS objective that exploits conditional dependency."""
   del incoming_log_weight, incremental_log_weight
   p_log_probs = {
@@ -87,11 +93,19 @@ def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   surrogate_loss = target_lp + forward_lp
   log_weight = target_lp + reverse_lp - (forward_lp + proposal_lp)
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def avo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def avo_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Annealed Variational Objective."""
   del q_trace, p_trace
   surrogate_loss = incremental_log_weight
@@ -99,11 +113,19 @@ def avo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
     w1 = 1.0 / incoming_log_weight.shape[0]
   else:
     w1 = 1.0
-  loss = -(w1 * surrogate_loss).sum()
+  loss = -(w1 * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def elbo_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Evidence Lower Bound objective."""
   del q_trace, p_trace
   surrogate_loss = incremental_log_weight
@@ -111,11 +133,19 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
     w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
   else:
     w1 = 1.0
-  loss = -(w1 * surrogate_loss).sum()
+  loss = -(w1 * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def fkl_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Forward KL objective. Here we do not optimize p."""
   del p_trace
   batch_ndims = incoming_log_weight.ndim
@@ -144,11 +174,19 @@ def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
     w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss - w1 * proposal_lp).sum()
+  loss = -(w * surrogate_loss - w1 * proposal_lp)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def iwae_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def iwae_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Importance Weighted Autoencoder objective."""
   del q_trace, p_trace
   log_weight = incoming_log_weight + incremental_log_weight
@@ -157,11 +195,19 @@ def iwae_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
     w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
   else:
     w = 1.0
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def rkl_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Reverse KL objective."""
   batch_ndims = incoming_log_weight.ndim
   p_log_probs = {
@@ -195,11 +241,19 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   )
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w1 * surrogate_loss - w * target_lp).sum()
+  loss = -(w1 * surrogate_loss - w * target_lp)
+  if aggregate:
+    loss = loss.sum()
   return loss
 
 
-def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
+def rws_loss(
+    q_trace,
+    p_trace,
+    incoming_log_weight,
+    incremental_log_weight,
+    aggregate=True,
+):
   """Reweighted Wake-Sleep objective."""
   batch_ndims = incoming_log_weight.ndim
   p_log_probs = {
@@ -234,5 +288,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   surrogate_loss = (target_lp - proposal_lp) + forward_lp
   log_weight = incoming_log_weight + incremental_log_weight
   w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
-  loss = -(w * surrogate_loss).sum()
+  loss = -(w * surrogate_loss)
+  if aggregate:
+    loss = loss.sum()
   return loss
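
Usage sketch (not part of the patch): with this change every loss function can return the unaggregated per-particle values via aggregate=False, while the default aggregate=True keeps the old summed-scalar behavior, so existing call sites are unaffected. The sketch below is a minimal illustration under assumed shapes and values, not a library guarantee; it uses elbo_loss, which (per the diff) deletes its trace arguments, so empty dicts stand in for q_trace and p_trace.

    # Minimal sketch assuming the patched coix.loss module.
    import jax.numpy as jnp

    from coix.loss import elbo_loss

    incoming = jnp.zeros(4)  # uniform incoming log weights over 4 particles
    incremental = jnp.array([-1.2, -0.7, -0.9, -1.1])  # per-particle log weight increments

    # elbo_loss ignores q_trace/p_trace, so empty dicts suffice here.
    per_particle = elbo_loss({}, {}, incoming, incremental, aggregate=False)
    total = elbo_loss({}, {}, incoming, incremental)  # aggregate=True by default
    assert jnp.allclose(per_particle.sum(), total)  # summing recovers the old scalar

Gating the final .sum() behind a default-on flag is a backward-compatible way to expose the per-particle terms, e.g. for custom masking or alternative reductions outside the loss function.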