From 4ee0eb9385a2fc0dfdf7324c26bf75c7362209ae Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 13 Feb 2024 15:26:00 -0800 Subject: [PATCH] Fix keyword handling for flowMC. Also makes a new release. PiperOrigin-RevId: 606765865 --- CHANGELOG.md | 8 ++++- bayeux/__init__.py | 2 +- bayeux/_src/mcmc/flowmc.py | 71 +++++++++++++++++--------------------- bayeux/tests/mcmc_test.py | 21 +++++++++++ 4 files changed, 61 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df24128..36f8fed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.1.7] - 2024-02-13 + +### Add SNAPER HMC from TFP +### Fix flowMC keyword handling + ## [0.1.6] - 2024-02-01 ### Add samplers from flowMC @@ -50,7 +55,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ### Initial release -[Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.5...HEAD +[Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.7...HEAD +[0.1.7]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.7 [0.1.6]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.6 [0.1.5]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.5 [0.1.4]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.4 diff --git a/bayeux/__init__.py b/bayeux/__init__.py index 6142514..026e7db 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.6' +__version__ = '0.1.7' # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py index c99326f..b5376d5 100644 --- a/bayeux/_src/mcmc/flowmc.py +++ b/bayeux/_src/mcmc/flowmc.py @@ -48,17 +48,6 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs): """Sets defaults and merges user-provided adaptation keywords.""" - nf_model_kwargs, nf_model_required = shared.get_default_signature( - nf_model) - nf_model_kwargs.update( - {k: kwargs[k] for k in nf_model_kwargs if k in kwargs}) - nf_model_kwargs.update( - {k: kwargs[k] for k in nf_model_required if k in kwargs}) - nf_model_kwargs.setdefault("n_features", n_features) - nf_model_required.remove("key") - nf_model_required.remove("kwargs") - nf_model_required = nf_model_required - nf_model_kwargs.keys() - defaults = { # RealNVP kwargs "n_hidden": 100, @@ -68,11 +57,15 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs): "num_bins": 8, "hidden_size": [64, 64], "spline_range": (-10.0, 10.0), - } - for key, value in defaults.items(): - if key in nf_model_required: - nf_model_kwargs[key] = value + "n_features": n_features, + } | kwargs + nf_model_kwargs, nf_model_required = shared.get_default_signature( + nf_model) + nf_model_kwargs.update( + {k: defaults[k] for k in nf_model_required if k in defaults}) + nf_model_required.remove("key") + nf_model_required.remove("kwargs") nf_model_required = nf_model_required - nf_model_kwargs.keys() if nf_model_required: @@ -81,27 +74,30 @@ def get_nf_model_kwargs(nf_model, n_features, kwargs): f"{','.join(nf_model_required)}. Probably file a bug, but " "you can try to manually supply them as keywords." ) + nf_model_kwargs.update( + {k: defaults[k] for k in nf_model_kwargs if k in defaults}) + return nf_model_kwargs def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): """Sets defaults and merges user-provided adaptation keywords.""" - kwargs["logpdf"] = log_density - sampler_kwargs, sampler_required = shared.get_default_signature( - local_sampler) - sampler_kwargs.setdefault("jit", True) - sampler_kwargs.update( - {k: kwargs[k] for k in sampler_required if k in kwargs}) - sampler_required = sampler_required - sampler_kwargs.keys() - defaults = { # HMC kwargs "condition_matrix": jnp.eye(n_features), "n_leapfrog": 10, # Both "step_size": 0.1, - } + "logpdf": log_density + } | kwargs + + sampler_kwargs, sampler_required = shared.get_default_signature( + local_sampler) + sampler_kwargs.setdefault("jit", True) + sampler_kwargs.update( + {k: defaults[k] for k in sampler_required if k in defaults}) + sampler_required = sampler_required - sampler_kwargs.keys() if "params" in sampler_required: sampler_kwargs["params"] = defaults else: @@ -120,15 +116,9 @@ def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): def get_sampler_kwargs(sampler, n_features, kwargs): """Sets defaults and merges user-provided adaptation keywords.""" - sampler_kwargs, sampler_required = shared.get_default_signature(sampler) - sampler_kwargs.update( - {k: kwargs[k] for k in sampler_required if k in kwargs}) - sampler_kwargs.setdefault("data", {}) - sampler_kwargs.setdefault("n_dim", n_features) - sampler_required = (sampler_required - - {"nf_model", "local_sampler", "rng_key_set", "kwargs"}) - sampler_required = sampler_required - sampler_kwargs.keys() - + # We support `num_chains` everywhere else, so support it here. + if "num_chains" in kwargs: + kwargs["n_chains"] = kwargs["num_chains"] defaults = { "n_loop_training": 5, "n_loop_production": 5, @@ -149,11 +139,14 @@ def get_sampler_kwargs(sampler, n_features, kwargs): "output_thinning": 1, "n_sample_max": 10_000, "precompile": False, - "verbose": False} - for key, value in defaults.items(): - if key not in sampler_kwargs: - sampler_kwargs[key] = value - + "verbose": False, + "n_dim": n_features, + "data": {}} | kwargs + sampler_kwargs, sampler_required = shared.get_default_signature(sampler) + sampler_kwargs.update( + {k: defaults[k] for k in sampler_required if k in defaults}) + sampler_required = (sampler_required - + {"nf_model", "local_sampler", "rng_key_set", "kwargs"}) sampler_required = sampler_required - sampler_kwargs.keys() if sampler_required: @@ -162,7 +155,7 @@ def get_sampler_kwargs(sampler, n_features, kwargs): f"{','.join(sampler_required)}. Probably file a bug, but " "you can try to manually supply them as keywords." ) - return sampler_kwargs + return defaults | sampler_kwargs class _FlowMCSampler(shared.Base): diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index 24ef3ea..48e46d5 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests for the API.""" +import importlib + import arviz as az import bayeux as bx import jax @@ -77,6 +79,25 @@ def test_return_pytree_tfp(): assert pytree["x"]["y"].shape == (10, 4) +@pytest.mark.skipif(importlib.util.find_spec("flowMC") is None, + reason="Test requires flowMC which is not installed") +def test_return_pytree_flowmc(): + model = bx.Model(log_density=lambda pt: -jnp.sum(pt["x"]["y"]**2), + test_point={"x": {"y": jnp.array([1., 1.])}}) + seed = jax.random.PRNGKey(0) + pytree = model.mcmc.flowmc_realnvp_mala( + seed=seed, + return_pytree=True, + n_chains=4, + n_local_steps=1, + n_global_steps=1, + n_loop_training=1, + n_loop_production=5, + ) + # 10 draws = (1 local + 1 global) * 5 loops + assert pytree["x"]["y"].shape == (4, 10, 2) + + @pytest.mark.parametrize("method", METHODS) def test_samplers(method): # flowMC samplers are broken for 0 or 1 dimensions, so just test