diff --git a/coix/algo.py b/coix/algo.py index 27453bd..2da2eb2 100644 --- a/coix/algo.py +++ b/coix/algo.py @@ -408,9 +408,11 @@ def vsmc(targets, proposals, *, num_targets=None): if _use_fori_loop(targets, num_targets, proposals): def body_fun(i, q): - return compose(proposals(i + 1), resample(propose(targets(i), q))) + return propose(targets(i), compose(proposals(i), resample(q))) - q = fori_loop(0, num_targets - 1, body_fun, proposals(0)) + q = propose(targets(0), proposals(0)) + q = fori_loop(1, num_targets - 1, body_fun, q) + q = compose(proposals(num_targets - 1), resample(q)) return propose(targets(num_targets - 1), q, loss_fn=iwae_loss) q = propose(targets[0], proposals[0]) diff --git a/coix/api.py b/coix/api.py index b1018c3..861b9a4 100644 --- a/coix/api.py +++ b/coix/api.py @@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False): additional batch dimensions to the whole program by using `vmap`, e.g. `vmap(propose(p, q))`. - Note: We assume superfluous variables, which appear in `q` but not in `p`, - implicitly follow Delta distribution in `p`. - Args: p: a target program q: a proposal program @@ -179,10 +176,10 @@ def wrapped(*args, **kwargs): for name, lp in p_log_probs.items() if util.is_observed_site(p_trace[name]) or (name in q_trace) ) - # Note: We include superfluous variables, whose `name in p_trace`. q_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) - for lp in q_log_probs.values() + for name, lp in q_log_probs.items() + if util.is_observed_site(q_trace[name]) or (name in p_trace) ) incremental_log_weight = p_log_weight - q_log_weight log_weight = in_log_weight + incremental_log_weight diff --git a/coix/core.py b/coix/core.py index e34abed..9121552 100644 --- a/coix/core.py +++ b/coix/core.py @@ -117,6 +117,8 @@ def desuffix(trace): new_trace = {} for name in trace: raw_name = names_to_raw_names[name] + if raw_name != name and isinstance(trace[name], dict): + trace[name]["suffix"] = True new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name] return new_trace diff --git a/coix/loss.py b/coix/loss.py index bc7fda4..2989971 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -115,29 +115,19 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): return loss -def _proposal_and_target_sites(q_trace, p_trace): - """Gets current proposal sites and current target sites.""" - proposal_sites = [] - target_sites = [] - for name in p_trace: - if not name.endswith("_PREV_"): - target_sites.append(name) - if name in q_trace: - while name + "_PREV_" in q_trace: - name += "_PREV_" - proposal_sites.append(name) - if not any(name.endswith("_PREV_") for name in proposal_sites): - proposal_sites = [] - return proposal_sites, target_sites - - def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): - """Forward KL objective.""" + """Forward KL objective. Here we do not optimize p.""" + del p_trace batch_ndims = incoming_log_weight.ndim q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if name.endswith("_PREV_") + or (isinstance(site, dict) and "suffix" in site) + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -180,7 +170,12 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if name.endswith("_PREV_") + or (isinstance(site, dict) and "suffix" in site) + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -190,7 +185,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): target_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() - if name in target_sites + if not name.endswith("_PREV_") ) w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0)) @@ -213,7 +208,12 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): q_log_probs = { name: util.get_site_log_prob(site) for name, site in q_trace.items() } - proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace) + proposal_sites = [ + name + for name, site in q_trace.items() + if name.endswith("_PREV_") + or (isinstance(site, dict) and "suffix" in site) + ] proposal_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) @@ -228,7 +228,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight): target_lp = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() - if name in target_sites + if not name.endswith("_PREV_") ) surrogate_loss = (target_lp - proposal_lp) + forward_lp diff --git a/coix/loss_test.py b/coix/loss_test.py index 46fde42..d822860 100644 --- a/coix/loss_test.py +++ b/coix/loss_test.py @@ -25,7 +25,7 @@ } q_trace = { "x": {"log_prob": np.ones((3, 2))}, - "y": {"log_prob": np.array([1.0, 1.0, 0.0])}, + "y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True}, "x_PREV_": {"log_prob": np.full((3, 2), 3.0)}, } incoming_weight = np.zeros(3) diff --git a/coix/util.py b/coix/util.py index 2ebf2ac..f356513 100644 --- a/coix/util.py +++ b/coix/util.py @@ -223,32 +223,6 @@ def step_fn(params, opt_state, *args, **kwargs): return params, metrics -def _remove_suffix(name): - i = 0 - while name.endswith("_PREV_"): - i += len("_PREV_") - name = name[: -len("_PREV_")] - return name, i - - -def desuffix(trace): - """Remove unnecessary suffix terms added to the trace.""" - names_to_raw_names = {} - num_suffix_min = {} - for name in trace: - raw_name, num_suffix = _remove_suffix(name) - names_to_raw_names[name] = raw_name - if raw_name in num_suffix_min: - num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix) - else: - num_suffix_min[raw_name] = num_suffix - new_trace = {} - for name in trace: - raw_name = names_to_raw_names[name] - new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name] - return new_trace - - def get_batch_ndims(xs): """Gets the number of same-size leading dimensions of the elements in xs.""" if not xs: diff --git a/examples/anneal.py b/examples/anneal.py index 91e27a4..a0f761d 100644 --- a/examples/anneal.py +++ b/examples/anneal.py @@ -121,6 +121,8 @@ def __call__(self, x): def anneal_target(network, k=0): x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event()) anneal_density = network.anneal_density(x, index=k) + # We make "anneal_density" a latent site so that it does not contribute + # to the likelihood weighting of the first proposal. numpyro.sample("anneal_density", dist.Unit(anneal_density)) return ({"x": x},)