Skip to content

Commit

Permalink
Check if calibration result has sufficient QSV before quantization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 692266950
  • Loading branch information
v-dziuba authored and copybara-github committed Nov 1, 2024
1 parent 9e35400 commit eeb967e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
33 changes: 33 additions & 0 deletions ai_edge_quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils
from ai_edge_quantizer.utils import validation_utils
from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import

Expand Down Expand Up @@ -225,6 +226,9 @@ def calibrate(
Returns:
Calibration result ({tensor_name: tensor QSVs (e.g., min/max)}).
Raises:
ValueError: If the calibration result is insufficient.
"""
if not self.need_calibration:
return {}
Expand All @@ -235,6 +239,32 @@ def calibrate(
calib.calibrate(calibration_data, self._recipe_manager)
return calib.get_model_qsvs()

def _ensure_model_qsv_sufficient(
    self, calibration_result: _CalibrationResult
):
  """Checks if the calibration result has sufficient QSV.

  Every tensor whose entry in `calibration_result` is empty (falsy) is treated
  as missing its QSVs. If such a tensor shows up among the tensors of the main
  subgraph of any signature of the float model, the check fails.

  Args:
    calibration_result: Mapping from tensor name to its QSVs (e.g., min/max).

  Raises:
    ValueError: If a tensor with an empty QSV entry belongs to the main
      subgraph of one of the model signatures.
  """
  # Collect the names with empty entries in a set so that the per-tensor
  # membership test below is O(1) instead of a linear scan.
  empty_qsvs = {key for key, value in calibration_result.items() if not value}
  if not empty_qsvs:
    # Nothing is missing, so there is no need to build an interpreter.
    return

  # Go over every signature and check if an empty-entry tensor belongs to it.
  tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
      self.float_model
  )
  for signature_key in tfl_interpreter.get_signature_list():
    subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
        tfl_interpreter, signature_key
    )

    for tensor_detail in tfl_interpreter.get_tensor_details(subgraph_idx):
      tensor_name = tensor_detail['name']
      if tensor_name in empty_qsvs:
        raise ValueError(
            f'Missing QSVs (min/max) for tensor {tensor_name} in Signature'
            f" '{signature_key}'. Please check if Signature"
            f' {signature_key} has been calibrated.'
        )

def quantize(
self, calibration_result: Optional[_CalibrationResult] = None
) -> QuantizationResult:
Expand All @@ -251,6 +281,9 @@ def quantize(
RuntimeError: If quantization recipe is empty.
"""

if calibration_result is not None:
self._ensure_model_qsv_sufficient(calibration_result)

if not self.get_quantization_recipe():
raise RuntimeError('Can not quantize without a quantization recipe.')
quant_params = self._get_quantization_params(calibration_result)
Expand Down
17 changes: 17 additions & 0 deletions ai_edge_quantizer/quantizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,23 @@ def test_recipe_conflict_raises_error(self):
):
qt.quantize(calib_result)

def test_quantization_with_insufficient_calibration(self):
  """Quantizing with QSVs from a single calibrated signature must fail."""
  # Provide calibration samples for the 'add' signature only.
  partial_dataset = {}
  partial_dataset['add'] = [{'x': np.array([2.0], dtype=np.float32)}]
  partial_qsvs = self._quantizer.calibrate(partial_dataset)

  # The error is expected to name the tensor and signature lacking QSVs.
  expected_fragment = (
      'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
      " 'multiply'."
  )

  def _mentions_missing_qsvs(err):
    return expected_fragment in str(err)

  with self.assertRaisesWithPredicateMatch(
      ValueError, _mentions_missing_qsvs
  ):
    self._quantizer.quantize(partial_qsvs)


class QuantizerToyGemma2Test(parameterized.TestCase):

Expand Down

0 comments on commit eeb967e

Please sign in to comment.