
Commit

jit grad update and donate args properly
fehiepsi committed Jun 20, 2024
1 parent dd17d17 commit cc95f79
Showing 11 changed files with 53 additions and 54 deletions.
10 changes: 5 additions & 5 deletions coix/api.py
@@ -328,10 +328,10 @@ def fn(*args, **kwargs):
maybe_get_along_first_axis = functools.partial(
_maybe_get_along_first_axis, idx=idx, n=n, squeeze=not k
)
- out = jax.tree_util.tree_map(
+ out = jax.tree.map(
maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list)
)
- resample_trace = jax.tree_util.tree_map(
+ resample_trace = jax.tree.map(
maybe_get_along_first_axis, trace, is_leaf=lambda x: isinstance(x, list)
)
return core.empirical(out, resample_trace, metrics)(*args, **kwargs)
@@ -504,15 +504,15 @@ def wrapped(*args, **kwargs):
_maybe_get_along_first_axis, idx=idxs, n=num_particles
)
metrics["log_weight"] = maybe_get_along_first_axis(p_log_weight)
- out = jax.tree_util.tree_map(
+ out = jax.tree.map(
maybe_get_along_first_axis, out, is_leaf=lambda x: isinstance(x, list)
)
- marginal_trace = jax.tree_util.tree_map(
+ marginal_trace = jax.tree.map(
maybe_get_along_first_axis,
marginal_trace,
is_leaf=lambda x: isinstance(x, list),
)
metrics["memory"] = jax.tree_util.tree_map(
metrics["memory"] = jax.tree.map(
maybe_get_along_first_axis,
new_memory,
is_leaf=lambda x: isinstance(x, list),
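For context, jax.tree.map is the newer spelling of jax.tree_util.tree_map exposed by recent JAX releases, with the same signature (including is_leaf). A minimal sketch of the call pattern used above, on a hypothetical pytree (the data and the indexing lambda are illustrative, not part of coix):

import jax
import jax.numpy as jnp

# Hypothetical pytree: array leaves plus a list that is_leaf keeps intact,
# mirroring the is_leaf=lambda x: isinstance(x, list) predicate used above.
out = {"z": jnp.arange(6).reshape(3, 2), "probs": [jnp.ones(3), jnp.zeros(3)]}

first_particle = jax.tree.map(
    lambda v: v if isinstance(v, list) else v[0],
    out,
    is_leaf=lambda v: isinstance(v, list),
)
# first_particle["z"].shape == (2,) while first_particle["probs"] is untouched.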
4 changes: 2 additions & 2 deletions coix/numpyro.py
@@ -126,13 +126,13 @@ def log_prob(self, value):
return d.log_prob(value)

def tree_flatten(self):
- params, treedef = jax.tree_util.tree_flatten(self.base_dist)
+ params, treedef = jax.tree.flatten(self.base_dist)
return params, (treedef, self.detach_sample, self.detach_args)

@classmethod
def tree_unflatten(cls, aux_data, params):
treedef, detach_sample, detach_args = aux_data
- base_dist = jax.tree_util.tree_unflatten(treedef, params)
+ base_dist = jax.tree.unflatten(treedef, params)
return cls(base_dist, detach_sample=detach_sample, detach_args=detach_args)


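The tree_flatten/tree_unflatten pair edited above follows JAX's custom-pytree protocol: dynamic array parameters become leaves, and everything non-traceable rides along as auxiliary data. A rough sketch of the same pattern with a toy wrapper class (the class and its fields are hypothetical, not coix's):

import jax


@jax.tree_util.register_pytree_node_class
class Detached:  # hypothetical wrapper, for illustration only
  def __init__(self, base, detach_sample=False):
    self.base = base
    self.detach_sample = detach_sample

  def tree_flatten(self):
    # Dynamic leaves come from the wrapped object; static flags go to aux_data.
    params, treedef = jax.tree.flatten(self.base)
    return params, (treedef, self.detach_sample)

  @classmethod
  def tree_unflatten(cls, aux_data, params):
    treedef, detach_sample = aux_data
    return cls(jax.tree.unflatten(treedef, params), detach_sample=detach_sample)


d = Detached({"loc": 0.0, "scale": 1.0})
leaves, treedef = jax.tree.flatten(d)   # leaves == [0.0, 1.0]
d2 = jax.tree.unflatten(treedef, leaves)  # round-trips through jit/vmap too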
18 changes: 9 additions & 9 deletions coix/oryx.py
@@ -218,7 +218,7 @@ def substitute_rule(state, *args, **kwargs):
name = kwargs.get("name")
if name in state:
flat_args = _split_list(args, kwargs["num_consts"])
- _, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
+ _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
value = state[name]
value = primitive.tie_in(flat_args, value)
jaxpr, _ = trace_util.stage(identity, dynamic=True)(value, dist)
@@ -247,10 +247,10 @@ def distribution_rule(state, *args, **kwargs):
name = kwargs.get("name")
if name is not None:
flat_args = _split_list(args, kwargs["num_consts"])
- _, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
- dist_flat, dist_tree = jax.tree_util.tree_flatten(dist)
+ _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
+ dist_flat, dist_tree = jax.tree.flatten(dist)
state[name] = {dist_tree: dist_flat}
- args = jax.tree_util.tree_map(jax.core.raise_as_much_as_possible, args)
+ args = jax.tree.map(jax.core.raise_as_much_as_possible, args)
return random_variable_p.bind(*args, **kwargs), state


@@ -341,20 +341,20 @@ def log_prob(self, value):
return jax.lax.stop_gradient(self.base_dist).log_prob(value)

def tree_flatten(self):
- params, treedef = jax.tree_util.tree_flatten(self.base_dist)
+ params, treedef = jax.tree.flatten(self.base_dist)
return (params, treedef)

@classmethod
def tree_unflatten(cls, aux_data, children):
- base_dist = jax.tree_util.tree_unflatten(aux_data, children)
+ base_dist = jax.tree.unflatten(aux_data, children)
return cls(base_dist)


def stl_rule(state, *args, **kwargs):
flat_args = _split_list(args, kwargs["num_consts"])
- key, dist = jax.tree_util.tree_unflatten(kwargs["in_tree"], flat_args)
+ key, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
stl_dist = STLDistribution(dist)
- _, in_tree = jax.tree_util.tree_flatten((key, stl_dist))
+ _, in_tree = jax.tree.flatten((key, stl_dist))
kwargs["in_tree"] = in_tree
out = random_variable_p.bind(*args, **kwargs)
return out, state
@@ -411,7 +411,7 @@ def wrapped(*args, **kwargs):
trace = {}
for name, value in tags[RANDOM_VARIABLE].items():
dist_tree, dist_flat = list(tags[DISTRIBUTION][name].items())[0]
- dist = jax.tree_util.tree_unflatten(dist_tree, dist_flat)
+ dist = jax.tree.unflatten(dist_tree, dist_flat)
trace[name] = {"value": value, "log_prob": dist.log_prob(value)}
if name in tags[OBSERVED]:
trace[name]["is_observed"] = True
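The STLDistribution hunk above wraps a distribution so that log_prob sees stop-gradient copies of its parameters, which is how "stick the landing"-style estimators drop score terms; jax.lax.stop_gradient works leaf-wise on whole pytrees. A minimal sketch of that effect (the toy density and values are illustrative, not coix's code):

import jax
import jax.numpy as jnp


def log_prob(params, value):
  # Toy Gaussian log-density; params is a pytree {"loc", "scale"}.
  detached = jax.lax.stop_gradient(params)  # applies to every leaf
  z = (value - detached["loc"]) / detached["scale"]
  return -0.5 * z**2 - jnp.log(detached["scale"]) - 0.5 * jnp.log(2 * jnp.pi)


params = {"loc": jnp.array(0.0), "scale": jnp.array(1.0)}
grads = jax.grad(log_prob)(params, jnp.array(0.5))
# Both gradients are zero: the dependence on params was detached.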
8 changes: 3 additions & 5 deletions coix/oryx_test.py
@@ -81,9 +81,7 @@ def model(x):

_, trace, _ = coix.traced_evaluate(model)(1.0)
samples = {name: site["value"] for name, site in trace.items()}
- jax.tree_util.tree_map(
- np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0}
- )
+ jax.tree.map(np.testing.assert_allclose, samples, {"x": 1.0, "y": 2.0})
assert "is_observed" not in trace["x"]
assert trace["y"]["is_observed"]

@@ -145,7 +143,7 @@ def model(key):
expected = {"x": 9.0}
_, trace, _ = coix.traced_evaluate(model, expected)(random.PRNGKey(0))
actual = {"x": trace["x"]["value"]}
- jax.tree_util.tree_map(np.testing.assert_allclose, actual, expected)
+ jax.tree.map(np.testing.assert_allclose, actual, expected)


def test_suffix():
@@ -155,7 +153,7 @@ def model(x):
f = coix.oryx.call_and_reap_tags(
coix.core.suffix(model), coix.oryx.RANDOM_VARIABLE
)
- jax.tree_util.tree_map(
+ jax.tree.map(
np.testing.assert_allclose,
f(1.0)[1][coix.oryx.RANDOM_VARIABLE],
{"x_PREV_": 1.0},
37 changes: 21 additions & 16 deletions coix/util.py
@@ -138,7 +138,25 @@ def __call__(self, *args, **kwargs):

def _skip_update(grad, opt_state, params):
del params
- return jax.tree_util.tree_map(jnp.zeros_like, grad), opt_state
+ return jax.tree.map(jnp.zeros_like, grad), opt_state


+ @functools.partial(jax.jit, donate_argnums=(0, 1, 2), static_argnums=(3,))
+ def _optimizer_update(params, opt_state, grads, optimizer):
+ """Updates the parameters and the optimizer state."""
+ # Helpful metric to print out during training.
+ squared_grad_norm = sum(jnp.square(p).sum() for p in jax.tree.leaves(grads))
+ grads = jax.tree.map(lambda x, y: x.astype(y.dtype), grads, params)
+ updates, opt_state = jax.lax.cond(
+ jnp.isfinite(squared_grad_norm),
+ optimizer.update,
+ _skip_update,
+ grads,
+ opt_state,
+ params,
+ )
+ params = jax.tree.map(lambda p, u: p + u, params, updates)
+ return params, opt_state, squared_grad_norm


def train(
@@ -161,23 +179,10 @@ def step_fn(params, opt_state, *args, **kwargs):
(_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
params, *args, **kwargs
)
- grads = jax.tree_util.tree_map(
- lambda x, y: x.astype(y.dtype), grads, params
- )
- # Helpful metric to print out during training.
- squared_grad_norm = sum(
- jnp.square(p).sum() for p in jax.tree_util.tree_leaves(grads)
+ params, opt_state, squared_grad_norm = _optimizer_update(
+ params, opt_state, grads, optimizer
)
metrics["squared_grad_norm"] = squared_grad_norm
- updates, opt_state = jax.lax.cond(
- jnp.isfinite(jax.flatten_util.ravel_pytree(grads)[0]).all(),
- optimizer.update,
- _skip_update,
- grads,
- opt_state,
- params,
- )
- params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
return params, opt_state, metrics

if callable(jit_compile):
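The new _optimizer_update above is the point of this commit: the update step is jitted once, the params/opt_state/grads buffers are donated so XLA may reuse their memory for the outputs, the optimizer travels as a static argument because it is not an array type, and jax.lax.cond skips the update when the gradient norm is not finite. A rough, self-contained sketch of the same shape with a stand-in SGD rule instead of a real optimizer (nothing here is coix's public API):

import functools

import jax
import jax.numpy as jnp


def _skip(grads, state):
  # Non-finite gradients: emit zero updates and keep the state unchanged.
  return jax.tree.map(jnp.zeros_like, grads), state


def _apply(grads, state):
  # Stand-in for optimizer.update; plain SGD for the sake of the sketch.
  return jax.tree.map(lambda g: -0.01 * g, grads), state


@functools.partial(jax.jit, donate_argnums=(0, 1, 2))
def update(params, state, grads):
  # The squared gradient norm doubles as finiteness check and training metric.
  sq_norm = sum(jnp.square(g).sum() for g in jax.tree.leaves(grads))
  updates, state = jax.lax.cond(jnp.isfinite(sq_norm), _apply, _skip, grads, state)
  params = jax.tree.map(lambda p, u: p + u, params, updates)
  return params, state, sq_norm


params = {"w": jnp.ones(3)}
state = {"count": jnp.zeros((), jnp.int32)}
grads = {"w": jnp.full(3, 0.1)}
# Donated inputs must not be reused after the call; rebind the returned values.
params, state, sq_norm = update(params, state, grads)

Backends that do not support buffer donation simply warn and ignore the hint, so a sketch like this still runs on CPU.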
8 changes: 3 additions & 5 deletions examples/anneal.py
@@ -93,9 +93,9 @@ def __call__(self, x, index=0):
out = vmap_net(name="kernel")(
jnp.broadcast_to(x, (self.M - 1,) + x.shape)
)
- return jax.tree_util.tree_map(lambda x: x[index], out)
+ return jax.tree.map(lambda x: x[index], out)
params = self.scope.get_variable("params", "kernel")
- params_i = jax.tree_util.tree_map(lambda x: x[index], params)
+ params_i = jax.tree.map(lambda x: x[index], params)
return AnnealKernel(name="kernel").apply(
flax.core.freeze({"params": params_i}), x
)
@@ -197,9 +197,7 @@ def eval_program(seed):
_, trace, metrics = jax.vmap(eval_program)(rng_keys)

metrics.pop("log_weight")
- anneal_metrics = jax.tree_util.tree_map(
- lambda x: round(float(jnp.mean(x)), 4), metrics
- )
+ anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics)
print(anneal_metrics)

plt.figure(figsize=(8, 8))
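The lambda x: x[index] calls above pull one kernel's parameters out of a parameter tree whose leaves were stacked along a leading axis by the vmapped module. A tiny sketch with a hypothetical stacked pytree (shapes and names are illustrative):

import jax
import jax.numpy as jnp

# Hypothetical parameters for M = 3 stacked kernels.
stacked = {"dense": {"kernel": jnp.zeros((3, 4, 4)), "bias": jnp.zeros((3, 4))}}

index = 1
params_i = jax.tree.map(lambda x: x[index], stacked)
# params_i["dense"]["kernel"].shape == (4, 4): every leaf lost its leading axis.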
8 changes: 3 additions & 5 deletions examples/anneal_oryx.py
@@ -94,9 +94,9 @@ def __call__(self, x, index=0):
out = vmap_net(name="kernel")(
jnp.broadcast_to(x, (self.M - 1,) + x.shape)
)
- return jax.tree_util.tree_map(lambda x: x[index], out)
+ return jax.tree.map(lambda x: x[index], out)
params = self.scope.get_variable("params", "kernel")
- params_i = jax.tree_util.tree_map(lambda x: x[index], params)
+ params_i = jax.tree.map(lambda x: x[index], params)
return AnnealKernel(name="kernel").apply(
flax.core.freeze({"params": params_i}), x
)
@@ -188,9 +188,7 @@ def main(args):
)(rng_keys)

metrics.pop("log_weight")
- anneal_metrics = jax.tree_util.tree_map(
- lambda x: round(float(jnp.mean(x)), 4), metrics
- )
+ anneal_metrics = jax.tree.map(lambda x: round(float(jnp.mean(x)), 4), metrics)
print(anneal_metrics)

plt.figure(figsize=(8, 8))
2 changes: 1 addition & 1 deletion examples/dmm_oryx.py
@@ -261,7 +261,7 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
# Run the program and get metrics.
program = make_dmm(params, num_sweeps)
_, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
- metrics = jax.tree_util.tree_map(
+ metrics = jax.tree.map(
partial(jnp.mean, axis=0), metrics
) # mean across batch
return metrics["loss"], metrics
2 changes: 1 addition & 1 deletion examples/gmm.py
@@ -180,7 +180,7 @@ def gmm_kernel_mean_tau(network, inputs):
alpha, beta, mu, nu = network.encode_mean_tau(xc)
else:
alpha, beta, mu, nu = network.encode_initial_mean_tau(inputs["x"])
- alpha, beta, mu, nu = jax.tree_util.tree_map(
+ alpha, beta, mu, nu = jax.tree.map(
lambda x: jnp.expand_dims(x, -3), (alpha, beta, mu, nu)
)
tau = numpyro.sample("tau", dist.Gamma(alpha, beta).to_event(2))
2 changes: 1 addition & 1 deletion examples/gmm_oryx.py
@@ -233,7 +233,7 @@ def loss_fn(params, key, batch, num_sweeps, num_particles):
# Run the program and get metrics.
program = make_gmm(params, num_sweeps)
_, _, metrics = jax.vmap(coix.traced_evaluate(program))(rng_keys, batch)
- metrics = jax.tree_util.tree_map(
+ metrics = jax.tree.map(
partial(jnp.mean, axis=0), metrics
) # mean across batch
return metrics["loss"], metrics
8 changes: 4 additions & 4 deletions notebooks/tutorial_part3_smcs.ipynb
@@ -228,9 +228,9 @@
" out = vmap_net(name='kernel')(\n",
" jnp.broadcast_to(x, (self.M - 1,) + x.shape)\n",
" )\n",
" return jax.tree_util.tree_map(lambda x: x[index], out)\n",
" return jax.tree.map(lambda x: x[index], out)\n",
" params = self.scope.get_variable('params', 'kernel')\n",
" params_i = jax.tree_util.tree_map(lambda x: x[index], params)\n",
" params_i = jax.tree.map(lambda x: x[index], params)\n",
" return VariationalKernelNetwork(name='kernel').apply(\n",
" flax.core.freeze({'params': params_i}), x\n",
" )\n",
@@ -394,7 +394,7 @@
")\n",
"\n",
"metrics.pop(\"log_weight\")\n",
"anneal_metrics = jax.tree_util.tree_map(\n",
"anneal_metrics = jax.tree.map(\n",
" lambda x: round(float(jnp.mean(x)), 4), metrics\n",
")\n",
"print(anneal_metrics)\n",
@@ -546,7 +546,7 @@
" random.PRNGKey(1), trained_params, num_particles=100000\n",
")\n",
"\n",
"anneal_metrics = jax.tree_util.tree_map(\n",
"anneal_metrics = jax.tree.map(\n",
" lambda x: round(float(jnp.mean(x)), 4), metrics\n",
")\n",
"print(anneal_metrics)\n",
