diff --git a/docs/_static/gmm.png b/docs/_static/gmm.png
new file mode 100644
index 0000000..25eef27
Binary files /dev/null and b/docs/_static/gmm.png differ
diff --git a/docs/_static/gmm_oryx.png b/docs/_static/gmm_oryx.png
index 3119066..fceb69e 100644
Binary files a/docs/_static/gmm_oryx.png and b/docs/_static/gmm_oryx.png differ
diff --git a/examples/bmnist.py b/examples/bmnist.py
index 255b5ad..e37046e 100644
--- a/examples/bmnist.py
+++ b/examples/bmnist.py
@@ -28,8 +28,6 @@
 """
 
-# TODO: Refactor using numpyro backend. The current code is likely not working yet.
-
 import argparse
 from functools import partial
 import sys
diff --git a/examples/dmm.py b/examples/dmm.py
index a20fc52..d453b83 100644
--- a/examples/dmm.py
+++ b/examples/dmm.py
@@ -25,6 +25,9 @@
 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient
    statistics. ICML 2020.
 
+.. image:: ../_static/dmm.png
+    :align: center
+
 """
 
 import argparse
@@ -142,7 +145,7 @@ def __call__(self, x):
     x = nn.tanh(x)
     x = nn.Dense(2)(x)
     angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
-    radius = 1.  # self.param("radius", nn.initializers.ones, (1,))
+    radius = 1.0  # self.param("radius", nn.initializers.ones, (1,))
     return radius * angle
diff --git a/examples/dmm_oryx.py b/examples/dmm_oryx.py
index 14ec38b..28f2851 100644
--- a/examples/dmm_oryx.py
+++ b/examples/dmm_oryx.py
@@ -25,6 +25,9 @@
 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient
    statistics. ICML 2020.
 
+.. image:: ../_static/dmm_oryx.png
+    :align: center
+
 """
 
 import argparse
diff --git a/examples/gmm.py b/examples/gmm.py
index 48fbdd2..d57a007 100644
--- a/examples/gmm.py
+++ b/examples/gmm.py
@@ -25,6 +25,9 @@
 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient
    statistics. ICML 2020.
 
+.. image:: ../_static/gmm.png
+    :align: center
+
 """
 
 import argparse
@@ -121,6 +124,12 @@ def __call__(self, x):
     return logits + jnp.log(jnp.ones(3) / 3)
 
 
+def broadcast_concatenate(*xs):
+  shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
+  xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
+  return jnp.concatenate(xs, -1)
+
+
 class GMMEncoder(nn.Module):
 
   def setup(self):
@@ -133,12 +142,7 @@ def __call__(self, x):  # N x D
     alpha, beta, mean, _ = self.encode_initial_mean_tau(x)  # M x D
     tau = alpha / beta  # M x D
-    concatenate_fn = lambda x, m, t: jnp.concatenate(
-        [x, m, t], axis=-1
-    )  # N x M x 3D
-    xmt = jax.vmap(
-        jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
-    )(x, mean, tau)
+    xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
     logits = self.encode_c(xmt)  # N x D
     c = jnp.argmax(logits, -1)  # N
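Review note: broadcast_concatenate (added to gmm.py here and to gmm_oryx.py below) replaces the nested jax.vmap over a pairwise concatenate, pairing every data point with every cluster's statistics in a single call. A standalone sketch of what the helper and the new vmap call compute; the shapes are toy values chosen for illustration, not taken from the example's data:

    import jax
    import jax.numpy as jnp

    def broadcast_concatenate(*xs):
      # Broadcast all leading (batch) shapes to a common shape, then
      # concatenate along the last (feature) axis.
      shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
      xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
      return jnp.concatenate(xs, -1)

    x = jnp.zeros((5, 2))     # N=5 data points with D=2 features
    mean = jnp.zeros((3, 2))  # M=3 cluster means
    tau = jnp.ones((3, 2))    # M=3 cluster precisions
    # Map over the cluster axis (-2) of mean/tau, reusing x for each cluster.
    xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
    print(xmt.shape)  # (5, 3, 6): each point paired with each cluster's stats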
@@ -150,61 +154,53 @@ def __call__(self, x):  # N x D
 # Then, we define the target and kernels as in Section 6.2.
 
 
-def gmm_target(network, inputs):
+def gmm_target(inputs):
+  tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
+  mean = numpyro.sample(
+      "mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)).to_event()
+  )
   with numpyro.plate("N", inputs.shape[-2], dim=-1):
-    tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event())
-    mean = numpyro.sample("mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)))
     c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(3) / 3))
     loc = Vindex(mean)[..., c, :]
     scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :])
     x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs)
 
   out = {"mean": mean, "tau": tau, "c": c, "x": x}
-  return out,
-
-
-def dmm_target(network, inputs):
-  mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event())
-  with numpyro.plate("N", inputs.shape[-2], dim=-1):
-    c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4))
-    h = numpyro.sample("h", dist.Beta(1, 1))
-    x_recon = network.decode_h(h) + Vindex(mu)[..., c, :]
-    x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs)
-
-  out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
   return (out,)
 
 
-def gmm_kernel_mean_tau(network, key, inputs):
+def gmm_kernel_mean_tau(network, inputs):
   if not isinstance(inputs, dict):
     inputs = {"x": inputs}
-  key_out, key_mean, key_tau = random.split(key, 3)
 
   if "c" in inputs:
+    x = inputs["x"]
     c = jax.nn.one_hot(inputs["c"], 3)
-    xc = jnp.concatenate([inputs["x"], c], -1)
+    xc = broadcast_concatenate(x, c)
     alpha, beta, mu, nu = network.encode_mean_tau(xc)
   else:
     alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
-  tau = coix.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
-  mean = coix.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean)
+  alpha, beta, mu, nu = jax.tree_util.tree_map(
+      lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu)
+  )
+  tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2))
+  mean = numpyro.sample(
+      "mean", dist.Normal(mu, 1 / jnp.sqrt(tau * nu)).to_event(2)
+  )
 
   out = {**inputs, **{"mean": mean, "tau": tau}}
-  return key_out, out
-
+  return (out,)
 
 
-def gmm_kernel_c(network, key, inputs):
-  key_out, key_c = random.split(key, 2)
-  concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1)
-  xmt = jax.vmap(
-      jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
-  )(inputs["x"], inputs["mean"], inputs["tau"])
+def gmm_kernel_c(network, inputs):
+  x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
+  xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
   logits = network.encode_c(xmt)
-  c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c)
+  with numpyro.plate("N", logits.shape[-2], dim=-1):
+    c = numpyro.sample("c", dist.Categorical(logits=logits))
 
   out = {**inputs, **{"c": c}}
-  return key_out, out
+  return (out,)
 
 
 # %%
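Review note: the refactored kernels above no longer take or return PRNG keys; every sampling statement goes through numpyro.sample, and randomness is injected by an effect handler when the program is evaluated. A rough sketch of how a kernel in this style can be smoke-tested in isolation with numpyro's seed handler; toy_kernel and its shapes are hypothetical stand-ins for illustration, not part of this PR:

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def toy_kernel(inputs):
      # Same convention as the new gmm_kernel_c: sample inside a plate,
      # thread no keys by hand, and return a singleton tuple of outputs.
      n = inputs.shape[-2]
      with numpyro.plate("N", n, dim=-1):
        c = numpyro.sample("c", dist.Categorical(logits=jnp.zeros((n, 3))))
      return ({"x": inputs, "c": c},)

    seeded = numpyro.handlers.seed(toy_kernel, rng_seed=0)
    (out,) = seeded(jnp.zeros((10, 2)))
    print(out["c"].shape)  # (10,): one cluster assignment per data point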
@@ -212,14 +208,14 @@ def gmm_kernel_c(network, inputs):
 # run the training loop, and plot the results.
 
 
-def make_gmm(params, num_sweeps):
+def make_gmm(params, num_sweeps, num_particles):
   network = coix.util.BindModule(GMMEncoder(), params)
   # Add particle dimension and construct a program.
-  target = jax.vmap(partial(gmm_target, network))
-  kernels = [
-      jax.vmap(partial(gmm_kernel_mean_tau, network)),
-      jax.vmap(partial(gmm_kernel_c, network)),
-  ]
+  make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-3)
+  target = make_particle_plate()(gmm_target)
+  kernel_mean_tau = make_particle_plate()(partial(gmm_kernel_mean_tau, network))
+  kernel_c = make_particle_plate()(partial(gmm_kernel_c, network))
+  kernels = [kernel_mean_tau, kernel_c]
   program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)
   return program
 
@@ -228,16 +224,12 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
   # Prepare data for the program.
   shuffle_rng, rng_key = random.split(key)
   batch = random.permutation(shuffle_rng, batch, axis=1)
-  batch_rng = random.split(rng_key, batch.shape[0])
-  batch = jnp.repeat(batch[:, None], num_particles, axis=1)
-  rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng)
 
   # Run the program and get metrics.
-  program = make_gmm(params, num_sweeps)
-  _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
-  metrics = jax.tree_util.tree_map(
-      partial(jnp.mean, axis=0), metrics
-  )  # mean across batch
+  program = make_gmm(params, num_sweeps, num_particles)
+  _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch)
+  for metric_name in ["log_Z", "log_density", "loss"]:
+    metrics[metric_name] = metrics[metric_name] / batch.shape[0]
   return metrics["loss"], metrics
 
@@ -260,32 +252,28 @@ def main(args):
       train_ds,
   )
 
-  program = make_gmm(gmm_params, num_sweeps)
+  program = make_gmm(gmm_params, num_sweeps, num_particles)
   batch, label = next(test_ds)
-  batch = jnp.repeat(batch[:, None], num_particles, axis=1)
-  rng_keys = jax.vmap(partial(random.split, num=num_particles))(
-      random.split(jax.random.PRNGKey(1), batch.shape[0])
-  )
-  _, out = jax.vmap(program)(rng_keys, batch)
-
-  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-  for i in range(3):
-    n = i
-    axes[i].scatter(
-        batch[n, 0, :, 0],
-        batch[n, 0, :, 1],
+  out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch)
+  out = out[0]
+
+  _, axes = plt.subplots(2, 3, figsize=(15, 10))
+  for i in range(6):
+    axes[i // 3][i % 3].scatter(
+        batch[i, :, 0],
+        batch[i, :, 1],
         marker=".",
-        color=np.array(["c", "m", "y"])[label[n]],
+        color=np.array(["c", "m", "y"])[label[i]],
     )
     for j, c in enumerate(["r", "g", "b"]):
       ellipse = Ellipse(
-          xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]),
-          width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]),
-          height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]),
+          xy=(out["mean"][0, i, 0, j, 0], out["mean"][0, i, 0, j, 1]),
+          width=4 / jnp.sqrt(out["tau"][0, i, 0, j, 0]),
+          height=4 / jnp.sqrt(out["tau"][0, i, 0, j, 1]),
           fc=c,
           alpha=0.3,
       )
-      axes[i].add_patch(ellipse)
+      axes[i // 3][i % 3].add_patch(ellipse)
   plt.show()
 
@@ -303,6 +291,5 @@ def main(args):
   tf.config.experimental.set_visible_devices([], "GPU")  # Disable GPU for TF.
   numpyro.set_platform(args.device)
-  coix.set_backend("coix.oryx")
 
   main(args)
diff --git a/examples/gmm_oryx.py b/examples/gmm_oryx.py
index 5e5e4f5..3c7c643 100644
--- a/examples/gmm_oryx.py
+++ b/examples/gmm_oryx.py
@@ -124,6 +124,12 @@ def __call__(self, x):
     return logits + jnp.log(jnp.ones(3) / 3)
 
 
+def broadcast_concatenate(*xs):
+  shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs])
+  xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs]
+  return jnp.concatenate(xs, -1)
+
+
 class GMMEncoder(nn.Module):
 
   def setup(self):
@@ -136,12 +142,7 @@ def __call__(self, x):  # N x D
     alpha, beta, mean, _ = self.encode_initial_mean_tau(x)  # M x D
     tau = alpha / beta  # M x D
-    concatenate_fn = lambda x, m, t: jnp.concatenate(
-        [x, m, t], axis=-1
-    )  # N x M x 3D
-    xmt = jax.vmap(
-        jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
-    )(x, mean, tau)
+    xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
     logits = self.encode_c(xmt)  # N x D
     c = jnp.argmax(logits, -1)  # N
@@ -174,8 +175,9 @@ def gmm_kernel_mean_tau(network, key, inputs):
   key_out, key_mean, key_tau = random.split(key, 3)
 
   if "c" in inputs:
+    x = inputs["x"]
     c = jax.nn.one_hot(inputs["c"], 3)
-    xc = jnp.concatenate([inputs["x"], c], -1)
+    xc = broadcast_concatenate(x, c)
     alpha, beta, mu, nu = network.encode_mean_tau(xc)
   else:
     alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
@@ -191,10 +193,8 @@ def gmm_kernel_mean_tau(network, key, inputs):
 
 def gmm_kernel_c(network, key, inputs):
   key_out, key_c = random.split(key, 2)
-  concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1)
-  xmt = jax.vmap(
-      jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None)
-  )(inputs["x"], inputs["mean"], inputs["tau"])
+  x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"]
+  xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau)
   logits = network.encode_c(xmt)
   c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)
 
@@ -263,24 +263,23 @@ def main(args):
   )
   _, out = jax.vmap(program)(rng_keys, batch)
 
-  fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-  for i in range(3):
-    n = i
-    axes[i].scatter(
-        batch[n, 0, :, 0],
-        batch[n, 0, :, 1],
+  _, axes = plt.subplots(2, 3, figsize=(15, 10))
+  for i in range(6):
+    axes[i // 3][i % 3].scatter(
+        batch[i, 0, :, 0],
+        batch[i, 0, :, 1],
         marker=".",
-        color=np.array(["c", "m", "y"])[label[n]],
+        color=np.array(["c", "m", "y"])[label[i]],
     )
     for j, c in enumerate(["r", "g", "b"]):
       ellipse = Ellipse(
-          xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]),
-          width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]),
-          height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]),
+          xy=(out["mean"][i, 0, j, 0], out["mean"][i, 0, j, 1]),
+          width=4 / jnp.sqrt(out["tau"][i, 0, j, 0]),
+          height=4 / jnp.sqrt(out["tau"][i, 0, j, 1]),
           fc=c,
           alpha=0.3,
       )
-      axes[i].add_patch(ellipse)
+      axes[i // 3][i % 3].add_patch(ellipse)
   plt.show()
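Closing review note: unlike gmm.py, gmm_oryx.py keeps the oryx-style convention in which each kernel receives an explicit PRNG key, splits it for its sample sites, and returns a fresh key alongside its outputs. A minimal sketch of that key-threading pattern in plain JAX; toy_kernel is a made-up placeholder, not coix or oryx API:

    import jax.numpy as jnp
    from jax import random

    def toy_kernel(key, inputs):
      # Split the incoming key, consume one branch for sampling, and hand
      # the other back so kernels can be chained deterministically.
      key_out, key_use = random.split(key)
      noise = random.normal(key_use, inputs.shape)
      return key_out, {"x": inputs + noise}

    key = random.PRNGKey(0)
    key, out = toy_kernel(key, jnp.zeros((10, 2)))
    key, out = toy_kernel(key, out["x"])  # chain with the returned key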