
Commit

update gmm example
fehiepsi committed Apr 4, 2024
1 parent 21a64b3 commit 3ce8c3d
Showing 7 changed files with 84 additions and 94 deletions.
Binary file added docs/_static/gmm.png
Binary file modified docs/_static/gmm_oryx.png
2 changes: 0 additions & 2 deletions examples/bmnist.py
@@ -28,8 +28,6 @@
"""

# TODO: Refactor using numpyro backend. The current code is likely not working yet.

import argparse
from functools import partial
import sys
5 changes: 4 additions & 1 deletion examples/dmm.py
@@ -25,6 +25,9 @@
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
.. image:: ../_static/dmm.png
:align: center
"""

import argparse
@@ -142,7 +145,7 @@ def __call__(self, x):
x = nn.tanh(x)
x = nn.Dense(2)(x)
angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
radius = 1. # self.param("radius", nn.initializers.ones, (1,))
radius = 1.0 # self.param("radius", nn.initializers.ones, (1,))
return radius * angle


3 changes: 3 additions & 0 deletions examples/dmm_oryx.py
@@ -25,6 +25,9 @@
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
.. image:: ../_static/dmm_oryx.png
:align: center
"""

import argparse
125 changes: 56 additions & 69 deletions examples/gmm.py
@@ -25,6 +25,9 @@
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
.. image:: ../_static/gmm.png
:align: center
"""

import argparse
@@ -121,6 +124,12 @@ def __call__(self, x):
return logits + jnp.log(jnp.ones(3) / 3)


def broadcast_concatenate(*xs):
shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
return jnp.concatenate(xs, -1)


class GMMEncoder(nn.Module):

def setup(self):
@@ -133,12 +142,7 @@ def __call__(self, x): # N x D
alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D
tau = alpha / beta # M x D

concatenate_fn = lambda x, m, t: jnp.concatenate(
[x, m, t], axis=-1
) # N x M x 3D
xmt = jax.vmap(
jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
)(x, mean, tau)
xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
logits = self.encode_c(xmt) # N x D
c = jnp.argmax(logits, -1) # N

@@ -150,76 +154,68 @@ def __call__(self, x): # N x D
# Then, we define the target and kernels as in Section 6.2.


def gmm_target(network, inputs):
def gmm_target(inputs):
tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
mean = numpyro.sample(
"mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)).to_event()
)
with numpyro.plate("N", inputs.shape[-2], dim=-1):
tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
mean = numpyro.sample("mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)))
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
loc = Vindex(mean)[..., c, :]
scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :])
x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs)

out = {"mean": mean, "tau": tau, "c": c, "x": x}
return out,


def dmm_target(network, inputs):
mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event())
with numpyro.plate("N", inputs.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
h = numpyro.sample("h", dist.Beta(1, 1))
x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)

out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
return (out,)


def gmm_kernel_mean_tau(network, key, inputs):
def gmm_kernel_mean_tau(network, inputs):
if not isinstance(inputs, dict):
inputs = {"x": inputs}
key_out, key_mean, key_tau = random.split(key, 3)

if "c" in inputs:
x = inputs["x"]
c = jax.nn.one_hot(inputs["c"], 3)
xc = jnp.concatenate([inputs["x"], c], -1)
xc = broadcast_concatenate(x, c)
alpha, beta, mu, nu = network.encode_mean_tau(xc)
else:
alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
tau = coix.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
mean = coix.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean)
alpha, beta, mu, nu = jax.tree_util.tree_map(
lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu)
)
tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2))
mean = numpyro.sample(
"mean", dist.Normal(mu, 1 / jnp.sqrt(tau * nu)).to_event(2)
)

out = {**inputs, **{"mean": mean, "tau": tau}}
return key_out, out

return (out,)

def gmm_kernel_c(network, key, inputs):
key_out, key_c = random.split(key, 2)

concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1)
xmt = jax.vmap(
jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
)(inputs["x"], inputs["mean"], inputs["tau"])
def gmm_kernel_c(network, inputs):
x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
logits = network.encode_c(xmt)
c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c)
with numpyro.plate("N", logits.shape[-2], dim=-1):
c = numpyro.sample("c", dist.Categorical(logits=logits))

out = {**inputs, **{"c": c}}
return key_out, out
return (out,)


# %%
# Finally, we create the gmm inference program, define the loss function,
# run the training loop, and plot the results.


def make_gmm(params, num_sweeps):
def make_gmm(params, num_sweeps, num_particles):
network = coix.util.BindModule(GMMEncoder(), params)
# Add particle dimension and construct a program.
target = jax.vmap(partial(gmm_target, network))
kernels = [
jax.vmap(partial(gmm_kernel_mean_tau, network)),
jax.vmap(partial(gmm_kernel_c, network)),
]
make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-3)
target = make_particle_plate()(gmm_target)
kernel_mean_tau = make_particle_plate()(partial(gmm_kernel_mean_tau, network))
kernel_c = make_particle_plate()(partial(gmm_kernel_c, network))
kernels = [kernel_mean_tau, kernel_c]
program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
return program

Expand All @@ -228,16 +224,12 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
# Prepare data for the program.
shuffle_rng, rng_key = random.split(key)
batch = random.permutation(shuffle_rng, batch, axis=1)
batch_rng = random.split(rng_key, batch.shape[0])
batch = jnp.repeat(batch[:, None], num_particles, axis=1)
rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng)

# Run the program and get metrics.
program = make_gmm(params, num_sweeps)
_, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
metrics = jax.tree_util.tree_map(
partial(jnp.mean, axis=0), metrics
) # mean across batch
program = make_gmm(params, num_sweeps, num_particles)
_, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
for metric_name in ["log_Z", "log_density", "loss"]:
metrics[metric_name] = metrics[metric_name] / batch.shape[0]
return metrics["loss"], metrics


Expand All @@ -260,32 +252,28 @@ def main(args):
train_ds,
)

program = make_gmm(gmm_params, num_sweeps)
program = make_gmm(gmm_params, num_sweeps, num_particles)
batch, label = next(test_ds)
batch = jnp.repeat(batch[:, None], num_particles, axis=1)
rng_keys = jax.vmap(partial(random.split, num=num_particles))(
random.split(jax.random.PRNGKey(1), batch.shape[0])
)
_, out = jax.vmap(program)(rng_keys, batch)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i in range(3):
n = i
axes[i].scatter(
batch[n, 0, :, 0],
batch[n, 0, :, 1],
out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
out = out[0]

_, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(6):
axes[i // 3][i % 3].scatter(
batch[i, :, 0],
batch[i, :, 1],
marker=".",
color=np.array(["c", "m", "y"])[label[n]],
color=np.array(["c", "m", "y"])[label[i]],
)
for j, c in enumerate(["r", "g", "b"]):
ellipse = Ellipse(
xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]),
width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]),
height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]),
xy=(out["mean"][0, i, 0, j, 0], out["mean"][0, i, 0, j, 1]),
width=4 / jnp.sqrt(out["tau"][0, i, 0, j, 0]),
height=4 / jnp.sqrt(out["tau"][0, i, 0, j, 1]),
fc=c,
alpha=0.3,
)
axes[i].add_patch(ellipse)
axes[i // 3][i % 3].add_patch(ellipse)
plt.show()


@@ -303,6 +291,5 @@ def main(args):

tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF.
numpyro.set_platform(args.device)
coix.set_backend("coix.oryx")

main(args)
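
Side note (not part of the commit): the main structural change in gmm.py above is that the particle dimension is now handled by wrapping the target and kernels in a numpyro plate at dim=-3, instead of jax.vmap-ing the program over a batch of PRNG keys. Below is a minimal sketch of that pattern with made-up sizes, assuming numpyro's plate messenger can wrap a model callable the way the new make_gmm relies on:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def model(inputs):
    # One cluster assignment per data point, as in gmm_target.
    with numpyro.plate("N", inputs.shape[-2], dim=-1):
        numpyro.sample("c", dist.Categorical(probs=jnp.ones(3) / 3))


# Wrapping the model in a "particle" plate gives every sample site an extra
# leading dimension at dim=-3, which is what the updated make_gmm relies on.
particle_model = numpyro.plate("particle", 10, dim=-3)(model)

with handlers.seed(rng_seed=0):
    tr = handlers.trace(particle_model).get_trace(jnp.zeros((60, 2)))
print(tr["c"]["value"].shape)  # (10, 1, 60): particle x 1 x N

Judging from the diff, coix.traced_evaluate(program, seed=...) then runs the whole program under a single seed, so the explicit key plumbing and the jnp.repeat of the batch over particles are no longer needed.
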
43 changes: 21 additions & 22 deletions examples/gmm_oryx.py
@@ -124,6 +124,12 @@ def __call__(self, x):
return logits + jnp.log(jnp.ones(3) / 3)


def broadcast_concatenate(*xs):
shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
return jnp.concatenate(xs, -1)


class GMMEncoder(nn.Module):

def setup(self):
@@ -136,12 +142,7 @@ def __call__(self, x): # N x D
alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D
tau = alpha / beta # M x D

concatenate_fn = lambda x, m, t: jnp.concatenate(
[x, m, t], axis=-1
) # N x M x 3D
xmt = jax.vmap(
jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
)(x, mean, tau)
xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
logits = self.encode_c(xmt) # N x D
c = jnp.argmax(logits, -1) # N

@@ -174,8 +175,9 @@ def gmm_kernel_mean_tau(network, key, inputs):
key_out, key_mean, key_tau = random.split(key, 3)

if "c" in inputs:
x = inputs["x"]
c = jax.nn.one_hot(inputs["c"], 3)
xc = jnp.concatenate([inputs["x"], c], -1)
xc = jnp.concatenate([x, c], -1)
alpha, beta, mu, nu = network.encode_mean_tau(xc)
else:
alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
@@ -191,10 +193,8 @@ def gmm_kernel_mean_tau(network, key, inputs):
def gmm_kernel_c(network, key, inputs):
key_out, key_c = random.split(key, 2)

concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1)
xmt = jax.vmap(
jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
)(inputs["x"], inputs["mean"], inputs["tau"])
x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
logits = network.encode_c(xmt)
c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)

@@ -263,24 +263,23 @@ def main(args):
)
_, out = jax.vmap(program)(rng_keys, batch)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i in range(3):
n = i
axes[i].scatter(
batch[n, 0, :, 0],
batch[n, 0, :, 1],
_, axes = plt.subplots(2, 3, figsize=(15, 10))
for i in range(6):
axes[i // 3][i % 3].scatter(
batch[i, 0, :, 0],
batch[i, 0, :, 1],
marker=".",
color=np.array(["c", "m", "y"])[label[n]],
color=np.array(["c", "m", "y"])[label[i]],
)
for j, c in enumerate(["r", "g", "b"]):
ellipse = Ellipse(
xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]),
width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]),
height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]),
xy=(out["mean"][i, 0, j, 0], out["mean"][i, 0, j, 1]),
width=4 / jnp.sqrt(out["tau"][i, 0, j, 0]),
height=4 / jnp.sqrt(out["tau"][i, 0, j, 1]),
fc=c,
alpha=0.3,
)
axes[i].add_patch(ellipse)
axes[i // 3][i % 3].add_patch(ellipse)
plt.show()


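
Side note (not part of the commit): both updated examples replace the nested jax.vmap over a concatenate lambda with the new broadcast_concatenate helper plus a single vmap over the cluster axis. A small sketch of the resulting shapes, with made-up sizes:

import jax
import jax.numpy as jnp


def broadcast_concatenate(*xs):
    # Broadcast the leading (batch) shapes, then concatenate the feature axes.
    shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
    xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
    return jnp.concatenate(xs, -1)


x = jnp.zeros((60, 2))    # N x D data points
mean = jnp.zeros((3, 2))  # M x D cluster means
tau = jnp.ones((3, 2))    # M x D cluster precisions

# Map over the cluster axis (-2) of mean and tau, stacking the results on a
# new axis -2, so each data point is paired with every cluster's statistics.
xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
print(xmt.shape)  # (60, 3, 6), i.e. N x M x 3D, as the encoder comments expect
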
