Skip to content

Commit

Permalink
update the exposed api and add docs to examples
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Mar 29, 2024
1 parent 8a9a825 commit 493c965
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 62 deletions.
8 changes: 4 additions & 4 deletions coix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from coix.api import resample
from coix.core import detach
from coix.core import empirical
from coix.core import factor
from coix.core import prng_key
from coix.core import register_backend
from coix.core import rv
from coix.core import set_backend
from coix.core import suffix
from coix.core import stick_the_landing
from coix.core import traced_evaluate

Expand All @@ -39,16 +39,16 @@
"compose",
"detach",
"extend",
"factor",
"fori_loop",
"loss",
"memoize",
"prng_key",
"propose",
"register_backend",
"resample",
"rv",
"set_backend",
"stick_the_landing",
"suffix",
"traced_evaluate",
"util",
]
Expand Down
30 changes: 3 additions & 27 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
__all__ = [
"detach",
"empirical",
"factor",
"prng_key",
"rv",
"register_backend",
"set_backend",
"stick_the_landing",
Expand All @@ -39,22 +37,18 @@ def register_backend(
traced_evaluate=None,
empirical=None,
suffix=None,
prng_key=None,
detach=None,
stick_the_landing=None,
rv=None,
factor=None,
prng_key=None,
):
"""Register backend."""
fn_map = {
"traced_evaluate": traced_evaluate,
"empirical": empirical,
"suffix": suffix,
"prng_key": prng_key,
"detach": detach,
"stick_the_landing": stick_the_landing,
"rv": rv,
"factor": factor,
"prng_key": prng_key,
}
_BACKENDS[backend] = fn_map

Expand All @@ -73,11 +67,9 @@ def set_backend(backend):
"traced_evaluate",
"empirical",
"suffix",
"prng_key",
"detach",
"stick_the_landing",
"rv",
"factor",
"prng_key",
]:
fn_map[fn] = getattr(module, fn, None)
register_backend(backend, **fn_map)
Expand Down Expand Up @@ -172,22 +164,6 @@ def stick_the_landing(p):
return p


def rv(*args, **kwargs):
fn = get_backend()["rv"]
if fn is not None:
return fn(*args, **kwargs)
else:
raise NotImplementedError


def factor(*args, **kwargs):
fn = get_backend()["factor"]
if fn is not None:
return fn(*args, **kwargs)
else:
raise NotImplementedError


def prng_key():
fn = get_backend()["prng_key"]
if fn is not None:
Expand Down
10 changes: 10 additions & 0 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@
from numpyro import handlers
import numpyro.distributions as dist


__all__ = [
"detach",
"empirical",
"prng_key",
"stick_the_landing",
"suffix",
"traced_evaluate",
]

prng_key = numpyro.prng_key


Expand Down
5 changes: 5 additions & 0 deletions coix/oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"detach",
"empirical",
"factor",
"prng_key",
"rv",
"stick_the_landing",
"suffix",
Expand Down Expand Up @@ -449,3 +450,7 @@ def wrapped(*args, **kwargs):
return out

return wrapped


def prng_key():
raise ValueError("Cannot genenerate random key under the oryx backend.")
13 changes: 11 additions & 2 deletions examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,17 @@
# limitations under the License.

"""
Example: Anneal example in NumPyro
==================================
Example: Annealed Variational Inference in NumPyro
==================================================
This example illustrates how to construct an inference program based on the NVI
algorithm [1] for AVI. The details of AVI can be found in the sections E.1 of
the reference. We will use the NumPyro (default) backend for this example.
**References**
1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.
"""

import argparse
Expand Down
20 changes: 15 additions & 5 deletions examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@
# limitations under the License.

"""
Example: Anneal example in Oryx
===============================
Example: Annealed Variational Inference in Oryx
===============================================
This example illustrates how to construct an inference program based on the NVI
algorithm [1] for AVI. The details of AVI can be found in the sections E.1 of
the reference. We will use the Oryx backend for this example.
**References**
1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.
"""

import argparse
from functools import partial

import coix
import coix.oryx as coryx
import flax
import flax.linen as nn
import jax
Expand Down Expand Up @@ -106,19 +116,19 @@ def __call__(self, x):

def anneal_target(network, key, k=0):
key_out, key = random.split(key)
x = coix.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key)
x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key)
coix.factor(network.anneal_density(x, index=k), name="anneal_density")
return key_out, {"x": x}


def anneal_forward(network, key, inputs, k=0):
mu, sigma = network.forward_kernels(inputs["x"], index=k)
return coix.rv(dist.Normal(mu, sigma), name="x")(key)
return coryx.rv(dist.Normal(mu, sigma), name="x")(key)


def anneal_reverse(network, key, inputs, k=0):
mu, sigma = network.reverse_kernels(inputs["x"], index=k)
return coix.rv(dist.Normal(mu, sigma), name="x")(key)
return coryx.rv(dist.Normal(mu, sigma), name="x")(key)


### Train
Expand Down
29 changes: 20 additions & 9 deletions examples/dmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,25 @@
# limitations under the License.

"""
Example: DMM example in Oryx
============================
Example: Deep Generative Mixture Model in Oryx
==============================================
This example illustrates how to construct an inference program based on the APGS
sampler [1] for DMM. The details of DMM can be found in the sections 6.3 and
F.2 of the reference. We will use the Oryx backend for this example.
**References**
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
"""

import argparse
from functools import partial

import coix
import coix.oryx as coryx
import flax.linen as nn
import jax
from jax import random
Expand Down Expand Up @@ -164,11 +175,11 @@ def dmm_target(network, key, inputs):
key_out, key_mu, key_c, key_h = random.split(key, 4)
N = inputs.shape[-2]

mu = coix.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu)
c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
h = coix.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h)
mu = coryx.rv(dist.Normal(0, 10).expand([4, 2]), name="mu")(key_mu)
c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
h = coryx.rv(dist.Beta(1, 1).expand([N]), name="h")(key_h)
x_recon = mu[c] + network.decode_h(h)
x = coix.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x")
x = coryx.rv(dist.Normal(x_recon, 0.1), obs=inputs, name="x")

out = {"mu": mu, "c": c, "h": h, "x_recon": x_recon, "x": x}
return key_out, out
Expand All @@ -186,7 +197,7 @@ def dmm_kernel_mu(network, key, inputs):
loc, scale = network.encode_mu(xch)
else:
loc, scale = network.encode_initial_mu(inputs["x"])
mu = coix.rv(dist.Normal(loc, scale), name="mu")(key_mu)
mu = coryx.rv(dist.Normal(loc, scale), name="mu")(key_mu)

out = {**inputs, **{"mu": mu}}
return key_out, out
Expand All @@ -200,9 +211,9 @@ def dmm_kernel_c_h(network, key, inputs):
inputs["x"], inputs["mu"]
)
logits = network.encode_c(xmu)
c = coix.rv(dist.Categorical(logits=logits), name="c")(key_c)
c = coryx.rv(dist.Categorical(logits=logits), name="c")(key_c)
alpha, beta = network.encode_h(inputs["x"] - inputs["mu"][c])
h = coix.rv(dist.Beta(alpha, beta), name="h")(key_h)
h = coryx.rv(dist.Beta(alpha, beta), name="h")(key_h)

out = {**inputs, **{"c": c, "h": h}}
return key_out, out
Expand Down
43 changes: 28 additions & 15 deletions examples/gmm_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,25 @@
# limitations under the License.

"""
Example: GMM example in Oryx
============================
Example: Gaussian Mixture Model in Oryx
=======================================
This example illustrates how to construct an inference program for GMM, based on
the APGS sampler [1]. The details of GMM can be found in the sections 6.2 and
F.1 of the reference. We will use the Oryx backend for this example.
**References**
1. Wu, Hao, et al. Amortized population Gibbs samplers with neural
sufficient statistics. ICML 2020.
"""

import argparse
from functools import partial

import coix
import coix.oryx as coryx
import flax.linen as nn
import jax
from jax import random
Expand All @@ -34,8 +45,9 @@
import tensorflow as tf
import tensorflow_datasets as tfds

### Data

# %%
# First, let's simulate a synthetic dataset of Gaussian clusters.

def simulate_clusters(num_instances=1, N=60, seed=0):
np.random.seed(seed)
Expand Down Expand Up @@ -66,8 +78,8 @@ def load_dataset(split, *, is_training, batch_size):
return iter(tfds.as_numpy(ds))


### Encoder

# %%
# Next, we define the neural proposals for the Gibbs kernels.

class GMMEncoderMeanTau(nn.Module):

Expand Down Expand Up @@ -129,17 +141,17 @@ def __call__(self, x): # N x D
return self.encode_mean_tau(xc)


### Model and kernels

# %%
# Then, we define the target and kernels as in Section 6.2.

def gmm_target(network, key, inputs):
key_out, key_mean, key_tau, key_c = random.split(key, 4)
N = inputs.shape[-2]

tau = coix.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
mean = coix.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean)
c = coix.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
x = coix.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")
tau = coryx.rv(dist.Gamma(2, 2).expand([3, 2]), name="tau")(key_tau)
mean = coryx.rv(dist.Normal(0, 1 / jnp.sqrt(tau * 0.1)), name="mean")(key_mean)
c = coryx.rv(dist.DiscreteUniform(0, 3).expand([N]), name="c")(key_c)
x = coryx.rv(dist.Normal(mean[c], 1 / jnp.sqrt(tau[c])), obs=inputs, name="x")

out = {"mean": mean, "tau": tau, "c": c, "x": x}
return key_out, out
Expand All @@ -156,8 +168,8 @@ def gmm_kernel_mean_tau(network, key, inputs):
alpha, beta, mu, nu = network.encode_mean_tau(xc)
else:
alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
tau = coix.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
mean = coix.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean)
tau = coryx.rv(dist.Gamma(alpha, beta), name="tau")(key_tau)
mean = coryx.rv(dist.Normal(mu, 1 / jnp.sqrt(tau * nu)), name="mean")(key_mean)

out = {**inputs, **{"mean": mean, "tau": tau}}
return key_out, out
Expand All @@ -177,8 +189,9 @@ def gmm_kernel_c(network, key, inputs):
return key_out, out


### Train

# %%
# Finally, we create the gmm inference program, define the loss function,
# run the training loop, and plot the results.

def make_gmm(params, num_sweeps):
network = coix.util.BindModule(GMMEncoder(), params)
Expand Down

0 comments on commit 493c965

Please sign in to comment.