diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py index c91066d..bc3b792 100644 --- a/ai_edge_quantizer/calibrator.py +++ b/ai_edge_quantizer/calibrator.py @@ -281,7 +281,13 @@ def _initialize_model_qsvs( algorithm_name, op_key ) op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config) - op_qsvs = qsv_init_func(op_info, graph_info) + # Ignore the input tensors where any dimension of the shape is 0. + inputs_to_ignore = [ + idx + for idx in op.inputs + if not np.all(graph_info.subgraph_tensors[idx].shape) + ] + op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore) # Step3: initialize tensor qsvs. for tensor_name, qsv in op_qsvs.items(): if tensor_name not in self._model_qsvs: diff --git a/ai_edge_quantizer/calibrator_test.py b/ai_edge_quantizer/calibrator_test.py index 454889e..38c4159 100644 --- a/ai_edge_quantizer/calibrator_test.py +++ b/ai_edge_quantizer/calibrator_test.py @@ -228,6 +228,18 @@ def test_calibrate_unsupported_ops_success(self): ) self.assertLen(test_calibrator.get_cached_output(), 10) + def test_calibrate_reshape_with_empty_shape_success(self): + test_model_path = os.path.join( + TEST_DATA_PREFIX_PATH, "tests/models/reshape_with_empty_shape.tflite" + ) + test_calibrator = calibrator.Calibrator(test_model_path) + _add_default_int8xint8_integer_recipe(self._recipe_manager) + calib_data = test_utils.create_random_normal_input_data( + test_model_path, num_samples=4 + ) + test_calibrator.calibrate(calib_data, self._recipe_manager) + self.assertNotEmpty(test_calibrator.get_model_qsvs()) + class CalibratorAlreadyQuantizedModelTest(googletest.TestCase): diff --git a/ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite b/ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite new file mode 100644 index 0000000..157ec77 Binary files /dev/null and b/ai_edge_quantizer/tests/models/reshape_with_empty_shape.tflite differ diff --git a/ai_edge_quantizer/utils/tfl_interpreter_utils.py b/ai_edge_quantizer/utils/tfl_interpreter_utils.py index a27a326..e356374 100644 --- a/ai_edge_quantizer/utils/tfl_interpreter_utils.py +++ b/ai_edge_quantizer/utils/tfl_interpreter_utils.py @@ -188,9 +188,14 @@ def get_tensor_name_to_content_map( """ tensors = {} for tensor_detail in tflite_interpreter.get_tensor_details(subgraph_index): - # Don't return temporary, unnamed tensors + # Don't return temporary, unnamed tensors. if not tensor_detail["name"]: continue + + # Don't return tensors where any dimension of the shape is 0. + if not np.all(tensor_detail["shape"]): + continue + tensors[tensor_detail["name"]] = get_tensor_data( tflite_interpreter, tensor_detail, subgraph_index, dequantize )