diff --git a/coix/loss.py b/coix/loss.py index b925d56..bc7fda4 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -31,11 +31,7 @@ def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): """RWS objective that exploits conditional dependency.""" - # del incoming_log_weight, incremental_log_weight - print("p_trace", jax.tree.map(jnp.shape, p_trace)) - print("q_trace", jax.tree.map(jnp.shape, q_trace)) - jax.debug.print("incoming={ilw}", ilw=incoming_log_weight) - jax.debug.print("incremental={ilw}", ilw=incremental_log_weight) + del incoming_log_weight, incremental_log_weight p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items() } @@ -210,10 +206,6 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): """Reweighted Wake-Sleep objective.""" - print("p_trace", jax.tree.map(jnp.shape, p_trace)) - print("q_trace", jax.tree.map(jnp.shape, q_trace)) - jax.debug.print("incoming={ilw}", ilw=incoming_log_weight) - jax.debug.print("incremental={ilw}", ilw=incremental_log_weight) batch_ndims = incoming_log_weight.ndim p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items()