Support gradient checkpointing for functional models via the standard fit method #98
Comments
Gradient checkpointing is a very attractive feature. It would be great to have official support for this. Discussion: https://discuss.tensorflow.org/t/support-gradient-checkpointing-in-tensorflow-2/
Hi @rivershah -- thanks for the feature request! This is something that would need to be implemented by the TensorFlow team as part of [...]
@ianstenbit
Oh interesting -- thanks @innat, I wasn't aware of that. This may be feasible as a feature using [...]
It sounds like this should be possible (probably through making a subclass of [...]). This seems like a nice-to-have feature, but it is not currently on our roadmap. I'll leave the issue open for now -- if any contributors are interested in experimenting with this approach, we could see how it goes.
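For anyone who wants to experiment with that approach, here is a minimal sketch of what a subclass-based wrapper could look like, assuming TensorFlow's existing tf.recompute_grad utility. The RecomputeGradWrapper layer, layer sizes, and shapes below are hypothetical illustrations, not an existing Keras API.

```python
import tensorflow as tf

class RecomputeGradWrapper(tf.keras.layers.Layer):
    """Hypothetical wrapper: recomputes the wrapped layer's activations
    during the backward pass instead of keeping them in memory."""

    def __init__(self, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.inner_layer = inner_layer

    def call(self, inputs):
        # tf.recompute_grad re-runs the wrapped callable when gradients are
        # requested, trading extra compute for lower activation memory.
        return tf.recompute_grad(self.inner_layer)(inputs)

# Example use inside a functional model (shapes are arbitrary):
inputs = tf.keras.Input(shape=(1024,))
x = RecomputeGradWrapper(tf.keras.layers.Dense(4096, activation="relu"))(inputs)
outputs = tf.keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)
```

Whether gradients flow correctly to variables captured by tf.recompute_grad has varied across TensorFlow versions, so any real implementation would need testing under fit, tf.function, and mixed precision.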
@sachinprasadhs
Hi, this issue was moved as part of the migration to Keras 3. If you feel it is related to Keras 3, feel free to open a new issue in the Keras repo.
System information.
TensorFlow version (you are using): 2.12.0
Are you willing to contribute it (Yes/No): No
Describe the feature and the current behavior/state.
The following feature excerpt is copied from the huggingface documentation here: In order to compute the gradients during the backward pass, all activations from the forward pass are normally saved. This can create a big memory overhead. Alternatively, one could forget all activations during the forward pass and recompute them on demand during the backward pass. This would, however, add a significant computational overhead and slow down training.
Gradient checkpointing strikes a compromise between the two approaches and saves strategically selected activations throughout the computational graph so only a fraction of the activations need to be re-computed for the gradients. See this great article explaining the ideas behind gradient checkpointing.
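As a rough illustration of that trade-off, the sketch below splits a stack of layers into segments and wraps each segment with TensorFlow's tf.recompute_grad, so only the segment inputs are stored and each segment's internal activations are recomputed during the backward pass. The segment count and layer sizes are arbitrary placeholders, and behavior can vary across TensorFlow versions.

```python
import tensorflow as tf

NUM_SEGMENTS = 4  # arbitrary: how many checkpointed chunks to split the stack into

def make_segment():
    # A small block of layers whose internal activations we are willing to recompute.
    return tf.keras.Sequential([
        tf.keras.layers.Dense(2048, activation="relu"),
        tf.keras.layers.Dense(2048, activation="relu"),
    ])

segments = [make_segment() for _ in range(NUM_SEGMENTS)]

x = tf.random.normal([32, 2048])
with tf.GradientTape() as tape:
    h = x
    for segment in segments:
        # Only `h` (the segment input) is kept; the segment's internal
        # activations are discarded and recomputed when gradients are taken.
        h = tf.recompute_grad(segment)(h)
    loss = tf.reduce_mean(tf.square(h))

variables = [v for segment in segments for v in segment.trainable_variables]
grads = tape.gradient(loss, variables)
```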
Will this change the current API? How?
Yes, it will introduce a flag in model.fit, such as gradient_checkpointing=True.
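For concreteness, here is a sketch of how the requested flag might look from the user's side, next to today's API. The gradient_checkpointing argument is the one proposed in this issue and does not exist in current Keras, so it is shown commented out; the model and data below are arbitrary stand-ins.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4096, activation="relu", input_shape=(1024,)),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

x = tf.random.normal([256, 1024])
y = tf.random.uniform([256], maxval=10, dtype=tf.int32)

# Current API: all forward-pass activations are kept for the backward pass.
model.fit(x, y, batch_size=32, epochs=1)

# Proposed in this issue (not an existing argument today):
# model.fit(x, y, batch_size=32, epochs=1, gradient_checkpointing=True)
```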
Who will benefit from this feature?
Any user with large intermediate activations and tensors that won't fit in device memory. Sequence models can be very memory-heavy, and this would allow scaling up without introducing model-parallelism complexity.