Skip to content

Commit

Permalink
revert changes at loss
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Apr 2, 2024
1 parent 05981aa commit 21a64b3
Showing 1 changed file with 1 addition and 9 deletions.
10 changes: 1 addition & 9 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@

def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""RWS objective that exploits conditional dependency."""
# del incoming_log_weight, incremental_log_weight
print("p_trace", jax.tree.map(jnp.shape, p_trace))
print("q_trace", jax.tree.map(jnp.shape, q_trace))
jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
del incoming_log_weight, incremental_log_weight
p_log_probs = {
name: util.get_site_log_prob(site) for name, site in p_trace.items()
}
Expand Down Expand Up @@ -210,10 +206,6 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):

def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Reweighted Wake-Sleep objective."""
print("p_trace", jax.tree.map(jnp.shape, p_trace))
print("q_trace", jax.tree.map(jnp.shape, q_trace))
jax.debug.print("incoming={ilw}", ilw=incoming_log_weight)
jax.debug.print("incremental={ilw}", ilw=incremental_log_weight)
batch_ndims = incoming_log_weight.ndim
p_log_probs = {
name: util.get_site_log_prob(site) for name, site in p_trace.items()
Expand Down

0 comments on commit 21a64b3

Please sign in to comment.