Skip to content

Commit

Permalink
Enable quantization when op may contain tensors with zero-size array
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723093115
  • Loading branch information
v-dziuba authored and copybara-github committed Feb 4, 2025
1 parent a3318da commit 83a275d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ai_edge_quantizer/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,13 @@ def _initialize_model_qsvs(
algorithm_name, op_key
)
op_info = qtyping.OpInfo(op, op_key, subgraph_op_id, op_quant_config)
op_qsvs = qsv_init_func(op_info, graph_info)
# Ignore the input tensors where any dimension of the shape is 0.
inputs_to_ignore = [
idx
for idx in op.inputs
if not np.all(graph_info.subgraph_tensors[idx].shape)
]
op_qsvs = qsv_init_func(op_info, graph_info, inputs_to_ignore)
# Step3: initialize tensor qsvs.
for tensor_name, qsv in op_qsvs.items():
if tensor_name not in self._model_qsvs:
Expand Down
12 changes: 12 additions & 0 deletions ai_edge_quantizer/calibrator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ def test_calibrate_unsupported_ops_success(self):
)
self.assertLen(test_calibrator.get_cached_output(), 10)

def test_calibrate_reshape_with_empty_shape_success(self):
test_model_path = os.path.join(
TEST_DATA_PREFIX_PATH, "tests/models/reshape_with_empty_shape.tflite"
)
test_calibrator = calibrator.Calibrator(test_model_path)
_add_default_int8xint8_integer_recipe(self._recipe_manager)
calib_data = test_utils.create_random_normal_input_data(
test_model_path, num_samples=4
)
test_calibrator.calibrate(calib_data, self._recipe_manager)
self.assertNotEmpty(test_calibrator.get_model_qsvs())


class CalibratorAlreadyQuantizedModelTest(googletest.TestCase):

Expand Down
Binary file not shown.
7 changes: 6 additions & 1 deletion ai_edge_quantizer/utils/tfl_interpreter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,14 @@ def get_tensor_name_to_content_map(
"""
tensors = {}
for tensor_detail in tflite_interpreter.get_tensor_details(subgraph_index):
# Don't return temporary, unnamed tensors
# Don't return temporary, unnamed tensors.
if not tensor_detail["name"]:
continue

# Don't return tensors where any dimension of the shape is 0.
if not np.all(tensor_detail["shape"]):
continue

tensors[tensor_detail["name"]] = get_tensor_data(
tflite_interpreter, tensor_detail, subgraph_index, dequantize
)
Expand Down

0 comments on commit 83a275d

Please sign in to comment.