
Support gradient checkpointing for functional models via the standard fit method #98

Open
rivershah opened this issue May 11, 2023 · 7 comments


System information.

TensorFlow version (you are using): 2.12.0
Are you willing to contribute it (Yes/No) : No

Describe the feature and the current behavior/state.


The following excerpt is copied from the Hugging Face documentation here:

In order to compute the gradients during the backward pass all activations from the forward pass are normally saved. This can create a big memory overhead. Alternatively, one could forget all activations during the forward pass and recompute them on demand during the backward pass. This would however add a significant computational overhead and slow down training.

Gradient checkpointing strikes a compromise between the two approaches and saves strategically selected activations throughout the computational graph so only a fraction of the activations need to be re-computed for the gradients. See this great article explaining the ideas behind gradient checkpointing.

Will this change the current api? How?

Yes, it would introduce a flag in model.fit, such as gradient_checkpointing=True.
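
For illustration only, here is a minimal sketch of how the requested flag might be used. The gradient_checkpointing argument is the hypothetical API this issue proposes; it does not exist in Keras today:

```python
import tensorflow as tf

# Toy functional model and data, purely for illustration.
inputs = tf.keras.Input(shape=(128,))
x = tf.keras.layers.Dense(4096, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

x_train = tf.random.normal([256, 128])
y_train = tf.random.normal([256, 1])

# Hypothetical flag from this feature request -- not a real argument today.
model.fit(x_train, y_train, epochs=2, gradient_checkpointing=True)
```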

Who will benefit from this feature?

Any user with large intermediate activations and tensors that won't fit in device memory. Sequence models can be very memory heavy, and this would allow scaling up without introducing model-parallelism complexity.

Contributing

  • Do you want to contribute a PR? (yes/no): No
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):
innat commented May 11, 2023

Grad checkpointing is a very attractive feature. It would be great to have official support for this. Discussion: https://discuss.tensorflow.org/t/support-gradient-checkpointing-in-tensorflow-2/.

ianstenbit (Contributor) commented:

Hi @rivershah -- thanks for the feature request!

This is something that would need to be implemented by the TensorFlow team as part of tf.GradientTape before we could support such a feature in Keras. Feel free to open a feature request at https://github.com/tensorflow/tensorflow

innat commented May 11, 2023

@ianstenbit
(Not 100% sure.) TensorFlow has a method, tf.recompute_grad, for this. The official EfficientDet repo used it for gradient checkpointing: code.

https://github.com/pidajay/addons/blob/grad_checkpointing_eager/docs/tutorials/training_gradient_checkpointing.ipynb
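
For reference, a minimal sketch of what using tf.recompute_grad directly looks like, assuming an arbitrary dense block (the layer sizes and block choice here are illustrative):

```python
import tensorflow as tf

# A memory-heavy block whose activations we would rather recompute
# during the backward pass than keep in memory.
dense_block = tf.keras.Sequential([
    tf.keras.layers.Dense(1024, activation="relu"),
    tf.keras.layers.Dense(1024, activation="relu"),
])
dense_block.build(input_shape=(None, 1024))  # create variables up front

# The wrapped function discards intermediate activations after the
# forward pass and recomputes them when gradients are requested.
checkpointed_block = tf.recompute_grad(dense_block)

x = tf.random.normal([32, 1024])
with tf.GradientTape() as tape:
    y = checkpointed_block(x)
    loss = tf.reduce_sum(y)

grads = tape.gradient(loss, dense_block.trainable_variables)
```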

@ianstenbit ianstenbit reopened this May 11, 2023
ianstenbit commented May 11, 2023

Oh interesting -- thanks @innat, I wasn't aware of that. This may be feasible as a feature using recompute_grad. Let me sync with some folks who have more expertise in this area.

ianstenbit commented May 11, 2023

It sounds like this should be possible (probably by making a subclass of Model that uses a custom train_step and wraps call with the recompute_grad decorator).

This seems like a nice-to-have feature, but is not currently on our roadmap. I'll leave the issue open for now -- if any contributors are interested in experimenting with this approach, we could see how it goes.
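
For any contributor who wants to experiment, here is a rough, untested sketch of that idea, combining the standard Keras custom train_step recipe with tf.recompute_grad. Note that wrapping the whole call recomputes every activation; a real implementation would checkpoint selected sub-blocks instead:

```python
import tensorflow as tf

class CheckpointedModel(tf.keras.Model):
    """Hypothetical sketch -- not an official Keras API."""

    def train_step(self, data):
        x, y = data
        # Wrap the forward pass so activations are discarded after the
        # forward computation and recomputed during backprop.
        checkpointed_call = tf.recompute_grad(
            lambda inputs: self(inputs, training=True)
        )
        with tf.GradientTape() as tape:
            y_pred = checkpointed_call(x)
            loss = self.compiled_loss(
                y, y_pred, regularization_losses=self.losses
            )
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```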

sachinprasadhs transferred this issue from keras-team/keras Sep 22, 2023
innat-asj commented:

@sachinprasadhs
Could you please explain why this issue was moved from Keras to here? It is a feature request, and other frameworks like torch support gradient checkpointing. That, IMO, makes it an appropriate feature for Keras.

sachinprasadhs (Collaborator) commented:

Hi, this was moved as part of the Keras 3 migration. If you feel this issue is related to Keras 3, feel free to open a new issue in the Keras repo.
