Skip to content

Commit

Permalink
Merge pull request #7 from jax-ml:dev-heiko
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 552643782
  • Loading branch information
The coix Authors committed Mar 28, 2024
2 parents ed92054 + 0937aec commit 01471b1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 40 deletions.
52 changes: 12 additions & 40 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,32 +68,6 @@ def wrapped(*args, **kwargs):
return wrapped


def _get_batch_ndims(log_probs):
if not log_probs:
return 0
min_ndim = min(jnp.ndim(lp) for lp in log_probs)
batch_ndims = 0
for i in range(min_ndim):
if len(set(jnp.shape(lp)[i] for lp in log_probs)) > 1:
break
batch_ndims = batch_ndims + 1
return batch_ndims


def _get_log_weight(trace, batch_ndims):
"""Computes log weight of the trace and keeps its batch dimensions."""
log_weight = jnp.zeros((1,) * batch_ndims)
for site in trace.values():
lp = util.get_site_log_prob(site)
if util.is_observed_site(site):
log_weight = log_weight + jnp.sum(
lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0))
)
else:
log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims])
return log_weight


def _split_key(key):
keys = jax.vmap(jax.random.split)(key.reshape(-1, 2)).reshape(
key.shape[:-1] + (2, 2)
Expand Down Expand Up @@ -150,16 +124,14 @@ def wrapped(*args, **kwargs):
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
log_probs = list(p_log_probs.values()) + list(q_log_probs.values())
batch_ndims = _get_batch_ndims(log_probs)
batch_ndims = util.get_batch_ndims(log_probs)

if "log_weight" in q_metrics:
in_log_weight = q_metrics["log_weight"]
in_log_weight = jnp.sum(
in_log_weight,
axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
)
else:
in_log_weight = _get_log_weight(q_trace, batch_ndims)
assert "log_weight" in q_metrics
in_log_weight = q_metrics["log_weight"]
in_log_weight = jnp.sum(
in_log_weight,
axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
)
p_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
Expand Down Expand Up @@ -269,7 +241,7 @@ def fn(*args, **kwargs):
log_probs = {
name: util.get_site_log_prob(site) for name, site in trace.items()
}
batch_ndims = _get_batch_ndims(log_probs.values())
batch_ndims = util.get_batch_ndims(log_probs.values())
weighted = ("log_weight" in q_metrics) or any(
util.is_observed_site(site) for site in trace.values()
)
Expand All @@ -284,7 +256,7 @@ def fn(*args, **kwargs):
axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
)
else:
in_log_weight = _get_log_weight(trace, batch_ndims)
in_log_weight = util.get_log_weight(trace, batch_ndims)
n = in_log_weight.shape[0]
k = n if num_samples is None else num_samples
log_weight = jax.nn.logsumexp(in_log_weight, 0) - jnp.log(k if k else 1)
Expand Down Expand Up @@ -321,8 +293,8 @@ def _add_missing_metrics(metrics, trace):
name: util.get_site_log_prob(site) for name, site in trace.items()
}
if "log_weight" not in metrics:
batch_ndims = min(_get_batch_ndims(list(log_probs.values())), 1)
log_weight = _get_log_weight(trace, batch_ndims)
batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1)
log_weight = util.get_log_weight(trace, batch_ndims)
full_metrics["log_weight"] = log_weight
if batch_ndims: # leftmost dimension is particle dimension
ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0)
Expand Down Expand Up @@ -430,7 +402,7 @@ def wrapped(*args, **kwargs):
p_log_probs = {
name: util.get_site_log_prob(site) for name, site in p_trace.items()
}
batch_ndims = _get_batch_ndims(p_log_probs.values())
batch_ndims = util.get_batch_ndims(p_log_probs.values())

p_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand Down
8 changes: 8 additions & 0 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Backend implementation for NumPyro."""

from coix.util import get_batch_ndims
from coix.util import get_log_weight
from coix.util import get_site_log_prob
import jax
import jax.numpy as jnp
import numpyro
Expand Down Expand Up @@ -33,6 +36,11 @@ def wrapped(*args, **kwargs):
for name, site in tr.items()
if site["type"] == "metric"
}
# add log_weight to metrics
if "log_weight" not in metrics:
log_probs = [get_site_log_prob(site) for site in trace.values()]
weight = get_log_weight(trace, get_batch_ndims(log_probs))
metrics = {**metrics, "log_weight": weight}
return out, trace, metrics

return wrapped
Expand Down
7 changes: 7 additions & 0 deletions coix/oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import inspect
import itertools

from coix.util import get_batch_ndims
from coix.util import get_log_weight
from coix.util import get_site_log_prob
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -403,6 +406,10 @@ def wrapped(*args, **kwargs):
if "log_density" not in metrics:
log_density = sum(jnp.sum(site["log_prob"]) for site in trace.values())
metrics["log_density"] = jnp.array(0.0) + log_density
if "log_weight" not in metrics:
log_probs = [get_site_log_prob(site) for site in trace.values()]
weight = get_log_weight(trace, get_batch_ndims(log_probs))
metrics = {**metrics, "log_weight": weight}
return out, trace, metrics

return wrapped
Expand Down
27 changes: 27 additions & 0 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,30 @@ def desuffix(trace):
raw_name = names_to_raw_names[name]
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace


def get_batch_ndims(xs):
"""Gets the number of same-size leading dimensions of the elements in xs."""
if not xs:
return 0
min_ndim = min(jnp.ndim(lp) for lp in xs)
batch_ndims = 0
for i in range(min_ndim):
if len(set(jnp.shape(lp)[i] for lp in xs)) > 1:
break
batch_ndims = batch_ndims + 1
return batch_ndims


def get_log_weight(trace, batch_ndims):
"""Computes log weight of the trace and keeps its batch dimensions."""
log_weight = jnp.zeros((1,) * batch_ndims)
for site in trace.values():
lp = get_site_log_prob(site)
if is_observed_site(site):
log_weight = log_weight + jnp.sum(
lp, axis=tuple(range(batch_ndims - jnp.ndim(lp), 0))
)
else:
log_weight = log_weight + jnp.zeros(jnp.shape(lp)[:batch_ndims])
return log_weight

0 comments on commit 01471b1

Please sign in to comment.