From ae6ec6eaf892efcf4744255137655a8a96685573 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Sat, 17 Feb 2024 12:44:34 -0800 Subject: [PATCH] Use `blackjax.util.run_inference`. We prefer library-provided functions where possible. PiperOrigin-RevId: 607999912 --- bayeux/_src/mcmc/blackjax.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 23c3c58..4b22e4d 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -149,7 +149,6 @@ def _blackjax_adapt( return last_state, parameters -# TODO(colcarroll): Use blackjax.util.run_inference_algorithm here. def _blackjax_inference( seed, adapt_state, @@ -173,24 +172,14 @@ def _blackjax_inference( """ algorithm_kwargs = kwargs[algorithm] | adapt_parameters - kernel = algorithm(**algorithm_kwargs).step - - @jax.jit - def inference_loop(rng_key): - - def one_step(state, rng_key): - state, info = kernel(rng_key, state) - return state, (state, info) - - keys = jax.random.split(rng_key, num_draws) - _, (states, infos) = jax.lax.scan(one_step, adapt_state, keys) - - return states, infos - - # Functions returned by chees adaptation. - adapt_parameters.pop("next_random_arg_fn", None) - adapt_parameters.pop("integration_steps_fn", None) - return inference_loop(seed), adapt_parameters + inference_algorithm = algorithm(**algorithm_kwargs) + _, states, infos = blackjax.util.run_inference_algorithm( + rng_key=seed, + initial_state_or_position=adapt_state, + inference_algorithm=inference_algorithm, + num_steps=num_draws, + progress_bar=False) + return states, infos def _blackjax_inference_loop( @@ -210,7 +199,7 @@ def _blackjax_inference_loop( adapt_parameters, algorithm, num_draws, - kwargs) + kwargs), adapt_parameters def _blackjax_stats_to_dict(sample_stats, potential_energy, adapt_parameters): @@ -362,7 +351,7 @@ def _sample_blackjax_dynamic( map_seed = jax.random.split(seed, num_chains) mapped_sampler = shared.map_fn(chain_method, sampler) - (states, stats), adapt_parameters = mapped_sampler(map_seed, adapt_state) + states, stats = mapped_sampler(map_seed, adapt_state) draws = transform_fn(states.position) if extra_parameters["return_pytree"]: return draws