From bf645d25fd8e2fce537c4c6bcfa4161364a8a877 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 29 Mar 2024 08:33:10 -0400 Subject: [PATCH] format the examples --- coix/__init__.py | 2 +- coix/numpyro.py | 1 - docs/index.rst | 2 ++ examples/anneal.py | 10 +++++++--- examples/anneal_oryx.py | 10 +++++++--- examples/bmnist.py | 32 +++++++++++++++++++++++++------- examples/dmm.py | 14 ++++++++++---- examples/dmm_oryx.py | 10 +++++++--- examples/gmm.py | 7 +++++-- examples/gmm_oryx.py | 13 ++++++++++--- 10 files changed, 74 insertions(+), 27 deletions(-) diff --git a/coix/__init__.py b/coix/__init__.py index d01472d..7188797 100644 --- a/coix/__init__.py +++ b/coix/__init__.py @@ -28,8 +28,8 @@ from coix.core import prng_key from coix.core import register_backend from coix.core import set_backend -from coix.core import suffix from coix.core import stick_the_landing +from coix.core import suffix from coix.core import traced_evaluate __all__ = [ diff --git a/coix/numpyro.py b/coix/numpyro.py index 6c52e95..ee96ad6 100644 --- a/coix/numpyro.py +++ b/coix/numpyro.py @@ -23,7 +23,6 @@ from numpyro import handlers import numpyro.distributions as dist - __all__ = [ "detach", "empirical", diff --git a/docs/index.rst b/docs/index.rst index 6fa2a7e..368868a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,7 +30,9 @@ Coix Documentation notebooks/tutorial_part3_smcs examples/anneal examples/anneal_oryx + examples/gmm examples/gmm_oryx + examples/dmm examples/dmm_oryx examples/bmnist diff --git a/examples/anneal.py b/examples/anneal.py index 2e62d98..08a3893 100644 --- a/examples/anneal.py +++ b/examples/anneal.py @@ -41,7 +41,8 @@ import numpyro.distributions as dist import optax -### Networks +# %% +# First, we define the neural networks for the targets and kernels. class AnnealKernel(nn.Module): @@ -110,7 +111,8 @@ def __call__(self, x): return self.forward_kernels(x) -### Model and kernels +# %% +# Then, we define the targets and kernels as in Section E.1. def anneal_target(network, k=0): @@ -130,7 +132,9 @@ def anneal_reverse(network, inputs, k=0): return numpyro.sample("x", dist.Normal(mu, sigma).to_event(1)) -### Train +# %% +# Finally, we create the anneal inference program, define the loss function, +# run the training loop, and plot the results. def make_anneal(params, unroll=False, num_particles=10): diff --git a/examples/anneal_oryx.py b/examples/anneal_oryx.py index 3e13b66..8803067 100644 --- a/examples/anneal_oryx.py +++ b/examples/anneal_oryx.py @@ -42,7 +42,8 @@ import numpyro.distributions as dist import optax -### Networks +# %% +# First, we define the neural networks for the targets and kernels. class AnnealKernel(nn.Module): @@ -111,7 +112,8 @@ def __call__(self, x): return self.forward_kernels(x) -### Model and kernels +# %% +# Then, we define the targets and kernels as in Section E.1. def anneal_target(network, key, k=0): @@ -131,7 +133,9 @@ def anneal_reverse(network, key, inputs, k=0): return coryx.rv(dist.Normal(mu, sigma), name="x")(key) -### Train +# %% +# Finally, we create the anneal inference program, define the loss function, +# run the training loop, and plot the results. def make_anneal(params, unroll=False): diff --git a/examples/bmnist.py b/examples/bmnist.py index 0fa67d7..255b5ad 100644 --- a/examples/bmnist.py +++ b/examples/bmnist.py @@ -13,8 +13,19 @@ # limitations under the License. """ -BMNIST example in NumPyro -========================= +Example: Time Series Model - Bouncing MNIST in NumPyro +====================================================== + +This example illustrates how to construct an inference program based on the APGS +sampler [1] for BMNIST. The details of BMNIST can be found in the sections +6.4 and F.3 of the reference. We will use the NumPyro (default) backend for this +example. + +**References** + + 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural + sufficient statistics. ICML 2020. + """ # TODO: Refactor using numpyro backend. The current code is likely not working yet. @@ -40,6 +51,9 @@ import tensorflow as tf import tensorflow_datasets as tfds +# %% +# First, let's load the moving mnist dataset. + batch_size = 5 T = 10 # using 10 time steps for training and 20 time steps for testing @@ -74,7 +88,9 @@ def get_digit_mean(): print("Digit shape:", digit_mean.shape) -### Autoencoder +# %% +# Next, we define the neural proposals for the Gibbs kernels and the neural +# decoder for the generative model. def scale_and_translate(image, where, out_size): @@ -191,7 +207,8 @@ def __call__(self, frames): return frames_recon -### Model and kernels +# %% +# Then, we define the target and kernels as in Section 6.4. test_key = random.PRNGKey(0) test_data = jnp.zeros((frame_length,) + (frame_size, frame_size)) @@ -281,6 +298,10 @@ def kernel_what(network, key, inputs, T=10): _, k2_out = kernel_what(test_network, test_key, p_out, T=frame_length) +# %% +# Finally, we create the dmm inference program, define the loss function, +# run the training loop, and plot the results. + num_sweeps = 5 num_particles = 10 @@ -297,9 +318,6 @@ def make_bmnist(params, T=10): return program -a - - def loss_fn(params, key, batch): # Prepare data for the program. shuffle_rng, rng_key = random.split(key) diff --git a/examples/dmm.py b/examples/dmm.py index 1703cc0..11942b9 100644 --- a/examples/dmm.py +++ b/examples/dmm.py @@ -44,7 +44,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -# Data +# %% +# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures. def simulate_rings(num_instances=1, N=200, seed=0): @@ -76,7 +77,9 @@ def load_dataset(split, *, is_training, batch_size): return iter(tfds.as_numpy(ds)) -### Autoencoder +# %% +# Next, we define the neural proposals for the Gibbs kernels and the neural +# decoder for the generative model. class EncoderMu(nn.Module): @@ -168,7 +171,8 @@ def __call__(self, x): # N x D return x_recon -### Model and kernels +# %% +# Then, we define the target and kernels as in Section 6.3. def dmm_target(network, key, inputs): @@ -219,7 +223,9 @@ def dmm_kernel_c_h(network, key, inputs): return key_out, out -### Train +# %% +# Finally, we create the dmm inference program, define the loss function, +# run the training loop, and plot the results. def make_dmm(params, num_sweeps): diff --git a/examples/dmm_oryx.py b/examples/dmm_oryx.py index a3fb043..7ba11fa 100644 --- a/examples/dmm_oryx.py +++ b/examples/dmm_oryx.py @@ -44,7 +44,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -# Data +# %% +# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures. def simulate_rings(num_instances=1, N=200, seed=0): @@ -76,7 +77,9 @@ def load_dataset(split, *, is_training, batch_size): return iter(tfds.as_numpy(ds)) -### Autoencoder +# %% +# Next, we define the neural proposals for the Gibbs kernels and the neural +# decoder for the generative model. class EncoderMu(nn.Module): @@ -168,7 +171,8 @@ def __call__(self, x): # N x D return x_recon -### Model and kernels +# %% +# Then, we define the target and kernels as in Section 6.3. def dmm_target(network, key, inputs): diff --git a/examples/gmm.py b/examples/gmm.py index 72af3f0..552a96b 100644 --- a/examples/gmm.py +++ b/examples/gmm.py @@ -45,9 +45,9 @@ import tensorflow as tf import tensorflow_datasets as tfds - # %% -# First, let's simulate a synthetic dataset of Gaussian clusters. +# First, let's simulate a synthetic dataset of 2D Gaussian mixtures. + def simulate_clusters(num_instances=1, N=60, seed=0): np.random.seed(seed) @@ -81,6 +81,7 @@ def load_dataset(split, *, is_training, batch_size): # %% # Next, we define the neural proposals for the Gibbs kernels. + class GMMEncoderMeanTau(nn.Module): @nn.compact @@ -144,6 +145,7 @@ def __call__(self, x): # N x D # %% # Then, we define the target and kernels as in Section 6.2. + def gmm_target(network, key, inputs): key_out, key_mean, key_tau, key_c = random.split(key, 4) N = inputs.shape[-2] @@ -193,6 +195,7 @@ def gmm_kernel_c(network, key, inputs): # Finally, we create the gmm inference program, define the loss function, # run the training loop, and plot the results. + def make_gmm(params, num_sweeps): network = coix.util.BindModule(GMMEncoder(), params) # Add particle dimension and construct a program. diff --git a/examples/gmm_oryx.py b/examples/gmm_oryx.py index 7270f23..f11082a 100644 --- a/examples/gmm_oryx.py +++ b/examples/gmm_oryx.py @@ -45,10 +45,10 @@ import tensorflow as tf import tensorflow_datasets as tfds - # %% # First, let's simulate a synthetic dataset of Gaussian clusters. + def simulate_clusters(num_instances=1, N=60, seed=0): np.random.seed(seed) tau = np.random.gamma(2, 0.5, (num_instances, 4, 2)) @@ -81,6 +81,7 @@ def load_dataset(split, *, is_training, batch_size): # %% # Next, we define the neural proposals for the Gibbs kernels. + class GMMEncoderMeanTau(nn.Module): @nn.compact @@ -144,12 +145,15 @@ def __call__(self, x): # N x D # %% # Then, we define the target and kernels as in Section 6.2. + def gmm_target(network, key, inputs): key_out, key_mean, key_tau, key_c = random.split(key, 4) N = inputs.shape[-2] tau = coryx.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau) - mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean) + mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")( + key_mean + ) c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) x = coryx.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x") @@ -169,7 +173,9 @@ def gmm_kernel_mean_tau(network, key, inputs): else: alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"]) tau = coryx.rv(dist.Gamma(alpha, beta), name="tau")(key_tau) - mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean) + mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")( + key_mean + ) out = {**inputs, **{"mean": mean, "tau": tau}} return key_out, out @@ -193,6 +199,7 @@ def gmm_kernel_c(network, key, inputs): # Finally, we create the gmm inference program, define the loss function, # run the training loop, and plot the results. + def make_gmm(params, num_sweeps): network = coix.util.BindModule(GMMEncoder(), params) # Add particle dimension and construct a program.