diff --git a/README.md b/README.md index d0d9f8a..2e0b309 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,32 @@ Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators [(Stites and Zimmermann et al., 2021)](https://arxiv.org/abs/2103.00668), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, and a set of pre-implemented losses and utility functions that allows to implement and run a wide variety of inference algorithms out-of-the-box. +Coix is a lightweight framework which includes the following main components: + +- **coix.api:** Implementation of the program combinators. +- **coix.core:** Basic program transformations which are used to modify behavior of a stochastic program. +- **coix.loss:** Common objectives for variational inference. +- **coix.algo:** Example inference algorithms. + +Currently, we support [numpyro](https://github.com/pyro-ppl/numpyro) and [oryx](https://github.com/jax-ml/oryx) backends. But other backends can be easily added via the [coix.register_backend](https://coix.readthedocs.io/en/latest/core.html#coix.core.register_backend) utility. + *This is not an officially supported Google product.* + +## Installation + +To install Coix, you can use pip: + +``` +pip install coix +``` + +or you can clone the repository: + +``` +git clone https://github.com/jax-ml/coix.git +cd coix +pip install -e .[dev,doc] +``` + +Many examples would run faster on accelerators. You can follow the [JAX installation](https://jax.readthedocs.io/en/latest/installation.html) instruction for how to install JAX with GPU or TPU support. + diff --git a/coix/core.py b/coix/core.py index f47a393..e34abed 100644 --- a/coix/core.py +++ b/coix/core.py @@ -137,10 +137,12 @@ def wrapped(*args, **kwargs): def empirical(out, trace, metrics): + """Creates an empirical program given a trace.""" return get_backend()["empirical"](out, trace, metrics) def suffix(p): + """Adds suffix `_PREV_` to variable names of `p`.""" fn = get_backend()["suffix"] if fn is not None: return fn(p) @@ -149,6 +151,7 @@ def suffix(p): def detach(p): + """Makes random variables in `p` become non-reparameterized.""" fn = get_backend()["detach"] if fn is not None: return fn(p) @@ -157,6 +160,7 @@ def detach(p): def stick_the_landing(p): + """Stops gradient of distributions' parameters before computing log prob.""" fn = get_backend()["stick_the_landing"] if fn is not None: return fn(p) @@ -165,6 +169,7 @@ def stick_the_landing(p): def prng_key(): + """Generates a random JAX PRNGKey.""" fn = get_backend()["prng_key"] if fn is not None: return fn()