Skip to content

Commit

Permalink
Add meads and chees samplers from blackjax.
Browse files Browse the repository at this point in the history
These samplers pass the current MCMC tests (which only require sampling from a 1D Gaussian), but they perform poorly enough on real problems that there is probably a bug somewhere. I *think* the bug is in blackjax, but exposing these methods may help track down where it is.

PiperOrigin-RevId: 595794770
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 4, 2024
1 parent f5fc144 commit eb6be2c
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 58 deletions.
3 changes: 1 addition & 2 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
"transform_fn",
"inverse_transform_fn",
"inverse_log_det_jacobian",
"initial_state",
)
"initial_state",)


class _Namespace:
Expand Down
242 changes: 197 additions & 45 deletions bayeux/_src/mcmc/blackjax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@
from bayeux._src import shared
import blackjax
import jax
import jax.numpy as jnp
import optax


_ADAPT_FNS = {
"window": blackjax.window_adaptation,
"pathfinder": blackjax.pathfinder_adaptation,
"chees": blackjax.chees_adaptation,
"meads": blackjax.meads_adaptation,
}

_ALGORITHMS = {
"hmc": blackjax.hmc,
"ghmc": blackjax.ghmc,
"dynamic_hmc": blackjax.dynamic_hmc,
"nuts": blackjax.nuts,
}

Expand All @@ -51,9 +57,15 @@ class _BlackjaxSampler(shared.Base):
def get_kwargs(self, **kwargs):
adapt_fn = _ADAPT_FNS[self.adapt_fn]
algorithm = _ALGORITHMS[self.algorithm]
return {adapt_fn: get_adaptation_kwargs(adapt_fn, algorithm, kwargs),
algorithm: get_algorithm_kwargs(algorithm, kwargs),
"extra_parameters": get_extra_kwargs(kwargs)}
extra_parameters = get_extra_kwargs(kwargs)
constrained_log_density = self.constrained_log_density()
adaptation_kwargs, run_kwargs = get_adaptation_kwargs(
adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs)
return {adapt_fn: adaptation_kwargs,
"adapt.run": run_kwargs,
algorithm: get_algorithm_kwargs(
algorithm, constrained_log_density, kwargs),
"extra_parameters": extra_parameters}

def __call__(self, seed, **kwargs):
init_key, sample_key = jax.random.split(seed)
Expand All @@ -62,7 +74,24 @@ def __call__(self, seed, **kwargs):
init_key, num_chains=kwargs["extra_parameters"]["num_chains"])

return _sample_blackjax(
log_density=self.constrained_log_density(),
initial_state=self.inverse_transform_fn(initial_state),
algorithm=_ALGORITHMS[self.algorithm],
transform_fn=self.transform_fn,
adapt_fn=_ADAPT_FNS[self.adapt_fn],
seed=sample_key,
kwargs=kwargs)


class _BlackjaxDynamicSampler(_BlackjaxSampler):
"""Base class for blackjax samplers."""

def __call__(self, seed, **kwargs):
init_key, sample_key = jax.random.split(seed)
kwargs = self.get_kwargs(**kwargs)
initial_state = self.get_initial_state(
init_key, num_chains=kwargs["extra_parameters"]["num_chains"])

return _sample_blackjax_dynamic(
initial_state=self.inverse_transform_fn(initial_state),
algorithm=_ALGORITHMS[self.algorithm],
transform_fn=self.transform_fn,
Expand All @@ -77,6 +106,18 @@ class HMC(_BlackjaxSampler):
algorithm = "hmc"


# Sampler named "blackjax_chees_hmc". Its `adapt_fn` and `algorithm` strings
# are the keys looked up in _ADAPT_FNS ("chees" -> blackjax.chees_adaptation)
# and _ALGORITHMS ("dynamic_hmc" -> blackjax.dynamic_hmc).
class CheesHMC(_BlackjaxDynamicSampler):
name = "blackjax_chees_hmc"
adapt_fn = "chees"
algorithm = "dynamic_hmc"


# Sampler named "blackjax_meads_hmc". Its `adapt_fn` and `algorithm` strings
# are the keys looked up in _ADAPT_FNS ("meads" -> blackjax.meads_adaptation)
# and _ALGORITHMS ("ghmc" -> blackjax.ghmc).
class MeadsHMC(_BlackjaxDynamicSampler):
name = "blackjax_meads_hmc"
adapt_fn = "meads"
algorithm = "ghmc"


class HMCPathfinder(_BlackjaxSampler):
name = "blackjax_hmc_pathfinder"
adapt_fn = "pathfinder"
Expand All @@ -95,23 +136,43 @@ class NUTSPathfinder(_BlackjaxSampler):
algorithm = "nuts"


def _blackjax_inference_loop(
def _blackjax_adapt(
seed,
init_position,
adapt_fn,
kwarg_dict,
**kwargs):
adapt = adapt_fn(**kwarg_dict[adapt_fn])
(last_state, parameters), _ = adapt.run(
rng_key=seed, **kwargs,
**kwarg_dict["adapt.run"])
return last_state, parameters


# TODO(colcarroll): Use blackjax.util.run_inference_algorithm here.
def _blackjax_inference(
seed,
adapt_state,
adapt_parameters,
algorithm,
log_density,
num_draws,
num_adapt_draws,
kwargs):
"""Constructs and runs inference loop."""
adapt_seed, inference_seed = jax.random.split(seed)
adapt = adapt_fn(logdensity_fn=log_density, **kwargs[adapt_fn])
(last_state, parameters), _ = adapt.run(
rng_key=adapt_seed, position=init_position, num_steps=num_adapt_draws)
"""Run blackjax inference loop in a vmappable way.
Args:
seed: jax PRNGKey
adapt_state: return value from a blackjax adaptation algorithm.
adapt_parameters: return value from blackjax adaptation algorithm
algorithm: name of algorithm to run.
num_draws: number of iterations to run for.
kwargs: Extra keyword arguments for algorithm.
algorithm_kwargs = kwargs[algorithm] | parameters
kernel = algorithm(log_density, **algorithm_kwargs).step
Returns:
The results of the inference and the adaptation:
(states, infos), adaptation_parameters
"""

algorithm_kwargs = kwargs[algorithm] | adapt_parameters
kernel = algorithm(**algorithm_kwargs).step

@jax.jit
def inference_loop(rng_key):
Expand All @@ -121,21 +182,45 @@ def one_step(state, rng_key):
return state, (state, info)

keys = jax.random.split(rng_key, num_draws)
_, (states, infos) = jax.lax.scan(one_step, last_state, keys)
_, (states, infos) = jax.lax.scan(one_step, adapt_state, keys)

return states, infos

return inference_loop(inference_seed)
# Functions returned by chees adaptation.
adapt_parameters.pop("next_random_arg_fn", None)
adapt_parameters.pop("integration_steps_fn", None)
return inference_loop(seed), adapt_parameters


def _blackjax_stats_to_dict(sample_stats, potential_energy):
def _blackjax_inference_loop(
    seed,
    init_position,
    adapt_fn,
    algorithm,
    num_draws,
    kwargs):
  """Runs adaptation and then sampling from a single starting position.

  Args:
    seed: jax PRNGKey; split into one key for adaptation and one for sampling.
    init_position: Starting position, forwarded to adaptation as `position`.
    adapt_fn: Adaptation constructor; `_blackjax_adapt` also uses it as the key
      under which it finds its keyword arguments in `kwargs`.
    algorithm: Sampling algorithm handed to `_blackjax_inference`.
    num_draws: Number of draws handed to `_blackjax_inference`.
    kwargs: Keyword-argument dictionary shared by both stages.

  Returns:
    Whatever `_blackjax_inference` returns for the adapted state and
    parameters.
  """
  warmup_key, sampling_key = jax.random.split(seed)
  warmup = _blackjax_adapt(
      warmup_key, adapt_fn, kwarg_dict=kwargs, position=init_position)
  state_after_warmup, tuned_parameters = warmup
  return _blackjax_inference(
      sampling_key,
      state_after_warmup,
      tuned_parameters,
      algorithm,
      num_draws,
      kwargs)


def _blackjax_stats_to_dict(sample_stats, potential_energy, adapt_parameters):
"""Extract ArviZ compatible stats from blackjax sampler.
Adapted from https://github.com/pymc-devs/pymc
Args:
sample_stats: Blackjax NUTSInfo object containing sampler statistics.
potential_energy: Potential energy values of sampled positions.
adapt_parameters: Parameters from adaptation.
Returns:
Dictionary of sampler statistics.
Expand All @@ -148,27 +233,42 @@ def _blackjax_stats_to_dict(sample_stats, potential_energy):
"acceptance_rate": "acceptance_rate", # naming here depends
"acceptance_probability": "acceptance_rate", # on blackjax version
}
converted_stats = {}
converted_stats["lp"] = potential_energy
converted_stats = {"lp": potential_energy}
step_size = adapt_parameters.get("step_size", None)
if step_size is not None:
if jnp.ndim(step_size) == 0:
converted_stats["step_size"] = jnp.full_like(potential_energy, step_size)
else:
converted_stats["step_size"] = jnp.repeat(
step_size[..., None], repeats=jnp.shape(potential_energy)[-1], axis=-1
)
for old_name, new_name in rename_key.items():
value = getattr(sample_stats, old_name, None)
if value is not None:
converted_stats[new_name] = value
return converted_stats


def get_adaptation_kwargs(adaptation_algorithm, algorithm, kwargs):
def get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs):
"""Sets defaults and merges user-provided adaptation keywords."""
adaptation_kwargs, adaptation_required = shared.get_default_signature(
adaptation_algorithm)
adaptation_kwargs.update(
{k: kwargs[k] for k in adaptation_required if k in kwargs})
adaptation_required.remove("logdensity_fn")
adaptation_required.remove("extra_parameters")
adaptation_required.remove("algorithm")
adaptation_kwargs["algorithm"] = algorithm
adaptation_kwargs = (
get_algorithm_kwargs(algorithm, kwargs) | adaptation_kwargs)
if "logdensity_fn" in adaptation_required:
adaptation_kwargs["logdensity_fn"] = log_density
adaptation_required.remove("logdensity_fn")
elif "logprob_fn" in adaptation_required:
adaptation_kwargs["logprob_fn"] = log_density
adaptation_required.remove("logprob_fn")

adaptation_required.discard("extra_parameters")
if "algorithm" in adaptation_required:
adaptation_required.remove("algorithm")
adaptation_kwargs["algorithm"] = algorithm
adaptation_kwargs = (
get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs
)

adaptation_required = adaptation_required - adaptation_kwargs.keys()

Expand All @@ -183,21 +283,35 @@ def get_adaptation_kwargs(adaptation_algorithm, algorithm, kwargs):
)
# step_size will get adapted -- maybe warn if this is set manually, and
# suggest setting init_step_size instead?
adaptation_kwargs.pop("step_size")
adaptation_kwargs.pop("step_size", None)
# blackjax doesn't have a pleasant way to accept this argument --
# window_adaptation calls `algorithm.build_kernel()` with no arguments, but
# it should probably take the below arguments:
adaptation_kwargs.pop("divergence_threshold", None)
adaptation_kwargs.pop("integrator", None)
adaptation_kwargs.pop("max_num_doublings", None)

return adaptation_kwargs
adapt = adaptation_algorithm(**adaptation_kwargs)
run_kwargs, run_required = shared.get_default_signature(adapt.run)
run_required.remove("rng_key")
run_kwargs.update({k: kwargs[k] for k in run_required if k in kwargs})
if "optim" in run_required:
run_kwargs["optim"] = optax.adam(learning_rate=0.01)
run_required.remove("optim")
if "step_size" in run_required:
run_kwargs["step_size"] = 0.001
run_required.remove("step_size")
run_kwargs["num_steps"] = kwargs.get("num_adapt_draws",
run_kwargs["num_steps"])

return adaptation_kwargs, run_kwargs


def get_algorithm_kwargs(algorithm, kwargs):
def get_algorithm_kwargs(algorithm, log_density, kwargs):
"""Sets defaults and merges user-provided keywords for sampling."""
algorithm_kwargs, algorithm_required = shared.get_default_signature(algorithm)
kwargs_with_defaults = {
"logdensity_fn": log_density,
"step_size": 0.01,
"num_integration_steps": 8,
} | kwargs
Expand All @@ -208,7 +322,10 @@ def get_algorithm_kwargs(algorithm, kwargs):
if k in kwargs_with_defaults
})
algorithm_required.remove("logdensity_fn")
algorithm_required.remove("inverse_mass_matrix")
algorithm_required.discard("inverse_mass_matrix")
algorithm_required.discard("alpha")
algorithm_required.discard("delta")
algorithm_required.discard("momentum_inverse_scale")

algorithm_required = algorithm_required - algorithm_kwargs.keys()
if algorithm_required:
Expand All @@ -223,9 +340,53 @@ def get_algorithm_kwargs(algorithm, kwargs):
return algorithm_kwargs


def _sample_blackjax_dynamic(
    *,
    initial_state,
    algorithm,
    seed,
    transform_fn,
    adapt_fn,
    kwargs):
  """Adapts once on all initial positions, then samples every chain.

  Unlike a per-chain loop, adaptation is invoked a single time with the full
  `initial_state` (passed as `positions`); only the sampling stage is mapped
  over chains.

  Args:
    initial_state: Initial positions handed to the adaptation routine.
    algorithm: Sampling algorithm forwarded to `_blackjax_inference`.
    seed: jax PRNGKey; split into an adaptation key and per-chain keys.
    transform_fn: Applied to the sampled positions before they are returned.
    adapt_fn: Adaptation constructor forwarded to `_blackjax_adapt`.
    kwargs: Keyword-argument dictionary. Its "extra_parameters" entry is
      popped (the dictionary is mutated) and must provide "num_draws",
      "num_chains", "chain_method" and "return_pytree".

  Returns:
    The transformed draws if `return_pytree` is set, otherwise the result of
    `az.from_dict` built from the draws and the sampler statistics.
  """
  settings = kwargs.pop("extra_parameters")
  n_draws = settings["num_draws"]
  n_chains = settings["num_chains"]
  how_to_map = settings["chain_method"]

  warmup_key, seed = jax.random.split(seed)
  warm_state, tuned = _blackjax_adapt(
      seed=warmup_key,
      adapt_fn=adapt_fn,
      kwarg_dict=kwargs,
      positions=initial_state,
  )

  # Everything except the per-chain key and state is fixed up front.
  run_chain = functools.partial(
      _blackjax_inference,
      adapt_parameters=tuned,
      algorithm=algorithm,
      num_draws=n_draws,
      kwargs=kwargs)
  chain_keys = jax.random.split(seed, n_chains)
  run_all_chains = shared.map_fn(how_to_map, run_chain)

  (states, stats), tuned = run_all_chains(chain_keys, warm_state)
  draws = transform_fn(states.position)
  if settings["return_pytree"]:
    return draws

  sample_stats = _blackjax_stats_to_dict(stats, states.logdensity, tuned)
  # Coerce the draws into a dict: namedtuples via `_asdict`, bare values
  # under the single key "var0".
  if hasattr(draws, "_asdict"):
    draws = draws._asdict()
  elif not isinstance(draws, dict):
    draws = {"var0": draws}
  return az.from_dict(posterior=draws, sample_stats=sample_stats)


def _sample_blackjax(
*,
log_density,
initial_state,
algorithm,
seed,
Expand All @@ -237,32 +398,23 @@ def _sample_blackjax(
num_draws = extra_parameters["num_draws"]
num_chains = extra_parameters["num_chains"]
chain_method = extra_parameters["chain_method"]
num_adapt_draws = extra_parameters["num_adapt_draws"]
sampler = functools.partial(
_blackjax_inference_loop,
log_density=log_density,
algorithm=algorithm,
adapt_fn=adapt_fn,
num_draws=num_draws,
num_adapt_draws=num_adapt_draws,
kwargs=kwargs)
map_seed = jax.random.split(seed, num_chains)
if chain_method == "parallel":
mapped_sampler = jax.pmap(sampler)
elif chain_method == "vectorized":
mapped_sampler = jax.vmap(sampler)
elif chain_method == "sequential":
mapped_sampler = functools.partial(jax.tree_map, sampler)
else:
raise ValueError(f"Chain method {chain_method} not supported.")
mapped_sampler = shared.map_fn(chain_method, sampler)

states, stats = mapped_sampler(map_seed, initial_state)
(states, stats), adapt_parameters = mapped_sampler(map_seed, initial_state)
draws = transform_fn(states.position)
if extra_parameters["return_pytree"]:
return draws
else:
potential_energy = states.logdensity
sample_stats = _blackjax_stats_to_dict(stats, potential_energy)
sample_stats = _blackjax_stats_to_dict(
stats, potential_energy, adapt_parameters)
if hasattr(draws, "_asdict"):
draws = draws._asdict()
elif not isinstance(draws, dict):
Expand Down
Loading

0 comments on commit eb6be2c

Please sign in to comment.