Support MUL op in quantization tool.
PiperOrigin-RevId: 653779941
ai-edge-bot authored and copybara-github committed Jul 18, 2024
1 parent 14f0713 commit f06f712
Showing 9 changed files with 309 additions and 0 deletions.
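
For context, a minimal usage sketch (not part of this commit) of quantizing a model that contains a MUL op, following the Quantizer API exercised by the new end-to-end test below. The model path, recipe path, signature input names, and sample count are illustrative placeholders.

# Illustrative sketch only; paths, input names, and sample count are placeholders.
import numpy as np

from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer('single_mul.tflite')  # placeholder: float model containing a MUL op
qt.load_quantization_recipe('default_a8w8_recipe.json')  # placeholder: full-integer recipe

# MUL is quantized with static-range quantization (SRQ), so calibration data
# is required before quantization.
calibration_data = [
    {
        'input_1': np.random.uniform(size=(1, 32, 32)).astype(np.float32),
        'input_2': np.random.uniform(size=(1, 32, 32)).astype(np.float32),
    }
    for _ in range(128)
]
calibration_result = qt.calibrate(calibration_data)
quant_result = qt.quantize(calibration_result)

with open('single_mul_quantized.tflite', 'wb') as f:
  f.write(quant_result.quantized_model)
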
2 changes: 2 additions & 0 deletions ai_edge_quantizer/algorithm_manager.py
@@ -73,6 +73,7 @@ class AlgorithmName(str, enum.Enum):
        _TFLOpName.GELU,
        _TFLOpName.ADD,
        _TFLOpName.SUB,
        _TFLOpName.MUL,
    ),
    (
        naive_min_max_quantize.materialize_fc_conv,
@@ -88,6 +89,7 @@ class AlgorithmName(str, enum.Enum):
        naive_min_max_quantize.materialize_gelu,
        naive_min_max_quantize.materialize_add,
        naive_min_max_quantize.materialize_sub,
        naive_min_max_quantize.materialize_mul,
    ),
):
  register_quantized_op(
13 changes: 13 additions & 0 deletions ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -84,6 +84,19 @@ def materialize_sub(
  )


def materialize_mul(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
    tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
  """Materialize tensors in tfl.mul."""
  return utils.materialize_standard_op(
      op_info,
      graph_info,
      tensor_name_to_qsv,
  )


def materialize_softmax(
    op_info: qtyping.OpInfo,
    graph_info: qtyping.GraphInfo,
@@ -0,0 +1,146 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TFLOpName = qtyping.TFLOperationName
_OpExecutionMode = qtyping.OpExecutionMode
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_QuantTransformation = qtyping.QuantTransformation
_OpTestInfo = naive_min_max_test_utils.OpTestInfo

_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
    "../../../tests/models"
)


class MulTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):

  def _custom_setup(self, test_model_file: str):
    np.random.seed(666)
    self._test_model_path = os.path.join(
        _TEST_DATA_PREFIX_PATH, test_model_file
    )
    self._op_test_info = _OpTestInfo(
        test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
        op_tensor_names={},
        input_range=(np.array([[-10]]), np.array([[8]])),
        output_range=(np.array([[10]]), np.array([[88]])),
    )
    # The test model has one subgraph for now.
    self._graph_info = qtyping.GraphInfo(
        subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
        buffers=self._op_test_info.test_model.buffers,
    )

  @parameterized.named_parameters(
      ("int8_nonsymmetric", 8, False),
      ("int16_symmetric", 16, True),
  )
  def test_materialize_srq_mul_succeeds(
      self,
      activation_num_bits: int,
      activation_symmetric: bool,
  ):
    self._custom_setup("single_mul.tflite")
    # Read from Model Explorer.
    subgraph0 = self._op_test_info.test_model.subgraphs[0]
    subgraph_op_id = 0
    op = subgraph0.operators[subgraph_op_id]
    op_tensor_names = {}
    op_tensor_names["input"] = "serving_default_input_1:0"
    op_tensor_names["input2"] = "serving_default_input_2:0"
    op_tensor_names["output"] = "PartitionedCall:0"
    self._op_test_info.op_tensor_names = op_tensor_names

    activation_tensor_config = _TensorQuantConfig(
        num_bits=activation_num_bits,
        symmetric=activation_symmetric,
        channel_wise=False,
    )
    op_info = qtyping.OpInfo(
        op=op,
        op_name=qtyping.TFLOperationName.MUL,
        subgraph_op_index=subgraph_op_id,
        op_quant_config=qtyping.OpQuantizationConfig(
            activation_tensor_config=activation_tensor_config,
            weight_tensor_config=activation_tensor_config,
            execution_mode=_OpExecutionMode.SRQ,
        ),
    )
    self._test_two_input_one_output_ops(
        op_info,
        self._graph_info,
        self._op_test_info,
        naive_min_max_quantize.materialize_mul,
    )

  @parameterized.named_parameters(
      ("int8_nonsymmetric", 8, False),
      ("int16_symmetric", 16, True),
  )
  def test_materialize_srq_mul2_constant_input_succeeds(
      self,
      activation_num_bits: int,
      activation_symmetric: bool,
  ):
    """Tests the case where one of the MUL inputs is a constant tensor."""
    self._custom_setup("single_mul2_constant_input.tflite")
    # Read from Model Explorer.
    subgraph0 = self._op_test_info.test_model.subgraphs[0]
    subgraph_op_id = 0
    op = subgraph0.operators[subgraph_op_id]
    op_tensor_names = {}
    op_tensor_names["input"] = "serving_default_input_1:0"
    op_tensor_names["weight"] = "model/multiply/ExpandDims"
    op_tensor_names["output"] = "PartitionedCall:0"
    self._op_test_info.op_tensor_names = op_tensor_names

    activation_tensor_config = _TensorQuantConfig(
        num_bits=activation_num_bits,
        symmetric=activation_symmetric,
        channel_wise=False,
    )
    op_info = qtyping.OpInfo(
        op=op,
        op_name=qtyping.TFLOperationName.MUL,
        subgraph_op_index=subgraph_op_id,
        op_quant_config=qtyping.OpQuantizationConfig(
            activation_tensor_config=activation_tensor_config,
            weight_tensor_config=activation_tensor_config,
            execution_mode=_OpExecutionMode.SRQ,
        ),
    )
    # We re-use the fc_bmm_conv helper test function here because the constant
    # tensor is treated as a weight tensor.
    self._test_fc_bmm_conv(
        op_info,
        self._graph_info,
        self._op_test_info,
        naive_min_max_quantize.materialize_mul,
    )


if __name__ == "__main__":
  googletest.main()
@@ -61,6 +61,7 @@
    _TFLOpName.GELU,
    _TFLOpName.ADD,
    _TFLOpName.SUB,
    _TFLOpName.MUL,
])

_INT4_SRQ_SUPPORTED_OPS = frozenset([
1 change: 1 addition & 0 deletions ai_edge_quantizer/qtyping.py
@@ -47,6 +47,7 @@ class TFLOperationName(str, enum.Enum):
  GELU = 'GELU'
  ADD = 'ADD'
  SUB = 'SUB'
  MUL = 'MUL'


# Use same code number as MOJAX for compatibility.
145 changes: 145 additions & 0 deletions ai_edge_quantizer/tests/end_to_end_tests/mul_test.py
@@ -0,0 +1,145 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""E2E tests for the quantizer for model with mul."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import

_OpExecutionMode = qtyping.OpExecutionMode
_OpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_OpQuantConfig = qtyping.OpQuantizationConfig

_RNG = np.random.default_rng(66)


def _get_dummy_data(num_inputs, num_samples):
  data = []
  for _ in range(num_samples):
    data.append({
        f'input_{i+1}': _RNG.uniform(size=(1, 32, 32)).astype(np.float32)
        for i in range(num_inputs)
    })
  return data


def _get_calibration_data(num_inputs, num_samples: int = 512):
  return _get_dummy_data(num_inputs, num_samples)


def _get_test_data(num_inputs, num_samples: int = 8):
  return _get_dummy_data(num_inputs, num_samples)


class MulTest(parameterized.TestCase):

  def _custom_setup(self, test_model_file):
    super().setUp()
    self.float_model_path = test_utils.get_path_to_datafile(
        f'../models/{test_model_file}'
    )
    self._quantizer = quantizer.Quantizer(self.float_model_path)

  @parameterized.parameters(
      '../recipes/default_a8w8_recipe.json',
      '../recipes/default_a16w8_recipe.json',
  )
  def test_mul_model_full_integer(self, recipe_path):
    self._custom_setup('single_mul.tflite')
    recipe_path = test_utils.get_path_to_datafile(recipe_path)
    self._quantizer.load_quantization_recipe(recipe_path)
    self.assertTrue(self._quantizer.need_calibration)
    calibration_result = self._quantizer.calibrate(
        _get_calibration_data(num_inputs=2)
    )
    _ = self._quantizer.quantize(calibration_result)
    # Skip model size check because the quantized model doesn't decrease as
    # there are no weights in the model file.

    comparion_result = self._quantizer.compare(
        error_metrics='mse', signature_test_data=_get_test_data(num_inputs=2)
    )
    self._check_comparion_result(
        comparion_result,
        output_tolerance=1e-4,
    )

  @parameterized.parameters(
      '../recipes/default_a8w8_recipe.json',
      '../recipes/default_a16w8_recipe.json',
  )
  def test_mul2_constant_input_model_full_integer(self, recipe_path):
    self._custom_setup('single_mul2_constant_input.tflite')
    recipe_path = test_utils.get_path_to_datafile(recipe_path)
    self._quantizer.load_quantization_recipe(recipe_path)
    self.assertTrue(self._quantizer.need_calibration)
    calibration_result = self._quantizer.calibrate(
        _get_calibration_data(num_inputs=1)
    )
    quant_result = self._quantizer.quantize(calibration_result)
    # Check model size.
    with gfile.GFile(self.float_model_path, 'rb') as f:
      float_model_bytearray = bytearray(f.read())
    self.assertLess(
        len(quant_result.quantized_model), len(float_model_bytearray)
    )

    comparion_result = self._quantizer.compare(
        error_metrics='mse', signature_test_data=_get_test_data(num_inputs=1)
    )
    self._check_comparion_result(
        comparion_result,
        output_tolerance=1e-4,
    )

  @parameterized.named_parameters(
      ('drq', _OpExecutionMode.DRQ),
      ('weight_only', _OpExecutionMode.WEIGHT_ONLY),
  )
  def test_mul2_fail(self, execution_mode):
    self._custom_setup('single_mul.tflite')
    with self.assertRaisesRegex(ValueError, 'Unsupported op for .*: MUL'):
      self._quantizer.update_quantization_recipe(
          regex='.*',
          operation_name='MUL',
          op_config=_OpQuantConfig(
              weight_tensor_config=_TensorQuantConfig(
                  num_bits=8, symmetric=False
              ),
              execution_mode=execution_mode,
          ),
          algorithm_key='min_max_uniform_quantize',
      )

  # TODO: b/345503484 - Check weight tensor type of the quantized model.
  def _check_comparion_result(
      self,
      comparion_result,
      output_tolerance,
  ):
    # Check final output.
    output_mse = comparion_result['PartitionedCall:0']
    self.assertLess(output_mse, output_tolerance)


if __name__ == '__main__':
  googletest.main()
Binary file added ai_edge_quantizer/tests/models/single_mul.tflite
Binary file added ai_edge_quantizer/tests/models/single_mul2_constant_input.tflite
1 change: 1 addition & 0 deletions ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -52,6 +52,7 @@
    _TFLOpName.GELU: schema_py_generated.BuiltinOperator.GELU,
    _TFLOpName.ADD: schema_py_generated.BuiltinOperator.ADD,
    _TFLOpName.SUB: schema_py_generated.BuiltinOperator.SUB,
    _TFLOpName.MUL: schema_py_generated.BuiltinOperator.MUL,
})

TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
