Fix comments and docstring formatting in params_generator.py
PiperOrigin-RevId: 653311734
marialyu authored and copybara-github committed Jul 17, 2024
1 parent 27b8d09 commit 65fd572
Showing 1 changed file with 25 additions and 24 deletions.
ai_edge_quantizer/params_generator.py: 49 changes (25 additions, 24 deletions)
@@ -27,15 +27,15 @@ def generate_quantization_parameters(
     """Generate the quantization parameters for the model.

     Args:
-      model_recipe_manager: the recipe manager for the model.
-      model_qsvs: quantization statistics values (qsvs) for the model. This is
+      model_recipe_manager: The recipe manager for the model.
+      model_qsvs: Quantization statistics values (QSVs) for the model. This is
         obtained through calibration process.

     Returns:
-      model_quant_results: the quantization parameters for tensors in the model.
+      model_quant_results: The quantization parameters for tensors in the model.

     Raises:
-      RuntimeError: if the calibration dataset is required but not provided.
+      RuntimeError: If the calibration dataset is required but not provided.
     """
     if model_recipe_manager.need_calibration() and not model_qsvs:
       raise RuntimeError(
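For context on the guard at the end of this hunk: parameter generation refuses to proceed when the recipe needs calibration but no QSVs were supplied. A minimal, self-contained sketch of that contract follows; the _ToyRecipeManager class, the standalone function signature, and the error message are illustrative simplifications, not the library's actual API.

```python
from typing import Any, Optional


class _ToyRecipeManager:
  """Hypothetical stand-in for the recipe manager referenced above."""

  def __init__(self, needs_calibration: bool):
    self._needs_calibration = needs_calibration

  def need_calibration(self) -> bool:
    return self._needs_calibration


def generate_quantization_parameters(
    model_recipe_manager: _ToyRecipeManager,
    model_qsvs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
  """Mirrors the guard shown above: QSVs are mandatory when calibration is."""
  if model_recipe_manager.need_calibration() and not model_qsvs:
    raise RuntimeError(
        'Calibration statistics (QSVs) are required by this recipe but were'
        ' not provided.'  # Illustrative message, not the library's wording.
    )
  # A real implementation would walk every op, query the recipe manager for
  # that op's quantization config, and accumulate per-tensor parameters here.
  return {}


# Weight-only style recipes need no QSVs; static-range style recipes do.
print(generate_quantization_parameters(_ToyRecipeManager(needs_calibration=False)))
try:
  generate_quantization_parameters(_ToyRecipeManager(needs_calibration=True))
except RuntimeError as err:
  print('Raised as expected:', err)
```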
@@ -95,7 +95,7 @@ def _post_process_results(self) -> None:
     """Post process the quantization results.

     Raises:
-      RuntimeError: if the tensors sharing the same buffer have different
+      RuntimeError: If the tensors sharing the same buffer have different
         quantization settings.
     """
     self._check_buffer_sharing()
@@ -110,10 +110,10 @@ def _update_model_quant_results(
     """Update the op quantization results to the final output.

     Args:
-      op_tensor_results: list of tensor level quantization params for the op.
+      op_tensor_results: Tensor level quantization params for the op.

     Raises:
-      RuntimeError: if the same tensor has multiple quantization configs.
+      RuntimeError: If the same tensor has multiple quantization configs.
     """

     for op_tensor_result in op_tensor_results:
@@ -124,7 +124,7 @@ def _update_model_quant_results(
         tensor_params = self.model_quant_results[tensor_name]
         # Set source op.
         if op_tensor_result.producer is not None:
-          # src params must be unique (a tensor can only be produced by one op).
+          # Src params must be unique (a tensor can only be produced by one op).
           if tensor_params.producer is not None:
             raise RuntimeError(
                 'Tensor %s received multiple quantization parameters from the'
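The check above enforces that producer parameters are set at most once per tensor, while consumer parameters may accumulate. A small sketch of that merge rule, using a hypothetical TensorParams dataclass rather than the library's real types:

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class TensorParams:
  """Hypothetical simplification of per-tensor transformation params."""
  tensor_name: str
  producer: Optional[Any] = None
  consumers: list[Any] = dataclasses.field(default_factory=list)


def merge_op_result(results: dict[str, TensorParams], update: TensorParams) -> None:
  """Folds one op's tensor result into the model-wide results."""
  existing = results.get(update.tensor_name)
  if existing is None:
    results[update.tensor_name] = update
    return
  if update.producer is not None:
    # A tensor can only be produced by one op, so a second producer is an error.
    if existing.producer is not None:
      raise RuntimeError(
          f'Tensor {update.tensor_name} received multiple producer parameters.'
      )
    existing.producer = update.producer
  # Consumer params, by contrast, accumulate: many ops may read the same tensor.
  existing.consumers.extend(update.consumers)


results: dict[str, TensorParams] = {}
merge_op_result(results, TensorParams('conv_out', producer='conv2d_params'))
merge_op_result(results, TensorParams('conv_out', consumers=['relu_params']))
print(results['conv_out'])
```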
@@ -148,11 +148,11 @@ def _get_op_scope(self, op: Any, subgraph_tensors: list[Any]) -> str:
       Explorer).

     Args:
-      op: the op that need to be parsed.
-      subgraph_tensors: list of tensors in the subgraph.
+      op: The op that needs to be parsed.
+      subgraph_tensors: Tensors in the subgraph.

     Returns:
-      scope: scope for the op.
+      Scope for the op.
     """
     scope = ''
     # Op scope is determined by output tensors.
@@ -161,7 +161,7 @@ def _get_op_scope(self, op: Any, subgraph_tensors: list[Any]) -> str:
       scope += tfl_flatbuffer_utils.get_tensor_name(
           subgraph_tensors[output_tensor_idx]
       )
-      scope += ';'  # split names
+      scope += ';'  # Split names.
     return scope

   def _get_params_for_no_quant_op(
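_get_op_scope, shown in the two hunks above, builds an op's scope by concatenating the names of its output tensors and terminating each with ';'. A minimal sketch of that naming scheme, with plain strings standing in for the flatbuffer tensor objects:

```python
def get_op_scope(output_tensor_names: list[str]) -> str:
  """Joins output tensor names with ';', mirroring _get_op_scope above."""
  scope = ''
  for name in output_tensor_names:
    scope += name
    scope += ';'  # Split names.
  return scope


# An op writing two outputs gets a scope that encodes both tensor names.
print(get_op_scope(['model/dense/MatMul', 'model/dense/BiasAdd']))
# -> 'model/dense/MatMul;model/dense/BiasAdd;'
```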
@@ -173,12 +173,12 @@ def _get_params_for_no_quant_op(
     """Get the quantization parameters for ops require no quantization.

     Args:
-      subgraph_op_id: the op id in the subgraph.
-      op: the op that need to be parsed.
-      subgraph_tensors: list of tensors in the subgraph.
+      subgraph_op_id: The op id in the subgraph.
+      op: The op that needs to be parsed.
+      subgraph_tensors: Tensors in the subgraph.

     Returns:
-      tensor_params: list of tensor level quantization params for the op.
+      Tensor level quantization params for the op.
     """

     def no_quant_tensor_params():
@@ -211,8 +211,8 @@ def _check_buffer_sharing(self) -> None:
     """Check if tensors sharing the same buffer have the same quantization.

     Raises:
-      RuntimeError: if the tensors sharing the same buffer have different
-      quantization settings.
+      RuntimeError: If the tensors sharing the same buffer have different
+        quantization settings.
     """
     for tensors in self.buffer_to_tensors.values():
       first_tensor = tensors[0]
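The loop starting here compares every tensor that shares a buffer against the first one. A self-contained sketch of that consistency rule, with plain dicts standing in for the real per-tensor quantization parameters:

```python
def check_buffer_sharing(
    buffer_to_tensors: dict[int, list[str]],
    quant_params: dict[str, dict],
) -> None:
  """Raises if tensors sharing a buffer disagree on quantization settings."""
  for buffer_idx, tensors in buffer_to_tensors.items():
    if not tensors:
      continue
    first = quant_params[tensors[0]]
    for tensor_name in tensors[1:]:
      if quant_params[tensor_name] != first:
        raise RuntimeError(
            f'Tensors {tensors[0]} and {tensor_name} share buffer {buffer_idx}'
            ' but have different quantization settings.'
        )


check_buffer_sharing(
    buffer_to_tensors={0: ['weights_a', 'weights_a_alias']},
    quant_params={
        'weights_a': {'num_bits': 8, 'symmetric': True},
        'weights_a_alias': {'num_bits': 8, 'symmetric': True},
    },
)
print('Buffer sharing is consistent.')
```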
@@ -240,14 +240,15 @@ def _modify_io_tensor_transformations(
   ) -> None:
     """Modify quantization information for I/O tensors.

-    This will not be trigged by weight-only/drq because they do not quantize
+    This will not be trigged by weight-only/DRQ because they do not quantize
     activation tensors.

-    Selective srq & emulated srq will be okay because only the I/O tensors will
+    Selective SRQ & emulated SRQ will be okay because only the I/O tensors will
     be left as quantized, if applicable. This is the intended behavior if user
-    choose to SRQ ops contain I/O tensors.
+    choose SRQ ops to contain I/O tensors.

     Args:
-      tensor_params: tensor level quantization params for the tensor.
+      tensor_params: Tensor level quantization params for the tensor.
     """
     # Change ADD_QUANTIZE to QUANTIZE_TENSOR for unique input/constant tensors.
     if (
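The trailing comment describes rewriting a leading ADD_QUANTIZE into QUANTIZE_TENSOR for graph input and constant tensors, so such tensors are stored quantized rather than gaining an extra quantize op. A rough sketch of that rewrite; the enum and list-of-transformations representation here are assumptions, not the library's actual types:

```python
import enum


class QuantTransformation(enum.Enum):
  """Hypothetical stand-in for the transformation enum named in the comment."""
  ADD_QUANTIZE = 1
  QUANTIZE_TENSOR = 2


def modify_io_tensor_transformations(
    transformations: list[QuantTransformation],
) -> list[QuantTransformation]:
  """If an I/O tensor would get a leading ADD_QUANTIZE, quantize it in place instead."""
  if transformations and transformations[0] == QuantTransformation.ADD_QUANTIZE:
    return [QuantTransformation.QUANTIZE_TENSOR] + transformations[1:]
  return transformations


# A graph input that would otherwise gain a quantize op is stored quantized instead.
print(modify_io_tensor_transformations([QuantTransformation.ADD_QUANTIZE]))
```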
@@ -293,7 +294,7 @@ def _compatible_tensor_transformation_params(
   if params1.consumers != params2.consumers:
     return False
   else:
-    # check all consumers within eah params are compatible.
+    # Check all consumers within each params are compatible.
     for params1_consumer in params1.consumers:
       if not _compatible_tensor_params(params1_consumer, params1.consumers[0]):
         return False
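The loop above checks every consumer against params1.consumers[0] rather than comparing all pairs; if each consumer is compatible with the first, they are mutually compatible. A compact sketch of that pattern, with a placeholder compatibility predicate:

```python
def compatible(params_a: dict, params_b: dict) -> bool:
  """Placeholder predicate: here, identical parameters count as compatible."""
  return params_a == params_b


def all_mutually_compatible(consumers: list[dict]) -> bool:
  """True if every consumer is compatible with the first, and hence with each other."""
  if not consumers:
    return True
  first = consumers[0]
  return all(compatible(c, first) for c in consumers)


print(all_mutually_compatible([{'num_bits': 8}, {'num_bits': 8}]))  # True
print(all_mutually_compatible([{'num_bits': 8}, {'num_bits': 4}]))  # False
```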
@@ -322,7 +323,7 @@ def _compatible_tensor_params(
   ]
   if params1.parameters != params2.parameters:
     return False
-  # we only need to check the first transformation because transformations are
+  # We only need to check the first transformation because transformations are
   # applied in order, and as long as the one that's immediately after the tensor
   # is the same, it's compatible.
   if (
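The comment above gives the rationale for comparing only the first transformation of each consumer: transformations are applied in order, so the one immediately adjacent to the tensor decides compatibility. A toy illustration using transformation names as strings:

```python
def first_transformations_match(
    transformations1: list[str], transformations2: list[str]
) -> bool:
  """Compatible as long as the transformation right next to the tensor agrees."""
  if not transformations1 or not transformations2:
    return not transformations1 and not transformations2
  return transformations1[0] == transformations2[0]


# Both consumers see QUANTIZE_TENSOR right after the tensor, so a difference
# further down the chain does not make them incompatible.
print(first_transformations_match(
    ['QUANTIZE_TENSOR', 'ADD_QUANTIZE'], ['QUANTIZE_TENSOR']))
```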