diff --git a/coix/api.py b/coix/api.py index b1018c3..861b9a4 100644 --- a/coix/api.py +++ b/coix/api.py @@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False): additional batch dimensions to the whole program by using `vmap`, e.g. `vmap(propose(p, q))`. - Note: We assume superfluous variables, which appear in `q` but not in `p`, - implicitly follow Delta distribution in `p`. - Args: p: a target program q: a proposal program @@ -179,10 +176,10 @@ def wrapped(*args, **kwargs): for name, lp in p_log_probs.items() if util.is_observed_site(p_trace[name]) or (name in q_trace) ) - # Note: We include superfluous variables, whose `name in p_trace`. q_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) - for lp in q_log_probs.values() + for name, lp in q_log_probs.items() + if util.is_observed_site(q_trace[name]) or (name in p_trace) ) incremental_log_weight = p_log_weight - q_log_weight log_weight = in_log_weight + incremental_log_weight diff --git a/coix/loss.py b/coix/loss.py index bc7fda4..6d4dc46 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -115,29 +115,17 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): return loss -def _proposal_and_target_sites(q_trace, p_trace): - """Gets current proposal sites and current target sites.""" - proposal_sites = [] - target_sites = [] - for name in p_trace: - if not name.endswith("_PREV_"): - target_sites.append(name) - if name in q_trace: - while name + "_PREV_" in q_trace: - name += "_PREV_" - proposal_sites.append(name) - if not any(name.endswith("_PREV_") for name in proposal_sites): - proposal_sites = [] - return proposal_sites, target_sites - - def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): - """Forward KL objective.""" + """Forward KL objective. Here we do not optimize p.""" batch_ndims = incoming_log_weight.ndim q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if util.is_observed_site(site) or name.endswith("_PREV_") + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -180,7 +168,11 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if util.is_observed_site(site) or name.endswith("_PREV_") + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -190,7 +182,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): target_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() - if name in target_sites + if not name.endswith("_PREV_") ) w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0)) @@ -213,7 +205,11 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if util.is_observed_site(site) or name.endswith("_PREV_") + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -228,7 +224,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): target_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() - if name in target_sites + if not name.endswith("_PREV_") ) surrogate_loss = (target_lp - proposal_lp) + forward_lp diff --git a/examples/anneal.py b/examples/anneal.py index 91e27a4..a0f761d 100644 --- a/examples/anneal.py +++ b/examples/anneal.py @@ -121,6 +121,8 @@ def __call__(self, x): def anneal_target(network, k=0): x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event()) anneal_density = network.anneal_density(x, index=k) + # We make "anneal_density" a latent site so that it does not contribute + # to the likelihood weighting of the first proposal. numpyro.sample("anneal_density", dist.Unit(anneal_density)) return ({"x": x},)