Skip to content

Commit

Permalink
Fix maybe_extract_keys logic and move lambda outside of util.train
Browse files Browse the repository at this point in the history
…to avoid recompiling lax.cond.

PiperOrigin-RevId: 568289058
  • Loading branch information
fehiepsi authored and The coix Authors committed Mar 28, 2024
1 parent 8167fe2 commit 178d900
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ def is_observed_site(site):


def can_extract_key(args):
return args and (
jax.dtypes.issubdtype(args[0].dtype, jax.dtypes.prng_key)
or (
isinstance(args[0], jnp.ndarray)
and (args[0].dtype == jnp.uint32)
and (jnp.ndim(args[0]) >= 1)
and (args[0].shape[-1] == 2)
return (
args
and isinstance(args[0], jnp.ndarray)
and (
jax.dtypes.issubdtype(args[0].dtype, jax.dtypes.prng_key)
or (
(args[0].dtype == jnp.uint32)
and (jnp.ndim(args[0]) >= 1)
and (args[0].shape[-1] == 2)
)
)
)

Expand Down Expand Up @@ -119,6 +122,11 @@ def __call__(self, *args, **kwargs):
return self.module.apply(self.params, *args, **kwargs)


def _skip_update(grad, opt_state, params):
del params
return jax.tree_util.tree_map(jnp.zeros_like, grad), opt_state


def train(
loss_fn,
init_params,
Expand Down Expand Up @@ -150,7 +158,7 @@ def step_fn(params, opt_state, *args, **kwargs):
updates, opt_state = jax.lax.cond(
jnp.isfinite(jax.flatten_util.ravel_pytree(grads)[0]).all(),
optimizer.update,
lambda g, o, p: (jax.tree_util.tree_map(jnp.zeros_like, g), o),
_skip_update,
grads,
opt_state,
params,
Expand Down

0 comments on commit 178d900

Please sign in to comment.