diff --git a/coix/api.py b/coix/api.py
index 861b9a4..4553de6 100644
--- a/coix/api.py
+++ b/coix/api.py
@@ -328,10 +328,10 @@ def fn(*args, **kwargs):
     maybe_get_along_first_axis = functools.partial(
         _maybe_get_along_first_axis, idx=idx, n=n, squeeze=not k
    )
-    out = jax.tree_util.tree_map(
+    out = jax.tree.map(
        maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list)
    )
-    resample_trace = jax.tree_util.tree_map(
+    resample_trace = jax.tree.map(
        maybe_get_along_first_axis, trace, is_leaf=lambda x: isinstance(x, list)
    )
     return core.empirical(out, resample_trace, metrics)(*args, **kwargs)
@@ -504,15 +504,15 @@ def wrapped(*args, **kwargs):
         _maybe_get_along_first_axis, idx=idxs, n=num_particles
    )
     metrics["log_weight"] = maybe_get_along_first_axis(p_log_weight)
-    out = jax.tree_util.tree_map(
+    out = jax.tree.map(
        maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list)
    )
-    marginal_trace = jax.tree_util.tree_map(
+    marginal_trace = jax.tree.map(
        maybe_get_along_first_axis,
        marginal_trace,
        is_leaf=lambda x: isinstance(x, list),
    )
-    metrics["memory"] = jax.tree_util.tree_map(
+    metrics["memory"] = jax.tree.map(
        maybe_get_along_first_axis,
        new_memory,
        is_leaf=lambda x: isinstance(x, list),
diff --git a/coix/numpyro.py b/coix/numpyro.py
index d8fcfe2..961c4a4 100644
--- a/coix/numpyro.py
+++ b/coix/numpyro.py
@@ -126,13 +126,13 @@ def log_prob(self, value):
     return d.log_prob(value)
 
   def tree_flatten(self):
-    params, treedef = jax.tree_util.tree_flatten(self.base_dist)
+    params, treedef = jax.tree.flatten(self.base_dist)
     return params, (treedef, self.detach_sample, self.detach_args)
 
   @classmethod
   def tree_unflatten(cls, aux_data, params):
     treedef, detach_sample, detach_args = aux_data
-    base_dist = jax.tree_util.tree_unflatten(treedef, params)
+    base_dist = jax.tree.unflatten(treedef, params)
     return cls(base_dist, detach_sample=detach_sample, detach_args=detach_args)
 
 
diff --git a/coix/oryx.py b/coix/oryx.py
index fa687da..9636907 100644
--- a/coix/oryx.py
+++ b/coix/oryx.py
@@ -218,7 +218,7 @@ def substitute_rule(state, *args, **kwargs):
   name = kwargs.get("name")
   if name in state:
     flat_args = _split_list(args, kwargs["num_consts"])
-    _, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
+    _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
     value = state[name]
     value = primitive.tie_in(flat_args, value)
     jaxpr, _ = trace_util.stage(identity, dynamic=True)(value, dist)
@@ -247,10 +247,10 @@ def distribution_rule(state, *args, **kwargs):
   name = kwargs.get("name")
   if name is not None:
     flat_args = _split_list(args, kwargs["num_consts"])
-    _, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
-    dist_flat, dist_tree = jax.tree_util.tree_flatten(dist)
+    _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
+    dist_flat, dist_tree = jax.tree.flatten(dist)
     state[name] = {dist_tree: dist_flat}
-  args = jax.tree_util.tree_map(jax.core.raise_as_much_as_possible, args)
+  args = jax.tree.map(jax.core.raise_as_much_as_possible, args)
   return random_variable_p.bind(*args, **kwargs), state
 
 
@@ -341,20 +341,20 @@ def log_prob(self, value):
     return jax.lax.stop_gradient(self.base_dist).log_prob(value)
 
   def tree_flatten(self):
-    params, treedef = jax.tree_util.tree_flatten(self.base_dist)
+    params, treedef = jax.tree.flatten(self.base_dist)
     return (params, treedef)
 
   @classmethod
   def tree_unflatten(cls, aux_data, children):
-    base_dist = jax.tree_util.tree_unflatten(aux_data, children)
+    base_dist = jax.tree.unflatten(aux_data, children)
     return cls(base_dist)
 
 
 def stl_rule(state, *args, **kwargs):
   flat_args = _split_list(args, kwargs["num_consts"])
-  key, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
+  key, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
   stl_dist = STLDistribution(dist)
-  _, in_tree = jax.tree_util.tree_flatten((key, stl_dist))
+  _, in_tree = jax.tree.flatten((key, stl_dist))
   kwargs["in_tree"] = in_tree
   out = random_variable_p.bind(*args, **kwargs)
   return out, state
@@ -411,7 +411,7 @@ def wrapped(*args, **kwargs):
     trace = {}
     for name, value in tags[RANDOM_VARIABLE].items():
       dist_tree, dist_flat = list(tags[DISTRIBUTION][name].items())[0]
-      dist = jax.tree_util.tree_unflatten(dist_tree, dist_flat)
+      dist = jax.tree.unflatten(dist_tree, dist_flat)
       trace[name] = {"value": value, "log_prob": dist.log_prob(value)}
       if name in tags[OBSERVED]:
         trace[name]["is_observed"] = True
diff --git a/coix/oryx_test.py b/coix/oryx_test.py
index 82911b9..65a490b 100644
--- a/coix/oryx_test.py
+++ b/coix/oryx_test.py
@@ -81,9 +81,7 @@ def model(x):
 
   _, trace, _ = coix.traced_evaluate(model)(1.0)
   samples = {name: site["value"] for name, site in trace.items()}
-  jax.tree_util.tree_map(
-      np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0}
-  )
+  jax.tree.map(np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0})
   assert "is_observed" not in trace["x"]
   assert trace["y"]["is_observed"]
 
@@ -145,7 +143,7 @@ def model(key):
   expected = {"x": 9.0}
   _, trace, _ = coix.traced_evaluate(model, expected)(random.PRNGKey(0))
   actual = {"x": trace["x"]["value"]}
-  jax.tree_util.tree_map(np.testing.assert_allclose, actual, expected)
+  jax.tree.map(np.testing.assert_allclose, actual, expected)
 
 
 def test_suffix():
@@ -155,7 +153,7 @@ def model(x):
   f = coix.oryx.call_and_reap_tags(
       coix.core.suffix(model), coix.oryx.RANDOM_VARIABLE
  )
-  jax.tree_util.tree_map(
+  jax.tree.map(
      np.testing.assert_allclose,
      f(1.0)[1][coix.oryx.RANDOM_VARIABLE],
      {"x_PREV_": 1.0},
diff --git a/coix/util.py b/coix/util.py
index f356513..52dad25 100644
--- a/coix/util.py
+++ b/coix/util.py
@@ -138,7 +138,25 @@ def __call__(self, *args, **kwargs):
 
 def _skip_update(grad, opt_state, params):
   del params
-  return jax.tree_util.tree_map(jnp.zeros_like, grad), opt_state
+  return jax.tree.map(jnp.zeros_like, grad), opt_state
+
+
+@functools.partial(jax.jit, donate_argnums=(0, 1, 2), static_argnums=(3,))
+def _optimizer_update(params, opt_state, grads, optimizer):
+  """Updates the parameters and the optimizer state."""
+  # Helpful metric to print out during training.
+  squared_grad_norm = sum(jnp.square(p).sum() for p in jax.tree.leaves(grads))
+  grads = jax.tree.map(lambda x, y: x.astype(y.dtype), grads, params)
+  updates, opt_state = jax.lax.cond(
+      jnp.isfinite(squared_grad_norm),
+      optimizer.update,
+      _skip_update,
+      grads,
+      opt_state,
+      params,
+  )
+  params = jax.tree.map(lambda p, u: p + u, params, updates)
+  return params, opt_state, squared_grad_norm
 
 
 def train(
@@ -161,23 +179,10 @@ def step_fn(params, opt_state, *args, **kwargs):
     (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, *args, **kwargs
    )
-    grads = jax.tree_util.tree_map(
-        lambda x, y: x.astype(y.dtype), grads, params
-    )
-    # Helpful metric to print out during training.
-    squared_grad_norm = sum(
-        jnp.square(p).sum() for p in jax.tree_util.tree_leaves(grads)
+    params, opt_state, squared_grad_norm = _optimizer_update(
+        params, opt_state, grads, optimizer
    )
     metrics["squared_grad_norm"] = squared_grad_norm
-    updates, opt_state = jax.lax.cond(
-        jnp.isfinite(jax.flatten_util.ravel_pytree(grads)[0]).all(),
-        optimizer.update,
-        _skip_update,
-        grads,
-        opt_state,
-        params,
-    )
-    params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
     return params, opt_state, metrics
 
   if callable(jit_compile):
diff --git a/examples/anneal.py b/examples/anneal.py
index a0f761d..b8d9e52 100644
--- a/examples/anneal.py
+++ b/examples/anneal.py
@@ -93,9 +93,9 @@ def __call__(self, x, index=0):
       out = vmap_net(name="kernel")(
           jnp.broadcast_to(x, (self.M - 1,) + x.shape)
      )
-      return jax.tree_util.tree_map(lambda x: x[index], out)
+      return jax.tree.map(lambda x: x[index], out)
     params = self.scope.get_variable("params", "kernel")
-    params_i = jax.tree_util.tree_map(lambda x: x[index], params)
+    params_i = jax.tree.map(lambda x: x[index], params)
     return AnnealKernel(name="kernel").apply(
         flax.core.freeze({"params": params_i}), x
    )
@@ -197,9 +197,7 @@ def eval_program(seed):
 
   _, trace, metrics = jax.vmap(eval_program)(rng_keys)
   metrics.pop("log_weight")
-  anneal_metrics = jax.tree_util.tree_map(
-      lambda x: round(float(jnp.mean(x)), 4), metrics
-  )
+  anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics)
   print(anneal_metrics)
 
   plt.figure(figsize=(8, 8))
diff --git a/examples/anneal_oryx.py b/examples/anneal_oryx.py
index 759b107..b835d66 100644
--- a/examples/anneal_oryx.py
+++ b/examples/anneal_oryx.py
@@ -94,9 +94,9 @@ def __call__(self, x, index=0):
       out = vmap_net(name="kernel")(
           jnp.broadcast_to(x, (self.M - 1,) + x.shape)
      )
-      return jax.tree_util.tree_map(lambda x: x[index], out)
+      return jax.tree.map(lambda x: x[index], out)
     params = self.scope.get_variable("params", "kernel")
-    params_i = jax.tree_util.tree_map(lambda x: x[index], params)
+    params_i = jax.tree.map(lambda x: x[index], params)
     return AnnealKernel(name="kernel").apply(
         flax.core.freeze({"params": params_i}), x
    )
@@ -188,9 +188,7 @@ def main(args):
   )(rng_keys)
 
   metrics.pop("log_weight")
-  anneal_metrics = jax.tree_util.tree_map(
-      lambda x: round(float(jnp.mean(x)), 4), metrics
-  )
+  anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics)
   print(anneal_metrics)
 
   plt.figure(figsize=(8, 8))
diff --git a/examples/dmm_oryx.py b/examples/dmm_oryx.py
index 7e925e3..baeefde 100644
--- a/examples/dmm_oryx.py
+++ b/examples/dmm_oryx.py
@@ -261,7 +261,7 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
   # Run the program and get metrics.
   program = make_dmm(params, num_sweeps)
   _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
-  metrics = jax.tree_util.tree_map(
+  metrics = jax.tree.map(
      partial(jnp.mean, axis=0), metrics
  )  # mean across batch
   return metrics["loss"], metrics
diff --git a/examples/gmm.py b/examples/gmm.py
index d57a007..9a69156 100644
--- a/examples/gmm.py
+++ b/examples/gmm.py
@@ -180,7 +180,7 @@ def gmm_kernel_mean_tau(network, inputs):
     alpha, beta, mu, nu = network.encode_mean_tau(xc)
   else:
     alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
-  alpha, beta, mu, nu = jax.tree_util.tree_map(
+  alpha, beta, mu, nu = jax.tree.map(
      lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu)
  )
   tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2))
diff --git a/examples/gmm_oryx.py b/examples/gmm_oryx.py
index 8aa6890..20e5130 100644
--- a/examples/gmm_oryx.py
+++ b/examples/gmm_oryx.py
@@ -233,7 +233,7 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
   # Run the program and get metrics.
   program = make_gmm(params, num_sweeps)
   _, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
-  metrics = jax.tree_util.tree_map(
+  metrics = jax.tree.map(
      partial(jnp.mean, axis=0), metrics
  )  # mean across batch
   return metrics["loss"], metrics
diff --git a/notebooks/tutorial_part3_smcs.ipynb b/notebooks/tutorial_part3_smcs.ipynb
index 34eeac1..6bae794 100644
--- a/notebooks/tutorial_part3_smcs.ipynb
+++ b/notebooks/tutorial_part3_smcs.ipynb
@@ -228,9 +228,9 @@
    "      out = vmap_net(name='kernel')(\n",
    "          jnp.broadcast_to(x, (self.M - 1,) + x.shape)\n",
    "      )\n",
-    "      return jax.tree_util.tree_map(lambda x: x[index], out)\n",
+    "      return jax.tree.map(lambda x: x[index], out)\n",
    "    params = self.scope.get_variable('params', 'kernel')\n",
-    "    params_i = jax.tree_util.tree_map(lambda x: x[index], params)\n",
+    "    params_i = jax.tree.map(lambda x: x[index], params)\n",
    "    return VariationalKernelNetwork(name='kernel').apply(\n",
    "        flax.core.freeze({'params': params_i}), x\n",
    "    )\n",
@@ -394,7 +394,7 @@
    ")\n",
    "\n",
    "metrics.pop(\"log_weight\")\n",
-    "anneal_metrics = jax.tree_util.tree_map(\n",
+    "anneal_metrics = jax.tree.map(\n",
    "    lambda x: round(float(jnp.mean(x)), 4), metrics\n",
    ")\n",
    "print(anneal_metrics)\n",
@@ -546,7 +546,7 @@
    "    random.PRNGKey(1), trained_params, num_particles=100000\n",
    ")\n",
    "\n",
-    "anneal_metrics = jax.tree_util.tree_map(\n",
+    "anneal_metrics = jax.tree.map(\n",
    "    lambda x: round(float(jnp.mean(x)), 4), metrics\n",
    ")\n",
    "print(anneal_metrics)\n",
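Note on the rename: recent JAX releases expose jax.tree.map, jax.tree.flatten, jax.tree.unflatten, and jax.tree.leaves as aliases of the corresponding jax.tree_util functions, so the changes above are drop-in replacements. A minimal sketch of the equivalence (not part of the patch; assumes a JAX version that ships the jax.tree module):

import jax
import jax.numpy as jnp

tree = {"a": jnp.ones(3), "b": [jnp.zeros(2), jnp.arange(4.0)]}

# The jax.tree.* helpers behave like their jax.tree_util.* counterparts.
doubled = jax.tree.map(lambda x: 2 * x, tree)
leaves, treedef = jax.tree.flatten(tree)
rebuilt = jax.tree.unflatten(treedef, leaves)

assert jax.tree.leaves(doubled)[0].shape == (3,)
assert jax.tree_util.tree_structure(rebuilt) == treedef

The coix/util.py change also factors the parameter update into a jitted _optimizer_update helper (with donated buffers and a static optimizer argument) and reuses the squared gradient norm both as a training metric and as the finiteness check that decides whether to skip the update. A standalone sketch of that skip-on-non-finite-gradients pattern, using hypothetical toy params/grads and assuming an optax-style optimizer (not the library's exact code):

import jax
import jax.numpy as jnp
import optax


def skip_update(grads, opt_state, params):
  # Same signature as optimizer.update, but returns zero updates.
  del params
  return jax.tree.map(jnp.zeros_like, grads), opt_state


optimizer = optax.adam(1e-3)
params = {"w": jnp.ones(3)}
opt_state = optimizer.init(params)
grads = {"w": jnp.array([1.0, jnp.inf, 0.0])}  # contains a non-finite value

# The squared gradient norm is non-finite whenever any gradient entry is.
squared_grad_norm = sum(jnp.square(g).sum() for g in jax.tree.leaves(grads))
updates, opt_state = jax.lax.cond(
    jnp.isfinite(squared_grad_norm),
    optimizer.update,
    skip_update,
    grads,
    opt_state,
    params,
)
params = jax.tree.map(lambda p, u: p + u, params, updates)  # unchanged here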