Commit 00ccc7d
Add split op support
PiperOrigin-RevId: 654920880
marialyu authored and copybara-github committed Jul 22, 2024
1 parent 7bdc407 commit 00ccc7d
Showing 9 changed files with 293 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ai_edge_quantizer/algorithm_manager.py
@@ -78,6 +78,7 @@ class AlgorithmName(str, enum.Enum):
_TFLOpName.RSQRT,
_TFLOpName.CONCATENATION,
_TFLOpName.STRIDED_SLICE,
_TFLOpName.SPLIT,
),
(
naive_min_max_quantize.materialize_fc_conv,
@@ -98,6 +99,7 @@ class AlgorithmName(str, enum.Enum):
naive_min_max_quantize.materialize_rsqrt,
naive_min_max_quantize.materialize_concatenation,
naive_min_max_quantize.materialize_strided_slice,
naive_min_max_quantize.materialize_split,
),
):
register_quantized_op(
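For context, GitHub elides the loop these two parallel tuples feed. A minimal sketch of the registration, assuming the zip pattern and keyword names from the visible call; the algorithm enum member and the `min_max_calibrate` argument are inferences, not shown in this diff:

# Hedged sketch: each op name above is paired with its materialize
# function, so adding _TFLOpName.SPLIT plus materialize_split is all the
# dispatcher needs. Names below are assumptions where not visible above.
for op_name, materialize_func in zip(op_names, materialize_funcs):
  register_quantized_op(
      AlgorithmName.MIN_MAX_UNIFORM_QUANT,  # assumed enum member
      op_name,
      naive_min_max_quantize.init_qsvs,
      calibration_func=naive_min_max_quantize.min_max_calibrate,
      materialize_func=materialize_func,
  )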
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -348,6 +348,21 @@ def materialize_concatenation(
)


def materialize_split(
op_info: qtyping.OpInfo,
graph_info: qtyping.GraphInfo,
tensor_name_to_qsv: dict[str, Any],
) -> list[qtyping.TensorTransformationParams]:
"""Materialize tensors in tfl.split."""
return utils.materialize_standard_op(
op_info,
graph_info,
tensor_name_to_qsv,
constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
inputs_to_ignore=[0], # Split dimension does not need to be quantized.
)
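
The `SAME_AS_INPUT_SCALE` constraint is sound for split because the op only partitions values: quantization is elementwise, so it commutes with the split and every output can reuse the input's parameters verbatim. A self-contained numeric sketch (illustrative only; the [-10, 8] range mirrors the op test below):

import numpy as np

def quant(x, scale, zp):
  """Asymmetric int8 quantization (illustrative)."""
  return np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int8)

scale = (8.0 - -10.0) / 255.0   # ~0.0706 for an observed [-10, 8] range
zp = round(10.0 / scale) - 128  # 14
x = np.linspace(-10.0, 8.0, 30).astype(np.float32)
# Quantize-then-split equals split-then-quantize, element for element:
lhs = np.split(quant(x, scale, zp), 2)
rhs = [quant(part, scale, zp) for part in np.split(x, 2)]
assert all(np.array_equal(a, b) for a, b in zip(lhs, rhs))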


# TODO: b/333731147 - Use named tuple to store min/max.
def init_qsvs(
op_info: qtyping.OpInfo,
108 changes: 108 additions & 0 deletions ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/split_test.py
@@ -0,0 +1,108 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
from ai_edge_quantizer.algorithms.uniform_quantize.naive_min_max_quantize_op_tests import test_utils as naive_min_max_test_utils
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils

_TFLOpName = qtyping.TFLOperationName
_OpExecutionMode = qtyping.OpExecutionMode
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_QuantTransformation = qtyping.QuantTransformation
_OpTestInfo = naive_min_max_test_utils.OpTestInfo

_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
"../../../tests/models"
)
_DEFAULT_ACTIVATION_QUANT_SETTING = (
naive_min_max_test_utils.DEFAULT_ACTIVATION_QUANT_SETTING
)
_DEFAULT_WEIGHT_QUANT_SETTING = (
naive_min_max_test_utils.DEFAULT_WEIGHT_QUANT_SETTING
)


class SplitTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):

def setUp(self):
super().setUp()
np.random.seed(666)
self._test_model_path = os.path.join(
_TEST_DATA_PREFIX_PATH, "single_split.tflite"
)
self._op_test_info = _OpTestInfo(
test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
op_tensor_names={},
input_range=(np.array([[-10]]), np.array([[8]])),
output_range=(np.array([[10]]), np.array([[88]])),
)
# The test model has one subgraph for now.
self._graph_info = qtyping.GraphInfo(
subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
buffers=self._op_test_info.test_model.buffers,
)

@parameterized.parameters(
(_DEFAULT_ACTIVATION_QUANT_SETTING),
(
_TensorQuantConfig(
num_bits=16,
symmetric=True,
channel_wise=False,
)
),
)
def test_materialize_split_succeeds(self, activation_tensor_config):
op_quant_config = qtyping.OpQuantizationConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_DEFAULT_WEIGHT_QUANT_SETTING,
execution_mode=_OpExecutionMode.SRQ,
)
# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
subgraph_op_id = 0
op = subgraph0.operators[subgraph_op_id]
op_info = qtyping.OpInfo(
op=op,
op_name=qtyping.TFLOperationName.SPLIT,
subgraph_op_index=subgraph_op_id,
op_quant_config=op_quant_config,
)

# Test settings.
op_tensor_names = {}
op_tensor_names["input"] = "serving_default_input_1:0"
op_tensor_names["output"] = "PartitionedCall:0"
op_tensor_names["output2"] = "PartitionedCall:1"
self._op_test_info.op_tensor_names = op_tensor_names
self._test_one_input_two_output_ops(
op_info,
self._graph_info,
self._op_test_info,
naive_min_max_quantize.materialize_split,
same_input_output_params=True,
)


if __name__ == "__main__":
googletest.main()
ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/test_utils.py
@@ -79,6 +79,7 @@ def _setup_op_test_config(
execution_mode,
op_test_info,
num_inputs=1,
num_outputs=1,
):
"""Helper to set up qsv for op test."""
# SRQ requires QSVs (min/max).
@@ -93,16 +94,20 @@
"min": output_min,
"max": output_max,
}
self._tensor_name_to_qsv = {}
for i in range(num_inputs):
input_name = "input"
if i > 0:
input_name = f"input{i+1}"
self._tensor_name_to_qsv[op_test_info.op_tensor_names[input_name]] = (
input_qsv
)
for i in range(num_outputs):
output_name = "output"
if i > 0:
output_name = f"output{i+1}"
self._tensor_name_to_qsv[op_test_info.op_tensor_names[output_name]] = (
output_qsv
)
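
With the new `num_outputs` parameter, the helper covers ops like split that fan out to several tensors. For the tensor names used in the split op test above, the resulting map looks roughly like this (a hedged illustration with that test's min/max values; both outputs share one `output_qsv`):

import numpy as np

input_qsv = {"min": np.array([[-10]]), "max": np.array([[8]])}
output_qsv = {"min": np.array([[10]]), "max": np.array([[88]])}
tensor_name_to_qsv = {
    "serving_default_input_1:0": input_qsv,  # "input"
    "PartitionedCall:0": output_qsv,         # "output"
    "PartitionedCall:1": output_qsv,         # "output2"
}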

def _test_single_input_output_ops(
self,
@@ -232,6 +237,78 @@ def _test_two_input_one_output_ops(
self.assertEqual(input1_tensor_quant_params, output_tensor_quant_params)
self.assertEqual(input2_tensor_quant_params, output_tensor_quant_params)

def _test_one_input_two_output_ops(
self,
op_info,
graph_info,
op_test_info,
materialization_func,
same_input_output_params=False,
):
"""Tests ops with one input and two outputs.
Can be used for ops such as SPLIT.
Args:
op_info: OpInfo object.
graph_info: GraphInfo object.
op_test_info: OpTestInfo object.
materialization_func: Function to materialize tensor transformation
parameters.
same_input_output_params: Whether the input and output tensor
transformation parameters are the same.
"""
op_quant_config = op_info.op_quant_config
self._setup_op_test_config(
execution_mode=op_quant_config.execution_mode,
op_test_info=op_test_info,
num_outputs=2,
)
tensor_quant_params = materialization_func(
op_info, graph_info, self._tensor_name_to_qsv
)
self.assertLen(tensor_quant_params, 3)

# Test input tensor settings.
transformations = [_QuantTransformation.NO_QUANTIZE]
if op_quant_config.execution_mode == _OpExecutionMode.SRQ:
transformations = [_QuantTransformation.ADD_QUANTIZE]
self._test_tensor_transformation_params(
op_test_info.op_tensor_names["input"],
op_info.subgraph_op_index,
is_inbounding_tensor=True,
tensor_quant_config=op_quant_config.activation_tensor_config,
transformation_params=tensor_quant_params[0],
desired_transformations=transformations,
)
# Test output tensor settings.
transformations = [_QuantTransformation.NO_QUANTIZE]
if op_quant_config.execution_mode == _OpExecutionMode.SRQ:
transformations = [_QuantTransformation.ADD_DEQUANTIZE]
self._test_tensor_transformation_params(
op_test_info.op_tensor_names["output"],
op_info.subgraph_op_index,
is_inbounding_tensor=False,
tensor_quant_config=op_quant_config.activation_tensor_config,
transformation_params=tensor_quant_params[1],
desired_transformations=transformations,
)
self._test_tensor_transformation_params(
op_test_info.op_tensor_names["output2"],
op_info.subgraph_op_index,
is_inbounding_tensor=False,
tensor_quant_config=op_quant_config.activation_tensor_config,
transformation_params=tensor_quant_params[2],
desired_transformations=transformations,
)

if same_input_output_params:
input_tensor_quant_params = tensor_quant_params[0].consumers[0].parameters # pytype: disable=attribute-error
output1_tensor_quant_params = tensor_quant_params[1].producer.parameters # pytype: disable=attribute-error
output2_tensor_quant_params = tensor_quant_params[2].producer.parameters # pytype: disable=attribute-error
self.assertEqual(output1_tensor_quant_params, input_tensor_quant_params)
self.assertEqual(output2_tensor_quant_params, input_tensor_quant_params)

def _test_fc_bmm_conv(
self,
op_info,
@@ -67,6 +67,7 @@
_TFLOpName.RSQRT,
_TFLOpName.CONCATENATION,
_TFLOpName.STRIDED_SLICE,
_TFLOpName.SPLIT,
])

_INT4_SRQ_SUPPORTED_OPS = frozenset([
1 change: 1 addition & 0 deletions ai_edge_quantizer/qtyping.py
@@ -52,6 +52,7 @@ class TFLOperationName(str, enum.Enum):
RSQRT = 'RSQRT'
CONCATENATION = 'CONCATENATION'
STRIDED_SLICE = 'STRIDED_SLICE'
SPLIT = 'SPLIT'


# Use same code number as MOJAX for compatibility.
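Because `TFLOperationName` subclasses `str`, the new member is interchangeable with its literal value, which is presumably what lets recipes and serialized configs refer to ops by name:

assert qtyping.TFLOperationName.SPLIT == 'SPLIT'  # str enum compares as text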
85 changes: 85 additions & 0 deletions ai_edge_quantizer/tests/end_to_end_tests/split_test.py
@@ -0,0 +1,85 @@
# Copyright 2024 The AI Edge Quantizer Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""E2E tests for the quantizer for model with transpose."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.platform import googletest
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import quantizer
from ai_edge_quantizer.utils import test_utils

_OpExecutionMode = qtyping.OpExecutionMode
_OpName = qtyping.TFLOperationName
_TensorQuantConfig = qtyping.TensorQuantizationConfig
_OpQuantConfig = qtyping.OpQuantizationConfig

_RNG = np.random.default_rng(66)


def _get_dummy_data(num_samples):
data = []
for _ in range(num_samples):
data.append(
{'input_1': _RNG.uniform(size=(1, 10, 20, 30)).astype(np.float32)}
)
return data


def _get_calibration_data(num_samples: int = 128):
return _get_dummy_data(num_samples)


def _get_test_data(num_samples: int = 8):
return _get_dummy_data(num_samples)


class SplitTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self.float_model_path = test_utils.get_path_to_datafile(
'../models/single_split.tflite'
)
self._quantizer = quantizer.Quantizer(self.float_model_path)

@parameterized.parameters(
'../recipes/default_a8w8_recipe.json',
'../recipes/default_a16w8_recipe.json',
)
def test_split_model_full_integer(self, recipe_path):
recipe_path = test_utils.get_path_to_datafile(recipe_path)
self._quantizer.load_quantization_recipe(recipe_path)
self.assertTrue(self._quantizer.need_calibration)
calibration_result = self._quantizer.calibrate(_get_calibration_data())
_ = self._quantizer.quantize(calibration_result)

comparison_result = self._quantizer.compare(
error_metrics='mse', signature_test_data=_get_test_data()
)
self._check_comparison_result(comparison_result, output_tolerance=1e-4)

# TODO: b/345503484 - Check weight tensor type of the quantized model.
def _check_comparison_result(self, comparison_result, output_tolerance):
output_mse_1 = comparison_result['PartitionedCall:0']
self.assertLess(output_mse_1, output_tolerance)
output_mse_2 = comparison_result['PartitionedCall:1']
self.assertLess(output_mse_2, output_tolerance)


if __name__ == '__main__':
googletest.main()
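
Outside the test harness, the same public API drives full-integer quantization of a split model. A hedged usage sketch reusing only the calls exercised above (paths are illustrative, and the data helpers are the ones defined in this file):

from ai_edge_quantizer import quantizer

qt = quantizer.Quantizer('single_split.tflite')
qt.load_quantization_recipe('default_a8w8_recipe.json')
if qt.need_calibration:  # true for the full-integer recipes used above
  calibration_result = qt.calibrate(_get_calibration_data())
  _ = qt.quantize(calibration_result)
comparison = qt.compare(error_metrics='mse', signature_test_data=_get_test_data())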
ai_edge_quantizer/tests/models/single_split.tflite: binary file not shown.
1 change: 1 addition & 0 deletions ai_edge_quantizer/utils/tfl_flatbuffer_utils.py
@@ -57,6 +57,7 @@
_TFLOpName.RSQRT: schema_py_generated.BuiltinOperator.RSQRT,
_TFLOpName.CONCATENATION: schema_py_generated.BuiltinOperator.CONCATENATION,
_TFLOpName.STRIDED_SLICE: schema_py_generated.BuiltinOperator.STRIDED_SLICE,
_TFLOpName.SPLIT: schema_py_generated.BuiltinOperator.SPLIT,
})

TFL_OP_CODE_TO_NAME = immutabledict.immutabledict(
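The new entry is what lets the flatbuffer utilities recognize split ops when parsing a model. Assuming the enclosing map is named `TFL_OP_NAME_TO_CODE` (inferred from the reverse map `TFL_OP_CODE_TO_NAME` built just below; the name itself is not visible in this diff), the round trip would be:

# Hypothetical round trip; the forward map's name is an assumption.
code = TFL_OP_NAME_TO_CODE[_TFLOpName.SPLIT]
assert code == schema_py_generated.BuiltinOperator.SPLIT
assert TFL_OP_CODE_TO_NAME[code] == _TFLOpName.SPLIT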
