diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py
index 66506dc..cc1765c 100644
--- a/ai_edge_quantizer/quantizer.py
+++ b/ai_edge_quantizer/quantizer.py
@@ -30,6 +30,7 @@
 from ai_edge_quantizer import recipe_manager
 from ai_edge_quantizer.utils import test_utils
 from ai_edge_quantizer.utils import tfl_flatbuffer_utils
+from ai_edge_quantizer.utils import tfl_interpreter_utils
 from ai_edge_quantizer.utils import validation_utils
 from tensorflow.python.platform import gfile  # pylint: disable=g-direct-tensorflow-import
 
@@ -225,6 +226,9 @@ def calibrate(
 
     Returns:
       Calibration result ({tensor_name: tensor QSVs (e.g.,min/max)}).
+
+    Raises:
+      ValueError: If the calibration result is insufficient.
     """
     if not self.need_calibration:
       return {}
@@ -235,6 +239,46 @@ def calibrate(
     calib.calibrate(calibration_data, self._recipe_manager)
     return calib.get_model_qsvs()
 
+  def _ensure_model_qsv_sufficient(
+      self, calibration_result: _CalibrationResult
+  ):
+    """Checks if the calibration result has sufficient QSV.
+
+    Each signature's tensors are looked up among the calibration entries that
+    are empty; the first tensor found there is reported.
+
+    Args:
+      calibration_result: Calibration result ({tensor_name: tensor QSVs (e.g.,
+        min/max)}) to validate.
+
+    Raises:
+      ValueError: If a tensor of any signature has an empty QSV entry.
+    """
+    # Collect the tensor names with empty entries in a set so that the
+    # per-tensor membership test below is a constant-time lookup.
+    empty_qsvs = {key for key, value in calibration_result.items() if not value}
+    if not empty_qsvs:
+      # Nothing is missing: skip building an interpreter for the float model.
+      return
+
+    # Go over every signature and check if empty entry tensor belongs to it.
+    tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
+        self.float_model
+    )
+    for signature_key in tfl_interpreter.get_signature_list():
+      subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
+          tfl_interpreter, signature_key
+      )
+
+      for tensor_detail in tfl_interpreter.get_tensor_details(subgraph_idx):
+        tensor_name = tensor_detail['name']
+        if tensor_name in empty_qsvs:
+          raise ValueError(
+              f'Missing QSVs (min/max) for tensor {tensor_name} in Signature'
+              f" '{signature_key}'. Please check if Signature"
+              f' {signature_key} has been calibrated.'
+          )
+
   def quantize(
       self, calibration_result: Optional[_CalibrationResult] = None
   ) -> QuantizationResult:
@@ -251,6 +295,9 @@ def quantize(
       RuntimeError: If quantization recipe is empty.
     """
 
+    if calibration_result is not None:
+      self._ensure_model_qsv_sufficient(calibration_result)
+
     if not self.get_quantization_recipe():
       raise RuntimeError('Can not quantize without a quantization recipe.')
     quant_params = self._get_quantization_params(calibration_result)
diff --git a/ai_edge_quantizer/quantizer_test.py b/ai_edge_quantizer/quantizer_test.py
index e6a8966..7123c37 100644
--- a/ai_edge_quantizer/quantizer_test.py
+++ b/ai_edge_quantizer/quantizer_test.py
@@ -450,6 +450,23 @@ def test_recipe_conflict_raises_error(self):
     ):
       qt.quantize(calib_result)
 
+  def test_quantization_with_insufficient_calibration(self):
+    # Run calibration for one signature only.
+    scarce_calibration_dataset = {
+        'add': [{'x': np.array([2.0], dtype=np.float32)}],
+    }
+    calib_result = self._quantizer.calibrate(scarce_calibration_dataset)
+
+    # Quantize and expect an error about missing signature in calibration data.
+    error_message = (
+        'Missing QSVs (min/max) for tensor multiply_x:0 in Signature'
+        " 'multiply'."
+    )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError, lambda err: error_message in str(err)
+    ):
+      self._quantizer.quantize(calib_result)
+
 
 
 class QuantizerToyGemma2Test(parameterized.TestCase):