Skip to content

Commit

Permalink
Add batch matmul test for the two-input version (no weights)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653042791
  • Loading branch information
ai-edge-bot authored and copybara-github committed Jul 17, 2024
1 parent 27b8d09 commit 31019e3
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@

class BatchMatmulTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):

def setUp(self):
super().setUp()
def _custom_setup(self, test_model_file):
np.random.seed(666)
self._test_model_path = os.path.join(_TEST_DATA_PREFIX_PATH, "bmm.tflite")
self._test_model_path = os.path.join(
_TEST_DATA_PREFIX_PATH, test_model_file
)
self._op_test_info = _OpTestInfo(
test_model=tfl_flatbuffer_utils.read_model(self._test_model_path),
op_tensor_names={},
Expand All @@ -44,7 +45,6 @@ def setUp(self):
@parameterized.product(
num_bits_weight=(4, 8),
symmetric_weight=(True, False),
channel_wise_weight=(True, False),
execution_mode=(
_OpExecutionMode.SRQ,
_OpExecutionMode.DRQ,
Expand All @@ -55,10 +55,9 @@ def test_batch_matmul_adjy_false_succeeds(
self,
num_bits_weight,
symmetric_weight,
channel_wise_weight,
execution_mode,
):

self._custom_setup("bmm.tflite")
# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
subgraph_op_id = 0
Expand All @@ -85,7 +84,7 @@ def test_batch_matmul_adjy_false_succeeds(
weight_tensor_config=_TensorQuantConfig(
num_bits=num_bits_weight,
symmetric=symmetric_weight,
channel_wise=channel_wise_weight,
channel_wise=False,
),
execution_mode=execution_mode,
),
Expand All @@ -100,7 +99,6 @@ def test_batch_matmul_adjy_false_succeeds(
@parameterized.product(
num_bits_weight=(4, 8),
symmetric_weight=(True, False),
channel_wise_weight=(True, False),
execution_mode=(
_OpExecutionMode.SRQ,
_OpExecutionMode.DRQ,
Expand All @@ -111,9 +109,9 @@ def test_batch_matmul_adjy_true_succeeds(
self,
num_bits_weight,
symmetric_weight,
channel_wise_weight,
execution_mode,
):
self._custom_setup("bmm.tflite")

# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
Expand Down Expand Up @@ -141,7 +139,7 @@ def test_batch_matmul_adjy_true_succeeds(
weight_tensor_config=_TensorQuantConfig(
num_bits=num_bits_weight,
symmetric=symmetric_weight,
channel_wise=channel_wise_weight,
channel_wise=False,
),
execution_mode=execution_mode,
),
Expand All @@ -153,6 +151,110 @@ def test_batch_matmul_adjy_true_succeeds(
naive_min_max_quantize.materialize_batch_matmul,
)

@parameterized.product(
    num_bits_weight=(4, 8),
    symmetric_weight=(True, False),
    execution_mode=(_OpExecutionMode.SRQ,),
)
def test_batch_matmul_two_inputs_adjy_false_succeeds(
    self,
    num_bits_weight,
    symmetric_weight,
    execution_mode,
):
  """Quantizes a two-input (no constant weights) BMM op with adj_y=False.

  Exercises materialize_batch_matmul on the first operator of
  bmm_two_inputs.tflite, where both operands are graph inputs rather than
  stored weights. Only SRQ is parameterized for this variant.
  """
  self._custom_setup("bmm_two_inputs.tflite")
  # Tensor names were read off the model in Model Explorer.
  graph = self._op_test_info.test_model.subgraphs[0]
  op_index = 0
  target_op = graph.operators[op_index]
  self._op_test_info.op_tensor_names = {
      "input": "input1",
      "input2": "input2",
      "output": (
          "BatchMatMulV3;jax2tf_export_func_/PartitionedCall/BatchMatMulV3"
      ),
  }
  self._op_test_info.quantized_dimension = 2

  # SRQ is the only mode in the product above, but keep the guard so the
  # config stays correct if more modes are added later.
  activation_tensor_config = (
      _DEFAULT_ACTIVATION_QUANT_SETTING
      if execution_mode == _OpExecutionMode.SRQ
      else None
  )

  op_info = qtyping.OpInfo(
      op=target_op,
      op_name=qtyping.TFLOperationName.BATCH_MATMUL,
      subgraph_op_index=op_index,
      op_quant_config=qtyping.OpQuantizationConfig(
          activation_tensor_config=activation_tensor_config,
          weight_tensor_config=_TensorQuantConfig(
              num_bits=num_bits_weight,
              symmetric=symmetric_weight,
              # Channel-wise quantization is not applicable when both
              # operands are activations.
              channel_wise=False,
          ),
          execution_mode=execution_mode,
      ),
  )
  self._test_two_input_one_output_ops(
      op_info,
      self._graph_info,
      self._op_test_info,
      naive_min_max_quantize.materialize_batch_matmul,
  )

@parameterized.product(
    num_bits_weight=(4, 8),
    symmetric_weight=(True, False),
    execution_mode=(_OpExecutionMode.SRQ,),
)
def test_batch_matmul_two_inputs_adjy_true_succeeds(
    self,
    num_bits_weight,
    symmetric_weight,
    execution_mode,
):
  """Quantizes a two-input (no constant weights) BMM op with adj_y=True.

  Exercises materialize_batch_matmul on the second operator of
  bmm_two_inputs.tflite, whose first operand is the output of the previous
  BMM. Only SRQ is parameterized for this variant.
  """
  self._custom_setup("bmm_two_inputs.tflite")

  # Tensor names were read off the model in Model Explorer.
  graph = self._op_test_info.test_model.subgraphs[0]
  op_index = 1
  target_op = graph.operators[op_index]
  self._op_test_info.op_tensor_names = {
      "input": (
          "BatchMatMulV3;jax2tf_export_func_/PartitionedCall/BatchMatMulV3"
      ),
      "input2": "input2",
      "output": "Identity_1",
  }
  self._op_test_info.quantized_dimension = 1

  # SRQ is the only mode in the product above, but keep the guard so the
  # config stays correct if more modes are added later.
  activation_tensor_config = (
      _DEFAULT_ACTIVATION_QUANT_SETTING
      if execution_mode == _OpExecutionMode.SRQ
      else None
  )

  op_info = qtyping.OpInfo(
      op=target_op,
      op_name=qtyping.TFLOperationName.BATCH_MATMUL,
      subgraph_op_index=op_index,
      op_quant_config=qtyping.OpQuantizationConfig(
          activation_tensor_config=activation_tensor_config,
          weight_tensor_config=_TensorQuantConfig(
              num_bits=num_bits_weight,
              symmetric=symmetric_weight,
              # Channel-wise quantization is not applicable when both
              # operands are activations.
              channel_wise=False,
          ),
          execution_mode=execution_mode,
      ),
  )
  self._test_two_input_one_output_ops(
      op_info,
      self._graph_info,
      self._op_test_info,
      naive_min_max_quantize.materialize_batch_matmul,
  )

# Entry point: run all parameterized tests when executed as a script.
if __name__ == "__main__":
  googletest.main()
Binary file added ai_edge_quantizer/tests/models/bmm_two_inputs.tflite
Binary file not shown.

0 comments on commit 31019e3

Please sign in to comment.