Skip to content

Commit

Permalink
Add more rigorous int4 config checks.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647433191
  • Loading branch information
Google AI Edge authored and copybara-github committed Jun 27, 2024
1 parent 49f2b42 commit 133b893
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 16 deletions.
15 changes: 12 additions & 3 deletions ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,21 @@
_TFLOpName.SOFTMAX,
])

# Ops whose weights may be quantized to int4 under dynamic-range (DRQ) and
# static-range (SRQ) quantization; int4 weight configs for any other op are
# rejected by check_drq_config / check_srq_config below.
_INT4_DRQ_SRQ_SUPPORTED_OPS = frozenset([
    _TFLOpName.FULLY_CONNECTED,
    _TFLOpName.CONV_2D,
])

def check_weight_only_config(op_name: _TFLOpName):

def check_weight_only_config(op_name: _TFLOpName) -> None:
  """Checks that weight-only quantization is available for the given op.

  Args:
    op_name: The TFLite op to validate.

  Raises:
    ValueError: If the op is not in the weight-only supported set.
  """
  is_supported = op_name in _SUPPORTED_WEIGHT_ONLY_OPS
  if not is_supported:
    raise ValueError(f"Unsupported op for weight-only quantization: {op_name}.")


def check_drq_config(
op_name: _TFLOpName, op_quant_config: qtyping.OpQuantizationConfig
):
) -> None:
"""Checks the op quantization config for dynamic range quantization."""
weight_config = op_quant_config.weight_tensor_config
if op_name not in _SUPPORTED_DRQ_OPS:
Expand All @@ -52,11 +57,13 @@ def check_drq_config(
raise ValueError(
f"Only int4/int8 symmetric DRQ is supported for op {op_name}"
)
if weight_config.num_bits == 4 and op_name not in _INT4_DRQ_SRQ_SUPPORTED_OPS:
raise ValueError(f"Int4 DRQ is not supported for op {op_name}.")


def check_srq_config(
op_name: _TFLOpName, op_quant_config: qtyping.OpQuantizationConfig
):
) -> None:
"""Checks the op quantization config for static range quantization."""
act_config = op_quant_config.activation_tensor_config
weight_config = op_quant_config.weight_tensor_config
Expand All @@ -81,6 +88,8 @@ def check_srq_config(
"Currently only int4/int8 symmetric weight are supported for op"
f" {op_name}."
)
if weight_config.num_bits == 4 and op_name not in _INT4_DRQ_SRQ_SUPPORTED_OPS:
raise ValueError(f"Int4 weight SRQ is not supported for op {op_name}.")


class OpQuantConstraint(enum.Enum):
Expand Down
98 changes: 85 additions & 13 deletions ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,14 @@
# TODO: b/335008966 - increase test coverage.
class MinMaxQuantizeUtilsTest(parameterized.TestCase):

@parameterized.parameters(
(_OpExecutionMode.WEIGHT_ONLY, True, True),
(_OpExecutionMode.WEIGHT_ONLY, False, False),
(_OpExecutionMode.WEIGHT_ONLY, True, False),
(_OpExecutionMode.WEIGHT_ONLY, False, True),
(_OpExecutionMode.DRQ, True, True),
(_OpExecutionMode.DRQ, False, False),
(_OpExecutionMode.DRQ, True, False),
(_OpExecutionMode.DRQ, False, True),
(_OpExecutionMode.SRQ, True, True),
(_OpExecutionMode.SRQ, False, False),
(_OpExecutionMode.SRQ, True, False),
(_OpExecutionMode.SRQ, False, True),
@parameterized.product(
execution_mode=[
_OpExecutionMode.WEIGHT_ONLY,
_OpExecutionMode.DRQ,
_OpExecutionMode.SRQ,
],
is_inbounding_tensor=[True, False],
is_constant=[True, False],
)
def test_get_tensor_transformations(
self, execution_mode, is_inbounding_tensor, is_constant
Expand Down Expand Up @@ -262,6 +257,83 @@ def test_check_srq_config_asym_weight_raise_error(self):
_TFLOpName.FULLY_CONNECTED, op_quant_config
)

@parameterized.product(
op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
activation_tensor_config=[
None,
_TensorQuantConfig(num_bits=8, symmetric=False),
_TensorQuantConfig(num_bits=16, symmetric=True),
],
execution_mode=[
_OpExecutionMode.WEIGHT_ONLY,
_OpExecutionMode.DRQ,
_OpExecutionMode.SRQ,
],
)
def test_check_supported_int4_config_succeeds(
self, op_name, activation_tensor_config, execution_mode
):
# Exclude invalid SRQ config.
if (
activation_tensor_config is not None
and execution_mode != _OpExecutionMode.SRQ
) or (
activation_tensor_config is None
and execution_mode == _OpExecutionMode.SRQ
):
return
op_quant_config = _OpQuantConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_TensorQuantConfig(
num_bits=4, symmetric=True, channel_wise=True
),
execution_mode=execution_mode,
)
# Raise error if the config is not supported.
if execution_mode == _OpExecutionMode.DRQ:
min_max_quantize_utils.check_drq_config(op_name, op_quant_config)
elif execution_mode == _OpExecutionMode.WEIGHT_ONLY:
min_max_quantize_utils.check_weight_only_config(op_name)
elif execution_mode == _OpExecutionMode.SRQ:
min_max_quantize_utils.check_srq_config(op_name, op_quant_config)

@parameterized.product(
op_name=[_TFLOpName.BATCH_MATMUL],
activation_tensor_config=[
None,
_TensorQuantConfig(num_bits=8, symmetric=False),
_TensorQuantConfig(num_bits=16, symmetric=True),
],
execution_mode=[
_OpExecutionMode.DRQ,
_OpExecutionMode.SRQ,
],
)
def test_check_unsupported_int4_config_raise_error(
self, op_name, activation_tensor_config, execution_mode
):
# Exclude invalid SRQ config.
if (
activation_tensor_config is not None
and execution_mode != _OpExecutionMode.SRQ
) or (
activation_tensor_config is None
and execution_mode == _OpExecutionMode.SRQ
):
return
op_quant_config = _OpQuantConfig(
activation_tensor_config=activation_tensor_config,
weight_tensor_config=_TensorQuantConfig(
num_bits=4, symmetric=True, channel_wise=True
),
execution_mode=execution_mode,
)
with self.assertRaises(ValueError):
if execution_mode == _OpExecutionMode.DRQ:
min_max_quantize_utils.check_drq_config(op_name, op_quant_config)
elif execution_mode == _OpExecutionMode.SRQ:
min_max_quantize_utils.check_srq_config(op_name, op_quant_config)


if __name__ == "__main__":
googletest.main()

0 comments on commit 133b893

Please sign in to comment.