Commit 0d264c6
Merge pull request #9 from jax-ml:add-test
PiperOrigin-RevId: 552927184
The coix Authors committed Mar 28, 2024
2 parents 07fb6f9 + be14a87 commit 0d264c6
Showing 9 changed files with 527 additions and 39 deletions.
5 changes: 3 additions & 2 deletions coix/algo.py
@@ -130,6 +130,7 @@ def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
   if _use_fori_loop(targets, num_targets):
 
     def body_fun(i, q):
+      assert callable(targets)
       p = extend(compose(momentum, targets(i), suffix=False), refreshment)
       return propose(p, compose(refreshment, compose(leapfrog, q)))

@@ -141,7 +142,7 @@ def body_fun(i, q):
   targets = [compose(momentum, p, suffix=False) for p in targets]
   q = targets[0]
-  loss_fns = [None] * (len(targets) - 2) + [iwae_loss]
+  loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,)
   for p, loss_fn in zip(targets[1:], loss_fns):
     q = compose(refreshment, compose(leapfrog, q))
     q = propose(extend(p, refreshment), q, loss_fn=loss_fn)
@@ -399,7 +400,7 @@ def body_fun(i, q):
     return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)
 
   q = propose(targets[0], proposals[0])
-  loss_fns = [None] * (len(proposals) - 2) + [iwae_loss]
+  loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,)
   for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns):
     q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn)
   return q
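
A note on the two `loss_fns` hunks (commentary, not part of the commit): both call sites build the same schedule, where every intermediate proposal step gets `None` (no loss term) and only the final step gets `iwae_loss`; switching from list to tuple concatenation leaves that pairing unchanged. A minimal sketch of the schedule, with a string standing in for the real `iwae_loss` callable:

num_targets = 4  # illustrative
loss_fns = (None,) * (num_targets - 2) + ("iwae_loss",)
assert loss_fns == (None, None, "iwae_loss")
# zip(targets[1:], loss_fns) then pairs the 3 proposal steps with these
# 3 entries, so only the last step contributes a loss.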
102 changes: 102 additions & 0 deletions coix/algo_test.py
@@ -0,0 +1,102 @@
"""Tests for algo.py."""

import functools

import coix
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import optax

coix.set_backend("coix.oryx")

np.random.seed(0)
num_data, dim = 4, 2
data = np.random.randn(num_data, dim).astype(np.float32)
loc_p = np.random.randn(dim).astype(np.float32)
precision_p = np.random.rand(dim).astype(np.float32)
scale_p = np.sqrt(1 / precision_p)
precision_x = np.random.rand(dim).astype(np.float32)
scale_x = np.sqrt(1 / precision_x)
precision_q = precision_p + num_data * precision_x
loc_q = (data.sum(0) * precision_x + loc_p * precision_p) / precision_q
log_scale_q = -0.5 * np.log(precision_q)


def model(params, key):
  del params
  key_z, key_next = random.split(key)
  z = coix.rv(dist.Normal(loc_p, scale_p), name="z")(key_z)
  z = jnp.broadcast_to(z, (num_data, dim))
  x = coix.rv(dist.Normal(z, scale_x), obs=data, name="x")
  return key_next, z, x


def guide(params, key, *args):
  del args
  key, _ = random.split(key)  # split here to test tie_in
  scale_q = jnp.exp(params["log_scale_q"])
  z = coix.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key)
  return z


def check_ess(make_program):
  params = {"loc_q": loc_q, "log_scale_q": log_scale_q}
  p = jax.vmap(functools.partial(model, params))
  q = jax.vmap(functools.partial(guide, params))
  program = make_program(p, q)

  keys = random.split(random.PRNGKey(0), 5)
  ess = coix.traced_evaluate(program)(keys)[2]["ess"]
  np.testing.assert_allclose(ess, 5.0)


def run_inference(make_program, num_steps=1000):
"""Performs inference given an algorithm `make_program`."""

def loss_fn(params, key):
p = jax.vmap(functools.partial(model, params))
q = jax.vmap(functools.partial(guide, params))
program = make_program(p, q)

keys = random.split(key, 5)
metrics = coix.traced_evaluate(program)(keys)[2]
return metrics["loss"], metrics

init_params = {
"loc_q": jnp.zeros_like(loc_q),
"log_scale_q": jnp.zeros_like(log_scale_q),
}
params, _ = coix.util.train(
loss_fn, init_params, optax.adam(0.01), num_steps=num_steps
)

np.testing.assert_allclose(params["loc_q"], loc_q, atol=0.2)
np.testing.assert_allclose(params["log_scale_q"], log_scale_q, atol=0.2)


def test_apgs():
  check_ess(lambda p, q: coix.algo.apgs(p, [q]))
  run_inference(lambda p, q: coix.algo.apgs(p, [q]))


def test_rws():
  check_ess(coix.algo.rws)
  run_inference(coix.algo.rws)


def test_svi_elbo():
  check_ess(coix.algo.svi)
  run_inference(coix.algo.svi)


def test_svi_iwae():
  check_ess(coix.algo.svi_iwae)
  run_inference(coix.algo.svi_iwae)


def test_svi_stl():
  check_ess(coix.algo.svi_stl)
  run_inference(coix.algo.svi_stl)
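
A note on the test fixture (commentary, not part of the commit): `precision_q`, `loc_q`, and `log_scale_q` are the exact parameters of the conjugate Normal-Normal posterior. With prior z ~ N(loc_p, scale_p^2) and likelihood x_i ~ N(z, scale_x^2) for i = 1..N, writing lambda = 1/scale^2 for precision:

lambda_q = lambda_p + N * lambda_x
loc_q = (lambda_x * sum_i x_i + lambda_p * loc_p) / lambda_q

which is exactly what the module-level setup computes, applied per dimension since the covariances are diagonal. Because `check_ess` evaluates the guide at these true posterior parameters, every particle receives the same importance weight, so the ESS over 5 keys is exactly 5.0.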
69 changes: 43 additions & 26 deletions coix/api.py
@@ -126,12 +126,14 @@ def wrapped(*args, **kwargs):
     log_probs = list(p_log_probs.values()) + list(q_log_probs.values())
     batch_ndims = util.get_batch_ndims(log_probs)
 
-    assert "log_weight" in q_metrics
-    in_log_weight = q_metrics["log_weight"]
-    in_log_weight = jnp.sum(
-        in_log_weight,
-        axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
-    )
+    if "log_weight" in q_metrics:
+      in_log_weight = q_metrics["log_weight"]
+      in_log_weight = jnp.sum(
+          in_log_weight,
+          axis=tuple(range(batch_ndims - jnp.ndim(in_log_weight), 0)),
+      )
+    else:
+      in_log_weight = util.get_log_weight(q_trace, batch_ndims)
     p_log_weight = sum(
         lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
         for name, lp in p_log_probs.items()
@@ -140,7 +142,7 @@ def wrapped(*args, **kwargs):
     # Note: We include superfluous variables, whose `name in p_trace`.
     q_log_weight = sum(
         lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
-        for name, lp in q_log_probs.items()
+        for lp in q_log_probs.values()
     )
     incremental_log_weight = p_log_weight - q_log_weight
     log_weight = in_log_weight + incremental_log_weight
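
Two readings of this pair of hunks (commentary, not part of the commit): the hard `assert` on `log_weight` is relaxed to a fallback that recomputes the incoming weight from the proposal's trace, and the comprehension drops an unused `name` binding. The weight bookkeeping itself is additive in log space; an illustrative sketch with made-up numbers:

import jax.numpy as jnp

# The outgoing weight multiplies the incoming weight by the incremental
# ratio p/q per particle, i.e. adds in log space.
in_log_weight = jnp.log(jnp.array([0.5, 2.0]))
p_log_weight = jnp.array([-1.0, -2.0])  # summed log p terms per particle
q_log_weight = jnp.array([-1.5, -1.0])  # summed log q terms per particle
log_weight = in_log_weight + (p_log_weight - q_log_weight)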
@@ -193,12 +195,20 @@ def _maybe_get_along_first_axis(x, idx, n, squeeze=False):
     x = np.array(x)
   # Special treatment for cascades.
   if hasattr(x, "value"):
-    x.value = _maybe_get_along_first_axis(
-        util.get_site_value(x), idx, n, squeeze=squeeze
+    setattr(
+        x,
+        "value",
+        _maybe_get_along_first_axis(
+            util.get_site_value(x), idx, n, squeeze=squeeze
+        ),
     )
   if hasattr(x, "log_density"):
-    x.log_density = _maybe_get_along_first_axis(
-        util.get_site_log_prob(x), idx, n, squeeze=squeeze
+    setattr(
+        x,
+        "log_density",
+        _maybe_get_along_first_axis(
+            util.get_site_log_prob(x), idx, n, squeeze=squeeze
+        ),
     )
   if (
       isinstance(x, (np.ndarray, jnp.ndarray))
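
The `setattr` rewrite is behaviorally identical to direct attribute assignment; plausibly it is there to satisfy a static checker that objects to assigning attributes on values of unknown type (our guess, the commit does not say). For the record:

class Site:  # illustrative stand-in for a traced site object
  pass

s = Site()
setattr(s, "value", 1.0)  # same effect as `s.value = 1.0`
assert s.value == 1.0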
@@ -233,7 +243,7 @@ def fn(*args, **kwargs):
     if util.can_extract_key(args):
       key_r, key_q = _split_key(args[0])
       # We just need a single key for resampling.
-      key_r = key_r.reshape((-1, 2)).sum(0)
+      key_r = key_r.reshape((-1, 2))[0]
       args = (key_q,) + args[1:]
     else:
       key_r = core.prng_key()
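
A likely motivation for this one-line change (our reading, not stated in the commit): summing the rows of a batch of PRNG keys produces a bit pattern that `jax.random` never emitted, while indexing keeps a well-formed key. A minimal sketch:

import jax
from jax import random

keys = random.split(random.PRNGKey(0), 5)  # uint32 array of shape (5, 2)
key_r = keys.reshape((-1, 2))[0]  # take one key instead of summing rows
sample = random.normal(key_r)  # key_r is still a valid PRNG key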
@@ -296,12 +306,17 @@ def _add_missing_metrics(metrics, trace):
     batch_ndims = min(util.get_batch_ndims(list(log_probs.values())), 1)
     log_weight = util.get_log_weight(trace, batch_ndims)
     full_metrics["log_weight"] = log_weight
-    if batch_ndims:  # leftmost dimension is particle dimension
-      ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0)
-      full_metrics["ess"] = ess.mean()
-      n = log_weight.shape[0]
-      log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n)
-      full_metrics["log_Z"] = log_z.mean()
+  else:
+    batch_ndims = metrics["log_weight"].ndim
+    log_weight = metrics["log_weight"]
+  # leftmost dimension is particle dimension
+  if batch_ndims and "ess" not in metrics:
+    assert "log_Z" not in metrics
+    ess = 1 / (jax.nn.softmax(log_weight, axis=0) ** 2).sum(0)
+    full_metrics["ess"] = ess.mean()
+    n = log_weight.shape[0]
+    log_z = jax.scipy.special.logsumexp(log_weight, 0) - jnp.log(n)
+    full_metrics["log_Z"] = log_z.mean()
   if "loss" not in metrics:
     full_metrics["loss"] = jnp.array(0.0)
   if "log_density" not in metrics:
@@ -325,17 +340,18 @@ def fori_loop(lower, upper, body_fun, init_program):
   """
 
   def fn(*args, **kwargs):
-    if util.can_extract_key(args):
-      key = args[0]
+    def trace_arg_key(fn, key):
+      return core.traced_evaluate(fn)(key, *args[1:], **kwargs)
 
-      def trace_fn(fn, key):
-        return core.traced_evaluate(fn)(key, *args[1:], **kwargs)
+    def trace_with_seed(fn, key):
+      return core.traced_evaluate(fn, seed=key)(*args, **kwargs)
 
+    if util.can_extract_key(args):
+      key = args[0]
+      trace_fn = trace_arg_key
     else:
       key = core.prng_key()
-
-      def trace_fn(fn, key):
-        return core.traced_evaluate(fn, seed=key)(*args, **kwargs)
+      trace_fn = trace_with_seed
 
     key_body, key_init = _split_key(key)
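
The `fori_loop` hunk looks like a pure refactor (commentary, not part of the commit): both tracing closures are now defined unconditionally and one of them is bound to `trace_fn`, instead of `def`-ing the same name separately inside each branch. The shape of the pattern:

def use_key_argument(x):  # illustrative
  return ("key", x)

def use_fresh_seed(x):  # illustrative
  return ("seed", x)

has_key = True  # illustrative flag
handler = use_key_argument if has_key else use_fresh_seed
assert handler(0) == ("key", 0)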

@@ -406,7 +422,7 @@ def wrapped(*args, **kwargs):
 
     p_log_weight = sum(
         lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
-        for name, lp in p_log_probs.items()
+        for lp in p_log_probs.values()
     )
 
     marginal_trace = {
@@ -417,6 +433,7 @@ def wrapped(*args, **kwargs):
     new_memory = {
         name: util.get_site_value(site) for name, site in marginal_trace.items()
     }
+    assert not isinstance(p_log_weight, int)
     num_particles = p_log_weight.shape[0]
     batch_dim = p_log_weight.ndim
     flat_memory = {
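
Why the new `assert` is useful (commentary, not part of the commit): Python's built-in `sum` over an empty generator returns the int `0`, which has no `.shape`, so an empty `p_log_probs` would otherwise fail later with a confusing `AttributeError`. A minimal reproduction:

p_log_probs = {}  # no traced sites
p_log_weight = sum(lp.sum() for lp in p_log_probs.values())
print(type(p_log_weight))  # <class 'int'> -- `.shape` would raise here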
(The remaining 6 changed files are not shown.)
