diff --git a/coix/util.py b/coix/util.py
index 52dad25..5807ec6 100644
--- a/coix/util.py
+++ b/coix/util.py
@@ -146,16 +146,15 @@ def _optimizer_update(params, opt_state, grads, optimizer):
   """Updates the parameters and the optimizer state."""
   # Helpful metric to print out during training.
   squared_grad_norm = sum(jnp.square(p).sum() for p in jax.tree.leaves(grads))
+  do_update = jnp.isfinite(squared_grad_norm)
   grads = jax.tree.map(lambda x, y: x.astype(y.dtype), grads, params)
-  updates, opt_state = jax.lax.cond(
-      jnp.isfinite(squared_grad_norm),
-      optimizer.update,
-      _skip_update,
-      grads,
-      opt_state,
-      params,
+  updates, new_opt_state = optimizer.update(grads, opt_state, params)
+  opt_state = jax.tree.map(
+      lambda x, y: jnp.where(do_update, x, y), new_opt_state, opt_state
+  )
+  params = jax.tree.map(
+      lambda p, u: jnp.where(do_update, p + u, p), params, updates
   )
-  params = jax.tree.map(lambda p, u: p + u, params, updates)
   return params, opt_state, squared_grad_norm
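
The diff replaces a `jax.lax.cond` between `optimizer.update` and a `_skip_update` helper with unconditional computation plus `jnp.where` masking: the optimizer step is always computed, and the finite-gradient flag selects, leaf by leaf, whether to keep the new parameters and optimizer state or the old ones. Below is a minimal standalone sketch of that pattern, assuming optax as the optimizer library; `masked_update` and the toy pytrees are illustrative names, not part of coix.

```python
import jax
import jax.numpy as jnp
import optax


def masked_update(params, opt_state, grads, optimizer):
  """Applies an optimizer step only when all gradients are finite."""
  squared_grad_norm = sum(jnp.square(g).sum() for g in jax.tree.leaves(grads))
  do_update = jnp.isfinite(squared_grad_norm)
  # Both branches are always computed; jnp.where selects per leaf, so no
  # lax.cond (and no signature-matched _skip_update helper) is needed.
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  opt_state = jax.tree.map(
      lambda new, old: jnp.where(do_update, new, old), new_opt_state, opt_state
  )
  params = jax.tree.map(
      lambda p, u: jnp.where(do_update, p + u, p), params, updates
  )
  return params, opt_state


optimizer = optax.adam(1e-2)
params = {"w": jnp.ones(2)}
opt_state = optimizer.init(params)

# Finite gradients: the step is applied as usual.
params1, opt_state1 = masked_update(
    params, opt_state, {"w": jnp.full(2, 0.5)}, optimizer
)

# A NaN gradient: the step is skipped; params and opt_state are unchanged.
params2, _ = masked_update(
    params1, opt_state1, {"w": jnp.array([jnp.nan, 0.5])}, optimizer
)
assert jnp.all(params2["w"] == params1["w"])
```

A note on the trade-off: the masked version always pays for the optimizer step even when it is discarded, but under `jit`/`vmap` a `lax.cond` is often lowered to a select that evaluates both branches anyway, so on accelerators the cost is typically comparable while the code avoids keeping a `_skip_update` helper whose signature must mirror `optimizer.update`.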