Instances of nn.AdaptiveAvgPool2d are always decomposed into primitive floating-point operations #399
Comments
This is probably similar to #53
I tried tfl full int8 quantization to see if it would work better... it doesn't have the DQ/Q nodes, but it still decomposes, which I think will have the same issue you are seeing:

import tensorflow as tf
import torch
import ai_edge_torch
import numpy as np
from torch.export import export_for_training
from ai_edge_torch.quantize.pt2e_quantizer import (
    PT2EQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from ai_edge_torch.quantize.quant_config import QuantConfig


def aiet_export(model: torch.nn.Module, fname: str):
    # PT2E static int8 quantization flow, followed by conversion to TFLite.
    sample_input = torch.rand((1, 8, 8, 3))
    model = ai_edge_torch.to_channel_last_io(model, args=[0], outputs=[0]).eval()
    model(sample_input)
    model = export_for_training(model, (sample_input,)).module()
    aiet_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False)
    )
    model = prepare_pt2e(model, aiet_quantizer)
    with torch.no_grad():
        model(sample_input)  # calibration pass
    model = convert_pt2e(model, fold_quantize=False)
    torch.ao.quantization.move_exported_model_to_eval(model)
    aiet_tc = ai_edge_torch.convert(
        model, (sample_input,), quant_config=QuantConfig(pt2e_quantizer=aiet_quantizer)
    )
    aiet_tc.export(fname)


class TinyPool1(torch.nn.Module):
    # nn.AdaptiveAvgPool2d variant: gets decomposed during conversion.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


class TinyPool2(torch.nn.Module):
    # Functionally equivalent nn.AvgPool2d variant: converts cleanly.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AvgPool2d((6, 6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


tp1 = TinyPool1()
aiet_export(tp1.eval(), "tinypool1.tflite")
tp2 = TinyPool2()
aiet_export(tp2.eval(), "tinypool2.tflite")


def representative_dataset():
    # Calibration data for TFLite full-integer quantization.
    for _ in range(100):
        data = np.random.rand(1, 3, 32, 32)
        yield [data.astype(np.float32)]


sample_args = (torch.randn(1, 3, 32, 32),)
tfl_converter_flags = {
    'optimizations': [tf.lite.Optimize.DEFAULT],
    'representative_dataset': representative_dataset,
    'target_spec.supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS_INT8],
    'inference_input_type': tf.int8,
    'inference_output_type': tf.int8,
}
tfl_quant_model = ai_edge_torch.convert(
    tp1, sample_args, _ai_edge_converter_flags=tfl_converter_flags
)
tfl_quant_model.export("tinypool1_tflq.tflite")
Yes, the operations are already decomposed early on during the export call. I will try to remove the average pooling from the decompositions table passed to PyTorch and see what happens.
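A minimal sketch of that idea, assuming the decomposition table starts from PyTorch's default core-aten decompositions (the actual table construction inside ai-edge-torch may differ):

import torch
from torch._decomp import core_aten_decompositions

# Build the default decomposition table, then drop the adaptive average
# pooling entries (if present) so the op survives run_decompositions()
# instead of being lowered to sum/mul primitives.
decomp_table = core_aten_decompositions()
for op in (
    torch.ops.aten.adaptive_avg_pool2d.default,
    torch.ops.aten._adaptive_avg_pool2d.default,
):
    decomp_table.pop(op, None)

model = TinyPool1().eval()  # TinyPool1 as defined in the example above
exported = torch.export.export(model, (torch.randn(1, 3, 8, 8),))
exported = exported.run_decompositions(decomp_table)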
Description of the bug:
When converting CNN architectures, any instance of nn.AdaptiveAvgPool2d is decomposed into sum and mul floating-point operators. Since these are common in CNNs, this leads to frequent dequantization/quantization steps being inserted into the output graph, which degrades performance on backends that rely on full quantization. At least for static dimensions, converting these to AvgPool2d operators should be trivial.

Minimal working example:
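The example code appears to have been stripped from this view of the issue; the following compact sketch matches the TinyPool1/TinyPool2 modules quoted in the comments above:

import torch

class TinyPool1(torch.nn.Module):
    # nn.AdaptiveAvgPool2d: decomposed into sum/mul during conversion.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.conv(x))

class TinyPool2(torch.nn.Module):
    # For an 8x8 input, the conv output is 6x6, so AvgPool2d((6, 6)) is
    # functionally equivalent to the adaptive pool and converts cleanly.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AvgPool2d((6, 6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.conv(x))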
Actual vs expected behavior:

Expected:
Fully-quantized graph for both operators

Actual:
Suboptimal graph with Dequantize/Quantize nodes and float operations inserted for TinyPool1
Fully-quantized graph for TinyPool2
Any other information you'd like to share?

A not very pretty workaround is to trace the input/output dimensions of nn.AdaptiveAvgPool2d using hooks before execution and to replace them with functionally equivalent nn.AvgPool2d instances before converting. I wonder whether there is a cleaner way of doing this as part of the conversion, though.
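A minimal sketch of this workaround, assuming tuple output sizes and input sizes that divide the target evenly (the helper name replace_adaptive_pools is mine):

import torch


def replace_adaptive_pools(model: torch.nn.Module, sample_input: torch.Tensor):
    """Swap each nn.AdaptiveAvgPool2d for a functionally equivalent nn.AvgPool2d."""
    observed = {}  # module name -> spatial input size seen during the trace
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.AdaptiveAvgPool2d):
            def make_hook(n):
                def hook(mod, inputs, output):
                    observed[n] = inputs[0].shape[-2:]
                return hook
            hooks.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(sample_input)  # single trace pass to record input sizes
    for hk in hooks:
        hk.remove()

    for name, (in_h, in_w) in observed.items():
        module = model.get_submodule(name)
        out_h, out_w = module.output_size  # assumes a (H, W) tuple target
        # The replacement is only exact when the input divides the target evenly.
        assert in_h % out_h == 0 and in_w % out_w == 0
        kernel = (in_h // out_h, in_w // out_w)
        parent_name, _, child = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child, torch.nn.AvgPool2d(kernel, stride=kernel))
    return model

For TinyPool1 above, replace_adaptive_pools(TinyPool1().eval(), torch.randn(1, 3, 8, 8)) would yield the same pooling graph as TinyPool2 when converted.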
The decomposition of the aten.avgpool2d node happens during the first canonicalization pass:

ai-edge-torch/ai_edge_torch/fx_pass_base.py, lines 107 to 109 at d4e358e
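For illustration, a minimal sketch (using plain torch.export, not the actual fx_pass_base.py code) of how the decomposition can be observed:

import torch

model = TinyPool1().eval()  # TinyPool1 from the example above
ep = torch.export.export(model, (torch.randn(1, 3, 8, 8),))
# run_decompositions() without a table applies PyTorch's default core-aten
# decompositions, which lower the adaptive pool to sum/mul primitives.
ep = ep.run_decompositions()
print(ep.graph_module.code)  # no adaptive_avg_pool2d node remains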
Environment

AI Edge Torch: 0.3.0 (current head of main)
Python: 3.10.5
PyTorch: 2.5.0
TF: 2.19.0-dev20241127
OS: Ubuntu 24.04.1 LTS in WSL2