Commit bdcfbd3

No public description

PiperOrigin-RevId: 636683030

zichuan-wei authored and copybara-github committed May 23, 2024
1 parent 9f7767c commit bdcfbd3
Showing 2 changed files with 105 additions and 0 deletions.
41 changes: 41 additions & 0 deletions quantization_toolkit/transformation_instruction_generator.py
@@ -454,9 +454,50 @@ def _quant_params_to_transformation_insts(
    # Adding other consumers rules.
    transformations += other_consumer_transformations
    tensor_trans_insts.instructions = transformations
    # Check that the generated transformation instructions are valid; this
    # raises an error if the instructions are not.
    self._check_tensor_transformation_instructions_valid(tensor_trans_insts)

    return tensor_trans_insts

  def _check_tensor_transformation_instructions_valid(
      self, instructions: qtyping.TensorTransformationInsts
  ):
    """Checks if the tensor transformation instructions are valid.

    Args:
      instructions: Transformation instructions for a tensor.

    Raises:
      ValueError: If the instructions are not valid.
    """
    is_tensor_unquantized = False
    is_tensor_quantized = False
    is_operator_emulated = False
    for instruction in instructions.instructions:
      transform_type = instruction.transformation
      if transform_type == qtyping.QuantTransformation.NO_QUANTIZE:
        is_tensor_unquantized = True
      elif transform_type in (
          qtyping.QuantTransformation.ADD_QUANTIZE,
          qtyping.QuantTransformation.QUANTIZE_TENSOR,
          qtyping.QuantTransformation.ADD_DEQUANTIZE,
      ):
        is_tensor_quantized = True
      elif transform_type == qtyping.QuantTransformation.EMULATED_SUBCHANNEL:
        is_operator_emulated = True
    if is_tensor_unquantized and is_tensor_quantized:
      raise ValueError(
          "Tensor %s cannot be both quantized and unquantized"
          % instructions.tensor_name
      )
    if is_operator_emulated and len(instructions.instructions) > 1:
      raise ValueError(
          "Tensor %s: op replacement transformation cannot be combined with"
          " other transformations." % instructions.tensor_name
      )

  def quant_params_to_transformation_insts(
      self,
      params: dict[str, qtyping.TensorTransformationParams],
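As an aside for readers skimming the diff: the new check enforces two invariants on a tensor's instruction list. Below is a minimal standalone sketch of the same logic; the simplified Transform enum and the check_valid helper are illustrative stand-ins for this example only, not the toolkit's actual API.

from enum import Enum, auto

class Transform(Enum):
  # Simplified stand-in for qtyping.QuantTransformation.
  NO_QUANTIZE = auto()
  ADD_QUANTIZE = auto()
  QUANTIZE_TENSOR = auto()
  ADD_DEQUANTIZE = auto()
  EMULATED_SUBCHANNEL = auto()

# Transformations that imply the tensor is quantized.
_QUANTIZING = {
    Transform.ADD_QUANTIZE,
    Transform.QUANTIZE_TENSOR,
    Transform.ADD_DEQUANTIZE,
}

def check_valid(tensor_name, transforms):
  """Raises ValueError if the transformation list breaks either invariant."""
  # Invariant 1: a tensor cannot be both quantized and unquantized.
  if Transform.NO_QUANTIZE in transforms and _QUANTIZING & set(transforms):
    raise ValueError(
        f"Tensor {tensor_name} cannot be both quantized and unquantized"
    )
  # Invariant 2: an emulated-subchannel op replacement must stand alone.
  if Transform.EMULATED_SUBCHANNEL in transforms and len(transforms) > 1:
    raise ValueError(
        f"Tensor {tensor_name}: op replacement cannot be combined with"
        " other transformations"
    )

check_valid("ok", [Transform.ADD_QUANTIZE, Transform.ADD_DEQUANTIZE])  # passes
# check_valid("bad", [Transform.NO_QUANTIZE, Transform.ADD_QUANTIZE])  # raises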
64 changes: 64 additions & 0 deletions quantization_toolkit/transformation_instruction_generator_test.py
@@ -822,6 +822,70 @@ def test_generate_instruction_for_single_fc_bias(self):
instructions["StatefulPartitionedCall:0"], output_transformation
)

  def test_raise_error_on_op_replacement_transformation_is_not_unique(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "test_models/insert_dequant_test.tflite"
    )
    quant_parameters = {}
    quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
        "tfl.quantize",
        qtyping.OpToTensorParams(
            subgraph_op_id=0,
            transformations=[
                qtyping.QuantTransformation.ADD_DEQUANTIZE,
                qtyping.QuantTransformation.EMULATED_SUBCHANNEL,
            ],
            parameters=qtyping.UniformQuantParams(
                8, None, np.array([1]), np.array([0])
            ),
        ),
        [],
    )
    ins_gen = instruction_generator.TransformationInstructionsGenerator(
        test_model_path
    )
    with self.assertRaisesRegex(
        ValueError, "op replacement transformation cannot be combined"
    ):
      ins_gen.quant_params_to_transformation_insts(quant_parameters)

  def test_raise_error_on_no_quant_conflict(self):
    test_model_path = os.path.join(
        TEST_DATA_PREFIX_PATH, "test_models/insert_dequant_test.tflite"
    )
    quant_parameters = {}
    quant_parameters["tfl.quantize"] = qtyping.TensorTransformationParams(
        "tfl.quantize",
        qtyping.OpToTensorParams(
            subgraph_op_id=0,
            transformations=[qtyping.QuantTransformation.ADD_DEQUANTIZE],
            parameters=qtyping.UniformQuantParams(
                8, None, np.array([1]), np.array([0])
            ),
        ),
        [
            qtyping.OpToTensorParams(
                subgraph_op_id=1,
                transformations=[qtyping.QuantTransformation.ADD_QUANTIZE],
                parameters=qtyping.UniformQuantParams(
                    8, None, np.array([1]), np.array([0])
                ),
            ),
            qtyping.OpToTensorParams(
                subgraph_op_id=2,
                transformations=[qtyping.QuantTransformation.NO_QUANTIZE],
                parameters=None,
            ),
        ],
    )
    ins_gen = instruction_generator.TransformationInstructionsGenerator(
        test_model_path
    )
    with self.assertRaisesRegex(
        ValueError, "cannot be both quantized and unquantized"
    ):
      ins_gen.quant_params_to_transformation_insts(quant_parameters)

  def test_generate_instruction_for_branching(self):
    """Tests horizontal and vertical optimization on a multi-branch graph."""
    test_model_path = os.path.join(
