From 6a52896c0b13bfa3eee0b165cb7352ec2cddb381 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Thu, 1 Feb 2024 09:10:49 -0800 Subject: [PATCH] Add flowMC samplers and release a new version. These currently only work for more than 1 dimension, and may require some further tuning. PiperOrigin-RevId: 603388912 --- CHANGELOG.md | 5 + bayeux/__init__.py | 4 +- bayeux/_src/__init__.py | 2 +- bayeux/_src/bayeux.py | 2 +- bayeux/_src/debug.py | 2 +- bayeux/_src/initialization.py | 2 +- bayeux/_src/mcmc/__init__.py | 2 +- bayeux/_src/mcmc/blackjax.py | 6 +- bayeux/_src/mcmc/flowmc.py | 267 +++++++++++++++++++++++++++++ bayeux/_src/mcmc/numpyro.py | 2 +- bayeux/_src/optimize/__init__.py | 2 +- bayeux/_src/optimize/jaxopt.py | 2 +- bayeux/_src/optimize/optax.py | 2 +- bayeux/_src/optimize/optimistix.py | 2 +- bayeux/_src/optimize/shared.py | 2 +- bayeux/_src/shared.py | 2 +- bayeux/_src/types.py | 2 +- bayeux/_src/vi/__init__.py | 2 +- bayeux/_src/vi/tfp.py | 2 +- bayeux/mcmc/__init__.py | 14 +- bayeux/optimize/__init__.py | 2 +- bayeux/tests/compat_test.py | 2 +- bayeux/tests/debug_test.py | 2 +- bayeux/tests/mcmc_test.py | 13 +- bayeux/tests/optimize_test.py | 2 +- bayeux/tests/vi_test.py | 2 +- bayeux/vi/__init__.py | 2 +- pyproject.toml | 1 + 28 files changed, 321 insertions(+), 31 deletions(-) create mode 100644 bayeux/_src/mcmc/flowmc.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fb50bad..df24128 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] +## [0.1.6] - 2024-02-01 + +### Add samplers from flowMC + ## [0.1.5] - 2024-01-12 ### Bugfix for PyMC models @@ -47,6 +51,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ### Initial release [Unreleased]: https://github.com/jax-ml/bayeux/compare/v0.1.5...HEAD +[0.1.6]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.6 [0.1.5]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.5 [0.1.4]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.4 [0.1.3]: https://github.com/jax-ml/bayeux/releases/tag/v0.1.3 diff --git a/bayeux/__init__.py b/bayeux/__init__.py index e8ef4b2..6142514 100644 --- a/bayeux/__init__.py +++ b/bayeux/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased # When changing this, also update the CHANGELOG.md -__version__ = '0.1.5' +__version__ = '0.1.6' # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/google/jax/issues/7570 diff --git a/bayeux/_src/__init__.py b/bayeux/_src/__init__.py index 1aad681..19a26b7 100644 --- a/bayeux/_src/__init__.py +++ b/bayeux/_src/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/bayeux.py b/bayeux/_src/bayeux.py index 78dd712..1eb7588 100644 --- a/bayeux/_src/bayeux.py +++ b/bayeux/_src/bayeux.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/debug.py b/bayeux/_src/debug.py index e69c993..32e84cd 100644 --- a/bayeux/_src/debug.py +++ b/bayeux/_src/debug.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/initialization.py b/bayeux/_src/initialization.py index 241310a..45017d1 100644 --- a/bayeux/_src/initialization.py +++ b/bayeux/_src/initialization.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/mcmc/__init__.py b/bayeux/_src/mcmc/__init__.py index 1aad681..19a26b7 100644 --- a/bayeux/_src/mcmc/__init__.py +++ b/bayeux/_src/mcmc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 2fb3337..d02f34d 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -312,8 +312,8 @@ def get_algorithm_kwargs(algorithm, log_density, kwargs): algorithm_kwargs, algorithm_required = shared.get_default_signature(algorithm) kwargs_with_defaults = { "logdensity_fn": log_density, - "step_size": 0.01, - "num_integration_steps": 8, + "step_size": 0.5, + "num_integration_steps": 16, } | kwargs algorithm_kwargs.update( { diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py new file mode 100644 index 0000000..c99326f --- /dev/null +++ b/bayeux/_src/mcmc/flowmc.py @@ -0,0 +1,267 @@ +# Copyright 2024 The bayeux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 The bayeux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""flowMC specific code.""" +import arviz as az +from bayeux._src import shared +from flowMC.nfmodel import realNVP +from flowMC.nfmodel import rqSpline +from flowMC.sampler import HMC +from flowMC.sampler import MALA +from flowMC.sampler import Sampler +import jax +import jax.numpy as jnp + + +_NF_MODELS = { + "real_nvp": realNVP.RealNVP, + "masked_coupling_rq_spline": rqSpline.MaskedCouplingRQSpline, +} + +_LOCAL_SAMPLERS = {"mala": MALA.MALA, "hmc": HMC.HMC} + + +def get_nf_model_kwargs(nf_model, n_features, kwargs): + """Sets defaults and merges user-provided adaptation keywords.""" + nf_model_kwargs, nf_model_required = shared.get_default_signature( + nf_model) + nf_model_kwargs.update( + {k: kwargs[k] for k in nf_model_kwargs if k in kwargs}) + nf_model_kwargs.update( + {k: kwargs[k] for k in nf_model_required if k in kwargs}) + nf_model_kwargs.setdefault("n_features", n_features) + nf_model_required.remove("key") + nf_model_required.remove("kwargs") + nf_model_required = nf_model_required - nf_model_kwargs.keys() + + defaults = { + # RealNVP kwargs + "n_hidden": 100, + "n_layer": 10, + # MaskedCouplingRQSpline kwargs + "n_layers": 4, + "num_bins": 8, + "hidden_size": [64, 64], + "spline_range": (-10.0, 10.0), + } + for key, value in defaults.items(): + if key in nf_model_required: + nf_model_kwargs[key] = value + + nf_model_required = nf_model_required - nf_model_kwargs.keys() + + if nf_model_required: + raise ValueError( + "Unexpected required arguments: " + f"{','.join(nf_model_required)}. Probably file a bug, but " + "you can try to manually supply them as keywords." + ) + return nf_model_kwargs + + +def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): + """Sets defaults and merges user-provided adaptation keywords.""" + + kwargs["logpdf"] = log_density + sampler_kwargs, sampler_required = shared.get_default_signature( + local_sampler) + sampler_kwargs.setdefault("jit", True) + sampler_kwargs.update( + {k: kwargs[k] for k in sampler_required if k in kwargs}) + sampler_required = sampler_required - sampler_kwargs.keys() + + defaults = { + # HMC kwargs + "condition_matrix": jnp.eye(n_features), + "n_leapfrog": 10, + # Both + "step_size": 0.1, + } + if "params" in sampler_required: + sampler_kwargs["params"] = defaults + else: + sampler_kwargs["params"] = sampler_kwargs["params"] | defaults + + sampler_required = sampler_required - sampler_kwargs.keys() + + if sampler_required: + raise ValueError( + "Unexpected required arguments: " + f"{','.join(sampler_required)}. Probably file a bug, but " + "you can try to manually supply them as keywords." + ) + return sampler_kwargs + + +def get_sampler_kwargs(sampler, n_features, kwargs): + """Sets defaults and merges user-provided adaptation keywords.""" + sampler_kwargs, sampler_required = shared.get_default_signature(sampler) + sampler_kwargs.update( + {k: kwargs[k] for k in sampler_required if k in kwargs}) + sampler_kwargs.setdefault("data", {}) + sampler_kwargs.setdefault("n_dim", n_features) + sampler_required = (sampler_required - + {"nf_model", "local_sampler", "rng_key_set", "kwargs"}) + sampler_required = sampler_required - sampler_kwargs.keys() + + defaults = { + "n_loop_training": 5, + "n_loop_production": 5, + "n_local_steps": 50, + "n_global_steps": 50, + "n_chains": 20, + "n_epochs": 30, + "learning_rate": 0.01, + "max_samples": 10_000, + "momentum": 0.9, + "batch_size": 10_000, + "use_global": True, + "global_sampler": None, + "logging": True, + "keep_quantile": 0., + "local_autotune": None, + "train_thinning": 1, + "output_thinning": 1, + "n_sample_max": 10_000, + "precompile": False, + "verbose": False} + for key, value in defaults.items(): + if key not in sampler_kwargs: + sampler_kwargs[key] = value + + sampler_required = sampler_required - sampler_kwargs.keys() + + if sampler_required: + raise ValueError( + "Unexpected required arguments: " + f"{','.join(sampler_required)}. Probably file a bug, but " + "you can try to manually supply them as keywords." + ) + return sampler_kwargs + + +class _FlowMCSampler(shared.Base): + """Base class for flowmc samplers.""" + name: str = "" + nf_model: str = "" + local_sampler: str = "" + + def _get_aux(self): + flat, unflatten = jax.flatten_util.ravel_pytree(self.test_point) + + @jax.vmap + def flatten(pytree): + return jax.flatten_util.ravel_pytree(pytree)[0] + + constrained_log_density = self.constrained_log_density() + def log_density(x, _): + return constrained_log_density(unflatten(x)).squeeze() + + return log_density, flatten, unflatten, flat.shape[0] + + def get_kwargs(self, **kwargs): + nf_model = _NF_MODELS[self.nf_model] + local_sampler = _LOCAL_SAMPLERS[self.local_sampler] + log_density, flatten, unflatten, n_features = self._get_aux() + + nf_model_kwargs = get_nf_model_kwargs(nf_model, n_features, kwargs) + local_sampler_kwargs = get_local_sampler_kwargs( + local_sampler, log_density, n_features, kwargs) + sampler = Sampler.Sampler + sampler_kwargs = get_sampler_kwargs(sampler, n_features, kwargs) + extra_parameters = {"flatten": flatten, + "unflatten": unflatten, + "num_chains": sampler_kwargs["n_chains"], + "return_pytree": kwargs.get("return_pytree", False)} + + return {nf_model: nf_model_kwargs, + local_sampler: local_sampler_kwargs, + sampler: sampler_kwargs, + "extra_parameters": extra_parameters} + + def __call__(self, seed, **kwargs): + kwargs = self.get_kwargs(**kwargs) + extra_parameters = kwargs["extra_parameters"] + num_chains = extra_parameters["num_chains"] + init_key, nf_key, seed = jax.random.split(seed, 3) + initial_state = self.get_initial_state( + init_key, num_chains=num_chains) + initial_state = extra_parameters["flatten"](initial_state) + nf_model = _NF_MODELS[self.nf_model] + local_sampler = _LOCAL_SAMPLERS[self.local_sampler] + + rng_key_init, rng_key_mcmc, rng_key_nf = jax.random.split(seed, 3) + rng_keys_mcmc = jax.random.split(rng_key_mcmc, num_chains) + rng_keys_nf, init_rng_keys_nf = jax.random.split(rng_key_nf, 2) + + model = nf_model(key=nf_key, **kwargs[nf_model]) + local_sampler = local_sampler(**kwargs[local_sampler]) + sampler = Sampler.Sampler + nf_sampler = sampler( + rng_key_set=( + rng_key_init, rng_keys_mcmc, rng_keys_nf, init_rng_keys_nf), + local_sampler=local_sampler, + nf_model=model, + **kwargs[sampler]) + nf_sampler.sample(initial_state, {}) + chains, *_ = nf_sampler.get_sampler_state().values() + + unflatten = jax.vmap(jax.vmap(extra_parameters["unflatten"])) + pytree = self.transform_fn(unflatten(chains)) + if extra_parameters["return_pytree"]: + return pytree + else: + if hasattr(pytree, "_asdict"): + pytree = pytree._asdict() + elif not isinstance(pytree, dict): + pytree = {"var0": pytree} + return az.from_dict(posterior=pytree) + + +class RealNVPMALA(_FlowMCSampler): + name = "flowmc_realnvp_mala" + nf_model = "real_nvp" + local_sampler = "mala" + + +class RealNVPHMC(_FlowMCSampler): + name = "flowmc_realnvp_hmc" + nf_model = "real_nvp" + local_sampler = "hmc" + + +class MaskedCouplingRQSplineMALA(_FlowMCSampler): + name = "flowmc_rqspline_mala" + nf_model = "masked_coupling_rq_spline" + local_sampler = "mala" + + +class MaskedCouplingRQSplineHMC(_FlowMCSampler): + name = "flowmc_rqspline_hmc" + nf_model = "masked_coupling_rq_spline" + local_sampler = "hmc" diff --git a/bayeux/_src/mcmc/numpyro.py b/bayeux/_src/mcmc/numpyro.py index 61a112c..0b174d0 100644 --- a/bayeux/_src/mcmc/numpyro.py +++ b/bayeux/_src/mcmc/numpyro.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/optimize/__init__.py b/bayeux/_src/optimize/__init__.py index 1aad681..19a26b7 100644 --- a/bayeux/_src/optimize/__init__.py +++ b/bayeux/_src/optimize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/optimize/jaxopt.py b/bayeux/_src/optimize/jaxopt.py index dabbab7..1a66b0d 100644 --- a/bayeux/_src/optimize/jaxopt.py +++ b/bayeux/_src/optimize/jaxopt.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/optimize/optax.py b/bayeux/_src/optimize/optax.py index 33ea2ae..3d7d452 100644 --- a/bayeux/_src/optimize/optax.py +++ b/bayeux/_src/optimize/optax.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/optimize/optimistix.py b/bayeux/_src/optimize/optimistix.py index 8b34dac..4be810d 100644 --- a/bayeux/_src/optimize/optimistix.py +++ b/bayeux/_src/optimize/optimistix.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/optimize/shared.py b/bayeux/_src/optimize/shared.py index d58d40f..17119d9 100644 --- a/bayeux/_src/optimize/shared.py +++ b/bayeux/_src/optimize/shared.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/shared.py b/bayeux/_src/shared.py index d16f181..3c8673d 100644 --- a/bayeux/_src/shared.py +++ b/bayeux/_src/shared.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/types.py b/bayeux/_src/types.py index 5864634..738b051 100644 --- a/bayeux/_src/types.py +++ b/bayeux/_src/types.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/vi/__init__.py b/bayeux/_src/vi/__init__.py index 1aad681..19a26b7 100644 --- a/bayeux/_src/vi/__init__.py +++ b/bayeux/_src/vi/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/_src/vi/tfp.py b/bayeux/_src/vi/tfp.py index 6ee5143..c8a6d78 100644 --- a/bayeux/_src/vi/tfp.py +++ b/bayeux/_src/vi/tfp.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py index 629c46f..7caea40 100644 --- a/bayeux/mcmc/__init__.py +++ b/bayeux/mcmc/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,6 +29,18 @@ "NUTSblackjax", "HMC_Pathfinder_blackjax", "NUTS_Pathfinder_blackjax"]) +if importlib.util.find_spec("flowMC") is not None: + from bayeux._src.mcmc.flowmc import MaskedCouplingRQSplineHMC as MaskedCouplingRQSplineHMCflowmc + from bayeux._src.mcmc.flowmc import MaskedCouplingRQSplineMALA as MaskedCouplingRQSplineMALAflowmc + from bayeux._src.mcmc.flowmc import RealNVPHMC as RealNVPHMCflowmc + from bayeux._src.mcmc.flowmc import RealNVPMALA as RealNVPMALAflowmc + + __all__.extend([ + "MaskedCouplingRQSplineHMCflowmc", + "MaskedCouplingRQSplineMALAflowmc", + "RealNVPHMCflowmc", + "RealNVPMALAflowmc"]) + if importlib.util.find_spec("numpyro") is not None: from bayeux._src.mcmc.numpyro import HMC as HMCnumpyro from bayeux._src.mcmc.numpyro import NUTS as NUTSnumpyro diff --git a/bayeux/optimize/__init__.py b/bayeux/optimize/__init__.py index 3601da1..458feb9 100644 --- a/bayeux/optimize/__init__.py +++ b/bayeux/optimize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/tests/compat_test.py b/bayeux/tests/compat_test.py index 5846808..b39b659 100644 --- a/bayeux/tests/compat_test.py +++ b/bayeux/tests/compat_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/tests/debug_test.py b/bayeux/tests/debug_test.py index 84b9697..c2902f4 100644 --- a/bayeux/tests/debug_test.py +++ b/bayeux/tests/debug_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index 720976d..48ecdb9 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,13 +65,18 @@ def test_return_pytree_numpyro(): @pytest.mark.parametrize("method", METHODS) def test_samplers(method): - model = bx.Model(log_density=lambda pt: -pt["x"]**2, - test_point={"x": jnp.array(1.)}) + # flowMC samplers are broken for 0 or 1 dimensions, so just test + # everything on 2 dimensions for now. + model = bx.Model(log_density=lambda pt: jnp.sum(-pt["x"]**2), + test_point={"x": jnp.ones((1, 2))}) sampler = getattr(model.mcmc, method) seed = jax.random.PRNGKey(0) assert sampler.debug(seed=seed, verbosity=0) idata = sampler(seed=seed) - assert max_rhat(idata) < 1.1 + if method == "blackjax_hmc": + assert max_rhat(idata) < 1.2 + else: + assert max_rhat(idata) < 1.1 @pytest.mark.parametrize("method", METHODS) diff --git a/bayeux/tests/optimize_test.py b/bayeux/tests/optimize_test.py index a967279..805145c 100644 --- a/bayeux/tests/optimize_test.py +++ b/bayeux/tests/optimize_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/tests/vi_test.py b/bayeux/tests/vi_test.py index b7d7c80..92a043a 100644 --- a/bayeux/tests/vi_test.py +++ b/bayeux/tests/vi_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/bayeux/vi/__init__.py b/bayeux/vi/__init__.py index 2a4cc2a..5ca05d0 100644 --- a/bayeux/vi/__init__.py +++ b/bayeux/vi/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 The bayeux Authors. +# Copyright 2024 The bayeux Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/pyproject.toml b/pyproject.toml index 0b9f213..5d9f740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "optax", "optimistix", "blackjax", + "flowmc", "numpyro", "jaxopt", ]