Skip to content

Commit

Permalink
Improve README and core docs
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Apr 6, 2024
1 parent d2c5363 commit 751802b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,32 @@

Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box.

Coix is a lightweight framework which includes the following main components:

- **coix.api:** Implementation of the program combinators.
- **coix.core:** Basic program transformations which are used to modify behavior of a stochastic program.
- **coix.loss:** Common objectives for variational inference.
- **coix.algo:** Example inference algorithms.

Currently, we support [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends. But other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility.

*This is not an officially supported Google product.*

## Installation

To install Coix, you can use pip:

```
pip install coix
```

or you can clone the repository:

```
git clone https://github.com/jax-ml/coix.git
cd coix
pip install -e .[dev,doc]
```

Many examples would run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instruction for how to install JAX with GPU or TPU support.

5 changes: 5 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ def wrapped(*args, **kwargs):


def empirical(out, trace, metrics):
"""Creates an empirical program given a trace."""
return get_backend()["empirical"](out, trace, metrics)


def suffix(p):
"""Adds suffix `_PREV_` to variable names of `p`."""
fn = get_backend()["suffix"]
if fn is not None:
return fn(p)
Expand All @@ -149,6 +151,7 @@ def suffix(p):


def detach(p):
"""Makes random variables in `p` become non-reparameterized."""
fn = get_backend()["detach"]
if fn is not None:
return fn(p)
Expand All @@ -157,6 +160,7 @@ def detach(p):


def stick_the_landing(p):
"""Stops gradient of distributions' parameters before computing log prob."""
fn = get_backend()["stick_the_landing"]
if fn is not None:
return fn(p)
Expand All @@ -165,6 +169,7 @@ def stick_the_landing(p):


def prng_key():
"""Generates a random JAX PRNGKey."""
fn = get_backend()["prng_key"]
if fn is not None:
return fn()
Expand Down

0 comments on commit 751802b

Please sign in to comment.