Skip to content

Commit

Permalink
do not use cond in params update
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Jun 24, 2024
1 parent 65b3862 commit b116c9a
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,15 @@ def _optimizer_update(params, opt_state, grads, optimizer):
"""Updates the parameters and the optimizer state."""
# Helpful metric to print out during training.
squared_grad_norm = sum(jnp.square(p).sum() for p in jax.tree.leaves(grads))
do_update = jnp.isfinite(squared_grad_norm)
grads = jax.tree.map(lambda x, y: x.astype(y.dtype), grads, params)
updates, opt_state = jax.lax.cond(
jnp.isfinite(squared_grad_norm),
optimizer.update,
_skip_update,
grads,
opt_state,
params,
updates, new_opt_state = optimizer.update(grads, opt_state, params)
opt_state = jax.tree.map(
lambda x, y: jnp.where(do_update, x, y), new_opt_state, opt_state
)
params = jax.tree.map(
lambda p, u: jnp.where(do_update, p + u, u), params, updates
)
params = jax.tree.map(lambda p, u: p + u, params, updates)
return params, opt_state, squared_grad_norm


Expand Down

0 comments on commit b116c9a

Please sign in to comment.