From 4d8fec2103d80c16531ba2c762cad1168fac7079 Mon Sep 17 00:00:00 2001
From: Colin Carroll
Date: Tue, 13 Feb 2024 09:09:52 -0800
Subject: [PATCH] Fix keyword handling for flowMC. Also makes a new release.

PiperOrigin-RevId: 606642479
---
 CHANGELOG.md               |   8 ++-
 bayeux/__init__.py         |   2 +-
 bayeux/_src/mcmc/flowmc.py |  71 ++++++++++-------------
 bayeux/_src/mcmc/tfp.py    | 116 +++++++++++++++++++++++++++++++++++++
 bayeux/mcmc/__init__.py    |   6 +-
 bayeux/tests/mcmc_test.py  |  35 +++++++++++
 6 files changed, 196 insertions(+), 42 deletions(-)
 create mode 100644 bayeux/_src/mcmc/tfp.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index df24128..36f8fed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ## [Unreleased]
 
+## [0.1.7] - 2024-02-13
+
+### Add SNAPER HMC from TFP
+### Fix flowMC keyword handling
+
 ## [0.1.6] - 2024-02-01
 
 ### Add samplers from flowMC
@@ -50,7 +55,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
 
 ### Initial release
 
-[Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.5...HEAD
+[Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.7...HEAD
+[0.1.7]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.7
 [0.1.6]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.6
 [0.1.5]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.5
 [0.1.4]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.4
diff --git a/bayeux/__init__.py b/bayeux/__init__.py
index 6142514..026e7db 100644
--- a/bayeux/__init__.py
+++ b/bayeux/__init__.py
@@ -16,7 +16,7 @@
 
 # A new PyPI release will be pushed everytime `__version__` is increased
 # When changing this, also update the CHANGELOG.md
-__version__ = '0.1.6'
+__version__ = '0.1.7'
 
 # Note: import as is required for names to be exported.
 # See PEP 484 & https://github.com/google/jax/issues/7570
diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py
index c99326f..b5376d5 100644
--- a/bayeux/_src/mcmc/flowmc.py
+++ b/bayeux/_src/mcmc/flowmc.py
@@ -48,17 +48,6 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs):
   """Sets defaults and merges user-provided adaptation keywords."""
-  nf_model_kwargs, nf_model_required = shared.get_default_signature(
-      nf_model)
-  nf_model_kwargs.update(
-      {k: kwargs[k] for k in nf_model_kwargs if k in kwargs})
-  nf_model_kwargs.update(
-      {k: kwargs[k] for k in nf_model_required if k in kwargs})
-  nf_model_kwargs.setdefault("n_features", n_features)
-  nf_model_required.remove("key")
-  nf_model_required.remove("kwargs")
-  nf_model_required = nf_model_required - nf_model_kwargs.keys()
-
   defaults = {
       # RealNVP kwargs
       "n_hidden": 100,
@@ -68,11 +57,15 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs):
       "num_bins": 8,
       "hidden_size": [64, 64],
       "spline_range": (-10.0, 10.0),
-  }
-  for key, value in defaults.items():
-    if key in nf_model_required:
-      nf_model_kwargs[key] = value
+      "n_features": n_features,
+  } | kwargs
 
+  nf_model_kwargs, nf_model_required = shared.get_default_signature(
+      nf_model)
+  nf_model_kwargs.update(
+      {k: defaults[k] for k in nf_model_required if k in defaults})
+  nf_model_required.remove("key")
+  nf_model_required.remove("kwargs")
   nf_model_required = nf_model_required - nf_model_kwargs.keys()
 
   if nf_model_required:
@@ -81,27 +74,30 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs):
         f"{','.join(nf_model_required)}. Probably file a bug, but "
         "you can try to manually supply them as keywords."
     )
+  nf_model_kwargs.update(
+      {k: defaults[k] for k in nf_model_kwargs if k in defaults})
+
   return nf_model_kwargs
 
 
 def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):
   """Sets defaults and merges user-provided adaptation keywords."""
-  kwargs["logpdf"] = log_density
-  sampler_kwargs, sampler_required = shared.get_default_signature(
-      local_sampler)
-  sampler_kwargs.setdefault("jit", True)
-  sampler_kwargs.update(
-      {k: kwargs[k] for k in sampler_required if k in kwargs})
-  sampler_required = sampler_required - sampler_kwargs.keys()
-
   defaults = {
       # HMC kwargs
       "condition_matrix": jnp.eye(n_features),
       "n_leapfrog": 10,
       # Both
       "step_size": 0.1,
-  }
+      "logpdf": log_density
+  } | kwargs
+
+  sampler_kwargs, sampler_required = shared.get_default_signature(
+      local_sampler)
+  sampler_kwargs.setdefault("jit", True)
+  sampler_kwargs.update(
+      {k: defaults[k] for k in sampler_required if k in defaults})
+  sampler_required = sampler_required - sampler_kwargs.keys()
   if "params" in sampler_required:
     sampler_kwargs["params"] = defaults
   else:
@@ -120,15 +116,9 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs):
 
 def get_sampler_kwargs(sampler, n_features, kwargs):
   """Sets defaults and merges user-provided adaptation keywords."""
-  sampler_kwargs, sampler_required = shared.get_default_signature(sampler)
-  sampler_kwargs.update(
-      {k: kwargs[k] for k in sampler_required if k in kwargs})
-  sampler_kwargs.setdefault("data", {})
-  sampler_kwargs.setdefault("n_dim", n_features)
-  sampler_required = (sampler_required -
-                      {"nf_model", "local_sampler", "rng_key_set", "kwargs"})
-  sampler_required = sampler_required - sampler_kwargs.keys()
-
+  # We support `num_chains` everywhere else, so support it here.
+  if "num_chains" in kwargs:
+    kwargs["n_chains"] = kwargs["num_chains"]
   defaults = {
       "n_loop_training": 5,
       "n_loop_production": 5,
@@ -149,11 +139,14 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
       "output_thinning": 1,
       "n_sample_max": 10_000,
       "precompile": False,
-      "verbose": False}
-  for key, value in defaults.items():
-    if key not in sampler_kwargs:
-      sampler_kwargs[key] = value
-
+      "verbose": False,
+      "n_dim": n_features,
+      "data": {}} | kwargs
+  sampler_kwargs, sampler_required = shared.get_default_signature(sampler)
+  sampler_kwargs.update(
+      {k: defaults[k] for k in sampler_required if k in defaults})
+  sampler_required = (sampler_required -
+                      {"nf_model", "local_sampler", "rng_key_set", "kwargs"})
   sampler_required = sampler_required - sampler_kwargs.keys()
 
   if sampler_required:
@@ -162,7 +155,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs):
         f"{','.join(sampler_required)}. Probably file a bug, but "
         "you can try to manually supply them as keywords."
     )
-  return sampler_kwargs
+  return defaults | sampler_kwargs
 
 
 class _FlowMCSampler(shared.Base):
diff --git a/bayeux/_src/mcmc/tfp.py b/bayeux/_src/mcmc/tfp.py
new file mode 100644
index 0000000..f682a91
--- /dev/null
+++ b/bayeux/_src/mcmc/tfp.py
@@ -0,0 +1,116 @@
+# Copyright 2024 The bayeux Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 The bayeux Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFP specific code."""
+
+import arviz as az
+from bayeux._src import shared
+import jax
+import numpy as np
+import tensorflow_probability.substrates.jax as tfp
+
+
+class SnaperHMC(shared.Base):
+  """Implements SNAPER HMC [1] with step size adaptation.
+
+  [1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for
+  Learning HMC Trajectory Lengths.
+  """
+  name = "tfp_snaper_hmc"
+
+  def get_kwargs(self, **kwargs):
+    kwargs_with_defaults = {
+        "num_results": 1_000,
+        "num_chains": 8,
+    } | kwargs
+    snaper = tfp.experimental.mcmc.sample_snaper_hmc
+    snaper_kwargs, snaper_required = shared.get_default_signature(snaper)
+    snaper_kwargs.update({k: kwargs_with_defaults[k] for k in snaper_required
+                          if k in kwargs_with_defaults})
+    snaper_required.remove("model")
+    # Initial state is handled internally
+    snaper_kwargs.pop("init_state")
+    # Seed set later
+    snaper_kwargs.pop("seed")
+
+    snaper_required = snaper_required - snaper_kwargs.keys()
+
+    if snaper_required:
+      raise ValueError(f"Unexpected required arguments: "
+                       f"{','.join(snaper_required)}. Probably file a bug, but "
+                       "you can try to manually supply them as keywords.")
+    snaper_kwargs.update({k: kwargs_with_defaults[k] for k in snaper_kwargs
+                          if k in kwargs_with_defaults})
+    return {
+        snaper: snaper_kwargs,
+        "extra_parameters": {
+            "return_pytree": kwargs.get("return_pytree", False)
+        },
+    }
+
+  def __call__(self, seed, **kwargs):
+    snaper = tfp.experimental.mcmc.sample_snaper_hmc
+    init_key, sample_key = jax.random.split(seed)
+    kwargs = self.get_kwargs(**kwargs)
+    initial_state = self.get_initial_state(
+        init_key, num_chains=kwargs[snaper]["num_chains"])
+
+    vmapped_constrained_log_prob = jax.vmap(self.constrained_log_density())
+
+    def tlp(*args, **kwargs):
+      if args:
+        return vmapped_constrained_log_prob(args)
+      else:
+        return vmapped_constrained_log_prob(kwargs)
+
+    (draws, trace), *_ = snaper(
+        model=tlp, init_state=initial_state, seed=sample_key, **kwargs[snaper]
+    )
+    draws = self.transform_fn(draws)
+    if kwargs["extra_parameters"]["return_pytree"]:
+      return draws
+
+    if hasattr(draws, "_asdict"):
+      draws = draws._asdict()
+    elif not isinstance(draws, dict):
+      draws = {"var0": draws}
+
+    draws = {x: np.swapaxes(v, 0, 1) for x, v in draws.items()}
+    return az.from_dict(posterior=draws, sample_stats=_tfp_stats_to_dict(trace))
+
+
+def _tfp_stats_to_dict(stats):
+  new_stats = {}
+  for k, v in stats.items():
+    if k == "variance_scaling":
+      continue
+    if np.ndim(v) > 1:
+      new_stats[k] = np.swapaxes(v, 0, 1)
+    else:
+      new_stats[k] = v
+  return new_stats
diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py
index 7caea40..a6978f6 100644
--- a/bayeux/mcmc/__init__.py
+++ b/bayeux/mcmc/__init__.py
@@ -15,9 +15,13 @@
 """Imports from submodules."""
 # pylint: disable=g-importing-member
 # pylint: disable=g-import-not-at-top
+# pylint: disable=g-bad-import-order
 import importlib
 
-__all__ = []
+# TFP-on-JAX always installed
+from bayeux._src.mcmc.tfp import SnaperHMC as SNAPER_HMC_TFP
+__all__ = ["SNAPER_HMC_TFP"]
+
 if importlib.util.find_spec("blackjax") is not None:
   from bayeux._src.mcmc.blackjax import CheesHMC as CheesHMCblackjax
   from bayeux._src.mcmc.blackjax import HMC as HMCblackjax
diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py
index 48ecdb9..48e46d5 100644
--- a/bayeux/tests/mcmc_test.py
+++ b/bayeux/tests/mcmc_test.py
@@ -13,6 +13,8 @@
 # limitations under the License.
"""Tests for the API.""" +import importlib + import arviz as az import bayeux as bx import jax @@ -63,6 +65,39 @@ def test_return_pytree_numpyro(): assert pytree["x"]["y"].shape == (4, 10) +def test_return_pytree_tfp(): + model = bx.Model(log_density=lambda pt: -pt["x"]["y"]**2, + test_point={"x": {"y": jnp.array(1.)}}) + seed = jax.random.PRNGKey(0) + pytree = model.mcmc.tfp_snaper_hmc( + seed=seed, + return_pytree=True, + num_chains=4, + num_results=10, + num_burnin_steps=10, + ) + assert pytree["x"]["y"].shape == (10, 4) + + +@pytest.mark.skipif(importlib.util.find_spec("flowMC") is None, + reason="Test requires flowMC which is not installed") +def test_return_pytree_flowmc(): + model = bx.Model(log_density=lambda pt: -jnp.sum(pt["x"]["y"]**2), + test_point={"x": {"y": jnp.array([1., 1.])}}) + seed = jax.random.PRNGKey(0) + pytree = model.mcmc.flowmc_realnvp_mala( + seed=seed, + return_pytree=True, + n_chains=4, + n_local_steps=1, + n_global_steps=1, + n_loop_training=1, + n_loop_production=5, + ) + # 10 draws = (1 local + 1 global) * 5 loops + assert pytree["x"]["y"].shape == (4, 10, 2) + + @pytest.mark.parametrize("method", METHODS) def test_samplers(method): # flowMC samplers are broken for 0 or 1 dimensions, so just test