
How to get the flops of TensorFlow.js model #7427

Open
lebron8dong opened this issue Feb 28, 2023 · 10 comments

Comments

@lebron8dong

I can't get FLOPs through tf.profile the way I can in TensorFlow; the TensorFlow.js tf.profile API is different from TensorFlow's.

@gbaned gbaned self-assigned this Mar 1, 2023
@lebron8dong
Author

Is there an API to get GFLOPs directly?

@gaikwadrahul8
Contributor

Hi, @lebron8dong

Apologies for the delay. As far as I know, there is no straightforward API for this in TFJS. I think you can use the tf.profiler API on the TensorFlow model before converting it from the Keras .h5 format or the default TensorFlow SavedModel format to the tensorflowjs_converter model.json format. After the model has been converted to model.json, I don't think we have an API like tf.profiler. Please correct me if I have missed something. Thank you!

@lebron8dong
Author

@gaikwadrahul8 Thanks. Does TensorFlow.js have an API that can calculate FLOPs?
Or can I only calculate the FLOPs of each layer one by one with a formula?
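For the formula route, the per-layer counts for the most common layers are short. The following is a minimal Python sketch, assuming a fully connected (Dense) layer and a standard, non-grouped 2-D convolution with dilation 1; the helper names `dense_macs`, `conv2d_macs` and `macs_to_flops` are hypothetical, and the convention of not counting bias adds or activations is an assumption (tools differ on this).

```python
def dense_macs(in_features: int, out_features: int) -> int:
    """Multiply-accumulates (MACs) for y = x @ W + b on one input vector.

    Each of the out_features outputs is a dot product of length
    in_features. The bias add is not counted (assumed convention).
    """
    return in_features * out_features


def conv2d_macs(out_h: int, out_w: int, kernel_h: int, kernel_w: int,
                in_channels: int, out_channels: int) -> int:
    """MACs for a standard (non-grouped, dilation 1) 2-D convolution.

    Every output element (out_h * out_w * out_channels of them) is a dot
    product over a kernel_h * kernel_w * in_channels input window.
    """
    return out_h * out_w * out_channels * kernel_h * kernel_w * in_channels


def macs_to_flops(macs: int) -> int:
    """One multiply-accumulate = one multiply + one add = 2 FLOPs."""
    return 2 * macs


if __name__ == "__main__":
    # Dense layer mapping 4096 features to 1000 outputs.
    d = dense_macs(4096, 1000)
    # 3x3 convolution, 64 -> 64 channels, producing a 56x56 feature map.
    c = conv2d_macs(56, 56, 3, 3, 64, 64)
    print(f"dense:  {d:,} MACs, {macs_to_flops(d):,} FLOPs")
    print(f"conv2d: {c:,} MACs, {macs_to_flops(c):,} FLOPs")
```

The output spatial size (`out_h`, `out_w`) has to come from the layer's padding and stride settings, so in practice these formulas are applied while walking the model and tracking shapes.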

@lebron8dong
Author

@gaikwadrahul8 Is there any other way to get the FLOPs of a TensorFlow.js model?

@gbaned gbaned assigned gaikwadrahul8 and unassigned gbaned Mar 7, 2023
@gaikwadrahul8
Contributor

gaikwadrahul8 commented Mar 7, 2023

Hi, @lebron8dong

Apologies for the delayed response. As far as I know, there is currently no API in TensorFlow.js to calculate FLOPs directly, so it seems you'll have to compute them on your end using a formula. I have added a FLOPs calculation with tf.keras; it calculates FLOPs for TensorFlow Keras layers, not TensorFlow.js layers, but please have a look at that notebook, which may help you work out the FLOPs of TensorFlow.js layers. Thank you!
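Applying a formula to a whole model means tracking each layer's output shape and summing the per-layer counts. The sketch below does this for a simple sequential stack of `Conv2D`, `Flatten` and `Dense` layers. It is an illustration under stated assumptions: the layer descriptions are a hypothetical dict format loosely modeled on Keras layer configs (not the exact `model.json` schema), shapes are channels-last without the batch dimension, dilation is 1, and any other layer type raises `NotImplementedError`.

```python
import math
from typing import Dict, List, Sequence, Tuple


def conv_out_size(size: int, kernel: int, stride: int, padding: str) -> int:
    """Output length along one spatial axis (dilation 1), Keras/TF rules."""
    if padding == "valid":
        return (size - kernel) // stride + 1
    if padding == "same":
        return math.ceil(size / stride)
    raise ValueError(f"unsupported padding: {padding!r}")


def count_macs(input_shape: Tuple[int, ...],
               layers: Sequence[Dict]) -> List[Tuple[str, int]]:
    """Return (layer name, MACs) for each layer, for one input sample.

    Shapes are channels-last without the batch dimension: (H, W, C) for
    images, (features,) after Flatten.
    """
    shape = tuple(input_shape)
    report = []
    for layer in layers:
        kind = layer["class_name"]
        if kind == "Conv2D":
            h, w, c_in = shape
            kh, kw = layer["kernel_size"]
            sh, sw = layer.get("strides", (1, 1))
            pad = layer.get("padding", "valid")
            out_h = conv_out_size(h, kh, sh, pad)
            out_w = conv_out_size(w, kw, sw, pad)
            c_out = layer["filters"]
            # Each output element is a dot product over a kh*kw*c_in window.
            macs = out_h * out_w * c_out * kh * kw * c_in
            shape = (out_h, out_w, c_out)
        elif kind == "Flatten":
            macs = 0  # reshape only, no arithmetic
            shape = (math.prod(shape),)
        elif kind == "Dense":
            (features,) = shape
            units = layer["units"]
            macs = features * units  # bias add not counted
            shape = (units,)
        else:
            raise NotImplementedError(f"no MAC formula for {kind}")
        report.append((layer["name"], macs))
    return report


# Hypothetical example model: 28x28x1 input, two convs, then a classifier.
EXAMPLE_LAYERS = [
    {"name": "conv1", "class_name": "Conv2D", "filters": 8,
     "kernel_size": (3, 3), "strides": (1, 1), "padding": "same"},
    {"name": "conv2", "class_name": "Conv2D", "filters": 16,
     "kernel_size": (3, 3), "strides": (2, 2), "padding": "valid"},
    {"name": "flatten", "class_name": "Flatten"},
    {"name": "fc", "class_name": "Dense", "units": 10},
]

if __name__ == "__main__":
    per_layer = count_macs((28, 28, 1), EXAMPLE_LAYERS)
    for name, macs in per_layer:
        print(f"{name:8s} {macs:>10,} MACs")
    total = sum(macs for _, macs in per_layer)
    print(f"total    {total:>10,} MACs = {2 * total:,} FLOPs")
```

Extending this to pooling, batch normalization or depthwise convolutions means adding one branch per layer type with its own shape rule and count.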

In TFJS you'll have to do something like the following, which is already possible for a TensorFlow model:

import tensorflow as tf
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
print('TensorFlow:', tf.__version__)

model = tf.keras.applications.ResNet50()

forward_pass = tf.function(
    model.call,
    input_signature=[tf.TensorSpec(shape=(1,) + model.input_shape[1:])])

graph_info = profile(forward_pass.get_concrete_function().graph,
                     options=ProfileOptionBuilder.float_operation())

# The //2 is necessary since `profile` counts multiply and accumulate
# as two flops, here we report the total number of multiply accumulate ops
flops = graph_info.total_float_ops // 2
print('Flops: {:,}'.format(flops))

I would suggest referring to these two issues, #32809 and #17273, which should help you calculate FLOPs in TFJS. If you're looking for a FLOPs calculation API in TFJS, we can consider this issue a feature request, since at the moment TFJS doesn't have a straightforward API like Keras or TensorFlow does. Would you like this issue to be considered a feature request? Thank you!

@lebron8dong
Copy link
Author

@gaikwadrahul8 Yes, I would like this issue to be considered a feature request.

@gaikwadrahul8
Contributor

@lebron8dong

Thank you for the confirmation. We'll consider this issue a feature request, and I'll forward it to the relevant team so they can take appropriate action. Thank you!

@lebron8dong
Author

@gaikwadrahul8
Thank you very much. I think "How can TensorFlow.js obtain the execution time of each layer during model inference?" could also be a feature request. Could you help me forward this to the relevant team as well so they can take appropriate action? Thank you very much!

@gaikwadrahul8
Contributor

Hi, @Linchenn

Could you please look into this issue? Thank you!

@moonsh

moonsh commented May 20, 2023

@gaikwadrahul8

I have a question about the script you shared.

It seems like this computes MACs. Why do we divide by 2 to get FLOPs?

flops = graph_info.total_float_ops // 2
print('Flops: {:,}'.format(flops))
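On the `// 2`: as the comment in the script says, the profiler counts each multiply-accumulate as two float operations (one multiply, one add), so halving `total_float_ops` gives the number of multiply-accumulates, and the value printed as `Flops` is, by the script's own comment, a MAC count. The pure-Python check below illustrates that factor of two on a naive matrix product; it does not call the TensorFlow profiler, and `matmul_with_op_counts` is a hypothetical helper written only for this illustration.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul_with_op_counts(a: Matrix, b: Matrix) -> Tuple[Matrix, int, int]:
    """Naive (M x K) @ (K x N) product that also counts multiplies and adds.

    Each output element is computed as acc = 0; acc += a[i][p] * b[p][j],
    i.e. K multiply-accumulate steps, each one multiply plus one add.
    """
    m, k = len(a), len(a[0])
    n = len(b[0])
    mults = adds = 0
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
                mults += 1
                adds += 1
            out[i][j] = acc
    return out, mults, adds


if __name__ == "__main__":
    m, k, n = 2, 3, 4
    a = [[1.0] * k for _ in range(m)]
    b = [[1.0] * n for _ in range(k)]
    _, mults, adds = matmul_with_op_counts(a, b)
    float_ops = mults + adds  # multiplies and adds counted separately
    macs = float_ops // 2     # analogous to total_float_ops // 2 in the script
    print(f"float ops = {float_ops}, MACs = {macs}, M*K*N = {m * k * n}")
```

So whether a reported number is "FLOPs" or "MACs" depends on whether that division by 2 was applied; many papers and tools report the MAC count under the name FLOPs, which is a common source of this confusion.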


5 participants