Skip to content

Commit

Permalink
propose with superfluous variables
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed May 3, 2024
1 parent a764c0f commit 9f8a493
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 27 deletions.
7 changes: 2 additions & 5 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
additional batch dimensions to the whole program by using `vmap`, e.g.
`vmap(propose(p, q))`.
Note: We assume superfluous variables, which appear in `q` but not in `p`,
implicitly follow Delta distribution in `p`.
Args:
p: a target program
q: a proposal program
Expand Down Expand Up @@ -179,10 +176,10 @@ def wrapped(*args, **kwargs):
for name, lp in p_log_probs.items()
if util.is_observed_site(p_trace[name]) or (name in q_trace)
)
# Note: We include superfluous variables, whose `name in p_trace`.
q_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for lp in q_log_probs.values()
for name, lp in q_log_probs.items()
if util.is_observed_site(q_trace[name]) or (name in p_trace)
)
incremental_log_weight = p_log_weight - q_log_weight
log_weight = in_log_weight + incremental_log_weight
Expand Down
40 changes: 18 additions & 22 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,17 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
return loss


def _proposal_and_target_sites(q_trace, p_trace):
"""Gets current proposal sites and current target sites."""
proposal_sites = []
target_sites = []
for name in p_trace:
if not name.endswith("_PREV_"):
target_sites.append(name)
if name in q_trace:
while name + "_PREV_" in q_trace:
name += "_PREV_"
proposal_sites.append(name)
if not any(name.endswith("_PREV_") for name in proposal_sites):
proposal_sites = []
return proposal_sites, target_sites


def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Forward KL objective."""
"""Forward KL objective. Here we do not optimize p."""
batch_ndims = incoming_log_weight.ndim
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand Down Expand Up @@ -180,7 +168,11 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -190,7 +182,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
Expand All @@ -213,7 +205,11 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -228,7 +224,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

surrogate_loss = (target_lp - proposal_lp) + forward_lp
Expand Down
2 changes: 2 additions & 0 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def __call__(self, x):
def anneal_target(network, k=0):
x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event())
anneal_density = network.anneal_density(x, index=k)
# We make "anneal_density" a latent site so that it does not contribute
# to the likelihood weighting of the first proposal.
numpyro.sample("anneal_density", dist.Unit(anneal_density))
return ({"x": x},)

Expand Down

0 comments on commit 9f8a493

Please sign in to comment.