diff --git a/coix/api.py b/coix/api.py index c4d6101..b528005 100644 --- a/coix/api.py +++ b/coix/api.py @@ -68,32 +68,6 @@ def wrapped(*args, **kwargs): return wrapped -def _get_batch_ndims(log_probs): - if not log_probs: - return 0 - min_ndim = min(jnp.ndim(lp) for lp in log_probs) - batch_ndims = 0 - for i in range(min_ndim): - if len(set(jnp.shape(lp)[i] for lp in log_probs)) > 1: - break - batch_ndims = batch_ndims + 1 - return batch_ndims - - -def _get_log_weight(trace, batch_ndims): - """Computes log weight of the trace and keeps its batch dimensions.""" - log_weight = jnp.zeros((1,) * batch_ndims) - for site in trace.values(): - lp = util.get_site_log_prob(site) - if util.is_observed_site(site): - log_weight = log_weight + jnp.sum( - lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0)) - ) - else: - log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims]) - return log_weight - - def _split_key(key): keys = jax.vmap(jax.random.split)(key.reshape(-1, 2)).reshape( key.shape[:-1] + (2, 2) @@ -150,16 +124,14 @@ def wrapped(*args, **kwargs): name: util.get_site_log_prob(site) for name, site in q_trace.items() } log_probs = list(p_log_probs.values()) + list(q_log_probs.values()) - batch_ndims = _get_batch_ndims(log_probs) + batch_ndims = util.get_batch_ndims(log_probs) - if "log_weight" in q_metrics: - in_log_weight = q_metrics["log_weight"] - in_log_weight = jnp.sum( - in_log_weight, - axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), - ) - else: - in_log_weight = _get_log_weight(q_trace, batch_ndims) + assert "log_weight" in q_metrics + in_log_weight = q_metrics["log_weight"] + in_log_weight = jnp.sum( + in_log_weight, + axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), + ) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) for name, lp in p_log_probs.items() @@ -269,7 +241,7 @@ def fn(*args, **kwargs): log_probs = { name: util.get_site_log_prob(site) for name, site in trace.items() } - batch_ndims = 
_get_batch_ndims(log_probs.values()) + batch_ndims = util.get_batch_ndims(log_probs.values()) weighted = ("log_weight" in q_metrics) or any( util.is_observed_site(site) for site in trace.values() ) @@ -284,7 +256,7 @@ def fn(*args, **kwargs): axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)), ) else: - in_log_weight = _get_log_weight(trace, batch_ndims) + in_log_weight = util.get_log_weight(trace, batch_ndims) n = in_log_weight.shape[0] k = n if num_samples is None else num_samples log_weight = jax.nn.logsumexp(in_log_weight, 0) - jnp.log(k if k else 1) @@ -321,8 +293,8 @@ def _add_missing_metrics(metrics, trace): name: util.get_site_log_prob(site) for name, site in trace.items() } if "log_weight" not in metrics: - batch_ndims = min(_get_batch_ndims(list(log_probs.values())), 1) - log_weight = _get_log_weight(trace, batch_ndims) + batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1) + log_weight = util.get_log_weight(trace, batch_ndims) full_metrics["log_weight"] = log_weight if batch_ndims: # leftmost dimension is particle dimension ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0) @@ -430,7 +402,7 @@ def wrapped(*args, **kwargs): p_log_probs = { name: util.get_site_log_prob(site) for name, site in p_trace.items() } - batch_ndims = _get_batch_ndims(p_log_probs.values()) + batch_ndims = util.get_batch_ndims(p_log_probs.values()) p_log_weight = sum( lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1) diff --git a/coix/numpyro.py b/coix/numpyro.py index a34e845..2672ed3 100644 --- a/coix/numpyro.py +++ b/coix/numpyro.py @@ -1,5 +1,8 @@ """Backend implementation for NumPyro.""" +from coix.util import get_batch_ndims +from coix.util import get_log_weight +from coix.util import get_site_log_prob import jax import jax.numpy as jnp import numpyro @@ -33,6 +36,11 @@ def wrapped(*args, **kwargs): for name, site in tr.items() if site["type"] == "metric" } + # add log_weight to metrics + if "log_weight" not in metrics: + log_probs = 
def get_batch_ndims(xs):
  """Gets the number of same-size leading dimensions of the elements in xs.

  Args:
    xs: An iterable of array-like values (anything `jnp.shape` accepts), for
      example the log probabilities of the sites of a trace. It may be a
      list, a dict view, or a one-shot iterable such as a generator.

  Returns:
    The number of leading axes on which every element has the same size.
    Returns 0 when xs is empty.
  """
  # Read the shapes exactly once: xs has to be scanned once per axis below,
  # which would silently see nothing on an already-consumed generator.
  shapes = [jnp.shape(x) for x in xs]
  if not shapes:
    return 0
  batch_ndims = 0
  # zip stops at the shortest shape, i.e. at the minimum rank among elements.
  for sizes in zip(*shapes):
    if len(set(sizes)) > 1:
      break
    batch_ndims += 1
  return batch_ndims


def get_log_weight(trace, batch_ndims):
  """Computes log weight of the trace and keeps its batch dimensions.

  Only observed sites contribute to the weight. Unobserved sites add zeros
  with their batch shape, so the result still broadcasts to the batch shape
  even when the trace has no observed site.

  Args:
    trace: A mapping whose values are sites; each site is passed to
      `get_site_log_prob` and `is_observed_site`.
    batch_ndims: The number of leading dimensions of each site's log
      probability that are kept (not summed over).

  Returns:
    An array holding the sum of the log probabilities of the observed sites,
    reduced over all but the leading `batch_ndims` dimensions.
  """
  # Start from size-1 batch dimensions so the first addition broadcasts.
  log_weight = jnp.zeros((1,) * batch_ndims)
  for site in trace.values():
    lp = get_site_log_prob(site)
    if is_observed_site(site):
      # Negative axes: sum over every trailing dimension past the batch ones.
      log_weight = log_weight + jnp.sum(
          lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0))
      )
    else:
      log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims])
  return log_weight