From 493c965a944ba526ecc71d890bfd36f53f47efba Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 29 Mar 2024 07:44:07 -0400 Subject: [PATCH] update the exposed api and add docs to examples --- coix/__init__.py | 8 ++++---- coix/core.py | 30 +++------------------------- coix/numpyro.py | 10 ++++++++++ coix/oryx.py | 5 +++++ examples/anneal.py | 13 +++++++++++-- examples/anneal_oryx.py | 20 ++++++++++++++----- examples/dmm_oryx.py | 29 ++++++++++++++++++--------- examples/gmm_oryx.py | 43 +++++++++++++++++++++++++++-------------- 8 files changed, 96 insertions(+), 62 deletions(-) diff --git a/coix/__init__.py b/coix/__init__.py index d46c565..d01472d 100644 --- a/coix/__init__.py +++ b/coix/__init__.py @@ -25,10 +25,10 @@ from coix.api import resample from coix.core import detach from coix.core import empirical -from coix.core import factor +from coix.core import prng_key from coix.core import register_backend -from coix.core import rv from coix.core import set_backend +from coix.core import suffix from coix.core import stick_the_landing from coix.core import traced_evaluate @@ -39,16 +39,16 @@ "compose", "detach", "extend", - "factor", "fori_loop", "loss", "memoize", + "prng_key", "propose", "register_backend", "resample", - "rv", "set_backend", "stick_the_landing", + "suffix", "traced_evaluate", "util", ] diff --git a/coix/core.py b/coix/core.py index 3d4e2b0..f47a393 100644 --- a/coix/core.py +++ b/coix/core.py @@ -19,9 +19,7 @@ __all__ = [ "detach", "empirical", - "factor", "prng_key", - "rv", "register_backend", "set_backend", "stick_the_landing", @@ -39,22 +37,18 @@ def register_backend( traced_evaluate=None, empirical=None, suffix=None, + prng_key=None, detach=None, stick_the_landing=None, - rv=None, - factor=None, - prng_key=None, ): """Register backend.""" fn_map = { "traced_evaluate": traced_evaluate, "empirical": empirical, "suffix": suffix, + "prng_key": prng_key, "detach": detach, "stick_the_landing": stick_the_landing, - "rv": rv, - "factor": factor, - 
"prng_key": prng_key, } _BACKENDS[backend] = fn_map @@ -73,11 +67,9 @@ def set_backend(backend): "traced_evaluate", "empirical", "suffix", + "prng_key", "detach", "stick_the_landing", - "rv", - "factor", - "prng_key", ]: fn_map[fn] = getattr(module, fn, None) register_backend(backend, **fn_map) @@ -172,22 +164,6 @@ def stick_the_landing(p): return p -def rv(*args, **kwargs): - fn = get_backend()["rv"] - if fn is not None: - return fn(*args, **kwargs) - else: - raise NotImplementedError - - -def factor(*args, **kwargs): - fn = get_backend()["factor"] - if fn is not None: - return fn(*args, **kwargs) - else: - raise NotImplementedError - - def prng_key(): fn = get_backend()["prng_key"] if fn is not None: diff --git a/coix/numpyro.py b/coix/numpyro.py index e624272..6c52e95 100644 --- a/coix/numpyro.py +++ b/coix/numpyro.py @@ -23,6 +23,16 @@ from numpyro import handlers import numpyro.distributions as dist + +__all__ = [ + "detach", + "empirical", + "prng_key", + "stick_the_landing", + "suffix", + "traced_evaluate", +] + prng_key = numpyro.prng_key diff --git a/coix/oryx.py b/coix/oryx.py index b03e15f..fa687da 100644 --- a/coix/oryx.py +++ b/coix/oryx.py @@ -37,6 +37,7 @@ "detach", "empirical", "factor", + "prng_key", "rv", "stick_the_landing", "suffix", @@ -449,3 +450,7 @@ def wrapped(*args, **kwargs): return out return wrapped + + +def prng_key(): + raise ValueError("Cannot genenerate random key under the oryx backend.") diff --git a/examples/anneal.py b/examples/anneal.py index a66edb2..2e62d98 100644 --- a/examples/anneal.py +++ b/examples/anneal.py @@ -13,8 +13,17 @@ # limitations under the License. """ -Example: Anneal example in NumPyro -================================== +Example: Annealed Variational Inference in NumPyro +================================================== + +This example illustrates how to construct an inference program based on the NVI +algorithm [1] for AVI. The details of AVI can be found in Section E.1 of +the reference. 
We will use the NumPyro (default) backend for this example. + +**References** + + 1. Zimmermann, Heiko, et al. "Nested variational inference." NeurIPS 2021. + """ import argparse diff --git a/examples/anneal_oryx.py b/examples/anneal_oryx.py index b220124..3e13b66 100644 --- a/examples/anneal_oryx.py +++ b/examples/anneal_oryx.py @@ -13,14 +13,24 @@ # limitations under the License. """ -Example: Anneal example in Oryx -=============================== +Example: Annealed Variational Inference in Oryx +=============================================== + +This example illustrates how to construct an inference program based on the NVI +algorithm [1] for AVI. The details of AVI can be found in Section E.1 of +the reference. We will use the Oryx backend for this example. + +**References** + + 1. Zimmermann, Heiko, et al. "Nested variational inference." NeurIPS 2021. + """ import argparse from functools import partial import coix +import coix.oryx as coryx import flax import flax.linen as nn import jax @@ -106,19 +116,19 @@ def __call__(self, x): def anneal_target(network, key, k=0): key_out, key = random.split(key) - x = coix.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key) + x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key) coix.factor(network.anneal_density(x, index=k), name="anneal_density") return key_out, {"x": x} def anneal_forward(network, key, inputs, k=0): mu, sigma = network.forward_kernels(inputs["x"], index=k) - return coix.rv(dist.Normal(mu, sigma), name="x")(key) + return coryx.rv(dist.Normal(mu, sigma), name="x")(key) def anneal_reverse(network, key, inputs, k=0): mu, sigma = network.reverse_kernels(inputs["x"], index=k) - return coix.rv(dist.Normal(mu, sigma), name="x")(key) + return coryx.rv(dist.Normal(mu, sigma), name="x")(key) ### Train diff --git a/examples/dmm_oryx.py b/examples/dmm_oryx.py index cb39155..a3fb043 100644 --- a/examples/dmm_oryx.py +++ b/examples/dmm_oryx.py @@ -13,14 +13,25 @@ # limitations under 
the License. """ -Example: DMM example in Oryx -============================ +Example: Deep Generative Mixture Model in Oryx +============================================== + +This example illustrates how to construct an inference program based on the APGS +sampler [1] for DMM. The details of DMM can be found in the sections 6.3 and +F.2 of the reference. We will use the Oryx backend for this example. + +**References** + + 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural + sufficient statistics. ICML 2020. + """ import argparse from functools import partial import coix +import coix.oryx as coryx import flax.linen as nn import jax from jax import random @@ -164,11 +175,11 @@ def dmm_target(network, key, inputs): key_out, key_mu, key_c, key_h = random.split(key, 4) N = inputs.shape[-2] - mu = coix.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu) - c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) - h = coix.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h) + mu = coryx.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu) + c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) + h = coryx.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h) x_recon = mu[c] + network.decode_h(h) - x = coix.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x") + x = coryx.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x") out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x} return key_out, out @@ -186,7 +197,7 @@ def dmm_kernel_mu(network, key, inputs): loc, scale = network.encode_mu(xch) else: loc, scale = network.encode_initial_mu(inputs["x"]) - mu = coix.rv(dist.Normal(loc, scale), name="mu")(key_mu) + mu = coryx.rv(dist.Normal(loc, scale), name="mu")(key_mu) out = {**inputs, **{"mu": mu}} return key_out, out @@ -200,9 +211,9 @@ def dmm_kernel_c_h(network, key, inputs): inputs["x"], inputs["mu"] ) logits = network.encode_c(xmu) - c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c) + c = 
coryx.rv(dist.Categorical(logits=logits), name="c")(key_c) alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c]) - h = coix.rv(dist.Beta(alpha, beta), name="h")(key_h) + h = coryx.rv(dist.Beta(alpha, beta), name="h")(key_h) out = {**inputs, **{"c": c, "h": h}} return key_out, out diff --git a/examples/gmm_oryx.py b/examples/gmm_oryx.py index 26dac93..7270f23 100644 --- a/examples/gmm_oryx.py +++ b/examples/gmm_oryx.py @@ -13,14 +13,25 @@ # limitations under the License. """ -Example: GMM example in Oryx -============================ +Example: Gaussian Mixture Model in Oryx +======================================= + +This example illustrates how to construct an inference program for GMM, based on +the APGS sampler [1]. The details of GMM can be found in the sections 6.2 and +F.1 of the reference. We will use the Oryx backend for this example. + +**References** + + 1. Wu, Hao, et al. Amortized population Gibbs samplers with neural + sufficient statistics. ICML 2020. + """ import argparse from functools import partial import coix +import coix.oryx as coryx import flax.linen as nn import jax from jax import random @@ -34,8 +45,9 @@ import tensorflow as tf import tensorflow_datasets as tfds -### Data +# %% +# First, let's simulate a synthetic dataset of Gaussian clusters. def simulate_clusters(num_instances=1, N=60, seed=0): np.random.seed(seed) @@ -66,8 +78,8 @@ def load_dataset(split, *, is_training, batch_size): return iter(tfds.as_numpy(ds)) -### Encoder - +# %% +# Next, we define the neural proposals for the Gibbs kernels. class GMMEncoderMeanTau(nn.Module): @@ -129,17 +141,17 @@ def __call__(self, x): # N x D return self.encode_mean_tau(xc) -### Model and kernels - +# %% +# Then, we define the target and kernels as in Section 6.2. 
def gmm_target(network, key, inputs): key_out, key_mean, key_tau, key_c = random.split(key, 4) N = inputs.shape[-2] - tau = coix.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau) - mean = coix.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean) - c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) - x = coix.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x") + tau = coryx.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau) + mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean) + c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c) + x = coryx.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x") out = {"mean": mean, "tau": tau, "c": c, "x": x} return key_out, out @@ -156,8 +168,8 @@ def gmm_kernel_mean_tau(network, key, inputs): alpha, beta, mu, nu = network.encode_mean_tau(xc) else: alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"]) - tau = coix.rv(dist.Gamma(alpha, beta), name="tau")(key_tau) - mean = coix.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean) + tau = coryx.rv(dist.Gamma(alpha, beta), name="tau")(key_tau) + mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean) out = {**inputs, **{"mean": mean, "tau": tau}} return key_out, out @@ -177,8 +189,9 @@ def gmm_kernel_c(network, key, inputs): return key_out, out -### Train - +# %% +# Finally, we create the gmm inference program, define the loss function, +# run the training loop, and plot the results. def make_gmm(params, num_sweeps): network = coix.util.BindModule(GMMEncoder(), params)