
Instances of nn.AdaptiveAvgPool2d are always decomposed into primitive floating-point operations #399

Open
etrommer opened this issue Nov 29, 2024 · 3 comments


etrommer commented Nov 29, 2024

Description of the bug:

When converting CNN architectures, any instance of nn.AdaptiveAvgPool2d is decomposed into sum and mul floating-point operators. Since these modules are common in CNNs, this inserts frequent dequantize/quantize steps into the output graph, which degrades performance on backends that rely on full quantization. At least for static input dimensions, converting these to AvgPool2d operators should be trivial.
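
For a fixed input size the two modules are numerically identical, which is easy to check (sizes chosen to match the example below, where the 8×8 input leaves a 6×6 feature map after the convolution):

import torch

x = torch.rand(1, 8, 6, 6)
# On a static 6x6 feature map, AdaptiveAvgPool2d((1, 1)) and AvgPool2d((6, 6)) agree.
assert torch.allclose(
    torch.nn.AdaptiveAvgPool2d((1, 1))(x),
    torch.nn.AvgPool2d((6, 6))(x),
)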

Minimal working example:

import torch
import ai_edge_torch
from torch.export import export_for_training
from ai_edge_torch.quantize.pt2e_quantizer import (
    PT2EQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from ai_edge_torch.quantize.quant_config import QuantConfig

def aiet_export(model: torch.nn.Module, fname: str):
    sample_input = torch.rand((1, 8, 8, 3))  # NHWC sample input
    # Wrap the model so the converted graph takes/returns channel-last tensors.
    model = ai_edge_torch.to_channel_last_io(model, args=[0], outputs=[0]).eval()
    model(sample_input)
    model = export_for_training(model, (sample_input,)).module()
    aiet_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False)
    )
    # PT2E static quantization: prepare, calibrate with one batch, convert.
    model = prepare_pt2e(model, aiet_quantizer)
    with torch.no_grad():
        model(sample_input)
    model = convert_pt2e(model, fold_quantize=False)
    torch.ao.quantization.move_exported_model_to_eval(model)
    aiet_tc = ai_edge_torch.convert(
        model, (sample_input,), quant_config=QuantConfig(pt2e_quantizer=aiet_quantizer)
    )
    aiet_tc.export(fname)


class TinyPool1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


class TinyPool2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AvgPool2d((6, 6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


tp1 = TinyPool1()
aiet_export(tp1, "tinypool1.tflite")

tp2 = TinyPool2()
aiet_export(tp2, "tinypool2.tflite")

Actual vs expected behavior:

Expected:
Fully quantized graph for both models.

Actual:
Suboptimal graph for TinyPool1, with Dequantize/Quantize nodes and float operations inserted:
[Screenshot: gh_aiet_tp1]

Fully quantized graph for TinyPool2:
[Screenshot: gh_aiet_tp2]

Any other information you'd like to share?

A not-very-pretty workaround is to trace the input/output dimensions of each nn.AdaptiveAvgPool2d instance with forward hooks before execution and to replace it with a functionally equivalent nn.AvgPool2d instance before conversion.
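
A minimal sketch of that workaround (the helper name is mine; only pools with output size (1, 1) are swapped, since arbitrary static output sizes additionally require the input size to be divisible by the output size):

import torch

def swap_adaptive_pools(model: torch.nn.Module, sample_input: torch.Tensor) -> torch.nn.Module:
    # Record the spatial size seen by each AdaptiveAvgPool2d via forward hooks.
    sizes, hooks = {}, []
    for name, mod in model.named_modules():
        if isinstance(mod, torch.nn.AdaptiveAvgPool2d):
            hooks.append(mod.register_forward_hook(
                lambda m, inp, out, name=name: sizes.__setitem__(name, inp[0].shape[-2:])
            ))
    with torch.no_grad():
        model(sample_input)
    for h in hooks:
        h.remove()
    # Replace each (1, 1) pool with an AvgPool2d whose kernel covers the whole input.
    for name, hw in sizes.items():
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        if getattr(parent, attr).output_size in (1, (1, 1)):
            setattr(parent, attr, torch.nn.AvgPool2d(tuple(hw)))
    return model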

I wonder whether there is a cleaner way of doing this as part of the conversion, though.

The decomposition of the aten.adaptive_avg_pool2d node happens during the first canonicalization pass:

exported_program = exported_program.run_decompositions(
    self._DUMMY_DECOMP_TABLE
)

Environment
AI Edge Torch: 0.3.0 (current head of main)
Python: 3.10.5
PyTorch: 2.5.0
TF: 2.19.0-dev20241127
OS: Ubuntu 24.04.1 LTS in WSL2

@etrommer etrommer added the type:bug Bug label Nov 29, 2024
etrommer commented Dec 2, 2024

This is probably similar to #53

@pkgoogle pkgoogle self-assigned this Dec 3, 2024
pkgoogle commented Dec 3, 2024

I tried TFLite full int8 quantization to see if it would work better... it doesn't have the DQ/Q nodes, but it still decomposes, which I think will lead to the same issue you are seeing:

import tensorflow as tf
import torch
import ai_edge_torch
import numpy as np
from torch.export import export_for_training
from ai_edge_torch.quantize.pt2e_quantizer import (
    PT2EQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from ai_edge_torch.quantize.quant_config import QuantConfig

def aiet_export(model: torch.nn.Module, fname: str):
    sample_input = torch.rand((1, 8, 8, 3))
    model = ai_edge_torch.to_channel_last_io(model, args=[0], outputs=[0]).eval()
    model(sample_input)
    model = export_for_training(model, (sample_input,)).module()
    aiet_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False)
    )
    model = prepare_pt2e(model, aiet_quantizer)
    with torch.no_grad():
        model(sample_input)
    model = convert_pt2e(model, fold_quantize=False)
    torch.ao.quantization.move_exported_model_to_eval(model)
    aiet_tc = ai_edge_torch.convert(
        model, (sample_input,), quant_config=QuantConfig(pt2e_quantizer=aiet_quantizer)
    )
    aiet_tc.export(fname)


class TinyPool1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


class TinyPool2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AvgPool2d((6, 6))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.pool(y)
        return y


tp1 = TinyPool1()
aiet_export(tp1.eval(), "tinypool1.tflite")

tp2 = TinyPool2()
aiet_export(tp2.eval(), "tinypool2.tflite")

def representative_dataset():
    # Calibration data for TFLite full-integer quantization (NCHW here).
    for _ in range(100):
        data = np.random.rand(1, 3, 32, 32)
        yield [data.astype(np.float32)]

sample_args = (torch.randn(1, 3, 32, 32),)

# Flags forwarded to the underlying TFLite converter to request full-int8 quantization.
tfl_converter_flags = {
    'optimizations': [tf.lite.Optimize.DEFAULT],
    'representative_dataset': representative_dataset,
    'target_spec.supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS_INT8],
    'inference_input_type': tf.int8,
    'inference_output_type': tf.int8,
}

tfl_quant_model = ai_edge_torch.convert(
    tp1, sample_args, _ai_edge_converter_flags=tfl_converter_flags
)

tfl_quant_model.export("tinypool1_tflq.tflite")

[Screenshot: resulting TFLite graph, still showing the decomposed pooling]

@pkgoogle pkgoogle added type:performance An issue with performance, primarily inference latency status:awaiting ai-edge-developer and removed type:bug Bug labels Dec 3, 2024
etrommer commented Dec 3, 2024

Yes, the operations are already decomposed early on, in the call to run_decompositions(), before they reach the TFLite conversion.

I will try to remove the average pooling from the decompositions table passed to PyTorch and see what happens.
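
A minimal sketch of that idea (which overload carries the decomposition may vary between PyTorch versions, so both candidates are removed; whether the converter then lowers the preserved op to a TFLite AVERAGE_POOL_2D is untested):

import torch
from torch._decomp import core_aten_decompositions

# Start from the default core-ATen decomposition table, then drop the adaptive
# average pooling entries so the op survives until TFLite lowering.
decomp_table = core_aten_decompositions()
for op in (torch.ops.aten.adaptive_avg_pool2d.default,
           torch.ops.aten._adaptive_avg_pool2d.default):
    decomp_table.pop(op, None)

# exported_program is assumed to come from torch.export.export(...).
exported_program = exported_program.run_decompositions(decomp_table)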
