Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for superfluous variables which appear in proposal but not in generative model #37

Merged
merged 4 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions coix/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,11 @@ def vsmc(targets, proposals, *, num_targets=None):
if _use_fori_loop(targets, num_targets, proposals):

def body_fun(i, q):
return compose(proposals(i + 1), resample(propose(targets(i), q)))
return propose(targets(i), compose(proposals(i), resample(q)))

q = fori_loop(0, num_targets - 1, body_fun, proposals(0))
q = propose(targets(0), proposals(0))
q = fori_loop(1, num_targets - 1, body_fun, q)
q = compose(proposals(num_targets - 1), resample(q))
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
Expand Down
7 changes: 2 additions & 5 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
additional batch dimensions to the whole program by using `vmap`, e.g.
`vmap(propose(p, q))`.

Note: We assume superfluous variables, which appear in `q` but not in `p`,
implicitly follow Delta distribution in `p`.

Args:
p: a target program
q: a proposal program
Expand Down Expand Up @@ -179,10 +176,10 @@ def wrapped(*args, **kwargs):
for name, lp in p_log_probs.items()
if util.is_observed_site(p_trace[name]) or (name in q_trace)
)
# Note: We include superfluous variables, whose `name in p_trace`.
q_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for lp in q_log_probs.values()
for name, lp in q_log_probs.items()
if util.is_observed_site(q_trace[name]) or (name in p_trace)
)
incremental_log_weight = p_log_weight - q_log_weight
log_weight = in_log_weight + incremental_log_weight
Expand Down
2 changes: 2 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def desuffix(trace):
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
if raw_name != name and isinstance(trace[name], dict):
trace[name]["suffix"] = True
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace

Expand Down
44 changes: 22 additions & 22 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,19 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
return loss


def _proposal_and_target_sites(q_trace, p_trace):
"""Gets current proposal sites and current target sites."""
proposal_sites = []
target_sites = []
for name in p_trace:
if not name.endswith("_PREV_"):
target_sites.append(name)
if name in q_trace:
while name + "_PREV_" in q_trace:
name += "_PREV_"
proposal_sites.append(name)
if not any(name.endswith("_PREV_") for name in proposal_sites):
proposal_sites = []
return proposal_sites, target_sites


def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Forward KL objective."""
"""Forward KL objective. Here we do not optimize p."""
del p_trace
batch_ndims = incoming_log_weight.ndim
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand Down Expand Up @@ -180,7 +170,12 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -190,7 +185,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
Expand All @@ -213,7 +208,12 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
proposal_sites = [
name
for name, site in q_trace.items()
if name.endswith("_PREV_")
or (isinstance(site, dict) and "suffix" in site)
]

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -228,7 +228,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
if name in target_sites
if not name.endswith("_PREV_")
)

surrogate_loss = (target_lp - proposal_lp) + forward_lp
Expand Down
2 changes: 1 addition & 1 deletion coix/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
}
q_trace = {
"x": {"log_prob": np.ones((3, 2))},
"y": {"log_prob": np.array([1.0, 1.0, 0.0])},
"y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True},
"x_PREV_": {"log_prob": np.full((3, 2), 3.0)},
}
incoming_weight = np.zeros(3)
Expand Down
26 changes: 0 additions & 26 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,32 +223,6 @@ def step_fn(params, opt_state, *args, **kwargs):
return params, metrics


def _remove_suffix(name):
i = 0
while name.endswith("_PREV_"):
i += len("_PREV_")
name = name[: -len("_PREV_")]
return name, i


def desuffix(trace):
"""Remove unnecessary suffix terms added to the trace."""
names_to_raw_names = {}
num_suffix_min = {}
for name in trace:
raw_name, num_suffix = _remove_suffix(name)
names_to_raw_names[name] = raw_name
if raw_name in num_suffix_min:
num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
else:
num_suffix_min[raw_name] = num_suffix
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace


def get_batch_ndims(xs):
"""Gets the number of same-size leading dimensions of the elements in xs."""
if not xs:
Expand Down
2 changes: 2 additions & 0 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def __call__(self, x):
def anneal_target(network, k=0):
x = numpyro.sample("x", dist.Normal(0, 5).expand([2]).mask(False).to_event())
anneal_density = network.anneal_density(x, index=k)
# We make "anneal_density" a latent site so that it does not contribute
# to the likelihood weighting of the first proposal.
numpyro.sample("anneal_density", dist.Unit(anneal_density))
return ({"x": x},)

Expand Down
Loading