Unexpected behavior of optax.apply_every
#2
Comments (truncated previews)

- "Hi @vwxyzjn, I'm glad you've found this repo useful. And good catch! But can you confirm that …"
- "Hi @jenkspt, thanks for your response. I can confirm w/ … Interestingly, I found nanoGPT with …"
- "Are you setting the …"
- "That is correct. See https://wandb.ai/costa-huang/cleanrlhf/runs/3g43kqom/logs. Without …"
- "I also just did a more comprehensive profile on an A100. See report … How it compares w/ my JAX implementation, which is heavily influenced by your repo :) …"
- "FYI, I did an ablation study with openwebtext: … I also recently found a library, https://github.com/google/maxtext, that is well optimized and could be a useful reference."
Original post

Hi @jenkspt, I have been learning nanoGPT and reproducing it in JAX from scratch. Your repo has been a very helpful reference.

I encountered an issue with `optax.apply_every` and thought you might want to know. It turns out `optax.apply_every` is not equivalent to how nanoGPT updates, which is to accumulate gradients and then clip by global norm. See this snippet.

Empirically, it can have a significant impact on training as well. I am following nanoGPT's single-GPU setting, which accumulates gradients over 40 steps. As shown below, `optax.apply_every` is significantly more unstable than `optax.MultiSteps`. While training might be more stable with fewer gradient accumulation steps, it still feels like an issue...
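
This is not the snippet linked in the issue, but a minimal sketch of the two setups being compared, with placeholder hyperparameters (`GRAD_CLIP`, `ACCUM_STEPS`, `LEARNING_RATE` are illustrative, not the issue's values): `optax.apply_every` accumulates the already-transformed updates at the end of the chain, whereas `optax.MultiSteps` accumulates raw gradients and only then runs the inner chain, which matches nanoGPT's accumulate-then-clip order.

```python
import optax

# Placeholder hyperparameters for illustration (not taken from the issue).
GRAD_CLIP = 1.0
ACCUM_STEPS = 40
LEARNING_RATE = 6e-4

# The inner optimizer: clip by global norm, then AdamW.
base_opt = optax.chain(
    optax.clip_by_global_norm(GRAD_CLIP),
    optax.adamw(LEARNING_RATE),
)

# Variant A: optax.apply_every appended to the chain.
# Clipping and AdamW run on every micro-batch gradient; apply_every then sums the
# already-transformed updates and only emits them every ACCUM_STEPS steps.
opt_apply_every = optax.chain(base_opt, optax.apply_every(ACCUM_STEPS))

# Variant B: optax.MultiSteps wrapping the chain.
# Raw gradients are accumulated (averaged by default) for ACCUM_STEPS steps, and only
# then does the inner chain run, so clipping sees the accumulated gradient --
# the same order as nanoGPT's "accumulate, clip by global norm, step".
opt_multi_steps = optax.MultiSteps(base_opt, every_k_schedule=ACCUM_STEPS)
```

With `optax.MultiSteps`, `update` is still called on every micro-batch; the inner chain only fires once the accumulation count reaches `ACCUM_STEPS`, and zero updates are emitted in between.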