Skip to content

Commit

Permalink
Fix a bug in CONV2D tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653291428
  • Loading branch information
ai-edge-bot authored and copybara-github committed Jul 17, 2024
1 parent 8b13d26 commit 27b8d09
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile(
"../../../tests/models"
)
_DEFAULT_ACTIVATION_QUANT_SETTING = (
naive_min_max_test_utils.DEFAULT_ACTIVATION_QUANT_SETTING
)


class Conv2dTest(naive_min_max_test_utils.NaiveMinMaxQuantizeTest):
Expand All @@ -37,50 +34,48 @@ def setUp(self):
op_tensor_names={},
input_range=(np.array([[-10]]), np.array([[8]])),
output_range=(np.array([[10]]), np.array([[88]])),
quantized_dimension=0,
)
# The test model has one subgraph for now.
self._graph_info = qtyping.GraphInfo(
subgraph_tensors=self._op_test_info.test_model.subgraphs[0].tensors,
buffers=self._op_test_info.test_model.buffers,
)
self._set_op_tensor_names()

def _set_op_tensor_names(self):
# Maps the logical tensor roles ("weight", "bias", "input", "output") of the
# single CONV_2D op in the test model to their concrete tensor names, and
# stores the mapping on the shared op-test info so every test in this class
# can look tensors up by role.
# NOTE(review): the bias/output names contain fused "Relu;BiasAdd;Conv2D"
# segments — presumably the names the TFLite converter emits after op
# fusion; the output name differs from the bias name only by the trailing
# "1" suffix. Verify against the test model if these ever drift.
op_tensor_names = {}
op_tensor_names["weight"] = "sequential/conv2d/Conv2D"
op_tensor_names["bias"] = (
"sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp"
)
op_tensor_names["input"] = "serving_default_conv2d_input:0"
op_tensor_names["output"] = (
"sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1"
)
self._op_test_info.op_tensor_names = op_tensor_names

# TODO(rewu): add int16 tests.
@parameterized.product(
num_bits_weight=(4, 8),
symmetric_weight=(True, False),
channel_wise_weight=(True, False),
execution_mode=(
_OpExecutionMode.WEIGHT_ONLY,
_OpExecutionMode.DRQ,
_OpExecutionMode.SRQ,
),
)
def test_materialize_conv2d_succeeds(
def test_materialize_weight_only_drq_conv2d_succeeds(
self,
num_bits_weight,
symmetric_weight,
channel_wise_weight,
execution_mode,
):

# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
subgraph_op_id = 0
op = subgraph0.operators[subgraph_op_id]
op_tensor_names = {}
op_tensor_names["weight"] = "sequential/conv2d/Conv2D"
op_tensor_names["bias"] = (
"sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp"
)
op_tensor_names["input"] = "serving_default_conv2d_input:0"
op_tensor_names["output"] = (
"sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/Conv2D;sequential/conv2d/BiasAdd/ReadVariableOp1"
)
self._op_test_info.op_tensor_names = op_tensor_names

activation_tensor_config = None
if execution_mode == _OpExecutionMode.SRQ:
activation_tensor_config = _DEFAULT_ACTIVATION_QUANT_SETTING
op_info = qtyping.OpInfo(
op=op,
op_name=qtyping.TFLOperationName.CONV_2D,
Expand All @@ -95,7 +90,53 @@ def test_materialize_conv2d_succeeds(
execution_mode=execution_mode,
),
)
self._test_fc_bmm_conv(
op_info,
self._graph_info,
self._op_test_info,
naive_min_max_quantize.materialize_fc_conv,
)

@parameterized.product(
activation_num_bits=(8, 16),
weight_num_bits=(4, 8),
)
def test_materialize_srq_conv2d_succeeds(
self,
activation_num_bits,
weight_num_bits,
):
# Read from Model Explorer.
subgraph0 = self._op_test_info.test_model.subgraphs[0]
subgraph_op_id = 0
op = subgraph0.operators[subgraph_op_id]

if activation_num_bits == 8:
activation_tensor_config = _TensorQuantConfig(
num_bits=8,
symmetric=False,
channel_wise=False,
)
else:
activation_tensor_config = _TensorQuantConfig(
num_bits=16,
symmetric=True,
channel_wise=False,
)
op_info = qtyping.OpInfo(
op=op,
op_name=qtyping.TFLOperationName.CONV_2D,
subgraph_op_index=subgraph_op_id,
op_quant_config=qtyping.OpQuantizationConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_TensorQuantConfig(
num_bits=weight_num_bits,
symmetric=True,
channel_wise=True,
),
execution_mode=_OpExecutionMode.SRQ,
),
)
self._test_fc_bmm_conv(
op_info,
self._graph_info,
Expand Down
13 changes: 10 additions & 3 deletions ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
_TFLOpName.DEPTHWISE_CONV_2D,
])

_INT4_DRQ_SUPPORTED_OPS = frozenset([
_TFLOpName.FULLY_CONNECTED,
# TODO: b/353365054 - Uncomment after int4 DRQ is supported for
# conv2d.
# _TFLOpName.CONV_2D,
])

_SUPPORTED_SRQ_OPS = frozenset([
_TFLOpName.FULLY_CONNECTED,
_TFLOpName.CONV_2D,
Expand All @@ -40,7 +47,7 @@
_TFLOpName.ADD,
])

_INT4_DRQ_SRQ_SUPPORTED_OPS = frozenset([
_INT4_SRQ_SUPPORTED_OPS = frozenset([
_TFLOpName.FULLY_CONNECTED,
_TFLOpName.CONV_2D,
_TFLOpName.EMBEDDING_LOOKUP,
Expand All @@ -66,7 +73,7 @@ def check_drq_config(
raise ValueError(
f"Only int4/int8 symmetric DRQ is supported for op {op_name}"
)
if weight_config.num_bits == 4 and op_name not in _INT4_DRQ_SRQ_SUPPORTED_OPS:
if weight_config.num_bits == 4 and op_name not in _INT4_DRQ_SUPPORTED_OPS:
raise ValueError(f"Int4 DRQ is not supported for op {op_name}.")


Expand Down Expand Up @@ -97,7 +104,7 @@ def check_srq_config(
"Currently only int4/int8 symmetric weight are supported for op"
f" {op_name}."
)
if weight_config.num_bits == 4 and op_name not in _INT4_DRQ_SRQ_SUPPORTED_OPS:
if weight_config.num_bits == 4 and op_name not in _INT4_SRQ_SUPPORTED_OPS:
raise ValueError(f"Int4 weight SRQ is not supported for op {op_name}.")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def test_check_weight_only_config_fails(self, op_name):
def test_check_drq_config_succeeds(
self, op_name, weight_num_bits, weight_channel_wise
):
# TODO: b/353365054 - Remove this check after int4 DRQ is supported for
# conv2d.
if op_name == _TFLOpName.CONV_2D and weight_num_bits == 4:
return
op_quant_config = _OpQuantConfig(
weight_tensor_config=_TensorQuantConfig(
num_bits=weight_num_bits, channel_wise=weight_channel_wise
Expand Down Expand Up @@ -259,7 +263,10 @@ def test_check_srq_config_asym_weight_raise_error(self):
)

@parameterized.product(
op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
op_name=[
_TFLOpName.FULLY_CONNECTED,
_TFLOpName.CONV_2D,
],
activation_tensor_config=[
None,
_TensorQuantConfig(num_bits=8, symmetric=False),
Expand All @@ -283,6 +290,10 @@ def test_check_supported_int4_config_succeeds(
and execution_mode == _OpExecutionMode.SRQ
):
return
# TODO: b/353365054 - Remove this check after int4 DRQ is supported for
# conv2d.
if execution_mode == _OpExecutionMode.DRQ and op_name == _TFLOpName.CONV_2D:
return
op_quant_config = _OpQuantConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_TensorQuantConfig(
Expand Down

0 comments on commit 27b8d09

Please sign in to comment.