Skip to content

Commit

Permalink
Prepare for blackjax API change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 637925845
  • Loading branch information
ColCarroll authored and The bayeux Authors committed May 28, 2024
1 parent f5e9cf6 commit 52140ce
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,24 @@
import optax


def _get_run_inference_algorithm_kwarg_name():
"""This is a hack while blackjax changes API for run_inference_algorithm.
This should be deleted and default to "initial_state" once the API
stabilizes and we can depend on some version > 1.2.1.
We do this out here so that it just runs once.
Returns:
keyword argument name that `blackjax.util.run_inference_algorithm` expects.
"""
_, req = shared.get_default_signature(blackjax.util.run_inference_algorithm)
if "initial_state_or_position" in req:
return "initial_state_or_position"
return "initial_state"
_INFERENCE_KWARG = _get_run_inference_algorithm_kwarg_name()


_ADAPT_FNS = {
"window": blackjax.window_adaptation,
"pathfinder": blackjax.pathfinder_adaptation,
Expand Down Expand Up @@ -182,10 +200,10 @@ def _blackjax_inference(
inference_algorithm = algorithm(**algorithm_kwargs)
_, states, infos = blackjax.util.run_inference_algorithm(
rng_key=seed,
initial_state_or_position=adapt_state,
inference_algorithm=inference_algorithm,
num_steps=num_draws,
progress_bar=False)
progress_bar=False,
**{_INFERENCE_KWARG: adapt_state})
return states, infos


Expand Down

0 comments on commit 52140ce

Please sign in to comment.