Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable calibration and model validation with XNNPACK #175

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ai_edge_quantizer/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Calibrator:
def __init__(
self,
float_tflite: Union[str, bytes],
num_threads: int = 16,
):
self._flatbuffer_model = tfl_flatbuffer_utils.read_model(float_tflite)

Expand All @@ -50,7 +51,7 @@ def __init__(
" the model (e.g., if it is already quantized)."
)
self._tfl_interpreter = tfl_interpreter_utils.create_tfl_interpreter(
float_tflite
float_tflite, use_xnnpack=True, num_threads=num_threads
)
# Tensor name to tensor content.
self._tensor_content_map: dict[str, Any] = {}
Expand Down
25 changes: 16 additions & 9 deletions ai_edge_quantizer/model_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@ def _setup_validation_interpreter(
model: bytes,
signature_input: dict[str, Any],
signature_key: Optional[str],
use_reference_kernel: bool,
use_xnnpack: bool,
num_threads: int,
) -> tuple[Any, int, dict[str, Any]]:
"""Setup the interpreter for validation given a signature key.
Expand All @@ -216,15 +217,15 @@ def _setup_validation_interpreter(
signature_input: A dictionary of input tensor name and its value.
signature_key: The signature key to be used for invoking the models. If the
model only has one signature, this can be set to None.
use_reference_kernel: Whether to use the reference kernel for the
interpreter.
use_xnnpack: Whether to use XNNPACK for the interpreter.
num_threads: The number of threads to use for the interpreter.
Returns:
A tuple of interpreter, subgraph_index and tensor_name_to_details.
"""

interpreter = utils.create_tfl_interpreter(
tflite_model=model, use_reference_kernel=use_reference_kernel
tflite_model=model, use_xnnpack=use_xnnpack, num_threads=num_threads
)
utils.invoke_interpreter_signature(
interpreter, signature_input, signature_key
Expand All @@ -247,7 +248,8 @@ def compare_model(
test_data: dict[str, Iterable[dict[str, Any]]],
error_metric: str,
compare_fn: Callable[[Any, Any], float],
use_reference_kernel: bool = False,
use_xnnpack: bool = True,
num_threads: int = 16,
) -> ComparisonResult:
"""Compares model tensors over a model signature using the compare_fn.
Expand All @@ -266,8 +268,8 @@ def compare_model(
compare_fn: a comparison function to be used for calculating the statistics,
this function must take in two ArrayLike structures and output a
single float value.
use_reference_kernel: Whether to use the reference kernel for the
interpreter.
use_xnnpack: Whether to use XNNPACK for the interpreter.
num_threads: The number of threads to use for the interpreter.
Returns:
A ComparisonResult object.
Expand All @@ -282,12 +284,17 @@ def compare_model(
reference_model,
signature_input,
signature_key,
use_reference_kernel,
use_xnnpack,
num_threads,
)
)
targ_interpreter, targ_subgraph_index, targ_tensor_name_to_details = (
_setup_validation_interpreter(
target_model, signature_input, signature_key, use_reference_kernel
target_model,
signature_input,
signature_key,
use_xnnpack,
num_threads,
)
)
# Compare the cached tensor values.
Expand Down
13 changes: 9 additions & 4 deletions ai_edge_quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,15 @@ def calibrate(
self,
calibration_data: dict[str, Iterable[_SignatureInput]],
previous_calibration_result: Optional[_CalibrationResult] = None,
num_threads: int = 16,
) -> _CalibrationResult:
"""Calibrates the float model (required by static range quantization).
Args:
calibration_data: Calibration data for a model signature.
previous_calibration_result: Previous calibration result to be loaded. The
calibration process will be resumed from the previous result.
num_threads: Number of threads to use for calibration.
Returns:
Calibration result ({tensor_name: tensor QSVs (e.g., min/max)}).
Expand All @@ -233,7 +235,7 @@ def calibrate(
if not self.need_calibration:
return {}

calib = calibrator.Calibrator(self.float_model)
calib = calibrator.Calibrator(self.float_model, num_threads=num_threads)
if previous_calibration_result is not None:
calib.load_model_qsvs(previous_calibration_result)
calib.calibrate(calibration_data, self._recipe_manager)
Expand Down Expand Up @@ -297,7 +299,8 @@ def validate(
self,
test_data: Optional[dict[str, Iterable[_SignatureInput]]] = None,
error_metrics: str = 'mse',
use_reference_kernel: bool = False,
use_xnnpack: bool = True,
num_threads: int = 16,
) -> model_validator.ComparisonResult:
"""Numerical validation of the quantized model for a model signature.
Expand All @@ -314,7 +317,8 @@ def validate(
data that will be used for validation. If set to None, random normal
distributed data will be used for all signatures in the model.
error_metrics: Error metrics to be used for comparison.
use_reference_kernel: Whether to use the reference kernel for validation.
use_xnnpack: Whether to use the XNNPACK library for validation.
num_threads: Number of threads to use for validation.
Returns:
The comparison result.
Expand All @@ -330,7 +334,8 @@ def validate(
test_data,
error_metrics,
validation_utils.get_validation_func(error_metrics),
use_reference_kernel=use_reference_kernel,
use_xnnpack=use_xnnpack,
num_threads=num_threads,
)

def _get_quantization_params(
Expand Down
12 changes: 7 additions & 5 deletions ai_edge_quantizer/utils/tfl_interpreter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@
def create_tfl_interpreter(
tflite_model: Union[str, bytes],
allocate_tensors: bool = True,
use_reference_kernel: bool = False,
use_xnnpack: bool = True,
num_threads: int = 16,
) -> tfl.Interpreter:
"""Creates a TFLite interpreter from a model file.
Args:
tflite_model: Model file path or bytes.
allocate_tensors: Whether to allocate tensors.
use_reference_kernel: Whether to use the reference kernel for the
interpreter.
use_xnnpack: Whether to use the XNNPACK delegate for the interpreter.
num_threads: The number of threads to use for the interpreter.
Returns:
A TFLite interpreter.
Expand All @@ -47,12 +48,13 @@ def create_tfl_interpreter(
with gfile.GFile(tflite_model, "rb") as f:
tflite_model = f.read()

if use_reference_kernel:
op_resolver = tfl.OpResolverType.BUILTIN_REF
if use_xnnpack:
op_resolver = tfl.OpResolverType.BUILTIN
else:
op_resolver = tfl.OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
tflite_interpreter = tfl.Interpreter(
model_content=bytes(tflite_model),
num_threads=num_threads,
experimental_op_resolver_type=op_resolver,
experimental_preserve_all_tensors=True,
)
Expand Down
Loading