Skip to content

Commit

Permalink
Add converters from numpyro and TFP.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597289022
  • Loading branch information
ColCarroll authored and The bayeux Authors committed Jan 10, 2024
1 parent 1cf2a7c commit de90bd7
Show file tree
Hide file tree
Showing 4 changed files with 1,425 additions and 879 deletions.
47 changes: 47 additions & 0 deletions bayeux/_src/bayeux.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from bayeux import optimize
from bayeux import vi
from bayeux._src import shared
import jax
import oryx

_MODULES = (mcmc, optimize, vi)
Expand Down Expand Up @@ -104,3 +105,49 @@ def __repr__(self):
k = getattr(self, name)
methods.append("\t." + "\n\t.".join(str(k).split()))
return "\n".join(methods)

@classmethod
def from_tfp(cls, pinned_joint_distribution, initial_state=None):
log_density = pinned_joint_distribution.log_prob
test_point = pinned_joint_distribution.sample_unpinned(
seed=jax.random.PRNGKey(0))
transform_fn = (
pinned_joint_distribution.experimental_default_event_space_bijector()
)
inverse_transform_fn = transform_fn.inverse
inverse_log_det_jacobian = transform_fn.inverse_log_det_jacobian
return cls(
log_density=log_density,
test_point=test_point,
transform_fn=transform_fn,
initial_state=initial_state,
inverse_transform_fn=inverse_transform_fn,
inverse_log_det_jacobian=inverse_log_det_jacobian)

@classmethod
def from_numpyro(cls, numpyro_fn, initial_state=None):
import numpyro # pylint: disable=g-import-not-at-top

def log_density(*args, **kwargs):
# This clause is only required because the tfp vi routine tries to
# pass dictionaries as keyword arguments, so this allows either
# log_density(params) or log_density(**params)
if args:
x = args[0]
else:
x = kwargs
return numpyro.infer.util.log_density(numpyro_fn, (), {}, x)[0]

test_point = numpyro.infer.Predictive(
numpyro_fn, num_samples=1)(jax.random.PRNGKey(0))
test_point = {k: v[0] for k, v in test_point.items() if k != "observed"}

def transform_fn(x):
return numpyro.infer.util.constrain_fn(numpyro_fn, (), {}, x)

return cls(
log_density=log_density,
test_point=test_point,
transform_fn=transform_fn,
initial_state=initial_state)

67 changes: 67 additions & 0 deletions bayeux/tests/compat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2023 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for bayeux.Model working with other libraries."""
import bayeux as bx
import jax
import numpy as np
import numpyro
import tensorflow_probability.substrates.jax as tfp

dist = numpyro.distributions
tfd = tfp.distributions


def test_from_numpyro():
treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
[15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

def numpyro_model():
avg_effect = numpyro.sample('avg_effect', dist.Normal(0.0, 10.0))
avg_stddev = numpyro.sample('avg_stddev', dist.HalfNormal(10.0))
with numpyro.plate('J', 8):
school_effects = numpyro.sample('school_effects', dist.Normal(0.0, 1.0))
numpyro.sample(
'observed',
dist.Normal(
avg_effect[..., None] + avg_stddev[..., None] * school_effects,
treatment_stddevs),
obs=treatment_effects)

bx_model = bx.Model.from_numpyro(numpyro_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None


def test_from_tfp():
treatment_effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=np.float32)
treatment_stddevs = np.array(
[15, 10, 16, 11, 9, 11, 10, 18], dtype=np.float32)

@tfd.JointDistributionCoroutineAutoBatched
def tfp_model():
avg_effect = yield tfd.Normal(0., 10., name='avg_effect')
avg_stddev = yield tfd.HalfNormal(10., name='avg_stddev')

school_effects = yield tfd.Sample(
tfd.Normal(0., 1.), sample_shape=8, name='school_effects')
yield tfd.Normal(
avg_effect[..., None] + avg_stddev[..., None] * school_effects,
treatment_stddevs, name='observed')

pinned_model = tfp_model.experimental_pin(observed=treatment_effects)
bx_model = bx.Model.from_tfp(pinned_model)
ret = bx_model.optimize.optax_adam(seed=jax.random.key(0))
assert ret is not None
Loading

0 comments on commit de90bd7

Please sign in to comment.