Skip to content

Commit

Permalink
Update the exposed api and add docs to examples (#25)
Browse files Browse the repository at this point in the history
* update the exposed api and add docs to examples

* initialize dmm and gmm examples for numpyro backend

* format the examples

* reorganize the sections

* fix failing tests

* nit

* revert README docs due to copybara merge
  • Loading branch information
fehiepsi authored Mar 29, 2024
1 parent 8a9a825 commit 07f0963
Show file tree
Hide file tree
Showing 16 changed files with 798 additions and 110 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
[![Documentation Status](https://readthedocs.org/projects/coix/badge/?version=latest)](https://coix.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://badge.fury.io/py/coix.svg)](https://badge.fury.io/py/coix)

Inference Combinators in JAX (Coix) is a machine learning framework used to
develop inference algorithms that are composed of probabilistic programs.
Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box.

*This is not an officially supported Google product.*

8 changes: 4 additions & 4 deletions coix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from coix.api import resample
from coix.core import detach
from coix.core import empirical
from coix.core import factor
from coix.core import prng_key
from coix.core import register_backend
from coix.core import rv
from coix.core import set_backend
from coix.core import stick_the_landing
from coix.core import suffix
from coix.core import traced_evaluate

__all__ = [
Expand All @@ -39,16 +39,16 @@
"compose",
"detach",
"extend",
"factor",
"fori_loop",
"loss",
"memoize",
"prng_key",
"propose",
"register_backend",
"resample",
"rv",
"set_backend",
"stick_the_landing",
"suffix",
"traced_evaluate",
"util",
]
Expand Down
7 changes: 4 additions & 3 deletions coix/algo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import functools

import coix
import coix.oryx as coryx
import jax
from jax import random
import jax.numpy as jnp
Expand All @@ -42,17 +43,17 @@
def model(params, key):
del params
key_z, key_next = random.split(key)
z = coix.rv(dist.Normal(loc_p, scale_p), name="z")(key_z)
z = coryx.rv(dist.Normal(loc_p, scale_p), name="z")(key_z)
z = jnp.broadcast_to(z, (num_data, dim))
x = coix.rv(dist.Normal(z, scale_x), obs=data, name="x")
x = coryx.rv(dist.Normal(z, scale_x), obs=data, name="x")
return key_next, z, x


def guide(params, key, *args):
del args
key, _ = random.split(key) # split here to test tie_in
scale_q = jnp.exp(params["log_scale_q"])
z = coix.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key)
z = coryx.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key)
return z


Expand Down
29 changes: 15 additions & 14 deletions coix/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for api.py."""

import coix
import coix.oryx as coryx
import jax
from jax import random
import numpy as np
Expand All @@ -27,11 +28,11 @@
def test_compose():
def p(key):
key, subkey = random.split(key)
x = coix.rv(dist.Normal(0, 1), name="x")(subkey)
x = coryx.rv(dist.Normal(0, 1), name="x")(subkey)
return key, x

def f(key, x):
return coix.rv(dist.Normal(x, 1), name="z")(key)
return coryx.rv(dist.Normal(x, 1), name="z")(key)

_, p_trace, _ = coix.traced_evaluate(coix.compose(f, p))(random.PRNGKey(0))
assert set(p_trace.keys()) == {"x", "z"}
Expand All @@ -40,11 +41,11 @@ def f(key, x):
def test_extend():
def p(key):
key, subkey = random.split(key)
x = coix.rv(dist.Normal(0, 1), name="x")(subkey)
x = coryx.rv(dist.Normal(0, 1), name="x")(subkey)
return key, x

def f(key, x):
return (coix.rv(dist.Normal(x, 1), name="z")(key),)
return (coryx.rv(dist.Normal(x, 1), name="z")(key),)

def g(z):
return z + 1
Expand All @@ -71,14 +72,14 @@ def g(z):
def test_propose():
def p(key):
key, subkey = random.split(key)
x = coix.rv(dist.Normal(0, 1), name="x")(subkey)
x = coryx.rv(dist.Normal(0, 1), name="x")(subkey)
return key, x

def f(key, x):
return coix.rv(dist.Normal(x, 1), name="z")(key)
return coryx.rv(dist.Normal(x, 1), name="z")(key)

def q(key):
return coix.rv(dist.Normal(1, 2), name="x")(key)
return coryx.rv(dist.Normal(1, 2), name="x")(key)

program = coix.propose(coix.extend(p, f), q)
key = random.PRNGKey(0)
Expand All @@ -98,7 +99,7 @@ def q(key):

def test_resample():
def q(key):
return coix.rv(dist.Normal(1, 2), name="x")(key)
return coryx.rv(dist.Normal(1, 2), name="x")(key)

particle_program = jax.vmap(q)
keys = random.split(random.PRNGKey(0), 3)
Expand All @@ -108,8 +109,8 @@ def q(key):

def test_resample_one():
def q(key):
x = coix.rv(dist.Normal(1, 2), name="x")(key)
return coix.rv(dist.Normal(x, 1), name="z", obs=0.0)
x = coryx.rv(dist.Normal(1, 2), name="x")(key)
return coryx.rv(dist.Normal(x, 1), name="z", obs=0.0)

particle_program = jax.vmap(q)
keys = random.split(random.PRNGKey(0), 3)
Expand All @@ -120,7 +121,7 @@ def q(key):
def test_fori_loop():
def drift(key, x):
key_out, key = random.split(key)
x_new = coix.rv(dist.Normal(x, 1.0), name="x")(key)
x_new = coryx.rv(dist.Normal(x, 1.0), name="x")(key)
return key_out, x_new

compile_time = {"value": 0}
Expand All @@ -145,12 +146,12 @@ def body_fun(_, q):
@pytest.mark.skip(reason="Currently, we only support memoised lists.")
def test_memoize():
def model(key):
x = coix.rv(dist.Normal(0, 1), name="x")(key)
y = coix.rv(dist.Normal(x, 1), name="y", obs=0.0)
x = coryx.rv(dist.Normal(0, 1), name="x")(key)
y = coryx.rv(dist.Normal(x, 1), name="y", obs=0.0)
return x, y

def guide(key):
return coix.rv(dist.Normal(1, 2), name="x")(key)
return coryx.rv(dist.Normal(1, 2), name="x")(key)

def vmodel(key):
return jax.vmap(model)(random.split(key, 5))
Expand Down
30 changes: 3 additions & 27 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
__all__ = [
"detach",
"empirical",
"factor",
"prng_key",
"rv",
"register_backend",
"set_backend",
"stick_the_landing",
Expand All @@ -39,22 +37,18 @@ def register_backend(
traced_evaluate=None,
empirical=None,
suffix=None,
prng_key=None,
detach=None,
stick_the_landing=None,
rv=None,
factor=None,
prng_key=None,
):
"""Register backend."""
fn_map = {
"traced_evaluate": traced_evaluate,
"empirical": empirical,
"suffix": suffix,
"prng_key": prng_key,
"detach": detach,
"stick_the_landing": stick_the_landing,
"rv": rv,
"factor": factor,
"prng_key": prng_key,
}
_BACKENDS[backend] = fn_map

Expand All @@ -73,11 +67,9 @@ def set_backend(backend):
"traced_evaluate",
"empirical",
"suffix",
"prng_key",
"detach",
"stick_the_landing",
"rv",
"factor",
"prng_key",
]:
fn_map[fn] = getattr(module, fn, None)
register_backend(backend, **fn_map)
Expand Down Expand Up @@ -172,22 +164,6 @@ def stick_the_landing(p):
return p


def rv(*args, **kwargs):
fn = get_backend()["rv"]
if fn is not None:
return fn(*args, **kwargs)
else:
raise NotImplementedError


def factor(*args, **kwargs):
fn = get_backend()["factor"]
if fn is not None:
return fn(*args, **kwargs)
else:
raise NotImplementedError


def prng_key():
fn = get_backend()["prng_key"]
if fn is not None:
Expand Down
9 changes: 9 additions & 0 deletions coix/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from numpyro import handlers
import numpyro.distributions as dist

__all__ = [
"detach",
"empirical",
"prng_key",
"stick_the_landing",
"suffix",
"traced_evaluate",
]

prng_key = numpyro.prng_key


Expand Down
5 changes: 5 additions & 0 deletions coix/oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"detach",
"empirical",
"factor",
"prng_key",
"rv",
"stick_the_landing",
"suffix",
Expand Down Expand Up @@ -449,3 +450,7 @@ def wrapped(*args, **kwargs):
return out

return wrapped


def prng_key():
raise ValueError("Cannot genenerate random key under the oryx backend.")
24 changes: 12 additions & 12 deletions coix/oryx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import coix
import coix.core
import coix.oryx
import coix.oryx as coryx
import jax
from jax import random
import jax.numpy as jnp
Expand All @@ -28,7 +28,7 @@

def test_call_and_reap_tags():
def model(key):
return coix.rv(dist.Normal(0, 1), name="x")(key)
return coryx.rv(dist.Normal(0, 1), name="x")(key)

_, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0))
assert set(trace.keys()) == {"x"}
Expand All @@ -38,23 +38,23 @@ def model(key):
def test_delta_distribution():
def model(key):
x = random.normal(key)
return coix.rv(dist.Delta(x, 5.0), name="x")(key)
return coryx.rv(dist.Delta(x, 5.0), name="x")(key)

_, trace, _ = coix.traced_evaluate(model)(random.PRNGKey(0))
assert set(trace.keys()) == {"x"}


def test_detach():
def model(x):
return coix.rv(dist.Delta(x, 0.0), name="x")(None) * x
return coryx.rv(dist.Delta(x, 0.0), name="x")(None) * x

x = 2.0
np.testing.assert_allclose(jax.grad(coix.detach(model))(x), x)


def test_detach_vmap():
def model(x):
return coix.rv(dist.Normal(x, 1.0), name="x")(random.PRNGKey(0))
return coryx.rv(dist.Normal(x, 1.0), name="x")(random.PRNGKey(0))

outs = coix.detach(jax.vmap(model))(jnp.ones(2))
np.testing.assert_allclose(outs[0], outs[1])
Expand All @@ -63,7 +63,7 @@ def model(x):
def test_distribution():
def model(key):
x = random.normal(key)
return coix.rv(dist.Delta(x, 5.0), name="x")(key)
return coryx.rv(dist.Delta(x, 5.0), name="x")(key)

f = coix.oryx.call_and_reap_tags(
coix.oryx.tag_distribution(model), coix.oryx.DISTRIBUTION
Expand All @@ -90,7 +90,7 @@ def model(x):

def test_factor():
def model(x):
return coix.factor(x, name="x")
return coryx.factor(x, name="x")

_, trace, _ = coix.traced_evaluate(model)(10.0)
assert "x" in trace
Expand All @@ -99,7 +99,7 @@ def model(x):

def test_log_prob_detach():
def model(loc):
x = coix.rv(dist.Normal(loc, 1), name="x")(random.PRNGKey(0))
x = coryx.rv(dist.Normal(loc, 1), name="x")(random.PRNGKey(0))
return x

def actual_fn(x):
Expand All @@ -115,7 +115,7 @@ def expected_fn(x):

def test_observed():
def model(a):
return coix.rv(dist.Delta(2.0, 3.0), obs=1.0, name="x") + a
return coryx.rv(dist.Delta(2.0, 3.0), obs=1.0, name="x") + a

_, trace, _ = coix.traced_evaluate(model)(2.0)
assert "x" in trace
Expand All @@ -125,7 +125,7 @@ def model(a):

def test_stick_the_landing():
def model(lp):
return coix.rv(dist.Delta(0.0, lp), name="x")(None)
return coryx.rv(dist.Delta(0.0, lp), name="x")(None)

def p(x):
return coix.traced_evaluate(coix.detach(model))(x)[1]["x"]["log_prob"]
Expand All @@ -140,7 +140,7 @@ def q(x):

def test_substitute():
def model(key):
return coix.rv(dist.Delta(1.0, 5.0), name="x")(key)
return coryx.rv(dist.Delta(1.0, 5.0), name="x")(key)

expected = {"x": 9.0}
_, trace, _ = coix.traced_evaluate(model, expected)(random.PRNGKey(0))
Expand All @@ -150,7 +150,7 @@ def model(key):

def test_suffix():
def model(x):
return coix.rv(dist.Delta(x, 5.0), name="x")(None)
return coryx.rv(dist.Delta(x, 5.0), name="x")(None)

f = coix.oryx.call_and_reap_tags(
coix.core.suffix(model), coix.oryx.RANDOM_VARIABLE
Expand Down
6 changes: 4 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ Coix Documentation
notebooks/tutorial_part2_api
notebooks/tutorial_part3_smcs
examples/anneal
examples/anneal_oryx
examples/gmm
examples/dmm
examples/bmnist
examples/gmm_oryx
examples/dmm_oryx
examples/bmnist
examples/anneal_oryx

Indices and tables
==================
Expand Down
Loading

0 comments on commit 07f0963

Please sign in to comment.