diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py index 6d5c320..71d83df 100644 --- a/ai_edge_quantizer/qtyping.py +++ b/ai_edge_quantizer/qtyping.py @@ -120,6 +120,8 @@ class UniformQuantParams: zero_point: The zero point of the quantization. symmetric: Whether the quantization is symmetric (force zero_point to be 0). quantized_data: The quantized data. + block_size: The block size for blockwise quantization; block_size=0 means + no blockwise quantization. """ num_bits: int @@ -128,6 +130,7 @@ class UniformQuantParams: zero_point: np.ndarray symmetric: bool = True quantized_data: Optional[np.ndarray] = None + block_size: int = 0 @classmethod def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams': @@ -170,6 +173,7 @@ def __eq__(self, other): and np.array_equal(self.zero_point, other.zero_point) and self.symmetric == other.symmetric and _compare_array_or_none(self.quantized_data, other.quantized_data) + and self.block_size == other.block_size ) diff --git a/ai_edge_quantizer/transformations/quantize_tensor.py b/ai_edge_quantizer/transformations/quantize_tensor.py index 933d960..094fb3c 100644 --- a/ai_edge_quantizer/transformations/quantize_tensor.py +++ b/ai_edge_quantizer/transformations/quantize_tensor.py @@ -121,29 +121,54 @@ def quantize_tensor( ) if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams): - flatbuffer_quantization = schema_py_generated.QuantizationParametersT() - flatbuffer_quantization.scale = list( - transformation_input.quant_params.scale.flatten().astype(np.float32) - ) # flatbuffer requires scale as list[float] - flatbuffer_quantization.zeroPoint = list( - transformation_input.quant_params.zero_point.flatten().astype(np.int64) - ) # flatbuffer requires zeroPoint as list[int64] - if transformation_input.quant_params.quantized_dimension is not None: - flatbuffer_quantization.quantizedDimension = ( - transformation_input.quant_params.quantized_dimension + if
transformation_input.quant_params.block_size == 0: + flatbuffer_quantization = schema_py_generated.QuantizationParametersT() + flatbuffer_quantization.scale = list( + transformation_input.quant_params.scale.flatten().astype(np.float32) + ) # flatbuffer requires scale as list[float] + if transformation_input.quant_params.zero_point is not None: + flatbuffer_quantization.zeroPoint = list( + transformation_input.quant_params.zero_point.flatten().astype( + np.int64 + ) + ) # flatbuffer requires zeroPoint as list[int64] + if transformation_input.quant_params.quantized_dimension is not None: + flatbuffer_quantization.quantizedDimension = ( + transformation_input.quant_params.quantized_dimension + ) + else: + flatbuffer_quantization = schema_py_generated.QuantizationParametersT() + flatbuffer_quantization.detailsType = ( + schema_py_generated.QuantizationDetails.BlockwiseQuantization ) + blockwise_details = ( + schema_py_generated.BlockwiseQuantizationT() + ) + scale_tensor_id = transformation_utils.add_new_constant_tensor( + tensor.name + b"_scale", + transformation_input.quant_params.scale, + schema_py_generated.TensorType.FLOAT16, + transformation_input.subgraph, + transformation_input.buffers, + ) + blockwise_details.scale = scale_tensor_id + blockwise_details.blockSize = ( + transformation_input.quant_params.block_size + ) + if transformation_input.quant_params.zero_point is not None: + zero_point_tensor_id = transformation_utils.add_new_constant_tensor( + tensor.name + b"_zero_point", + transformation_input.quant_params.zero_point, + schema_py_generated.TensorType.INT32, + transformation_input.subgraph, + transformation_input.buffers, + ) + blockwise_details.zeroPoint = zero_point_tensor_id + flatbuffer_quantization.details = blockwise_details tensor.quantization = flatbuffer_quantization tensor.type = quant_params_to_tflite_type( transformation_input.quant_params.num_bits ) - - if isinstance( - transformation_input.quant_params, qtyping.NonLinearQuantParams - ): - 
tensor.type = nonlinear_quant_params_to_tflite_type( - transformation_input.quant_params.num_bits - ) - if isinstance( transformation_input.quant_params, qtyping.NonLinearQuantParams ): diff --git a/ai_edge_quantizer/transformations/quantize_tensor_test.py b/ai_edge_quantizer/transformations/quantize_tensor_test.py index ee048e1..3cd8c41 100644 --- a/ai_edge_quantizer/transformations/quantize_tensor_test.py +++ b/ai_edge_quantizer/transformations/quantize_tensor_test.py @@ -135,6 +135,39 @@ def test_quantize_tensor_with_nonlinear_quantization(self): subgraph.tensors[4].type, schema_py_generated.TensorType.FLOAT16 ) + def test_blockwise_quantization_with_zero_point(self): + """test quantizing an activation tensor with blockwise quantization.""" + subgraph = self._model.subgraphs[0] + model = self._model + data = np.ones([1, 112, 112, 32]).astype(np.int8) + quantize_tensor.quantize_tensor( + transformation_utils.TransformationInput( + 7, + model.operatorCodes, + model.buffers, + subgraph, + 1, + [3], + qtyping.UniformQuantParams( + 8, + None, + np.ones([1, 112, 112, 1]).astype(np.float16), + np.zeros([1, 112, 112, 1]), + symmetric=True, + quantized_data=data, + block_size=32, + ), + ) + ) + quant_param = subgraph.tensors[7].quantization + self.assertEqual( + quant_param.detailsType, + schema_py_generated.QuantizationDetails.BlockwiseQuantization, + ) + self.assertEqual(quant_param.details.blockSize, 32) + self.assertEqual(quant_param.details.zeroPoint, 10) + self.assertEqual(quant_param.details.scale, 9) + def test_int4_constant_packed_correctly(self): subgraph = self._model.subgraphs[0] model = self._model