Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jax-ml/coix into tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
zmheiko committed Dec 8, 2023
2 parents c4acd07 + f2ddc60 commit a0cd4da
Show file tree
Hide file tree
Showing 31 changed files with 1,265 additions and 139 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/pytest_and_autopublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ jobs:
with:
python-version: '3.10'

- run: sudo apt install -y pandoc gsfonts
- run: pip --version
- run: pip install -e .[dev,oryx]
- run: pip install -e .[dev,doc,oryx]
- run: pip freeze

- name: Lint with pylint
Expand All @@ -26,6 +27,8 @@ jobs:
run: pyink --check .
- name: Lint with isort
run: isort --check .
- name: Build documentation
run: make docs

pytest-job:
needs: lint
Expand Down Expand Up @@ -53,6 +56,10 @@ jobs:
- name: Run core tests
run: pytest -vv -n auto

# Run custom prng tests
- name: Run custom prng tests
run: JAX_ENABLE_CUSTOM_PRNG=1 pytest -vv -n auto

# Auto-publish when version is increased
publish-job:
# Only try to publish if:
Expand Down
17 changes: 17 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
version: 2

build:
os: ubuntu-22.04
tools:
python: "3.9"

sphinx:
configuration: docs/conf.py

formats:
- pdf
- epub

python:
install:
- requirements: docs/requirements.txt
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ lint: FORCE
test: lint FORCE
pytest -vv -n auto

docs: FORCE
$(MAKE) -C docs html

FORCE:
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coix

[![Unittests](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/jax-ml/coix/actions/workflows/pytest_and_autopublish.yml)
[![Documentation Status](https://readthedocs.org/projects/coix/badge/?version=latest)](https://coix.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://badge.fury.io/py/coix.svg)](https://badge.fury.io/py/coix)

Inference Combinators in JAX (Coix) is a machine learning framework used to
Expand Down
5 changes: 3 additions & 2 deletions coix/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
if _use_fori_loop(targets, num_targets):

def body_fun(i, q):
assert callable(targets)
p = extend(compose(momentum, targets(i), suffix=False), refreshment)
return propose(p, compose(refreshment, compose(leapfrog, q)))

Expand All @@ -155,7 +156,7 @@ def body_fun(i, q):

targets = [compose(momentum, p, suffix=False) for p in targets]
q = targets[0]
loss_fns = [None] * (len(targets) - 2) + [iwae_loss]
loss_fns = (None,) * (len(targets) - 2) + (iwae_loss,)
for p, loss_fn in zip(targets[1:], loss_fns):
q = compose(refreshment, compose(leapfrog, q))
q = propose(extend(p, refreshment), q, loss_fn=loss_fn)
Expand Down Expand Up @@ -413,7 +414,7 @@ def body_fun(i, q):
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
loss_fns = [None] * (len(proposals) - 2) + [iwae_loss]
loss_fns = (None,) * (len(proposals) - 2) + (iwae_loss,)
for p, fwd, loss_fn in zip(targets[1:], proposals[1:], loss_fns):
q = propose(p, compose(fwd, resample(q)), loss_fn=loss_fn)
return q
116 changes: 116 additions & 0 deletions coix/algo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The coix Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for algo.py."""

import functools

import coix
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import optax

coix.set_backend("coix.oryx")

np.random.seed(0)
num_data, dim = 4, 2
data = np.random.randn(num_data, dim).astype(np.float32)
loc_p = np.random.randn(dim).astype(np.float32)
precision_p = np.random.rand(dim).astype(np.float32)
scale_p = np.sqrt(1 / precision_p)
precision_x = np.random.rand(dim).astype(np.float32)
scale_x = np.sqrt(1 / precision_x)
precision_q = precision_p + num_data * precision_x
loc_q = (data.sum(0) * precision_x + loc_p * precision_p) / precision_q
log_scale_q = -0.5 * np.log(precision_q)


def model(params, key):
del params
key_z, key_next = random.split(key)
z = coix.rv(dist.Normal(loc_p, scale_p), name="z")(key_z)
z = jnp.broadcast_to(z, (num_data, dim))
x = coix.rv(dist.Normal(z, scale_x), obs=data, name="x")
return key_next, z, x


def guide(params, key, *args):
del args
key, _ = random.split(key) # split here to test tie_in
scale_q = jnp.exp(params["log_scale_q"])
z = coix.rv(dist.Normal(params["loc_q"], scale_q), name="z")(key)
return z


def check_ess(make_program):
params = {"loc_q": loc_q, "log_scale_q": log_scale_q}
p = jax.vmap(functools.partial(model, params))
q = jax.vmap(functools.partial(guide, params))
program = make_program(p, q)

keys = random.split(random.PRNGKey(0), 5)
ess = coix.traced_evaluate(program)(keys)[2]["ess"]
np.testing.assert_allclose(ess, 5.0)


def run_inference(make_program, num_steps=1000):
"""Performs inference given an algorithm `make_program`."""

def loss_fn(params, key):
p = jax.vmap(functools.partial(model, params))
q = jax.vmap(functools.partial(guide, params))
program = make_program(p, q)

keys = random.split(key, 5)
metrics = coix.traced_evaluate(program)(keys)[2]
return metrics["loss"], metrics

init_params = {
"loc_q": jnp.zeros_like(loc_q),
"log_scale_q": jnp.zeros_like(log_scale_q),
}
params, _ = coix.util.train(
loss_fn, init_params, optax.adam(0.01), num_steps=num_steps
)

np.testing.assert_allclose(params["loc_q"], loc_q, atol=0.2)
np.testing.assert_allclose(params["log_scale_q"], log_scale_q, atol=0.2)


def test_apgs():
check_ess(lambda p, q: coix.algo.apgs(p, [q]))
run_inference(lambda p, q: coix.algo.apgs(p, [q]))


def test_rws():
check_ess(coix.algo.rws)
run_inference(coix.algo.rws)


def test_svi_elbo():
check_ess(coix.algo.svi)
run_inference(coix.algo.svi)


def test_svi_iwae():
check_ess(coix.algo.svi_iwae)
run_inference(coix.algo.svi_iwae)


def test_svi_stl():
check_ess(coix.algo.svi_stl)
run_inference(coix.algo.svi_stl)
Loading

0 comments on commit a0cd4da

Please sign in to comment.