Skip to content

Commit

Permalink
Merge pull request #15 from jax-ml:prng
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568218621
  • Loading branch information
The coix Authors committed Mar 28, 2024
2 parents a114e61 + 222781a commit 8167fe2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
42 changes: 35 additions & 7 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
import jax.numpy as jnp
import numpy as np

# pytype: disable=module-attr
try:
wrap_key_data = jax.random.wrap_key_data
except AttributeError:
try:
wrap_key_data = jax.extend.random.wrap_key_data
except AttributeError:

def _identity(k):
return k

wrap_key_data = _identity


# pytype: enable=module-attr


__all__ = [
"compose",
"extend",
Expand Down Expand Up @@ -69,15 +86,20 @@ def wrapped(*args, **kwargs):
return wrapped


def _reshape_key(key, shape):
if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
return jnp.reshape(key, shape)
else:
return jnp.reshape(key, shape + (2,))


def _split_key(key):
keys = jax.vmap(jax.random.split)(key.reshape(-1, 2)).reshape(
key.shape[:-1] + (2, 2)
)
return keys[..., 0, :], keys[..., 1, :]
keys = jax.vmap(jax.random.split, out_axes=1)(_reshape_key(key, (-1,)))
return keys[0].reshape(key.shape), keys[1].reshape(key.shape)


def _fold_in_key(key, i):
key_new = jax.vmap(jax.random.fold_in, (0, None))(key.reshape(-1, 2), i)
key_new = jax.vmap(jax.random.fold_in, (0, None))(_reshape_key(key, (-1,)), i)
return key_new.reshape(key.shape)


Expand Down Expand Up @@ -219,6 +241,12 @@ def _maybe_get_along_first_axis(x, idx, n, squeeze=False):
idx = idx.reshape(idx.shape + (1,) * (x.ndim - idx.ndim))
if isinstance(x, np.ndarray):
y = np.take_along_axis(x, idx, axis=0)
elif jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key):
x_data = jax.random.key_data(x)
idx = idx.reshape(idx.shape + (1,) * (x_data.ndim - idx.ndim))
y_data = jnp.take_along_axis(x_data, idx, axis=0)
y_data = y_data[0] if (idx.shape[0] == 1 and squeeze) else y_data
y = wrap_key_data(y_data)
else:
y = jnp.take_along_axis(x, idx, axis=0)
y = y.tolist() if is_list else y
Expand All @@ -244,7 +272,7 @@ def fn(*args, **kwargs):
if util.can_extract_key(args):
key_r, key_q = _split_key(args[0])
# We just need a single key for resampling.
key_r = key_r.reshape((-1, 2))[0]
key_r = _reshape_key(key_r, (-1,))[0]
args = (key_q,) + args[1:]
else:
key_r = core.prng_key()
Expand Down Expand Up @@ -401,7 +429,7 @@ def memoize(p, q, memory=None, memory_size=None):
def wrapped(*args, **kwargs):
if util.can_extract_key(args):
key = args[0]
p_key, q_key = key + jnp.asarray([1, 0], dtype=key.dtype), key + 1
p_key, q_key = _split_key(key)
p_args = (p_key,) + args[1:]
q_args = (q_key,) + args[1:]
else:
Expand Down
11 changes: 7 additions & 4 deletions coix/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@ def g(z):
assert set(trace.keys()) == {"x", "z"}

expected_key, expected_x = p(key)
np.testing.assert_allclose(out[0], np.asarray(expected_key))
expected_key = random.key_data(expected_key)
actual_key = random.key_data(out[0])
np.testing.assert_allclose(actual_key, expected_key)
np.testing.assert_allclose(out[1], expected_x)

marginal_pfg = coix.traced_evaluate(coix.extend(p, coix.compose(g, f)))(key)[
0
]
actual_key2, actual_x2 = marginal_pfg
np.testing.assert_allclose(actual_key2, np.asarray(expected_key))
actual_key2 = random.key_data(actual_key2)
np.testing.assert_allclose(actual_key2, expected_key)
np.testing.assert_allclose(actual_x2, expected_x)


Expand All @@ -68,15 +71,15 @@ def q(key):
out, trace, metrics = coix.traced_evaluate(program)(key)
assert set(trace.keys()) == {"x", "z"}
assert isinstance(out, tuple) and len(out) == 2
assert out[0].shape == (2,)
assert out[0].shape == key.shape
with np.testing.assert_raises(AssertionError):
np.testing.assert_allclose(metrics["log_density"], 0.0)

particle_program = coix.propose(jax.vmap(coix.extend(p, f)), jax.vmap(q))
keys = random.split(key, 3)
particle_out = particle_program(keys)
assert isinstance(particle_out, tuple) and len(particle_out) == 2
assert particle_out[0].shape == (3, 2)
assert particle_out[0].shape == keys.shape


def test_resample():
Expand Down
16 changes: 9 additions & 7 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ def is_observed_site(site):


def can_extract_key(args):
return (
args
and isinstance(args[0], jnp.ndarray)
and (args[0].dtype == jnp.uint32)
and (jnp.ndim(args[0]) >= 1)
and (args[0].shape[-1] == 2)
return args and (
jax.dtypes.issubdtype(args[0].dtype, jax.dtypes.prng_key)
or (
isinstance(args[0], jnp.ndarray)
and (args[0].dtype == jnp.uint32)
and (jnp.ndim(args[0]) >= 1)
and (args[0].shape[-1] == 2)
)
)


Expand Down Expand Up @@ -180,7 +182,7 @@ def step_fn(params, opt_state, *args, **kwargs):
params, opt_state, metrics = maybe_jitted_step_fn(
params, opt_state, *args, **kwargs
)
for name, value in kwargs.items():
for name in kwargs:
if name in metrics:
kwargs[name] = metrics[name]
if step == 1:
Expand Down

0 comments on commit 8167fe2

Please sign in to comment.