Skip to content

Commit

Permalink
Update gmm and dmm examples (#26)
Browse files Browse the repository at this point in the history
* port dmm example to numpyro

* make dmm numpyro work

* debug oryx

* revert changes at loss

* update gmm example

* add bmnist example

* use dev oryx version

* Improve README and core docs
  • Loading branch information
fehiepsi authored Apr 6, 2024
1 parent 07f0963 commit f00cbfd
Show file tree
Hide file tree
Showing 21 changed files with 459 additions and 365 deletions.
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,32 @@

Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box.

Coix is a lightweight framework which includes the following main components:

- **coix.api:** Implementation of the program combinators.
- **coix.core:** Basic program transformations which are used to modify behavior of a stochastic program.
- **coix.loss:** Common objectives for variational inference.
- **coix.algo:** Example inference algorithms.

Currently, we support [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends. But other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility.

*This is not an officially supported Google product.*

## Installation

To install Coix, you can use pip:

```
pip install coix
```

or you can clone the repository:

```
git clone https://github.com/jax-ml/coix.git
cd coix
pip install -e .[dev,doc]
```

Many examples would run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instruction for how to install JAX with GPU or TPU support.

5 changes: 5 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ def wrapped(*args, **kwargs):


def empirical(out, trace, metrics):
"""Creates an empirical program given a trace."""
return get_backend()["empirical"](out, trace, metrics)


def suffix(p):
"""Adds suffix `_PREV_` to variable names of `p`."""
fn = get_backend()["suffix"]
if fn is not None:
return fn(p)
Expand All @@ -149,6 +151,7 @@ def suffix(p):


def detach(p):
"""Makes random variables in `p` become non-reparameterized."""
fn = get_backend()["detach"]
if fn is not None:
return fn(p)
Expand All @@ -157,6 +160,7 @@ def detach(p):


def stick_the_landing(p):
"""Stops gradient of distributions' parameters before computing log prob."""
fn = get_backend()["stick_the_landing"]
if fn is not None:
return fn(p)
Expand All @@ -165,6 +169,7 @@ def stick_the_landing(p):


def prng_key():
"""Generates a random JAX PRNGKey."""
fn = get_backend()["prng_key"]
if fn is not None:
return fn()
Expand Down
9 changes: 7 additions & 2 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def wrapped(*args, **kwargs):
if site["type"] == "sample":
value = site["value"]
log_prob = site["fn"].log_prob(value)
trace[name] = {"value": value, "log_prob": log_prob}
event_dim_holder = jnp.empty([1] * site["fn"].event_dim)
trace[name] = {
"value": value,
"log_prob": log_prob,
"_event_dim_holder": event_dim_holder,
}
if site.get("is_observed", False):
trace[name]["is_observed"] = True
metrics = {
Expand Down Expand Up @@ -83,7 +88,7 @@ def wrapped(*args, **kwargs):
del args, kwargs
for name, site in trace.items():
value, lp = site["value"], site["log_prob"]
event_dim = jnp.ndim(value) - jnp.ndim(lp)
event_dim = jnp.ndim(site["_event_dim_holder"])
obs = value if "is_observed" in site else None
numpyro.sample(name, dist.Delta(value, lp, event_dim=event_dim), obs=obs)
for name, value in metrics.items():
Expand Down
Binary file added docs/_static/anneal.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/anneal_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/bmnist.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/dmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/dmm_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/gmm_oryx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 6 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,13 @@
):
toctree_path = "notebooks/" if src_file.endswith("ipynb") else "examples/"
filename = os.path.splitext(src_file.split("/")[-1])[0]
png_path = "_static/" + filename + ".png"
img_path = "_static/" + filename + ".png"
# use Coix logo if not exist png file
if not os.path.exists(png_path):
png_path = "_static/coix_logo.png"
nbsphinx_thumbnails[toctree_path + filename] = png_path
if not os.path.exists(img_path):
img_path = "_static/" + filename + ".gif"
if not os.path.exists(img_path):
img_path = "_static/coix_logo.png"
nbsphinx_thumbnails[toctree_path + filename] = img_path


# -- Options for HTML output -------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ Coix Documentation
examples/gmm
examples/dmm
examples/bmnist
examples/anneal_oryx
examples/gmm_oryx
examples/dmm_oryx
examples/anneal_oryx

Indices and tables
==================
Expand Down
5 changes: 4 additions & 1 deletion examples/anneal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.
.. image:: ../_static/anneal.png
:align: center
"""

import argparse
Expand Down Expand Up @@ -199,7 +202,7 @@ def eval_program(seed):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
7 changes: 5 additions & 2 deletions examples/anneal_oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
1. Zimmermann, Heiko, et al. "Nested variational inference." NeuRIPS 2021.
.. image:: ../_static/anneal_oryx.png
:align: center
"""

import argparse
Expand Down Expand Up @@ -119,7 +122,7 @@ def __call__(self, x):
def anneal_target(network, key, k=0):
key_out, key = random.split(key)
x = coryx.rv(dist.Normal(0, 5).expand([2]).mask(False), name="x")(key)
coix.factor(network.anneal_density(x, index=k), name="anneal_density")
coryx.factor(network.anneal_density(x, index=k), name="anneal_density")
return key_out, {"x": x}


Expand Down Expand Up @@ -192,7 +195,7 @@ def main(args):

plt.figure(figsize=(8, 8))
x = trace["x"]["value"].reshape((-1, 2))
H, xedges, yedges = np.histogram2d(x[:, 0], x[:, 1], bins=100)
H, _, _ = np.histogram2d(x[:, 0], x[:, 1], bins=100)
plt.imshow(H.T)
plt.show()

Expand Down
Loading

0 comments on commit f00cbfd

Please sign in to comment.