Skip to content

Commit

Permalink
port dmm example to numpyro
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Apr 2, 2024
1 parent 07f0963 commit 11425c5
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 60 deletions.
5 changes: 3 additions & 2 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def wrapped(*args, **kwargs):
if site["type"] == "sample":
value = site["value"]
log_prob = site["fn"].log_prob(value)
trace[name] = {"value": value, "log_prob": log_prob}
event_dim_holder = jnp.empty([1] * site["fn"].event_dim)
trace[name] = {"value": value, "log_prob": log_prob, "_event_dim_holder": event_dim_holder}
if site.get("is_observed", False):
trace[name]["is_observed"] = True
metrics = {
Expand Down Expand Up @@ -83,7 +84,7 @@ def wrapped(*args, **kwargs):
del args, kwargs
for name, site in trace.items():
value, lp = site["value"], site["log_prob"]
event_dim = jnp.ndim(value) - jnp.ndim(lp)
event_dim = jnp.ndim(site["_event_dim_holder"])
obs = value if "is_observed" in site else None
numpyro.sample(name, dist.Delta(value, lp, event_dim=event_dim), obs=obs)
for name, value in metrics.items():
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Coix Documentation
examples/gmm
examples/dmm
examples/bmnist
examples/anneal_oryx
examples/gmm_oryx
examples/dmm_oryx
examples/anneal_oryx

Indices and tables
==================
Expand Down
3 changes: 3 additions & 0 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
1. Zimmermann, Heiko, et al. "Nested variational inference." NeurIPS 2021.
.. image:: ../_static/anneal.png
:align: center
"""

import argparse
Expand Down
5 changes: 4 additions & 1 deletion examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
1. Zimmermann, Heiko, et al. "Nested variational inference." NeurIPS 2021.
.. image:: ../_static/anneal_oryx.png
:align: center
"""

import argparse
Expand Down Expand Up @@ -119,7 +122,7 @@ def __call__(self, x):
def anneal_target(network, key, k=0):
key_out, key = random.split(key)
x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key)
coix.factor(network.anneal_density(x, index=k), name="anneal_density")
coryx.factor(network.anneal_density(x, index=k), name="anneal_density")
return key_out, {"x": x}


Expand Down
94 changes: 42 additions & 52 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.ops.indexing import Vindex
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
Expand Down Expand Up @@ -69,7 +70,7 @@ def load_dataset(split, *, is_training, batch_size):
seed = 0 if is_training else 1
data = simulate_rings(num_data, num_points, seed=seed)
ds = tf.data.Dataset.from_tensor_slices(data)
ds = ds.cache().repeat()
ds = ds.repeat()
if is_training:
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
Expand Down Expand Up @@ -137,7 +138,8 @@ def __call__(self, x):
x = nn.tanh(x)
x = nn.Dense(2)(x)
angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
return angle
radius = self.param("radius", nn.initializers.ones, (1,))
return radius * angle


class DMMAutoEncoder(nn.Module):
Expand All @@ -153,8 +155,9 @@ def __call__(self, x): # N x D
# Heuristic procedure to setup initial parameters.
mu, _ = self.encode_initial_mu(x) # M x D

concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu)
# concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
# xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu)
xmu = jnp.expand_dims(x, -2) - mu
logits = self.encode_c(xmu) # N x M
c = jnp.argmax(logits, -1) # N

Expand All @@ -174,67 +177,63 @@ def __call__(self, x): # N x D
# Then, we define the target and kernels as in Section 6.3.


def dmm_target(network, key, inputs):
key_out, key_mu, key_c, key_h = random.split(key, 4)
N = inputs.shape[-2]

mu = coix.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu)
c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
h = coix.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h)
x_recon = mu[c] + network.decode_h(h)
x = coix.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x")
def dmm_target(network, inputs):
mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2 ]).to_event())
with numpyro.plate("N", inputs.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
h = numpyro.sample("h", dist.Beta(1, 1))
x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)

out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
return key_out, out
return out,


def dmm_kernel_mu(network, key, inputs):
def dmm_kernel_mu(network, inputs):
if not isinstance(inputs, dict):
inputs = {"x": inputs}
key_out, key_mu = random.split(key)

if "c" in inputs:
x = jnp.broadcast_to(inputs["x"], inputs["h"].shape + (2,))
c = jax.nn.one_hot(inputs["c"], 4)
h = jnp.expand_dims(inputs["h"], -1)
xch = jnp.concatenate([inputs["x"], c, h], -1)
xch = jnp.concatenate([x, c, h], -1)
loc, scale = network.encode_mu(xch)
else:
loc, scale = network.encode_initial_mu(inputs["x"])
mu = coix.rv(dist.Normal(loc, scale), name="mu")(key_mu)
loc, scale = jnp.expand_dims(loc, -3), jnp.expand_dims(scale, -3)
mu = numpyro.sample("mu", dist.Normal(loc, scale).to_event(2))

out = {**inputs, **{"mu": mu}}
return key_out, out

return out,

def dmm_kernel_c_h(network, key, inputs):
key_out, key_c, key_h = random.split(key, 3)

concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1)
xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(
inputs["x"], inputs["mu"]
)
def dmm_kernel_c_h(network, inputs):
x, mu = inputs["x"], inputs["mu"]
xmu = jnp.expand_dims(x, -2) - mu
logits = network.encode_c(xmu)
c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c)
alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c])
h = coix.rv(dist.Beta(alpha, beta), name="h")(key_h)
with numpyro.plate("N", logits.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(logits=logits))
alpha, beta = network.encode_h(inputs["x"] - Vindex(mu)[..., c, :])
h = numpyro.sample("h", dist.Beta(alpha, beta))

out = {**inputs, **{"c": c, "h": h}}
return key_out, out
return out,


# %%
# Finally, we create the dmm inference program, define the loss function,
# run the training loop, and plot the results.


def make_dmm(params, num_sweeps):
def make_dmm(params, num_sweeps=5, num_particles=10):
network = coix.util.BindModule(DMMAutoEncoder(), params)
# Add particle dimension and construct a program.
target = jax.vmap(partial(dmm_target, network))
kernels = [
jax.vmap(partial(dmm_kernel_mu, network)),
jax.vmap(partial(dmm_kernel_c_h, network)),
]
make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-3)
target = make_particle_plate()(partial(dmm_target, network))
kernel_mu = make_particle_plate()(partial(dmm_kernel_mu, network))
kernel_c_h = make_particle_plate()(partial(dmm_kernel_c_h, network))
kernels = [kernel_mu, kernel_c_h]
program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
return program

Expand All @@ -243,16 +242,12 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
# Prepare data for the program.
shuffle_rng, rng_key = random.split(key)
batch = random.permutation(shuffle_rng, batch, axis=1)
batch_rng = random.split(rng_key, batch.shape[0])
batch = jnp.repeat(batch[:, None], num_particles, axis=1)
rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng)

# Run the program and get metrics.
program = make_dmm(params, num_sweeps)
_, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
metrics = jax.tree_util.tree_map(
partial(jnp.mean, axis=0), metrics
) # mean across batch
program = make_dmm(params, num_sweeps, num_particles)
_, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
for metric_name in ["log_Z", "log_density", "loss"]:
metrics[metric_name] = metrics[metric_name] / batch.shape[0]
return metrics["loss"], metrics


Expand All @@ -278,12 +273,8 @@ def main(args):
)

program = make_dmm(dmm_params, num_sweeps)
batch = jnp.repeat(next(test_ds)[:, None], num_particles, axis=1)
rng_keys = jax.vmap(partial(random.split, num=num_particles))(
random.split(jax.random.PRNGKey(1), batch.shape[0])
)
_, out = jax.vmap(program)(rng_keys, batch)
batch.shape, out["x_recon"].shape
batch = next(test_ds)
out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(3):
Expand Down Expand Up @@ -312,14 +303,13 @@ def main(args):
parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
parser.add_argument("--num_particles", nargs="?", default=10, type=int)
parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float)
parser.add_argument("--num-steps", nargs="?", default=300000, type=int)
parser.add_argument("--num-steps", nargs="?", default=30000, type=int)
parser.add_argument(
"--device", default="gpu", type=str, help='use "cpu" or "gpu".'
)
args = parser.parse_args()

tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
numpyro.set_platform(args.device)
coix.set_backend("coix.oryx")

main(args)
10 changes: 6 additions & 4 deletions examples/dmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,10 @@ def dmm_kernel_c_h(network, key, inputs):
return key_out, out


### Train
# %%
# Finally, we create the dmm inference program, define the loss function,
# run the training loop, and plot the results. Note that we are using
# 10x fewer steps than the paper.


def make_dmm(params, num_sweeps):
Expand Down Expand Up @@ -282,7 +285,6 @@ def main(args):
random.split(jax.random.PRNGKey(1), batch.shape[0])
)
_, out = jax.vmap(program)(rng_keys, batch)
batch.shape, out["x_recon"].shape

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(3):
Expand All @@ -308,10 +310,10 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Annealing example")
parser.add_argument("--batch-size", nargs="?", default=20, type=int)
parser.add_argument("--num-sweeps", nargs="?", default=5, type=int)
parser.add_argument("--num-sweeps", nargs="?", default=8, type=int)
parser.add_argument("--num_particles", nargs="?", default=10, type=int)
parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float)
parser.add_argument("--num-steps", nargs="?", default=300000, type=int)
parser.add_argument("--num-steps", nargs="?", default=30000, type=int)
parser.add_argument(
"--device", default="gpu", type=str, help='use "cpu" or "gpu".'
)
Expand Down
3 changes: 3 additions & 0 deletions examples/gmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
.. image:: ../_static/gmm_oryx.png
:align: center
"""

import argparse
Expand Down

0 comments on commit 11425c5

Please sign in to comment.