Skip to content

Commit

Permalink
format the examples
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Mar 29, 2024
1 parent 22ff47f commit bf645d2
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 27 deletions.
2 changes: 1 addition & 1 deletion coix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from coix.core import prng_key
from coix.core import register_backend
from coix.core import set_backend
from coix.core import suffix
from coix.core import stick_the_landing
from coix.core import suffix
from coix.core import traced_evaluate

__all__ = [
Expand Down
1 change: 0 additions & 1 deletion coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from numpyro import handlers
import numpyro.distributions as dist


__all__ = [
"detach",
"empirical",
Expand Down
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ Coix Documentation
notebooks/tutorial_part3_smcs
examples/anneal
examples/anneal_oryx
examples/gmm
examples/gmm_oryx
examples/dmm
examples/dmm_oryx
examples/bmnist

Expand Down
10 changes: 7 additions & 3 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
import numpyro.distributions as dist
import optax

### Networks
# %%
# First, we define the neural networks for the targets and kernels.


class AnnealKernel(nn.Module):
Expand Down Expand Up @@ -110,7 +111,8 @@ def __call__(self, x):
return self.forward_kernels(x)


### Model and kernels
# %%
# Then, we define the targets and kernels as in Section E.1.


def anneal_target(network, k=0):
Expand All @@ -130,7 +132,9 @@ def anneal_reverse(network, inputs, k=0):
return numpyro.sample("x", dist.Normal(mu, sigma).to_event(1))


### Train
# %%
# Finally, we create the anneal inference program, define the loss function,
# run the training loop, and plot the results.


def make_anneal(params, unroll=False, num_particles=10):
Expand Down
10 changes: 7 additions & 3 deletions examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
import numpyro.distributions as dist
import optax

### Networks
# %%
# First, we define the neural networks for the targets and kernels.


class AnnealKernel(nn.Module):
Expand Down Expand Up @@ -111,7 +112,8 @@ def __call__(self, x):
return self.forward_kernels(x)


### Model and kernels
# %%
# Then, we define the targets and kernels as in Section E.1.


def anneal_target(network, key, k=0):
Expand All @@ -131,7 +133,9 @@ def anneal_reverse(network, key, inputs, k=0):
return coryx.rv(dist.Normal(mu, sigma), name="x")(key)


### Train
# %%
# Finally, we create the anneal inference program, define the loss function,
# run the training loop, and plot the results.


def make_anneal(params, unroll=False):
Expand Down
32 changes: 25 additions & 7 deletions examples/bmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@
# limitations under the License.

"""
BMNIST example in NumPyro
=========================
Example: Time Series Model - Bouncing MNIST in NumPyro
======================================================
This example illustrates how to construct an inference program based on the APGS
sampler [1] for BMNIST. The details of BMNIST can be found in the sections
6.4 and F.3 of the reference. We will use the NumPyro (default) backend for this
example.
**References**
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
"""

# TODO: Refactor using numpyro backend. The current code is likely not working yet.
Expand All @@ -40,6 +51,9 @@
import tensorflow as tf
import tensorflow_datasets as tfds

# %%
# First, let's load the moving mnist dataset.

batch_size = 5
T = 10 # using 10 time steps for training and 20 time steps for testing

Expand Down Expand Up @@ -74,7 +88,9 @@ def get_digit_mean():
print("Digit shape:", digit_mean.shape)


### Autoencoder
# %%
# Next, we define the neural proposals for the Gibbs kernels and the neural
# decoder for the generative model.


def scale_and_translate(image, where, out_size):
Expand Down Expand Up @@ -191,7 +207,8 @@ def __call__(self, frames):
return frames_recon


### Model and kernels
# %%
# Then, we define the target and kernels as in Section 6.4.

test_key = random.PRNGKey(0)
test_data = jnp.zeros((frame_length,) + (frame_size, frame_size))
Expand Down Expand Up @@ -281,6 +298,10 @@ def kernel_what(network, key, inputs, T=10):
_, k2_out = kernel_what(test_network, test_key, p_out, T=frame_length)


# %%
# Finally, we create the dmm inference program, define the loss function,
# run the training loop, and plot the results.

num_sweeps = 5
num_particles = 10

Expand All @@ -297,9 +318,6 @@ def make_bmnist(params, T=10):
return program


a


def loss_fn(params, key, batch):
# Prepare data for the program.
shuffle_rng, rng_key = random.split(key)
Expand Down
14 changes: 10 additions & 4 deletions examples/dmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
import tensorflow as tf
import tensorflow_datasets as tfds

# Data
# %%
# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures.


def simulate_rings(num_instances=1, N=200, seed=0):
Expand Down Expand Up @@ -76,7 +77,9 @@ def load_dataset(split, *, is_training, batch_size):
return iter(tfds.as_numpy(ds))


### Autoencoder
# %%
# Next, we define the neural proposals for the Gibbs kernels and the neural
# decoder for the generative model.


class EncoderMu(nn.Module):
Expand Down Expand Up @@ -168,7 +171,8 @@ def __call__(self, x): # N x D
return x_recon


### Model and kernels
# %%
# Then, we define the target and kernels as in Section 6.3.


def dmm_target(network, key, inputs):
Expand Down Expand Up @@ -219,7 +223,9 @@ def dmm_kernel_c_h(network, key, inputs):
return key_out, out


### Train
# %%
# Finally, we create the dmm inference program, define the loss function,
# run the training loop, and plot the results.


def make_dmm(params, num_sweeps):
Expand Down
10 changes: 7 additions & 3 deletions examples/dmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
import tensorflow as tf
import tensorflow_datasets as tfds

# Data
# %%
# First, let's simulate a synthetic dataset of 2D ring-shaped mixtures.


def simulate_rings(num_instances=1, N=200, seed=0):
Expand Down Expand Up @@ -76,7 +77,9 @@ def load_dataset(split, *, is_training, batch_size):
return iter(tfds.as_numpy(ds))


### Autoencoder
# %%
# Next, we define the neural proposals for the Gibbs kernels and the neural
# decoder for the generative model.


class EncoderMu(nn.Module):
Expand Down Expand Up @@ -168,7 +171,8 @@ def __call__(self, x): # N x D
return x_recon


### Model and kernels
# %%
# Then, we define the target and kernels as in Section 6.3.


def dmm_target(network, key, inputs):
Expand Down
7 changes: 5 additions & 2 deletions examples/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
import tensorflow as tf
import tensorflow_datasets as tfds


# %%
# First, let's simulate a synthetic dataset of Gaussian clusters.
# First, let's simulate a synthetic dataset of 2D Gaussian mixtures.


def simulate_clusters(num_instances=1, N=60, seed=0):
np.random.seed(seed)
Expand Down Expand Up @@ -81,6 +81,7 @@ def load_dataset(split, *, is_training, batch_size):
# %%
# Next, we define the neural proposals for the Gibbs kernels.


class GMMEncoderMeanTau(nn.Module):

@nn.compact
Expand Down Expand Up @@ -144,6 +145,7 @@ def __call__(self, x): # N x D
# %%
# Then, we define the target and kernels as in Section 6.2.


def gmm_target(network, key, inputs):
key_out, key_mean, key_tau, key_c = random.split(key, 4)
N = inputs.shape[-2]
Expand Down Expand Up @@ -193,6 +195,7 @@ def gmm_kernel_c(network, key, inputs):
# Finally, we create the gmm inference program, define the loss function,
# run the training loop, and plot the results.


def make_gmm(params, num_sweeps):
network = coix.util.BindModule(GMMEncoder(), params)
# Add particle dimension and construct a program.
Expand Down
13 changes: 10 additions & 3 deletions examples/gmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
import tensorflow as tf
import tensorflow_datasets as tfds


# %%
# First, let's simulate a synthetic dataset of Gaussian clusters.


def simulate_clusters(num_instances=1, N=60, seed=0):
np.random.seed(seed)
tau = np.random.gamma(2, 0.5, (num_instances, 4, 2))
Expand Down Expand Up @@ -81,6 +81,7 @@ def load_dataset(split, *, is_training, batch_size):
# %%
# Next, we define the neural proposals for the Gibbs kernels.


class GMMEncoderMeanTau(nn.Module):

@nn.compact
Expand Down Expand Up @@ -144,12 +145,15 @@ def __call__(self, x): # N x D
# %%
# Then, we define the target and kernels as in Section 6.2.


def gmm_target(network, key, inputs):
key_out, key_mean, key_tau, key_c = random.split(key, 4)
N = inputs.shape[-2]

tau = coryx.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean)
mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(
key_mean
)
c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
x = coryx.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")

Expand All @@ -169,7 +173,9 @@ def gmm_kernel_mean_tau(network, key, inputs):
else:
alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
tau = coryx.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean)
mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(
key_mean
)

out = {**inputs, **{"mean": mean, "tau": tau}}
return key_out, out
Expand All @@ -193,6 +199,7 @@ def gmm_kernel_c(network, key, inputs):
# Finally, we create the gmm inference program, define the loss function,
# run the training loop, and plot the results.


def make_gmm(params, num_sweeps):
network = coix.util.BindModule(GMMEncoder(), params)
# Add particle dimension and construct a program.
Expand Down

0 comments on commit bf645d2

Please sign in to comment.