diff --git a/README.md b/README.md index d0d9f8a..2e0b309 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,32 @@ Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box. +Coix is a lightweight framework which includes the following main components: + +- **coix.api:** Implementation of the program combinators. +- **coix.core:** Basic program transformations which are used to modify behavior of a stochastic program. +- **coix.loss:** Common objectives for variational inference. +- **coix.algo:** Example inference algorithms. + +Currently, we support [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends. But other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility. + *This is not an officially supported Google product.* + +## Installation + +To install Coix, you can use pip: + +``` +pip install coix +``` + +or you can clone the repository: + +``` +git clone https://github.com/jax-ml/coix.git +cd coix +pip install -e .[dev,doc] +``` + +Many examples would run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instruction for how to install JAX with GPU or TPU support. + diff --git a/coix/core.py b/coix/core.py index f47a393..e34abed 100644 --- a/coix/core.py +++ b/coix/core.py @@ -137,10 +137,12 @@ def wrapped(*args, **kwargs): def empirical(out, trace, metrics): + """Creates an empirical program given a trace.""" return get_backend()["empirical"](out, trace, metrics) def suffix(p): + """Adds suffix `_PREV_` to variable names of `p`.""" fn = get_backend()["suffix"] if fn is not None: return fn(p) @@ -149,6 +151,7 @@ def suffix(p): def detach(p): + """Makes random variables in `p` become non-reparameterized.""" fn = get_backend()["detach"] if fn is not None: return fn(p) @@ -157,6 +160,7 @@ def detach(p): def stick_the_landing(p): + """Stops gradient of distributions' parameters before computing log prob.""" fn = get_backend()["stick_the_landing"] if fn is not None: return fn(p) @@ -165,6 +169,7 @@ def stick_the_landing(p): def prng_key(): + """Generates a random JAX PRNGKey.""" fn = get_backend()["prng_key"] if fn is not None: return fn() diff --git a/coix/numpyro.py b/coix/numpyro.py index ee96ad6..d8fcfe2 100644 --- a/coix/numpyro.py +++ b/coix/numpyro.py @@ -51,7 +51,12 @@ def wrapped(*args, **kwargs): if site["type"] == "sample": value = site["value"] log_prob = site["fn"].log_prob(value) - trace[name] = {"value": value, "log_prob": log_prob} + event_dim_holder = jnp.empty([1] * site["fn"].event_dim) + trace[name] = { + "value": value, + "log_prob": log_prob, + "_event_dim_holder": event_dim_holder, + } if site.get("is_observed", False): trace[name]["is_observed"] = True metrics = { @@ -83,7 +88,7 @@ def wrapped(*args, **kwargs): del args, kwargs for name, site in trace.items(): value, lp = site["value"], site["log_prob"] - event_dim = jnp.ndim(value) - jnp.ndim(lp) + event_dim = jnp.ndim(site["_event_dim_holder"]) obs = value if "is_observed" in site else None numpyro.sample(name, 
dist.Delta(value, lp, event_dim=event_dim), obs=obs) for name, value in metrics.items(): diff --git a/docs/_static/anneal.png b/docs/_static/anneal.png new file mode 100644 index 0000000..f18f290 Binary files /dev/null and b/docs/_static/anneal.png differ diff --git a/docs/_static/anneal_oryx.png b/docs/_static/anneal_oryx.png new file mode 100644 index 0000000..c37bf31 Binary files /dev/null and b/docs/_static/anneal_oryx.png differ diff --git a/docs/_static/bmnist.gif b/docs/_static/bmnist.gif new file mode 100644 index 0000000..2ea1481 Binary files /dev/null and b/docs/_static/bmnist.gif differ diff --git a/docs/_static/dmm.png b/docs/_static/dmm.png new file mode 100644 index 0000000..5162194 Binary files /dev/null and b/docs/_static/dmm.png differ diff --git a/docs/_static/dmm_oryx.png b/docs/_static/dmm_oryx.png new file mode 100644 index 0000000..3ff0ecc Binary files /dev/null and b/docs/_static/dmm_oryx.png differ diff --git a/docs/_static/gmm.png b/docs/_static/gmm.png new file mode 100644 index 0000000..25eef27 Binary files /dev/null and b/docs/_static/gmm.png differ diff --git a/docs/_static/gmm_oryx.png b/docs/_static/gmm_oryx.png new file mode 100644 index 0000000..fceb69e Binary files /dev/null and b/docs/_static/gmm_oryx.png differ diff --git a/docs/conf.py b/docs/conf.py index ec78c96..13777e2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -207,11 +207,13 @@ ): toctree_path = "notebooks/" if src_file.endswith("ipynb") else "examples/" filename = os.path.splitext(src_file.split("/")[-1])[0] - png_path = "_static/" + filename + ".png" + img_path = "_static/" + filename + ".png" # use Coix logo if not exist png file - if not os.path.exists(png_path): - png_path = "_static/coix_logo.png" - nbsphinx_thumbnails[toctree_path + filename] = png_path + if not os.path.exists(img_path): + img_path = "_static/" + filename + ".gif" + if not os.path.exists(img_path): + img_path = "_static/coix_logo.png" + nbsphinx_thumbnails[toctree_path + filename] = img_path # -- Options for HTML output ------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 96e086f..eed738c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,9 +32,9 @@ Coix Documentation examples/gmm examples/dmm examples/bmnist + examples/anneal_oryx examples/gmm_oryx examples/dmm_oryx - examples/anneal_oryx Indices and tables ================== diff --git a/examples/anneal.py b/examples/anneal.py index 08a3893..91e27a4 100644 --- a/examples/anneal.py +++ b/examples/anneal.py @@ -24,6 +24,9 @@ 1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021. +.. image:: ../_static/anneal.png + :align: center + """ import argparse @@ -199,7 +202,7 @@ def eval_program(seed): plt.figure(figsize=(8, 8)) x = trace["x"]["value"].reshape((-1, 2)) - H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100) + H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100) plt.imshow(H.T) plt.show() diff --git a/examples/anneal_oryx.py b/examples/anneal_oryx.py index 8803067..759b107 100644 --- a/examples/anneal_oryx.py +++ b/examples/anneal_oryx.py @@ -24,6 +24,9 @@ 1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021. +.. 
image:: ../_static/anneal_oryx.png + :align: center + """ import argparse @@ -119,7 +122,7 @@ def __call__(self, x): def anneal_target(network, key, k=0): key_out, key = random.split(key) x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key) - coix.factor(network.anneal_density(x, index=k), name="anneal_density") + coryx.factor(network.anneal_density(x, index=k), name="anneal_density") return key_out, {"x": x} @@ -192,7 +195,7 @@ def main(args): plt.figure(figsize=(8, 8)) x = trace["x"]["value"].reshape((-1, 2)) - H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100) + H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100) plt.imshow(H.T) plt.show() diff --git a/examples/bmnist.py b/examples/bmnist.py index 255b5ad..520721d 100644 --- a/examples/bmnist.py +++ b/examples/bmnist.py @@ -26,25 +26,22 @@ 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient statistics. ICML 2020. -""" +.. image:: ../_static/bmnist.gif + :align: center -# TODO: Refactor using numpyro backend. The current code is likely not working yet. +""" import argparse from functools import partial -import sys import coix -import flax import flax.linen as nn import jax from jax import random import jax.numpy as jnp -from matplotlib.animation import FuncAnimation -from matplotlib.patches import Ellipse +import matplotlib.animation as animation from matplotlib.patches import Rectangle import matplotlib.pyplot as plt -import numpy as np import numpyro import numpyro.distributions as dist import optax @@ -54,40 +51,27 @@ # %% # First, let's load the moving mnist dataset. -batch_size = 5 -T = 10 # using 10 time steps for training and 20 time steps for testing - def load_dataset(*, is_training, batch_size): - ds, ds_info = tfds.load("moving_mnist:1.0.0", split="test", with_info=True) - ds = ds.cache().repeat() + ds = tfds.load("moving_mnist:1.0.0", split="test") + ds = ds.repeat() if is_training: ds = ds.shuffle(10 * batch_size, seed=0) - map_fn = lambda x: x["image_sequence"][..., :T, :, :, 0] / 255 + map_fn = lambda x: x["image_sequence"][..., :10, :, :, 0] / 255 else: map_fn = lambda x: x["image_sequence"][..., 0] / 255 ds = ds.batch(batch_size) ds = ds.map(map_fn) - return iter(tfds.as_numpy(ds)), ds_info + return iter(tfds.as_numpy(ds)) def get_digit_mean(): ds, ds_info = tfds.load("mnist:3.0.1", split="train", with_info=True) - ds = tfds.as_numpy(ds.batch(ds_info.splits["test"].num_examples)) + ds = tfds.as_numpy(ds.batch(ds_info.splits["train"].num_examples)) digit_mean = next(iter(ds))["image"].squeeze(-1).mean(axis=0) return digit_mean / 255 -train_ds, ds_info = load_dataset(is_training=True, batch_size=batch_size) -test_ds, _ = load_dataset(is_training=False, batch_size=1) -digit_mean = get_digit_mean() -frame_size = ds_info.features["image_sequence"].shape[-2] -frame_length = ds_info.features["image_sequence"]._length -print("Frame length: ", frame_length) -print("Frame size:", frame_size) -print("Digit shape:", digit_mean.shape) - - # %% # Next, we define the neural proposals for the Gibbs kernels and the neural # decoder for the generative model. 
@@ -110,34 +94,53 @@ def crop_frames(frames, z_where, digit_size=28): # frames: time.frame_size.frame_size # z_where: (digits).time.2 # out: (digits).time.digit_size.digit_size - crop_fn = partial(scale_and_translate, out_size=digit_size) - if z_where.ndim == 2: - return jax.vmap(crop_fn)(frames, z_where) - return jax.vmap(jax.vmap(crop_fn), in_axes=(None, 0))(frames, z_where) + if frames.ndim == 2 and z_where.ndim == 1: + return scale_and_translate(frames, z_where, out_size=digit_size) + elif frames.ndim == 3 and z_where.ndim == 2: + in_axes = (0, 0) + elif frames.ndim == 3 and z_where.ndim == 3: + in_axes = (None, 0) + elif frames.ndim == z_where.ndim: + in_axes = (0, 0) + elif frames.ndim > z_where.ndim: + in_axes = (0, None) + else: + in_axes = (None, 0) + return jax.vmap(partial(crop_frames, digit_size=digit_size), in_axes)( + frames, z_where + ) def embed_digits(digits, z_where, frame_size=64): # digits: (digits). .digit_size.digit_size # z_where: (digits).(time).2 # out: (digits).(time).frame_size.frame_size - embed_fn = partial(scale_and_translate, out_size=frame_size) - if digits.ndim == 2: - if z_where.ndim == 1: - return embed_fn(digits, z_where) - return jax.vmap(embed_fn, in_axes=(None, 0))(digits, z_where) - return jax.vmap(jax.vmap(embed_fn, in_axes=(None, 0)))(digits, z_where) + if digits.ndim == 2 and z_where.ndim == 1: + return scale_and_translate(digits, z_where, out_size=frame_size) + elif digits.ndim == 2 and z_where.ndim == 2: + in_axes = (None, 0) + elif digits.ndim >= z_where.ndim: + in_axes = (0, 0) + else: + in_axes = (None, 0) + return jax.vmap(partial(embed_digits, frame_size=frame_size), in_axes)( + digits, z_where + ) def conv2d(frames, digits): # frames: (time).frame_size.frame_size # digits: (digits). .digit_size.digit_size # out: (digits).(time).conv_size .conv_size - conv2d_fn = partial(jax.scipy.signal.convolve2d, mode="valid") - if frames.ndim == 2: - if digits.ndim == 2: - return conv2d_fn(frames, digits) - return jax.vmap(conv2d_fn, in_axes=(None, 0))(frames, digits) - return jax.vmap(conv2d_fn, in_axes=(0, None))(frames, digits) + if frames.ndim == 2 and digits.ndim == 2: + return jax.scipy.signal.convolve2d(frames, digits, mode="valid") + elif frames.ndim == digits.ndim: + in_axes = (0, 0) + elif frames.ndim > digits.ndim: + in_axes = (0, None) + else: + in_axes = (None, 0) + return jax.vmap(conv2d, in_axes=in_axes)(frames, digits) class EncoderWhat(nn.Module): @@ -150,7 +153,7 @@ def __call__(self, digits): x = nn.Dense(200)(x) x = nn.relu(x) - x = x.sum(-2) # sum across time + x = x.sum(-2) # sum/mean across time loc_raw = nn.Dense(10)(x) scale_raw = 0.5 * nn.Dense(10)(x) return loc_raw, jnp.exp(scale_raw) @@ -210,163 +213,185 @@ def __call__(self, frames): # %% # Then, we define the target and kernels as in Section 6.4. 
-test_key = random.PRNGKey(0) -test_data = jnp.zeros((frame_length,) + (frame_size, frame_size)) -bmnist_net = BMNISTAutoEncoder(digit_mean=digit_mean, frame_size=frame_size) -init_params = bmnist_net.init(test_key, test_data) -test_network = coix.util.BindModule(bmnist_net, init_params) - -def bmnist_target(network, key, inputs, D=2, T=10): - key_out, key_what, key_where = random.split(key, 3) - - z_what = coix.rv(dist.Normal(0, 1).expand([D, 10]), name="z_what")(key_what) +def bmnist_target(network, inputs, D=2, T=10): + z_what = numpyro.sample( + "z_what", dist.Normal(0, 1).expand([D, 10]).to_event() + ) digits = network.decode_what(z_what) # can cache this z_where = [] + # p = [] for d in range(D): z_where_d = [] z_where_d_t = jnp.zeros(2) for t in range(T): scale = 1 if t == 0 else 0.1 - key_d_t = random.fold_in(key_where, d * T + t) - name = f"z_where_{d}_{t}" - z_where_d_t = coix.rv(dist.Normal(z_where_d_t, scale), name=name)(key_d_t) + z_where_d_t = numpyro.sample( + f"z_where_{d}_{t}", dist.Normal(z_where_d_t, scale).to_event(1) + ) z_where_d.append(z_where_d_t) - z_where.append(jnp.stack(z_where_d, -2)) + z_where_d = jnp.stack(z_where_d, -2) + z_where.append(z_where_d) z_where = jnp.stack(z_where, -3) p = embed_digits(digits, z_where, network.frame_size) p = dist.util.clamp_probs(p.sum(-4)) # sum across digits - frames = coix.rv(dist.Bernoulli(p), obs=inputs, name="frames") + frames = numpyro.sample("frames", dist.Bernoulli(p).to_event(3), obs=inputs) out = { "frames": frames, "frames_recon": p, "z_what": z_what, "digits": jax.lax.stop_gradient(digits), - **{f"z_where_{t}": z_where[:, t, :] for t in range(T)}, + **{f"z_where_{t}": z_where[..., t, :] for t in range(T)}, } - return key_out, out + return (out,) -_, p_out = bmnist_target(test_network, test_key, test_data, T=frame_length) - - -def kernel_where(network, key, inputs, D=2, t=0): +def kernel_where(network, inputs, D=2, t=0): if not isinstance(inputs, dict): inputs = { "frames": inputs, "digits": jnp.repeat(jnp.expand_dims(network.digit_mean, -3), D, -3), } - key_out, key_where = random.split(key) - frame = inputs["frames"][t, :, :] + frame = inputs["frames"][..., t, :, :] z_where_t = [] - key_where = random.split(key_where, D) - for d, key_where_d in enumerate(key_where): - digit = inputs["digits"][d, :, :] + for d in range(D): + digit = inputs["digits"][..., d, :, :] x_conv = conv2d(frame, digit) loc, scale = network.encode_where(x_conv) - name = f"z_where_{d}_{t}" - z_where_d_t = coix.rv(dist.Normal(loc, scale), name=name)(key_where_d) + z_where_d_t = numpyro.sample( + f"z_where_{d}_{t}", dist.Normal(loc, scale).to_event(1) + ) z_where_t.append(z_where_d_t) frame_recon = embed_digits(digit, z_where_d_t, network.frame_size) frame = frame - frame_recon z_where_t = jnp.stack(z_where_t, -2) out = {**inputs, **{f"z_where_{t}": z_where_t}} - return key_out, out + return (out,) -_, k1_initial_out = kernel_where(test_network, test_key, test_data) -_, k1_out = kernel_where(test_network, test_key, p_out) - - -def kernel_what(network, key, inputs, T=10): - key_out, key_what = random.split(key) - +def kernel_what(network, inputs, T=10): z_where = jnp.stack([inputs[f"z_where_{t}"] for t in range(T)], -2) digits = crop_frames(inputs["frames"], z_where, 28) loc, scale = network.encode_what(digits) - z_what = coix.rv(dist.Normal(loc, scale), name="z_what")(key_what) + z_what = numpyro.sample("z_what", dist.Normal(loc, scale).to_event(2)) out = {**inputs, **{"z_what": z_what}} - return key_out, out - - -_, k2_out = 
kernel_what(test_network, test_key, p_out, T=frame_length) + return (out,) # %% # Finally, we create the dmm inference program, define the loss function, # run the training loop, and plot the results. -num_sweeps = 5 -num_particles = 10 - -def make_bmnist(params, T=10): +def make_bmnist(params, bmnist_net, T=10, num_sweeps=5, num_particles=10): network = coix.util.BindModule(bmnist_net, params) # Add particle dimension and construct a program. - target = jax.vmap(partial(bmnist_target, network, D=2, T=T)) + make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-2) + target = make_particle_plate()(partial(bmnist_target, network, D=2, T=T)) kernels = [] for t in range(T): - kernels.append(jax.vmap(partial(kernel_where, network, D=2, t=t))) - kernels.append(jax.vmap(partial(kernel_what, network, T=T))) + kernels.append( + make_particle_plate()(partial(kernel_where, network, D=2, t=t)) + ) + kernels.append(make_particle_plate()(partial(kernel_what, network, T=T))) program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps) return program -def loss_fn(params, key, batch): +def loss_fn(params, key, batch, bmnist_net, num_sweeps, num_particles): # Prepare data for the program. shuffle_rng, rng_key = random.split(key) batch = random.permutation(shuffle_rng, batch, axis=1) - batch_rng = random.split(rng_key, batch.shape[0]) - batch = jnp.repeat(batch[:, None], num_particles, axis=1) - rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng) + T = batch.shape[-3] # Run the program and get metrics. - program = make_bmnist(params) - _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch) - metrics = jax.tree_util.tree_map(jnp.mean, metrics) # mean across batch + program = make_bmnist(params, bmnist_net, T, num_sweeps, num_particles) + _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch) + for metric_name in ["log_Z", "log_density", "loss"]: + metrics[metric_name] = metrics[metric_name] / batch.shape[0] return metrics["loss"], metrics -lr = 1e-4 -num_steps = 200000 -bmnist_params, _ = coix.util.train( - loss_fn, init_params, optax.adam(lr), num_steps, train_ds -) - - -program = make_bmnist(bmnist_params, T=frame_length) -batch = jnp.repeat(next(test_ds)[:1], num_particles, axis=0) -rng_keys = random.split(random.PRNGKey(1), num_particles) -_, out = program(rng_keys, batch) -batch.shape, out["frames_recon"].shape - - -prop_cycle = plt.rcParams["axes.prop_cycle"] -colors = prop_cycle.by_key()["color"] -fig, axes = plt.subplots(1, 2, figsize=(12, 6)) - +def main(args): + lr = args.learning_rate + num_steps = args.num_steps + batch_size = args.batch_size + num_sweeps = args.num_sweeps + num_particles = args.num_particles + + train_ds = load_dataset(is_training=True, batch_size=batch_size) + test_ds = load_dataset(is_training=False, batch_size=1) + digit_mean = get_digit_mean() + + test_data = next(test_ds) + frame_size = test_data.shape[-1] + bmnist_net = BMNISTAutoEncoder(digit_mean=digit_mean, frame_size=frame_size) + init_params = bmnist_net.init(jax.random.PRNGKey(0), test_data[0]) + bmnist_params, _ = coix.util.train( + partial( + loss_fn, + bmnist_net=bmnist_net, + num_sweeps=num_sweeps, + num_particles=num_particles, + ), + init_params, + optax.adam(lr), + num_steps, + train_ds, + ) -def animate(i): - n = 2 - axes[0].cla() - axes[0].imshow(batch[n][i]) - axes[1].cla() - axes[1].imshow(out["frames_recon"][n, i]) - for d in range(2): - where = 0.5 * (out[f"z_where_{i}"][n, d] + 1) * (frame_size - 28) - 0.5 - color = colors[d] - 
axes[0].add_patch( - Rectangle(where, 28, 28, edgecolor=color, lw=3, fill=False) - ) + T_test = test_data.shape[-3] + program = make_bmnist( + bmnist_params, bmnist_net, T_test, num_sweeps, num_particles + ) + out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))( + test_data + ) + out = out[0] + + prop_cycle = plt.rcParams["axes.prop_cycle"] + colors = prop_cycle.by_key()["color"] + fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + + def animate(i): + axes[0].cla() + axes[0].imshow(test_data[0, i]) + axes[1].cla() + axes[1].imshow(out["frames_recon"][0, 0, i]) + for d in range(2): + where = 0.5 * (out[f"z_where_{i}"][0, 0, d] + 1) * (frame_size - 28) - 0.5 + color = colors[d] + axes[0].add_patch( + Rectangle(where, 28, 28, edgecolor=color, lw=3, fill=False) + ) + + plt.rc("animation", html="jshtml") + plt.tight_layout() + ani = animation.FuncAnimation(fig, animate, frames=range(20), interval=300) + writer = animation.PillowWriter(fps=15) + ani.save("bmnist.gif", writer=writer) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Annealing example") + parser.add_argument("--batch-size", nargs="?", default=5, type=int) + parser.add_argument("--num-sweeps", nargs="?", default=5, type=int) + parser.add_argument("--num_particles", nargs="?", default=10, type=int) + parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float) + parser.add_argument("--num-steps", nargs="?", default=20000, type=int) + parser.add_argument( + "--device", default="gpu", type=str, help='use "cpu" or "gpu".' + ) + args = parser.parse_args() + tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF. + numpyro.set_platform(args.device) -plt.rc("animation", html="jshtml") -anim = FuncAnimation(fig, animate, frames=range(20), interval=300) -plt.close() -anim + main(args) diff --git a/examples/dmm.py b/examples/dmm.py index e1e3c7d..8300b55 100644 --- a/examples/dmm.py +++ b/examples/dmm.py @@ -25,6 +25,9 @@ 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient statistics. ICML 2020. +.. image:: ../_static/dmm.png + :align: center + """ import argparse @@ -39,9 +42,9 @@ import numpy as np import numpyro import numpyro.distributions as dist +from numpyro.ops.indexing import Vindex import optax import tensorflow as tf -import tensorflow_datasets as tfds # %% # First, let's simulate a synthetic dataset of 2D ring-shaped mixtures. 
@@ -63,17 +66,22 @@ def simulate_rings(num_instances=1, N=200, seed=0): return np.take_along_axis(x, shuffle_idx, axis=1) -def load_dataset(split, *, is_training, batch_size): - num_data = 20000 if is_training else batch_size - num_points = 200 if is_training else 600 - seed = 0 if is_training else 1 +def load_dataset(split, *, batch_size): + if split == "train": + num_data = 20000 + num_points = 200 + seed = 0 + else: + num_data = batch_size + num_points = 600 + seed = 1 data = simulate_rings(num_data, num_points, seed=seed) ds = tf.data.Dataset.from_tensor_slices(data) - ds = ds.cache().repeat() - if is_training: + ds = ds.repeat() + if split == "train": ds = ds.shuffle(10 * batch_size, seed=0) ds = ds.batch(batch_size) - return iter(tfds.as_numpy(ds)) + return ds.as_numpy_iterator() # %% @@ -112,7 +120,7 @@ class EncoderC(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(32)(x) - x = nn.tanh(x) + x = nn.relu(x) # nn.tanh(x) logits = nn.Dense(1)(x).squeeze(-1) return logits + jnp.log(jnp.ones(4) / 4) @@ -137,7 +145,8 @@ def __call__(self, x): x = nn.tanh(x) x = nn.Dense(2)(x) angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True) - return angle + radius = 1.0 # self.param("radius", nn.initializers.ones, (1,)) + return radius * angle class DMMAutoEncoder(nn.Module): @@ -153,8 +162,7 @@ def __call__(self, x): # N x D # Heuristic procedure to setup initial parameters. mu, _ = self.encode_initial_mu(x) # M x D - concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1) - xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu) + xmu = jnp.expand_dims(x, -2) - mu logits = self.encode_c(xmu) # N x M c = jnp.argmax(logits, -1) # N @@ -174,52 +182,48 @@ def __call__(self, x): # N x D # Then, we define the target and kernels as in Section 6.3. 
-def dmm_target(network, key, inputs): - key_out, key_mu, key_c, key_h = random.split(key, 4) - N = inputs.shape[-2] - - mu = coix.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu) - c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) - h = coix.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h) - x_recon = mu[c] + network.decode_h(h) - x = coix.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x") +def dmm_target(network, inputs): + mu = numpyro.sample("mu", dist.Normal(0, 10).expand([4, 2]).to_event()) + with numpyro.plate("N", inputs.shape[-2], dim=-1): + c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4)) + h = numpyro.sample("h", dist.Beta(1, 1)) + x_recon = network.decode_h(h) + Vindex(mu)[..., c, :] + x = numpyro.sample("x", dist.Normal(x_recon, 0.1).to_event(1), obs=inputs) out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x} - return key_out, out + return (out,) -def dmm_kernel_mu(network, key, inputs): +def dmm_kernel_mu(network, inputs): if not isinstance(inputs, dict): inputs = {"x": inputs} - key_out, key_mu = random.split(key) if "c" in inputs: + x = jnp.broadcast_to(inputs["x"], inputs["h"].shape + (2,)) c = jax.nn.one_hot(inputs["c"], 4) h = jnp.expand_dims(inputs["h"], -1) - xch = jnp.concatenate([inputs["x"], c, h], -1) + xch = jnp.concatenate([x, c, h], -1) loc, scale = network.encode_mu(xch) else: loc, scale = network.encode_initial_mu(inputs["x"]) - mu = coix.rv(dist.Normal(loc, scale), name="mu")(key_mu) + loc, scale = jnp.expand_dims(loc, -3), jnp.expand_dims(scale, -3) + mu = numpyro.sample("mu", dist.Normal(loc, scale).to_event(2)) out = {**inputs, **{"mu": mu}} - return key_out, out + return (out,) -def dmm_kernel_c_h(network, key, inputs): - key_out, key_c, key_h = random.split(key, 3) - - concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1) - xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))( - inputs["x"], inputs["mu"] - ) +def dmm_kernel_c_h(network, inputs): + x, mu = inputs["x"], inputs["mu"] + xmu = jnp.expand_dims(x, -2) - mu logits = network.encode_c(xmu) - c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c) - alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c]) - h = coix.rv(dist.Beta(alpha, beta), name="h")(key_h) + with numpyro.plate("N", logits.shape[-2], dim=-1): + c = numpyro.sample("c", dist.Categorical(logits=logits)) + alpha, beta = network.encode_h(inputs["x"] - Vindex(mu)[..., c, :]) + h = numpyro.sample("h", dist.Beta(alpha, beta)) out = {**inputs, **{"c": c, "h": h}} - return key_out, out + return (out,) # %% @@ -227,14 +231,14 @@ def dmm_kernel_c_h(network, key, inputs): # run the training loop, and plot the results. -def make_dmm(params, num_sweeps): +def make_dmm(params, num_sweeps=5, num_particles=10): network = coix.util.BindModule(DMMAutoEncoder(), params) # Add particle dimension and construct a program. 
- target = jax.vmap(partial(dmm_target, network)) - kernels = [ - jax.vmap(partial(dmm_kernel_mu, network)), - jax.vmap(partial(dmm_kernel_c_h, network)), - ] + make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-3) + target = make_particle_plate()(partial(dmm_target, network)) + kernel_mu = make_particle_plate()(partial(dmm_kernel_mu, network)) + kernel_c_h = make_particle_plate()(partial(dmm_kernel_c_h, network)) + kernels = [kernel_mu, kernel_c_h] program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps) return program @@ -243,16 +247,12 @@ def loss_fn(params, key, batch, num_sweeps, num_particles): # Prepare data for the program. shuffle_rng, rng_key = random.split(key) batch = random.permutation(shuffle_rng, batch, axis=1) - batch_rng = random.split(rng_key, batch.shape[0]) - batch = jnp.repeat(batch[:, None], num_particles, axis=1) - rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng) # Run the program and get metrics. - program = make_dmm(params, num_sweeps) - _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch) - metrics = jax.tree_util.tree_map( - partial(jnp.mean, axis=0), metrics - ) # mean across batch + program = make_dmm(params, num_sweeps, num_particles) + _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch) + for metric_name in ["log_Z", "log_density", "loss"]: + metrics[metric_name] = metrics[metric_name] / batch.shape[0] return metrics["loss"], metrics @@ -263,8 +263,8 @@ def main(args): num_sweeps = args.num_sweeps num_particles = args.num_particles - train_ds = load_dataset("train", is_training=True, batch_size=batch_size) - test_ds = load_dataset("test", is_training=False, batch_size=batch_size) + train_ds = load_dataset("train", batch_size=batch_size) + test_ds = load_dataset("test", batch_size=batch_size) init_params = DMMAutoEncoder().init( jax.random.PRNGKey(0), jnp.zeros((200, 2)) @@ -277,28 +277,24 @@ def main(args): train_ds, ) - program = make_dmm(dmm_params, num_sweeps) - batch = jnp.repeat(next(test_ds)[:, None], num_particles, axis=1) - rng_keys = jax.vmap(partial(random.split, num=num_particles))( - random.split(jax.random.PRNGKey(1), batch.shape[0]) - ) - _, out = jax.vmap(program)(rng_keys, batch) - batch.shape, out["x_recon"].shape + program = make_dmm(dmm_params, num_sweeps, num_particles) + batch = next(test_ds) + out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch) + out = out[0] - fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + _, axes = plt.subplots(2, 3, figsize=(15, 10)) for i in range(3): - n = i - axes[0][i].scatter(out["x"][n, 0, :, 0], out["x"][n, 0, :, 1], marker=".") + axes[0][i].scatter(out["x"][i, :, 0], out["x"][i, :, 1], marker=".") axes[1][i].scatter( - out["x_recon"][n, 0, :, 0], - out["x_recon"][n, 0, :, 1], - c=out["c"][n, 0], + out["x_recon"][0, i, :, 0], + out["x_recon"][0, i, :, 1], + c=out["c"][0, i], cmap="Accent", marker=".", ) axes[1][i].scatter( - out["mu"][n, 0, :, 0], - out["mu"][n, 0, :, 1], + out["mu"][0, i, 0, :, 0], + out["mu"][0, i, 0, :, 1], c=range(4), marker="x", cmap="Accent", @@ -311,8 +307,8 @@ def main(args): parser.add_argument("--batch-size", nargs="?", default=20, type=int) parser.add_argument("--num-sweeps", nargs="?", default=5, type=int) parser.add_argument("--num_particles", nargs="?", default=10, type=int) - parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float) - parser.add_argument("--num-steps", nargs="?", default=300000, type=int) + 
parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float) + parser.add_argument("--num-steps", nargs="?", default=30000, type=int) parser.add_argument( "--device", default="gpu", type=str, help='use "cpu" or "gpu".' ) @@ -320,6 +316,5 @@ def main(args): tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF. numpyro.set_platform(args.device) - coix.set_backend("coix.oryx") main(args) diff --git a/examples/dmm_oryx.py b/examples/dmm_oryx.py index 7ba11fa..7e925e3 100644 --- a/examples/dmm_oryx.py +++ b/examples/dmm_oryx.py @@ -25,8 +25,14 @@ 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient statistics. ICML 2020. +.. image:: ../_static/dmm_oryx.png + :align: center + """ +# %% +# **Note:** The metrics seem to be incorrect in this example. + import argparse from functools import partial @@ -42,7 +48,6 @@ import numpyro.distributions as dist import optax import tensorflow as tf -import tensorflow_datasets as tfds # %% # First, let's simulate a synthetic dataset of 2D ring-shaped mixtures. @@ -65,16 +70,21 @@ def simulate_rings(num_instances=1, N=200, seed=0): def load_dataset(split, *, is_training, batch_size): - num_data = 20000 if is_training else batch_size - num_points = 200 if is_training else 600 - seed = 0 if is_training else 1 + if split == "train": + num_data = 20000 + num_points = 200 + seed = 0 + else: + num_data = batch_size + num_points = 600 + seed = 1 data = simulate_rings(num_data, num_points, seed=seed) ds = tf.data.Dataset.from_tensor_slices(data) - ds = ds.cache().repeat() - if is_training: + ds = ds.repeat() + if split == "train": ds = ds.shuffle(10 * batch_size, seed=0) ds = ds.batch(batch_size) - return iter(tfds.as_numpy(ds)) + return ds.as_numpy_iterator() # %% @@ -113,7 +123,7 @@ class EncoderC(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense(32)(x) - x = nn.tanh(x) + x = nn.relu(x) # nn.tanh(x) logits = nn.Dense(1)(x).squeeze(-1) return logits + jnp.log(jnp.ones(4) / 4) @@ -138,7 +148,8 @@ def __call__(self, x): x = nn.tanh(x) x = nn.Dense(2)(x) angle = x / jnp.linalg.norm(x, axis=-1, keepdims=True) - return angle + radius = 1.0 # self.param("radius", nn.initializers.ones, (1,)) + return radius * angle class DMMAutoEncoder(nn.Module): @@ -154,8 +165,7 @@ def __call__(self, x): # N x D # Heuristic procedure to setup initial parameters. 
mu, _ = self.encode_initial_mu(x) # M x D - concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1) - xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))(x, mu) + xmu = jnp.expand_dims(x, -2) - mu logits = self.encode_c(xmu) # N x M c = jnp.argmax(logits, -1) # N @@ -195,9 +205,10 @@ def dmm_kernel_mu(network, key, inputs): key_out, key_mu = random.split(key) if "c" in inputs: + x = inputs["x"] c = jax.nn.one_hot(inputs["c"], 4) h = jnp.expand_dims(inputs["h"], -1) - xch = jnp.concatenate([inputs["x"], c, h], -1) + xch = jnp.concatenate([x, c, h], -1) loc, scale = network.encode_mu(xch) else: loc, scale = network.encode_initial_mu(inputs["x"]) @@ -210,20 +221,21 @@ def dmm_kernel_mu(network, key, inputs): def dmm_kernel_c_h(network, key, inputs): key_out, key_c, key_h = random.split(key, 3) - concatenate_fn = lambda x, m: jnp.concatenate([x, m], axis=-1) - xmu = jax.vmap(jax.vmap(concatenate_fn, (None, 0)), (0, None))( - inputs["x"], inputs["mu"] - ) + x, mu = inputs["x"], inputs["mu"] + xmu = jnp.expand_dims(x, -2) - mu logits = network.encode_c(xmu) c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c) - alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c]) + alpha, beta = network.encode_h(x - mu[c]) h = coryx.rv(dist.Beta(alpha, beta), name="h")(key_h) out = {**inputs, **{"c": c, "h": h}} return key_out, out -### Train +# %% +# Finally, we create the dmm inference program, define the loss function, +# run the training loop, and plot the results. Note that we are using +# 10x less steps than the paper. def make_dmm(params, num_sweeps): @@ -282,7 +294,6 @@ def main(args): random.split(jax.random.PRNGKey(1), batch.shape[0]) ) _, out = jax.vmap(program)(rng_keys, batch) - batch.shape, out["x_recon"].shape fig, axes = plt.subplots(2, 3, figsize=(15, 10)) for i in range(3): @@ -308,10 +319,10 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Annealing example") parser.add_argument("--batch-size", nargs="?", default=20, type=int) - parser.add_argument("--num-sweeps", nargs="?", default=5, type=int) + parser.add_argument("--num-sweeps", nargs="?", default=8, type=int) parser.add_argument("--num_particles", nargs="?", default=10, type=int) - parser.add_argument("--learning-rate", nargs="?", default=1e-4, type=float) - parser.add_argument("--num-steps", nargs="?", default=300000, type=int) + parser.add_argument("--learning-rate", nargs="?", default=1e-3, type=float) + parser.add_argument("--num-steps", nargs="?", default=30000, type=int) parser.add_argument( "--device", default="gpu", type=str, help='use "cpu" or "gpu".' ) diff --git a/examples/gmm.py b/examples/gmm.py index bda6c32..d57a007 100644 --- a/examples/gmm.py +++ b/examples/gmm.py @@ -25,6 +25,9 @@ 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient statistics. ICML 2020. +.. image:: ../_static/gmm_oryx.png + :align: center + """ import argparse @@ -40,9 +43,9 @@ import numpy as np import numpyro import numpyro.distributions as dist +from numpyro.ops.indexing import Vindex import optax import tensorflow as tf -import tensorflow_datasets as tfds # %% # First, let's simulate a synthetic dataset of 2D Gaussian mixtures. 
@@ -61,20 +64,25 @@ def simulate_clusters(num_instances=1, N=60, seed=0): return x, c -def load_dataset(split, *, is_training, batch_size): - num_data = 20000 if is_training else batch_size - num_points = 60 if is_training else 100 - seed = 0 if is_training else 1 +def load_dataset(split, *, batch_size): + if split == "train": + num_data = 20000 + num_points = 60 + seed = 0 + else: + num_data = batch_size + num_points = 100 + seed = 1 data, label = simulate_clusters(num_data, num_points, seed=seed) - if is_training: + if split == "train": ds = tf.data.Dataset.from_tensor_slices(data) - ds = ds.cache().repeat() + ds = ds.repeat() ds = ds.shuffle(10 * batch_size, seed=0) else: ds = tf.data.Dataset.from_tensor_slices((data, label)) - ds = ds.cache().repeat() + ds = ds.repeat() ds = ds.batch(batch_size) - return iter(tfds.as_numpy(ds)) + return ds.as_numpy_iterator() # %% @@ -116,6 +124,12 @@ def __call__(self, x): return logits + jnp.log(jnp.ones(3) / 3) +def broadcast_concatenate(*xs): + shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs]) + xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs] + return jnp.concatenate(xs, -1) + + class GMMEncoder(nn.Module): def setup(self): @@ -128,12 +142,7 @@ def __call__(self, x): # N x D alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D tau = alpha / beta # M x D - concatenate_fn = lambda x, m, t: jnp.concatenate( - [x, m, t], axis=-1 - ) # N x M x 3D - xmt = jax.vmap( - jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None) - )(x, mean, tau) + xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau) logits = self.encode_c(xmt) # N x D c = jnp.argmax(logits, -1) # N @@ -145,49 +154,53 @@ def __call__(self, x): # N x D # Then, we define the target and kernels as in Section 6.2. 
-def gmm_target(network, key, inputs): - key_out, key_mean, key_tau, key_c = random.split(key, 4) - N = inputs.shape[-2] - - tau = coix.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau) - mean = coix.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean) - c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) - x = coix.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x") +def gmm_target(inputs): + tau = numpyro.sample("tau", dist.Gamma(2, 2).expand([3, 2]).to_event()) + mean = numpyro.sample( + "mean", dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)).to_event() + ) + with numpyro.plate("N", inputs.shape[-2], dim=-1): + c = numpyro.sample("c", dist.Categorical(probs=jnp.ones(4) / 4)) + loc = Vindex(mean)[..., c, :] + scale = 1 / jnp.sqrt(Vindex(tau)[..., c, :]) + x = numpyro.sample("x", dist.Normal(loc, scale).to_event(1), obs=inputs) out = {"mean": mean, "tau": tau, "c": c, "x": x} - return key_out, out + return (out,) -def gmm_kernel_mean_tau(network, key, inputs): +def gmm_kernel_mean_tau(network, inputs): if not isinstance(inputs, dict): inputs = {"x": inputs} - key_out, key_mean, key_tau = random.split(key, 3) if "c" in inputs: + x = inputs["x"] c = jax.nn.one_hot(inputs["c"], 3) - xc = jnp.concatenate([inputs["x"], c], -1) + xc = broadcast_concatenate(x, c) alpha, beta, mu, nu = network.encode_mean_tau(xc) else: alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"]) - tau = coix.rv(dist.Gamma(alpha, beta), name="tau")(key_tau) - mean = coix.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean) + alpha, beta, mu, nu = jax.tree_util.tree_map( + lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu) + ) + tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2)) + mean = numpyro.sample( + "mean", dist.Normal(mu, 1 / jnp.sqrt(tau * nu)).to_event(2) + ) out = {**inputs, **{"mean": mean, "tau": tau}} - return key_out, out + return (out,) -def gmm_kernel_c(network, key, inputs): - key_out, key_c = random.split(key, 2) - - concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1) - xmt = jax.vmap( - jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None) - )(inputs["x"], inputs["mean"], inputs["tau"]) +def gmm_kernel_c(network, inputs): + x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"] + xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau) logits = network.encode_c(xmt) - c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c) + with numpyro.plate("N", logits.shape[-2], dim=-1): + c = numpyro.sample("c", dist.Categorical(logits=logits)) out = {**inputs, **{"c": c}} - return key_out, out + return (out,) # %% @@ -195,14 +208,14 @@ def gmm_kernel_c(network, key, inputs): # run the training loop, and plot the results. -def make_gmm(params, num_sweeps): +def make_gmm(params, num_sweeps, num_particles): network = coix.util.BindModule(GMMEncoder(), params) # Add particle dimension and construct a program. 
- target = jax.vmap(partial(gmm_target, network)) - kernels = [ - jax.vmap(partial(gmm_kernel_mean_tau, network)), - jax.vmap(partial(gmm_kernel_c, network)), - ] + make_particle_plate = lambda: numpyro.plate("particle", num_particles, dim=-3) + target = make_particle_plate()(gmm_target) + kernel_mean_tau = make_particle_plate()(partial(gmm_kernel_mean_tau, network)) + kernel_c = make_particle_plate()(partial(gmm_kernel_c, network)) + kernels = [kernel_mean_tau, kernel_c] program = coix.algo.apgs(target, kernels, num_sweeps=num_sweeps) return program @@ -211,16 +224,12 @@ def loss_fn(params, key, batch, num_sweeps, num_particles): # Prepare data for the program. shuffle_rng, rng_key = random.split(key) batch = random.permutation(shuffle_rng, batch, axis=1) - batch_rng = random.split(rng_key, batch.shape[0]) - batch = jnp.repeat(batch[:, None], num_particles, axis=1) - rng_keys = jax.vmap(partial(random.split, num=num_particles))(batch_rng) # Run the program and get metrics. - program = make_gmm(params, num_sweeps) - _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch) - metrics = jax.tree_util.tree_map( - partial(jnp.mean, axis=0), metrics - ) # mean across batch + program = make_gmm(params, num_sweeps, num_particles) + _, _, metrics = coix.traced_evaluate(program, seed=rng_key)(batch) + for metric_name in ["log_Z", "log_density", "loss"]: + metrics[metric_name] = metrics[metric_name] / batch.shape[0] return metrics["loss"], metrics @@ -231,8 +240,8 @@ def main(args): num_sweeps = args.num_sweeps num_particles = args.num_particles - train_ds = load_dataset("train", is_training=True, batch_size=batch_size) - test_ds = load_dataset("test", is_training=False, batch_size=batch_size) + train_ds = load_dataset("train", batch_size=batch_size) + test_ds = load_dataset("test", batch_size=batch_size) init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2))) gmm_params, _ = coix.util.train( @@ -243,32 +252,28 @@ def main(args): train_ds, ) - program = make_gmm(gmm_params, num_sweeps) + program = make_gmm(gmm_params, num_sweeps, num_particles) batch, label = next(test_ds) - batch = jnp.repeat(batch[:, None], num_particles, axis=1) - rng_keys = jax.vmap(partial(random.split, num=num_particles))( - random.split(jax.random.PRNGKey(1), batch.shape[0]) - ) - _, out = jax.vmap(program)(rng_keys, batch) - - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - for i in range(3): - n = i - axes[i].scatter( - batch[n, 0, :, 0], - batch[n, 0, :, 1], + out, _, _ = coix.traced_evaluate(program, seed=jax.random.PRNGKey(1))(batch) + out = out[0] + + _, axes = plt.subplots(2, 3, figsize=(15, 10)) + for i in range(6): + axes[i // 3][i % 3].scatter( + batch[i, :, 0], + batch[i, :, 1], marker=".", - color=np.array(["c", "m", "y"])[label[n]], + color=np.array(["c", "m", "y"])[label[i]], ) for j, c in enumerate(["r", "g", "b"]): ellipse = Ellipse( - xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]), - width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]), - height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]), + xy=(out["mean"][0, i, 0, j, 0], out["mean"][0, i, 0, j, 1]), + width=4 / jnp.sqrt(out["tau"][0, i, 0, j, 0]), + height=4 / jnp.sqrt(out["tau"][0, i, 0, j, 1]), fc=c, alpha=0.3, ) - axes[i].add_patch(ellipse) + axes[i // 3][i % 3].add_patch(ellipse) plt.show() @@ -286,6 +291,5 @@ def main(args): tf.config.experimental.set_visible_devices([], "GPU") # Disable GPU for TF. 
numpyro.set_platform(args.device) - coix.set_backend("coix.oryx") main(args) diff --git a/examples/gmm_oryx.py b/examples/gmm_oryx.py index 3c0604e..8aa6890 100644 --- a/examples/gmm_oryx.py +++ b/examples/gmm_oryx.py @@ -25,8 +25,14 @@ 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural sufficient statistics. ICML 2020. +.. image:: ../_static/gmm_oryx.png + :align: center + """ +# %% +# **Note:** The metrics seem to be incorrect in this example. + import argparse from functools import partial @@ -43,7 +49,6 @@ import numpyro.distributions as dist import optax import tensorflow as tf -import tensorflow_datasets as tfds # %% # First, let's simulate a synthetic dataset of Gaussian clusters. @@ -62,20 +67,25 @@ def simulate_clusters(num_instances=1, N=60, seed=0): return x, c -def load_dataset(split, *, is_training, batch_size): - num_data = 20000 if is_training else batch_size - num_points = 60 if is_training else 100 - seed = 0 if is_training else 1 +def load_dataset(split, *, batch_size): + if split == "train": + num_data = 20000 + num_points = 60 + seed = 0 + else: + num_data = batch_size + num_points = 100 + seed = 1 data, label = simulate_clusters(num_data, num_points, seed=seed) - if is_training: + if split == "train": ds = tf.data.Dataset.from_tensor_slices(data) - ds = ds.cache().repeat() + ds = ds.repeat() ds = ds.shuffle(10 * batch_size, seed=0) else: ds = tf.data.Dataset.from_tensor_slices((data, label)) - ds = ds.cache().repeat() + ds = ds.repeat() ds = ds.batch(batch_size) - return iter(tfds.as_numpy(ds)) + return ds.as_numpy_iterator() # %% @@ -117,6 +127,12 @@ def __call__(self, x): return logits + jnp.log(jnp.ones(3) / 3) +def broadcast_concatenate(*xs): + shape = jnp.broadcast_shapes(*[x.shape[:-1] for x in xs]) + xs = [jnp.broadcast_to(x, shape + x.shape[-1:]) for x in xs] + return jnp.concatenate(xs, -1) + + class GMMEncoder(nn.Module): def setup(self): @@ -129,12 +145,7 @@ def __call__(self, x): # N x D alpha, beta, mean, _ = self.encode_initial_mean_tau(x) # M x D tau = alpha / beta # M x D - concatenate_fn = lambda x, m, t: jnp.concatenate( - [x, m, t], axis=-1 - ) # N x M x 3D - xmt = jax.vmap( - jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None) - )(x, mean, tau) + xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau) logits = self.encode_c(xmt) # N x D c = jnp.argmax(logits, -1) # N @@ -167,8 +178,9 @@ def gmm_kernel_mean_tau(network, key, inputs): key_out, key_mean, key_tau = random.split(key, 3) if "c" in inputs: + x = inputs["x"] c = jax.nn.one_hot(inputs["c"], 3) - xc = jnp.concatenate([inputs["x"], c], -1) + xc = jnp.concatenate([x, c], -1) alpha, beta, mu, nu = network.encode_mean_tau(xc) else: alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"]) @@ -184,10 +196,8 @@ def gmm_kernel_mean_tau(network, key, inputs): def gmm_kernel_c(network, key, inputs): key_out, key_c = random.split(key, 2) - concatenate_fn = lambda x, m, t: jnp.concatenate([x, m, t], axis=-1) - xmt = jax.vmap( - jax.vmap(concatenate_fn, in_axes=(None, 0, 0)), in_axes=(0, None, None) - )(inputs["x"], inputs["mean"], inputs["tau"]) + x, mean, tau = inputs["x"], inputs["mean"], inputs["tau"] + xmt = jax.vmap(broadcast_concatenate, (None, -2, -2), -2)(x, mean, tau) logits = network.encode_c(xmt) c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c) @@ -236,8 +246,8 @@ def main(args): num_sweeps = args.num_sweeps num_particles = args.num_particles - train_ds = load_dataset("train", is_training=True, batch_size=batch_size) - 
test_ds = load_dataset("test", is_training=False, batch_size=batch_size) + train_ds = load_dataset("train", batch_size=batch_size) + test_ds = load_dataset("test", batch_size=batch_size) init_params = GMMEncoder().init(jax.random.PRNGKey(0), jnp.zeros((60, 2))) gmm_params, _ = coix.util.train( @@ -256,24 +266,23 @@ def main(args): ) _, out = jax.vmap(program)(rng_keys, batch) - fig, axes = plt.subplots(1, 3, figsize=(15, 5)) - for i in range(3): - n = i - axes[i].scatter( - batch[n, 0, :, 0], - batch[n, 0, :, 1], + _, axes = plt.subplots(2, 3, figsize=(15, 10)) + for i in range(6): + axes[i // 3][i % 3].scatter( + batch[i, 0, :, 0], + batch[i, 0, :, 1], marker=".", - color=np.array(["c", "m", "y"])[label[n]], + color=np.array(["c", "m", "y"])[label[i]], ) for j, c in enumerate(["r", "g", "b"]): ellipse = Ellipse( - xy=(out["mean"][n, 0, j, 0], out["mean"][n, 0, j, 1]), - width=4 / jnp.sqrt(out["tau"][n, 0, j, 0]), - height=4 / jnp.sqrt(out["tau"][n, 0, j, 1]), + xy=(out["mean"][i, 0, j, 0], out["mean"][i, 0, j, 1]), + width=4 / jnp.sqrt(out["tau"][i, 0, j, 0]), + height=4 / jnp.sqrt(out["tau"][i, 0, j, 1]), fc=c, alpha=0.3, ) - axes[i].add_patch(ellipse) + axes[i // 3][i % 3].add_patch(ellipse) plt.show() diff --git a/notebooks/tutorial_part2_api.ipynb b/notebooks/tutorial_part2_api.ipynb index ec2342c..18eecb3 100644 --- a/notebooks/tutorial_part2_api.ipynb +++ b/notebooks/tutorial_part2_api.ipynb @@ -564,14 +564,16 @@ " m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n", ")\n", "handles, labels = ax_xy.get_legend_handles_labels()\n", - "handles.extend([\n", - " lines.Line2D(\n", - " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", - " ),\n", - " lines.Line2D(\n", - " [0], [0], label=\"target denstity of $extend(f,\\ k)$\", color=\"C1\"\n", - " ),\n", - "])\n", + "handles.extend(\n", + " [\n", + " lines.Line2D(\n", + " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", + " ),\n", + " lines.Line2D(\n", + " [0], [0], label=\"target denstity of $extend(f,\\ k)$\", color=\"C1\"\n", + " ),\n", + " ]\n", + ")\n", "ax_xy.legend(handles=handles, loc=\"lower left\");" ] }, @@ -966,20 +968,22 @@ " m_xy[..., 0], m_xy[..., 1], m_p_target, levels=[0.05, 0.3], colors=\"C1\"\n", ")\n", "handles, labels = ax_xy.get_legend_handles_labels()\n", - "handles.extend([\n", - " lines.Line2D(\n", - " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", - " ),\n", - " lines.Line2D(\n", - " [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n", - " ),\n", - " lines.Line2D(\n", - " [0],\n", - " [0],\n", - " label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n", - " color=\"C1\",\n", - " ),\n", - "])\n", + "handles.extend(\n", + " [\n", + " lines.Line2D(\n", + " [0], [0], label=\"prior density of $extend(f,\\ k)$\", color=\"C0\"\n", + " ),\n", + " lines.Line2D(\n", + " [0], [0], label=\"proposal denstity $compose(k,\\ q2)$\", color=\"C2\"\n", + " ),\n", + " lines.Line2D(\n", + " [0],\n", + " [0],\n", + " label=\"target denstity $extend(f, k)$ and $compose(k,\\ q2)$\",\n", + " color=\"C1\",\n", + " ),\n", + " ]\n", + ")\n", "ax_xy.legend(handles=handles, loc=\"lower left\")\n", "\n", "_, f_ext_trace, f_ext_metrics = traced_evaluate(\n", diff --git a/pyproject.toml b/pyproject.toml index 0429e33..58ef935 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ doc = [ "sphinx-gallery", ] oryx = [ - "oryx", + "oryx@git+https://github.com/jax-ml/oryx", ] [tool.pyink]
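
For readers following the conversion from the oryx-style programs (explicit PRNG keys and `jax.vmap` over particles) to the numpyro-style programs in this diff, the shared pattern is: write the target and kernels as ordinary NumPyro models that return a tuple of outputs, then add the particle dimension by wrapping each program in a `numpyro.plate`, as `make_dmm`, `make_gmm`, and `make_bmnist` do above. Below is a minimal, self-contained sketch of that wrapping step only; `toy_target`, the shapes, and the trace inspection are illustrative assumptions and are not part of the repository.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers


def toy_target(inputs):
  # Hypothetical target: a latent `z` with event shape (2,) and observed `x`.
  z = numpyro.sample("z", dist.Normal(0, 1).expand([2]).to_event())
  numpyro.sample("x", dist.Normal(z, 0.1).to_event(1), obs=inputs)
  # Converted programs return a tuple of outputs instead of (key_out, out).
  return ({"z": z, "x": inputs},)


# Add the particle dimension by wrapping the program in a `particle` plate,
# mirroring make_dmm/make_gmm/make_bmnist in this diff.
num_particles = 10
plated_target = numpyro.plate("particle", num_particles, dim=-2)(toy_target)

data = jnp.zeros((3, 2))  # toy batch of 3 observations
with handlers.seed(rng_seed=0):
  tr = handlers.trace(plated_target).get_trace(data)
print(tr["z"]["value"].shape)  # (10, 1, 2): the particle dimension sits at dim=-2
```

Plated targets and kernels built this way are then composed with `coix.algo.apgs(target, kernels, num_sweeps=num_sweeps)` and executed with `coix.traced_evaluate(program, seed=rng_key)(batch)`, as shown in `make_gmm` and `loss_fn` above.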