From 83a275db4a93da2a53cfe14130dba2ef5163aacd Mon Sep 17 00:00:00 2001 From: Vitalii Dziuba Date: Tue, 4 Feb 2025 08:19:20 -0800 Subject: [PATCH] Enable quantization when op may contain tensors with zero-size array PiperOrigin-RevId: 723093115 --- ai_edge_quantizer/calibrator.py | 8 +++++++- ai_edge_quantizer/calibrator_test.py | 12 ++++++++++++ .../tests/models/reshape_with_empty_shape.tflite | Bin 0 -> 852 bytes ai_edge_quantizer/utils/tfl_interpreter_utils.py | 7 ++++++- 4 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py index c91066d..bc3b792 100644 --- a/ai_edge_quantizer/calibrator.py +++ b/ai_edge_quantizer/calibrator.py @@ -281,7 +281,13 @@ def _initialize_model_qsvs( algorithm_name, op_key ) op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config) - op_qsvs = qsv_init_func(op_info, graph_info) + # Ignore the input tensors where any dimension of the shape is 0. + inputs_to_ignore = [ + idx + for idx in op.inputs + if not np.all(graph_info.subgraph_tensors[idx].shape) + ] + op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore) # Step3: initialize tensor qsvs. for tensor_name, qsv in op_qsvs.items(): if tensor_name not in self._model_qsvs: diff --git a/ai_edge_quantizer/calibrator_test.py b/ai_edge_quantizer/calibrator_test.py index 454889e..38c4159 100644 --- a/ai_edge_quantizer/calibrator_test.py +++ b/ai_edge_quantizer/calibrator_test.py @@ -228,6 +228,18 @@ def test_calibrate_unsupported_ops_success(self): ) self.assertLen(test_calibrator.get_cached_output(), 10) + def test_calibrate_reshape_with_empty_shape_success(self): + test_model_path = os.path.join( + TEST_DATA_PREFIX_PATH, "tests/models/reshape_with_empty_shape.tflite" + ) + test_calibrator = calibrator.Calibrator(test_model_path) + _add_default_int8xint8_integer_recipe(self._recipe_manager) + calib_data = test_utils.create_random_normal_input_data( + test_model_path, num_samples=4 + ) + test_calibrator.calibrate(calib_data, self._recipe_manager) + self.assertNotEmpty(test_calibrator.get_model_qsvs()) + class CalibratorAlreadyQuantizedModelTest(googletest.TestCase): diff --git a/ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite b/ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite new file mode 100644 index 0000000000000000000000000000000000000000..157ec77808a050240763f887f3b5dac676573db0 GIT binary patch literal 852 zcmZuvJxfAS7=AS?v5N_bIY>A(#wGh0f)+6a!H=*@t4Rvo!DHw(>K{ZygE%-gwAJ7s zv^4cQBAWUc-Ja)ouZ8sLbKi5`?~jub5vx$zjZ0cml9af_#DFX;A)Ey=kz3#scnFA$ z0a?(0;V}k?M6?%uBDm3?zI#9jXg6#HrtNtjxvm=o=b84!W=lD?b5qa!+I)SxF7=Cm ztuFfxIDPz>?}$MiZPtKXwQ`!TofNB;dO2@pkFr)4Js7%=uxIDnwry4K_K~vPs$aFb zJ-ef12{y051b7670CUo(i(KSqz_ee^CpwEw*(M-}!Wl3Y?McRLrXf9-LE+_TYvXQ! z5$b+uNBbQ!b6{?Ze0aGSGh^;kG@}t_@?VZI0XA=i$Xn`~a>HX@39! literal 0 HcmV?d00001 diff --git a/ai_edge_quantizer/utils/tfl_interpreter_utils.py b/ai_edge_quantizer/utils/tfl_interpreter_utils.py index a27a326..e356374 100644 --- a/ai_edge_quantizer/utils/tfl_interpreter_utils.py +++ b/ai_edge_quantizer/utils/tfl_interpreter_utils.py @@ -188,9 +188,14 @@ def get_tensor_name_to_content_map( """ tensors = {} for tensor_detail in tflite_interpreter.get_tensor_details(subgraph_index): - # Don't return temporary, unnamed tensors + # Don't return temporary, unnamed tensors. if not tensor_detail["name"]: continue + + # Don't return tensors where any dimension of the shape is 0. + if not np.all(tensor_detail["shape"]): + continue + tensors[tensor_detail["name"]] = get_tensor_data( tflite_interpreter, tensor_detail, subgraph_index, dequantize )