Skip to content

Commit

Permalink
Support for superfluous variables which appear in proposal but not in…
Browse files Browse the repository at this point in the history
… generative model (#37)

* propose with superfluous variables

* add desuffix to trace to extract forward/reverse variables

* fix lint

* fix failing test
  • Loading branch information
fehiepsi authored May 3, 2024
1 parent a764c0f commit 199103f
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 56 deletions.
6 changes: 4 additions & 2 deletions coix/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,11 @@ def vsmc(targets, proposals, *, num_targets=None):
if _use_fori_loop(targets, num_targets, proposals):

def body_fun(i, q):
return compose(proposals(i + 1), resample(propose(targets(i), q)))
return propose(targets(i), compose(proposals(i), resample(q)))

q = fori_loop(0, num_targets - 1, body_fun, proposals(0))
q = propose(targets(0), proposals(0))
q = fori_loop(1, num_targets - 1, body_fun, q)
q = compose(proposals(num_targets - 1), resample(q))
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
Expand Down
7 changes: 2 additions & 5 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
additional batch dimensions to the whole program by using `vmap`, e.g.
`vmap(propose(p, q))`.
Note: We assume superfluous variables, which appear in `q` but not in `p`,
implicitly follow Delta distribution in `p`.
Args:
p: a target program
q: a proposal program
Expand Down Expand Up @@ -179,10 +176,10 @@ def wrapped(*args, **kwargs):
for name, lp in p_log_probs.items()
if util.is_observed_site(p_trace[name]) or (name in q_trace)
)
# Note: We include superfluous variables, whose `name in p_trace`.
q_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for lp in q_log_probs.values()
for name, lp in q_log_probs.items()
if util.is_observed_site(q_trace[name]) or (name in p_trace)
)
incremental_log_weight = p_log_weight - q_log_weight
log_weight = in_log_weight + incremental_log_weight
Expand Down
2 changes: 2 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def desuffix(trace):
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
if raw_name != name and isinstance(trace[name], dict):
trace[name]["suffix"] = True
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace

Expand Down
44 changes: 22 additions & 22 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,19 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
return loss


def _proposal_and_target_sites(q_trace, p_trace):
"""Gets current proposal sites and current target sites."""
proposal_sites = []
target_sites = []
for name in p_trace:
if not name.endswith("_PREV_"):
target_sites.append(name)
if name in q_trace:
while name + "_PREV_" in q_trace:
name += "_PREV_"
proposal_sites.append(name)
if not any(name.endswith("_PREV_") for name in proposal_sites):
proposal_sites = []
return proposal_sites, target_sites


def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Forward KL objective."""
"""Forward KL objective. Here we do not optimize p."""
del p_trace
batch_ndims = incoming_log_weight.ndim
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand Down Expand Up @@ -180,7 +170,12 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -190,7 +185,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
Expand All @@ -213,7 +208,12 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -228,7 +228,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

surrogate_loss = (target_lp - proposal_lp) + forward_lp
Expand Down
2 changes: 1 addition & 1 deletion coix/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
}
q_trace = {
"x": {"log_prob": np.ones((3, 2))},
"y": {"log_prob": np.array([1.0, 1.0, 0.0])},
"y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True},
"x_PREV_": {"log_prob": np.full((3, 2), 3.0)},
}
incoming_weight = np.zeros(3)
Expand Down
26 changes: 0 additions & 26 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,32 +223,6 @@ def step_fn(params, opt_state, *args, **kwargs):
return params, metrics


def _remove_suffix(name):
i = 0
while name.endswith("_PREV_"):
i += len("_PREV_")
name = name[: -len("_PREV_")]
return name, i


def desuffix(trace):
"""Remove unnecessary suffix terms added to the trace."""
names_to_raw_names = {}
num_suffix_min = {}
for name in trace:
raw_name, num_suffix = _remove_suffix(name)
names_to_raw_names[name] = raw_name
if raw_name in num_suffix_min:
num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
else:
num_suffix_min[raw_name] = num_suffix
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace


def get_batch_ndims(xs):
"""Gets the number of same-size leading dimensions of the elements in xs."""
if not xs:
Expand Down
2 changes: 2 additions & 0 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def __call__(self, x):
def anneal_target(network, k=0):
x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event())
anneal_density = network.anneal_density(x, index=k)
# We make "anneal_density" a latent site so that it does not contribute
# to the likelihood weighting of the first proposal.
numpyro.sample("anneal_density", dist.Unit(anneal_density))
return ({"x": x},)

Expand Down

0 comments on commit 199103f

Please sign in to comment.