From 18aeb2a03f16ed0765b42c282d3ff5fce6e7e340 Mon Sep 17 00:00:00 2001
From: Vitalii Dziuba
Date: Wed, 20 Nov 2024 17:39:41 -0800
Subject: [PATCH] Add full integer quantization for SELECT_V2 in Quantizer

PiperOrigin-RevId: 698579890
---
 ai_edge_quantizer/algorithm_manager.py        |   2 +
 .../naive_min_max_quantize.py                 |  17 +++
 .../select_v2_test.py                         | 107 ++++++++++++++++
 ai_edge_quantizer/default_policy.py           |   6 +-
 ai_edge_quantizer/qtyping.py                  |   1 +
 .../tests/end_to_end_tests/select_v2_test.py  | 117 ++++++++++++++++++
 .../tests/models/single_select_v2.tflite      | Bin 0 -> 1120 bytes
 .../utils/tfl_flatbuffer_utils.py             |   1 +
 8 files changed, 249 insertions(+), 2 deletions(-)
 create mode 100644 ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py
 create mode 100644 ai_edge_quantizer/tests/end_to_end_tests/select_v2_test.py
 create mode 100644 ai_edge_quantizer/tests/models/single_select_v2.tflite

diff --git a/ai_edge_quantizer/algorithm_manager.py b/ai_edge_quantizer/algorithm_manager.py
index 5a72a70..046c69e 100644
--- a/ai_edge_quantizer/algorithm_manager.py
+++ b/ai_edge_quantizer/algorithm_manager.py
@@ -91,6 +91,7 @@ class AlgorithmName(str, enum.Enum):
         _TFLOpName.LOGISTIC,  # Sigmoid
         _TFLOpName.SLICE,
         _TFLOpName.SUM,
+        _TFLOpName.SELECT_V2,
     ),
     (
         naive_min_max_quantize.materialize_input,
@@ -118,6 +119,7 @@ class AlgorithmName(str, enum.Enum):
         naive_min_max_quantize.materialize_softmax_and_logistic,
         naive_min_max_quantize.materialize_slice,
         naive_min_max_quantize.materialize_sum,
+        naive_min_max_quantize.materialize_select_v2,
     ),
 ):
   register_quantized_op(
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
index 411f77d..cfe6b4c 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -325,6 +325,23 @@ def materialize_slice(
   )
 
 
+def materialize_select_v2(
+    op_info: qtyping.OpInfo,
+    graph_info: qtyping.GraphInfo,
+    tensor_name_to_qsv: dict[str, Any],
+) -> list[qtyping.TensorTransformationParams]:
+  """Materialize tensors in tfl.select_v2."""
+  return utils.materialize_standard_op(
+      op_info,
+      graph_info,
+      tensor_name_to_qsv,
+      constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
+      inputs_to_ignore=[
+          0,
+      ],  # Condition tensor does not need to be quantized.
+  )
+
+
 def materialize_sum(
     op_info: qtyping.OpInfo,
     graph_info: qtyping.GraphInfo,
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py
new file mode 100644
index 0000000..f6c020d
--- /dev/null
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/select_v2_test.py
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
+from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+_TFLOpName = qtyping.TFLOperationName
+_ComputePrecision = qtyping.ComputePrecision
+_TensorQuantConfig = qtyping.TensorQuantizationConfig
+_QuantTransformation = qtyping.QuantTransformation
+_OpTestInfo = naive_min_max_test_utils.OpTestInfo
+
+_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
+    "../../../tests/models"
+)
+
+_DEFAULT_WEIGHT_QUANT_SETTING = (
+    naive_min_max_test_utils.DEFAULT_WEIGHT_QUANT_SETTING
+)
+
+
+class SelectV2Test(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):
+
+  def setUp(self):
+    super().setUp()
+    np.random.seed(666)
+    self._test_model_path = os.path.join(
+        _TEST_DATA_PREFIX_PATH, "single_select_v2.tflite"
+    )
+    self._op_test_info = _OpTestInfo(
+        test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
+        op_tensor_names={},
+        input_range=(np.array([[-10]]), np.array([[10]])),
+        output_range=(np.array([[-10]]), np.array([[10]])),
+    )
+    # The test model has one subgraph for now.
+    self._graph_info = qtyping.GraphInfo(
+        subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
+        buffers=self._op_test_info.test_model.buffers,
+    )
+
+  @parameterized.parameters(
+      8,
+      16,
+  )
+  def test_materialize_select_v2_succeeds(self, num_bits):
+    activation_tensor_config = _TensorQuantConfig(
+        num_bits=num_bits,
+        symmetric=True,
+        granularity=qtyping.QuantGranularity.TENSORWISE,
+    )
+    op_quant_config = qtyping.OpQuantizationConfig(
+        activation_tensor_config=activation_tensor_config,
+        weight_tensor_config=_DEFAULT_WEIGHT_QUANT_SETTING,
+        compute_precision=_ComputePrecision.INTEGER,  # SRQ.
+    )
+    # Read from Model Explorer.
+    subgraph0 = self._op_test_info.test_model.subgraphs[0]
+    subgraph_op_id = 0
+    op = subgraph0.operators[subgraph_op_id]
+    op_info = qtyping.OpInfo(
+        op=op,
+        op_name=qtyping.TFLOperationName.SELECT_V2,
+        subgraph_op_index=subgraph_op_id,
+        op_quant_config=op_quant_config,
+    )
+
+    # Test settings.
+    op_tensor_names = {}
+    op_tensor_names["input"] = "selectv2_condition_tensor:0"
+    op_tensor_names["input2"] = "selectv2_t_tensor:0"
+    op_tensor_names["input3"] = "selectv2_e_tensor:0"
+    op_tensor_names["output"] = "PartitionedCall:0"
+    self._op_test_info.op_tensor_names = op_tensor_names
+    self._test_no_weights_op(
+        op_info,
+        self._graph_info,
+        self._op_test_info,
+        naive_min_max_quantize.materialize_select_v2,
+        same_input_output_params=True,
+        inputs_to_ignore=[0],
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/ai_edge_quantizer/default_policy.py b/ai_edge_quantizer/default_policy.py
index 65458f3..c30a0c3 100644
--- a/ai_edge_quantizer/default_policy.py
+++ b/ai_edge_quantizer/default_policy.py
@@ -166,7 +166,8 @@
       "OUTPUT",
       "SLICE",
       "EMBEDDING_LOOKUP",
-      "SUM"
+      "SUM",
+      "SELECT_V2"
     ],
     "static_wi8_ai8": [
       "ADD",
@@ -193,7 +194,8 @@
       "OUTPUT",
       "SLICE",
       "EMBEDDING_LOOKUP",
-      "SUM"
+      "SUM",
+      "SELECT_V2"
     ],
     "static_wi4_ai8": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
     "static_wi4_ai16": ["FULLY_CONNECTED", "CONV_2D", "INPUT", "OUTPUT"],
diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py
index 9dd692c..6d5c320 100644
--- a/ai_edge_quantizer/qtyping.py
+++ b/ai_edge_quantizer/qtyping.py
@@ -59,6 +59,7 @@ class TFLOperationName(str, enum.Enum):
   LOGISTIC = 'LOGISTIC'
   SLICE = 'SLICE'
   SUM = 'SUM'
+  SELECT_V2 = 'SELECT_V2'
 
 
 class QuantizeMode(enum.Enum):
diff --git a/ai_edge_quantizer/tests/end_to_end_tests/select_v2_test.py b/ai_edge_quantizer/tests/end_to_end_tests/select_v2_test.py
new file mode 100644
index 0000000..1dd2c2e
--- /dev/null
+++ b/ai_edge_quantizer/tests/end_to_end_tests/select_v2_test.py
@@ -0,0 +1,117 @@
+# Copyright 2024 The AI Edge Quantizer Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""E2E tests for the quantizer for model with select_v2."""
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.platform import googletest
+from ai_edge_quantizer import qtyping
+from ai_edge_quantizer import quantizer
+from ai_edge_quantizer.utils import test_utils
+from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+
+_OpExecutionMode = qtyping.OpExecutionMode
+_OpName = qtyping.TFLOperationName
+_TensorQuantConfig = qtyping.TensorQuantizationConfig
+_OpQuantConfig = qtyping.OpQuantizationConfig
+
+_RNG = np.random.default_rng(66)
+
+
+def _get_dummy_data(num_samples):
+  data = []
+  for _ in range(num_samples):
+    data.append({
+        'condition_tensor': _RNG.uniform(size=(1, 16)).astype(np.bool_),
+        'e_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
+        't_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
+    })
+  return data
+
+
+def _get_calibration_data(num_samples: int = 64):
+  calibration_samples = _get_dummy_data(num_samples)
+  calibration_data = {'selectv2': calibration_samples}
+  return calibration_data
+
+
+def _get_test_data(num_samples: int = 16):
+  return _get_calibration_data(num_samples)
+
+
+class SelectV2Test(parameterized.TestCase):
+
+  def _custom_setup(self, test_model_file):
+    super().setUp()
+    self.float_model_path = test_utils.get_path_to_datafile(
+        f'../models/{test_model_file}'
+    )
+    self._quantizer = quantizer.Quantizer(self.float_model_path)
+
+  @parameterized.parameters(
+      ('../../recipes/default_a8w8_recipe.json', 9),  # int8.
+      ('../../recipes/default_a16w8_recipe.json', 7),  # int16.
+  )
+  def test_select_v2_model_full_integer(self, recipe_path, tensor_type):
+    self._custom_setup('single_select_v2.tflite')
+    recipe_path = test_utils.get_path_to_datafile(recipe_path)
+    self._quantizer.load_quantization_recipe(recipe_path)
+    self.assertTrue(self._quantizer.need_calibration)
+    calibration_result = self._quantizer.calibrate(_get_calibration_data())
+    quantization_result = self._quantizer.quantize(calibration_result)
+
+    # Check input/output tensor type.
+    quantized_model = tfl_flatbuffer_utils.read_model(
+        quantization_result.quantized_model
+    )
+    self.assertLen(quantized_model.subgraphs, 1)
+    subgraph = quantized_model.subgraphs[0]
+    subgraph_tensors = subgraph.tensors
+    self.assertLen(subgraph.inputs, 3)
+    condition_tensor = subgraph_tensors[subgraph.inputs[0]]
+    e_tensor = subgraph_tensors[subgraph.inputs[1]]
+    t_tensor = subgraph_tensors[subgraph.inputs[2]]
+    output_tensor = subgraph_tensors[subgraph.outputs[0]]
+    # See schema_py_generated.py for type code.
+    self.assertEqual(condition_tensor.type, 6)  # bool.
+    self.assertEqual(e_tensor.type, tensor_type)
+    self.assertEqual(t_tensor.type, tensor_type)
+    self.assertEqual(output_tensor.type, tensor_type)
+
+    comparison_result = self._quantizer.validate(
+        error_metrics='mse', test_data=_get_test_data()
+    )
+    self._check_comparison_result(
+        comparison_result,
+        output_tolerance=1e-4,
+    )
+
+  # TODO: b/345503484 - Check weight tensor type of the quantized model.
+  def _check_comparison_result(
+      self,
+      comparison_result,
+      output_tolerance,
+  ):
+    # TODO: b/357959309 - Use comparison result directly for testing.
+    comparison_result = comparison_result.get_all_tensor_results()
+    # Check final output.
+    output_mse = comparison_result['PartitionedCall:0']
+    self.assertLess(output_mse, output_tolerance)
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/ai_edge_quantizer/tests/models/single_select_v2.tflite b/ai_edge_quantizer/tests/models/single_select_v2.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..36b58511d7441216b6a37ef6b9fb9ed10f1c47ab
GIT binary patch
literal 1120
zcmZ`&Jx?1!5Pb>8SmXpGAY=&&7Zg;MaVQd$#xfuwnGX>w?P2g0C)sz%UZY6l2SiF2
zQBtN%iKzGq`~V7m0Tl&8DFT=G)_2y~3VV9mx!IXFvoo`c0P35y_j4#<5qZpE1`c&O
zWGI6=;D#u@0`3XC2Am8qMp3^r1Tbm6?rHqT_43_y
+#bs>`*LyRW6@T0m_yvnjIaA>Mg%SOOS(~g`rdi|
sOEcwVh60VTog|;DM-Ie$dAX-17p3Vyw~U$J0o7R+Gw@bO0kXXF8{-$ywW_MXrtLl
z=G{#Ol}zFgfO-4QRIUC?Q{N}3wQDF(-8apDKFXgA`v|{Y^1=)Jx_==?KObaQ`7fFD
eoJoD4f9f8ik0%vRe(INFdr5)b>fxzGZvOz>f3GY6

literal 0
HcmV?d00001

diff --git a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
index 6058760..665e1e9 100644
--- a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
+++ b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -61,6 +61,7 @@
     _TFLOpName.LOGISTIC: schema_py_generated.BuiltinOperator.LOGISTIC,
     _TFLOpName.SLICE: schema_py_generated.BuiltinOperator.SLICE,
     _TFLOpName.SUM: schema_py_generated.BuiltinOperator.SUM,
+    _TFLOpName.SELECT_V2: schema_py_generated.BuiltinOperator.SELECT_V2,
 })
 
 TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
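Reviewer note: for trying the new SELECT_V2 path outside the test harness, the sketch below mirrors the end-to-end test added in this patch. The local model path, the 'selectv2' signature key, the input names, and the recipe location are assumptions carried over from the single_select_v2.tflite test model and a repo checkout; adjust them for your own model.

# Minimal usage sketch mirroring the e2e test above. Paths and the 'selectv2'
# signature key are illustrative assumptions, not a prescribed API surface.
import numpy as np

from ai_edge_quantizer import quantizer

_RNG = np.random.default_rng(0)

# Calibration samples are keyed by signature name, one dict of inputs per sample.
calibration_data = {
    'selectv2': [
        {
            'condition_tensor': _RNG.uniform(size=(1, 16)).astype(np.bool_),
            'e_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
            't_tensor': _RNG.uniform(size=(1, 16)).astype(np.float32),
        }
        for _ in range(64)
    ]
}

qt = quantizer.Quantizer('single_select_v2.tflite')  # Hypothetical local path.
# The default a8w8 / a16w8 recipes now cover SELECT_V2 via the default_policy.py
# change above; the recipe path below is relative to a repo checkout.
qt.load_quantization_recipe('ai_edge_quantizer/recipes/default_a8w8_recipe.json')
calibration_result = qt.calibrate(calibration_data)
quantization_result = qt.quantize(calibration_result)
# quantization_result.quantized_model holds the quantized flatbuffer; it can be
# inspected with tfl_flatbuffer_utils.read_model(), as in the e2e test.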