diff --git a/bayeux/_src/shared.py b/bayeux/_src/shared.py index 514a80a..97d3355 100644 --- a/bayeux/_src/shared.py +++ b/bayeux/_src/shared.py @@ -33,7 +33,7 @@ def map_fn(chain_method, fn): elif chain_method == "vectorized": return jax.vmap(fn) elif chain_method == "sequential": - return functools.partial(jax.tree_map, fn) + return functools.partial(jax.tree.map, fn) raise ValueError(f"Chain method {chain_method} not supported.") diff --git a/bayeux/_src/vi/tfp.py b/bayeux/_src/vi/tfp.py index c8a6d78..2bb5156 100644 --- a/bayeux/_src/vi/tfp.py +++ b/bayeux/_src/vi/tfp.py @@ -30,7 +30,7 @@ class Custom(tfb.Bijector): def __init__(self, bx_model): super().__init__( - forward_min_event_ndims=jax.tree_map(jnp.ndim, bx_model.test_point)) + forward_min_event_ndims=jax.tree.map(jnp.ndim, bx_model.test_point)) self.bx_model = bx_model def _forward(self, x): @@ -46,12 +46,12 @@ def _forward_log_det_jacobian(self, x): return -self.inverse_log_det_jacobian(self.forward(x)) def _forward_event_shape_tensor(self, input_shape): - return jax.tree_map(jnp.shape, - self._forward(jax.tree_map(jnp.ones, input_shape))) + return jax.tree.map(jnp.shape, + self._forward(jax.tree.map(jnp.ones, input_shape))) def _inverse_event_shape_tensor(self, output_shape): - return jax.tree_map(jnp.shape, - self._inverse(jax.tree_map(jnp.ones, output_shape))) + return jax.tree.map(jnp.shape, + self._inverse(jax.tree.map(jnp.ones, output_shape))) def get_fit_kwargs(log_density, kwargs): @@ -104,7 +104,7 @@ def get_kwargs(self, **kwargs): return { tfp.experimental.vi.build_factored_surrogate_posterior_stateless: ( get_build_kwargs( - jax.tree_map(jnp.shape, self.test_point), + jax.tree.map(jnp.shape, self.test_point), self.constraining_bijector(), kwargs)), tfp.vi.fit_surrogate_posterior_stateless: get_fit_kwargs( @@ -140,7 +140,7 @@ def __call__(self, seed, **kwargs): elif chain_method == "parallel": mapped_fit = jax.pmap(fit_fn) elif chain_method == "sequential": - mapped_fit = functools.partial(jax.tree_map, fit_fn) + mapped_fit = functools.partial(jax.tree.map, fit_fn) else: raise ValueError(f"Chain method {chain_method} not supported.") diff --git a/docs/examples/numpyro_and_bayeux.ipynb b/docs/examples/numpyro_and_bayeux.ipynb index 0dd66a9..a793b2e 100644 --- a/docs/examples/numpyro_and_bayeux.ipynb +++ b/docs/examples/numpyro_and_bayeux.ipynb @@ -637,7 +637,7 @@ "ax.plot(losses.T)\n", "\n", "draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))\n", - "jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)" + "jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)" ] } ], diff --git a/docs/examples/pymc_and_bayeux.ipynb b/docs/examples/pymc_and_bayeux.ipynb index cd4dcd8..a83071c 100644 --- a/docs/examples/pymc_and_bayeux.ipynb +++ b/docs/examples/pymc_and_bayeux.ipynb @@ -425,7 +425,7 @@ "ax.plot(losses.T)\n", "\n", "draws = surrogate_posterior.sample(100, seed=jax.random.PRNGKey(1))\n", - "jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)" + "jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)" ] } ], diff --git a/docs/examples/tfp_and_bayeux.ipynb b/docs/examples/tfp_and_bayeux.ipynb index d62abbf..f58c738 100644 --- a/docs/examples/tfp_and_bayeux.ipynb +++ b/docs/examples/tfp_and_bayeux.ipynb @@ -631,7 +631,7 @@ "ax.plot(losses.T)\n", "\n", "draws = surrogate_posterior.sample(100, seed=draw_key)\n", - "jax.tree_map(lambda x: np.mean(x, axis=(0, 1)), draws)" + "jax.tree.map(lambda x: np.mean(x, axis=(0, 1)), draws)" ] } ],