From 01dc6c13c6a5c08a4680a67886ae6236ffe44da5 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 27 Feb 2024 11:35:01 -0500 Subject: [PATCH] Update documentation, make available algorithms public. --- CHANGELOG.md | 4 + README.md | 210 ++------------------------------------- bayeux/__init__.py | 2 +- bayeux/_src/bayeux.py | 20 ++-- docs/debug_mode.md | 166 +++++++++++++++++++++++++++++++ docs/inference.md | 57 +++++++++++ docs/inspecting.md | 225 ++++++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 3 + 8 files changed, 477 insertions(+), 210 deletions(-) create mode 100644 docs/debug_mode.md create mode 100644 docs/inference.md create mode 100644 docs/inspecting.md diff --git a/CHANGELOG.md b/CHANGELOG.md index bb786fe..cb07ed3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.1.9] - 2024-02-27 + +### Add programmatic access to algorithms + ## [0.1.8] - 2024-02-14 ### Add HMC and NUTS from TFP diff --git a/README.md b/README.md index 61778df..7189dd2 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ pip install bayeux-ml ``` ## Quickstart -We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like numpyro, PyMC, TFP, distrax, oryx, coix, or directly in JAX. +We define a model by providing a log density in JAX. This could be defined using a probabilistic programming language (PPL) like [numpyro](/examples/numpyro_and_bayeux), [PyMC](/examples/pymc_and_bayeux), [TFP](/examples/tfp_and_bayeux), distrax, oryx, coix, or directly in JAX. ```python import bayeux as bx @@ -24,13 +24,8 @@ normal_density = bx.Model( log_density=lambda x: -x*x, test_point=1.) -seed = jax.random.PRNGKey(0) -``` - -## Simple -Every inference algorithm in `bayeux` will (try to) run with just a seed as an argument: +seed = jax.random.key(0) -```python opt_results = normal_density.optimize.optax_adam(seed=seed) # OR! idata = normal_density.mcmc.numpyro_nuts(seed=seed) @@ -38,201 +33,12 @@ idata = normal_density.mcmc.numpyro_nuts(seed=seed) surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed) ``` -An (only rarely) optional third argument to `bx.Model` is `transform_fn`, which maps a real number to the support of the distribution. The [oryx](https://github.com/jax-ml/oryx) library is used to automatically compute the inverse and Jacobian determinants for changes of variables, but the user can supply these if known. - -```python -half_normal_density = bx.Model( - lambda x: -x*x, - test_point=1., - transform_fn=jax.nn.softplus) -``` - -## Self descriptive - -Since `bayeux` is built on top of other fantastic libraries, it tries not to get in the way of them. 
Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called: - -```python -normal_density.optimize.jaxopt_bfgs.get_kwargs() - -{jaxopt._src.bfgs.BFGS: {'value_and_grad': False, - 'has_aux': False, - 'maxiter': 500, - 'tol': 0.001, - 'stepsize': 0.0, - 'linesearch': 'zoom', - 'linesearch_init': 'increase', - 'condition': None, - 'maxls': 30, - 'decrease_factor': None, - 'increase_factor': 1.5, - 'max_stepsize': 1.0, - 'min_stepsize': 1e-06, - 'implicit_diff': True, - 'implicit_diff_solve': None, - 'jit': True, - 'unroll': 'auto', - 'verbose': False}, - 'extra_parameters': {'chain_method': 'vectorized', - 'num_particles': 8, - 'num_iters': 1000, - 'apply_transform': True}} -``` - -If you pass an argument into `.get_kwargs()`, this will also tell you what will be passed on to the actual algorithms. - -``` -normal_density.mcmc.blackjax_nuts.get_kwargs( - num_chains=5, - target_acceptance_rate=0.99) - -{: {'algorithm': , - 'initial_step_size': 1.0, - 'is_mass_matrix_diagonal': True, - 'progress_bar': False, - 'target_acceptance_rate': 0.8}, - 'extra_parameters': {'chain_method': 'vectorized', - 'num_adapt_draws': 500, - 'num_chains': 17, - 'num_draws': 500, - 'return_pytree': False}, - : {'divergence_threshold': 1000, - 'integrator': , - 'max_num_doublings': 10, - 'step_size': 0.01}} -✓✓✓✓✓✓✓✓✓✓ - -Checking it is possible to compute an initial state ✓ -Initial state has shape -(17,) -✓✓✓✓✓✓✓✓✓✓ - -Checking initial state is has no NaN ✓ -No nans detected! -✓✓✓✓✓✓✓✓✓✓ - -Computing initial state log density × -Initial state has log density -Array([1.2212421 , nan, nan, 1.4113309 , nan, - nan, nan, nan, nan, nan, - 0.5912253 , nan, nan, nan, 0.65457666, - nan, nan], dtype=float32) -×××××××××× - -Transforming model to R^n ✓ -Transformed state has shape -(17,) -✓✓✓✓✓✓✓✓✓✓ - -Computing transformed state log density shape ✓ -Transformed state log density has shape -(17,) -✓✓✓✓✓✓✓✓✓✓ - -Computing gradients of transformed log density × -The gradient contains NaNs! Initial gradients has shape -(17,) -×××××××××× - -False -``` +* [Defining models](/inference) +* [Inspecting models](/inspecting) +* [Testing and debugging](/debug_mode) +* Also see `bayeux` integration with [numpyro](/examples/numpyro_and_bayeux), [PyMC](/examples/pymc_and_bayeux), and [TFP](/examples/tfp_and_bayeux)! -*This is not an officially supported Google product.* \ No newline at end of file +*This is not an officially supported Google product.* diff --git a/bayeux/__init__.py b/bayeux/__init__.py index 64c2d69..a7222ae 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.8' +__version__ = '0.1.9' # Note: import as is required for names to be exported. 
# See PEP 484 & https://github.com/google/jax/issues/7570
diff --git a/bayeux/_src/bayeux.py b/bayeux/_src/bayeux.py
index 1eb7588..116621d 100644
--- a/bayeux/_src/bayeux.py
+++ b/bayeux/_src/bayeux.py
@@ -36,15 +36,15 @@ class _Namespace:
 
   def __init__(self):
-    self._fns = []
+    self.methods = []
 
   def __repr__(self):
-    return "\n".join(self._fns)
+    return "\n".join(self.methods)
 
   def __setclass__(self, clas, parent):
     kwargs = {k: getattr(parent, k) for k in _REQUIRED_KWARGS}
     setattr(self, clas.name, clas(**kwargs))
-    self._fns.append(clas.name)
+    self.methods.append(clas.name)
 
 
 def is_tfp_bijector(bij):
@@ -100,12 +100,18 @@ def __post_init__(self):
 
   def __repr__(self):
     methods = []
-    for name in self._namespaces:
-      methods.append(name)
-      k = getattr(self, name)
-      methods.append("\t." + "\n\t.".join(str(k).split()))
+    for key, values in self.methods.items():
+      methods.append(key)
+      methods.append("\t." + "\n\t.".join(values))
     return "\n".join(methods)
 
+  @property
+  def methods(self):
+    methods = {}
+    for name in self._namespaces:
+      methods[name] = getattr(self, name).methods
+    return methods
+
   @classmethod
   def from_tfp(cls, pinned_joint_distribution, initial_state=None):
     log_density = pinned_joint_distribution.log_prob
diff --git a/docs/debug_mode.md b/docs/debug_mode.md
new file mode 100644
index 0000000..1dbe604
--- /dev/null
+++ b/docs/debug_mode.md
@@ -0,0 +1,166 @@
+# Debug Mode
+
+Algorithms come with a built-in `debug` mode that attempts to fail fast, in a way that helps diagnose the problem. `debug` accepts `verbosity` and `catch_exceptions` arguments, as well as a `kwargs` dictionary of the arguments you plan to pass to the algorithm itself.
+
+## Default behavior
+
+By default, debug mode prints a short description of each check it runs and whether it passed. This is also useful when unit testing your models, since the return value reports whether all the checks passed!
+
+```python
+import bayeux as bx
+import jax
+import jax.numpy as jnp
+
+normal_density = bx.Model(
+    log_density=lambda x: -x*x,
+    test_point=1.)
+
+seed = jax.random.key(0)
+
+normal_density.mcmc.numpyro_nuts.debug(seed=seed)
+
+Checking test_point shape ✓
+Computing test point log density ✓
+Loading keyword arguments... ✓
+Checking it is possible to compute an initial state ✓
+Checking initial state has no NaN ✓
+Computing initial state log density ✓
+Transforming model to R^n ✓
+Computing transformed state log density shape ✓
+Comparing transformed log density to untransformed ✓
+Computing gradients of transformed log density ✓
+True
+```
+
+## Do not catch exceptions
+
+Often our models are bad because they don't even run. Debug mode aggressively catches exceptions, but you can disable that to make sure the model can be evaluated at all.
+
+See if you can spot what is wrong with this model:
+
+```python
+bad_model = bx.Model(
+    log_density=lambda x: jnp.sqrt(x['mean']),
+    test_point=-1.)
+
+bad_model.mcmc.numpyro_nuts.debug(seed=seed, catch_exceptions=False)
+
+Checking test_point shape ✓
+Computing test point log density ×
+    ...
+      1 bad_model = bx.Model(
+----> 2     log_density=lambda x: jnp.sqrt(x['mean']),
+      3     test_point=-1.)
+
+TypeError: 'float' object is not subscriptable
+```
+
+## Changing verbosity
+
+Debug mode also accepts a `verbosity` argument (the default is 2). Here is a new, subtly mis-specified `bad_model`. With `verbosity=0`, the only output is the return value:
+
+```python
+bad_model = bx.Model(
+    log_density=jnp.sqrt,
+    test_point=-1.)
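+# Deliberately mis-specified: jnp.sqrt returns nan for negative inputs, so
+# this log density is nan at the test point of -1. and at any negative
+# initial state -- the checks below surface this.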
+
+bad_model.mcmc.blackjax_nuts.debug(seed=seed, verbosity=0, kwargs={"num_chains": 17})
+
+False
+```
+
+With `verbosity=1` there is minimal output:
+
+```python
+bad_model.mcmc.blackjax_nuts.debug(seed=seed, verbosity=1, kwargs={"num_chains": 17})
+
+✓ × ✓ ✓ ✓ × ✓ ✓ ×
+False
+```
+
+With higher verbosity, we can see the actual outputs and perhaps diagnose the problem after seeing that the log density of the initial point is `nan`. We should have passed in `transform_fn=jnp.exp` or similar:
+
+```python
+bad_model.mcmc.blackjax_nuts.debug(seed=seed, verbosity=3, kwargs={"num_chains": 17})
+
+Checking test_point shape ✓
+Test point has shape
+()
+✓✓✓✓✓✓✓✓✓✓
+
+Computing test point log density ×
+Test point has log density
+Array(nan, dtype=float32, weak_type=True)
+××××××××××
+
+Loading keyword arguments... ✓
+Keyword arguments are
+{<function window_adaptation>: {'algorithm': <class 'blackjax.mcmc.nuts.nuts'>,
+  'initial_step_size': 1.0,
+  'is_mass_matrix_diagonal': True,
+  'logdensity_fn': <function wrap_log_density.<locals>.wrapped at 0x15fb97880>,
+  'progress_bar': False,
+  'target_acceptance_rate': 0.8},
+ 'adapt.run': {'num_steps': 500},
+ 'extra_parameters': {'chain_method': 'vectorized',
+  'num_adapt_draws': 500,
+  'num_chains': 17,
+  'num_draws': 500,
+  'return_pytree': False},
+ <class 'blackjax.mcmc.nuts.nuts'>: {'divergence_threshold': 1000,
+  'integrator': <function euclidean_integrator at 0x14bad0e50>,
+  'logdensity_fn': <function wrap_log_density.<locals>.wrapped at 0x15fb97880>,
+  'max_num_doublings': 10,
+  'step_size': 0.5}}
+✓✓✓✓✓✓✓✓✓✓
+
+Checking it is possible to compute an initial state ✓
+Initial state has shape
+(17,)
+✓✓✓✓✓✓✓✓✓✓
+
+Checking initial state has no NaN ✓
+No nans detected!
+✓✓✓✓✓✓✓✓✓✓
+
+Computing initial state log density ×
+Initial state has log density
+Array([1.2212421 , nan, nan, 1.4113309 , nan,
+       nan, nan, nan, nan, nan,
+       0.5912253 , nan, nan, nan, 0.65457666,
+       nan, nan], dtype=float32)
+××××××××××
+
+Transforming model to R^n ✓
+Transformed state has shape
+(17,)
+✓✓✓✓✓✓✓✓✓✓
+
+Computing transformed state log density shape ✓
+Transformed state log density has shape
+(17,)
+✓✓✓✓✓✓✓✓✓✓
+
+Computing gradients of transformed log density ×
+The gradient contains NaNs! Initial gradients have shape
+(17,)
+××××××××××
+
+False
+```
+
+Even higher verbosity values give even more detail.
+
+## Fun mode
+
+I mean, you're reading about debugging statistical models.
+
+```python
+bx.debug.FunMode.engaged = True
+
+bad_model.mcmc.blackjax_nuts.debug(seed=seed, verbosity=1, kwargs={"num_chains": 17})
+
+🌈 👎 💪 🙌 🚀 💀 🌈 ✓ ❌
+False
+```
diff --git a/docs/inference.md b/docs/inference.md
new file mode 100644
index 0000000..f5d5f8d
--- /dev/null
+++ b/docs/inference.md
@@ -0,0 +1,57 @@
+# Building models
+
+The two main contracts `bayeux` makes are:
+
+1. You can specify a model using a log density, a test point, and a transformation (the transformation defaults to the identity, but that is rarely what you want).
+2. Every inference algorithm in `bayeux` will (try to) run with just a seed as an argument.
+
+## Specifying a model
+
+For a scalar model, we only need a log density and a test point. Note that the density does not need to be normalized.
+
+```python
+import bayeux as bx
+import jax
+import numpy as np
+
+normal_density = bx.Model(
+    log_density=lambda x: -x*x,
+    test_point=1.)
+```
+
+Suppose we have a bunch of observations from a normal distribution, and we want to infer the mean and scale.
+Maybe we write this down by hand, putting an N(0, 10) prior on the mean and a half-normal with scale 10 on the scale:
+
+```python
+import jax.numpy as jnp
+import jax.scipy.stats as jst
+
+points = 3 * np.random.randn(100) - 10
+
+def log_density(pt):
+  log_prior = -(pt['loc'] ** 2 + pt['scale']**2) / 200.
+  log_likelihood = jnp.sum(jst.norm.logpdf(points, loc=pt['loc'], scale=pt['scale']))
+  return log_prior + log_likelihood
+```
+
+We additionally need to restrict the scale to be positive. A [softplus](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)#Softplus) is useful for this:
+
+```python
+def transform_fn(pt):
+  return {'loc': pt['loc'], 'scale': jax.nn.softplus(pt['scale'])}
+```
+
+The [oryx](https://github.com/jax-ml/oryx) library is used to automatically compute the inverse and Jacobian determinants for changes of variables, but the user can supply these if known.
+
+Then we can build the model and run an optimizer:
+
+```python
+model = bx.Model(
+    log_density=log_density,
+    test_point={'loc': 0., 'scale': 1.},
+    transform_fn=transform_fn)
+
+seed = jax.random.key(0)
+opt = model.optimize.optax_adam(seed=seed, num_iters=10000)
+opt.params
+
+{'loc': Array([-9.428163, -9.428162, -9.428163, -9.428162, -9.428165, -9.428163,
+        -9.428163, -9.428164], dtype=float32),
+ 'scale': Array([2.9746027, 2.9746041, 2.9746022, 2.9746022, 2.9745977, 2.9746022,
+        2.9746027, 2.9746022], dtype=float32)}
+```
+
+By default, optimization runs 8 particles, which makes it easy to check that all of them found approximately the same optimum.
diff --git a/docs/inspecting.md b/docs/inspecting.md
new file mode 100644
index 0000000..6641134
--- /dev/null
+++ b/docs/inspecting.md
@@ -0,0 +1,225 @@
+# Inspecting models
+
+## Seeing keyword arguments
+
+Since `bayeux` is built on top of other fantastic libraries, it tries not to get in the way of them. Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and which functions are being called:
+
+```python
+normal_density.optimize.jaxopt_bfgs.get_kwargs()
+
+{jaxopt._src.bfgs.BFGS: {'value_and_grad': False,
+  'has_aux': False,
+  'maxiter': 500,
+  'tol': 0.001,
+  'stepsize': 0.0,
+  'linesearch': 'zoom',
+  'linesearch_init': 'increase',
+  'condition': None,
+  'maxls': 30,
+  'decrease_factor': None,
+  'increase_factor': 1.5,
+  'max_stepsize': 1.0,
+  'min_stepsize': 1e-06,
+  'implicit_diff': True,
+  'implicit_diff_solve': None,
+  'jit': True,
+  'unroll': 'auto',
+  'verbose': False},
+ 'extra_parameters': {'chain_method': 'vectorized',
+  'num_particles': 8,
+  'num_iters': 1000,
+  'apply_transform': True}}
+```
+
+If you pass an argument into `.get_kwargs()`, it will also tell you what will be passed on to the actual algorithms.
+
+```python
+normal_density.mcmc.blackjax_nuts.get_kwargs(
+    num_chains=5,
+    target_acceptance_rate=0.99)
+
+{