From 52140cebc9965ffb72edde712d8d2180330e90fa Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 28 May 2024 09:28:33 -0700 Subject: [PATCH] Prepare for blackjax API change. PiperOrigin-RevId: 637925845 --- bayeux/_src/mcmc/blackjax.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/bayeux/_src/mcmc/blackjax.py b/bayeux/_src/mcmc/blackjax.py index 62dd4bf..757cc74 100644 --- a/bayeux/_src/mcmc/blackjax.py +++ b/bayeux/_src/mcmc/blackjax.py @@ -23,6 +23,24 @@ import optax +def _get_run_inference_algorithm_kwarg_name(): + """This is a hack while blackjax changes API for run_inference_algorithm. + + This should be deleted and default to "initial_state" once the API + stabilizes and we can depend on some version > 1.2.1. + + We do this out here so that it just runs once. + + Returns: + keyword argument name that `blackjax.util.run_inference_algorithm` expects. + """ + _, req = shared.get_default_signature(blackjax.util.run_inference_algorithm) + if "initial_state_or_position" in req: + return "initial_state_or_position" + return "initial_state" +_INFERENCE_KWARG = _get_run_inference_algorithm_kwarg_name() + + _ADAPT_FNS = { "window": blackjax.window_adaptation, "pathfinder": blackjax.pathfinder_adaptation, @@ -182,10 +200,10 @@ def _blackjax_inference( inference_algorithm = algorithm(**algorithm_kwargs) _, states, infos = blackjax.util.run_inference_algorithm( rng_key=seed, - initial_state_or_position=adapt_state, inference_algorithm=inference_algorithm, num_steps=num_draws, - progress_bar=False) + progress_bar=False, + **{_INFERENCE_KWARG: adapt_state}) return states, infos