diff --git a/ai_edge_quantizer/qtyping.py b/ai_edge_quantizer/qtyping.py
index 6d5c320..71d83df 100644
--- a/ai_edge_quantizer/qtyping.py
+++ b/ai_edge_quantizer/qtyping.py
@@ -120,6 +120,8 @@ class UniformQuantParams:
     zero_point: The zero point of the quantization.
     symmetric: Whether the quantization is symmetric (force zero_point to be 0).
     quantized_data: The quantized data.
+    block_size: The block size for blockwise quantization; block_size=0 means
+      no blockwise quantization.
   """

   num_bits: int
@@ -128,6 +130,7 @@ class UniformQuantParams:
   zero_point: np.ndarray
   symmetric: bool = True
   quantized_data: Optional[np.ndarray] = None
+  block_size: int = 0

   @classmethod
   def from_tfl_tensor_details(cls, tensor_detail) -> 'UniformQuantParams':
@@ -170,6 +173,7 @@ def __eq__(self, other):
       and np.array_equal(self.zero_point, other.zero_point)
       and self.symmetric == other.symmetric
       and _compare_array_or_none(self.quantized_data, other.quantized_data)
+      and self.block_size == other.block_size
     )

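As context for the schema change above, here is a minimal sketch of how the
new block_size field composes with the rest of UniformQuantParams. The weight
shape, the blocking along the last axis, and the max-abs scale computation
are illustrative assumptions, not something this patch prescribes:

import numpy as np

from ai_edge_quantizer import qtyping

# Hypothetical [256, 64] float weight, quantized in blocks of 32 along the
# last axis: one scale (and zero point) per block, so scales are [256, 2].
block_size = 32
weight = np.random.uniform(-1.0, 1.0, size=(256, 64)).astype(np.float32)
blocks = weight.reshape(256, -1, block_size)
scales = np.maximum(np.abs(blocks).max(axis=-1), 1e-6) / 127.0

params = qtyping.UniformQuantParams(
    num_bits=8,
    quantized_dimension=None,
    scale=scales.astype(np.float16),  # Serialized as a FLOAT16 scale tensor.
    zero_point=np.zeros_like(scales, dtype=np.int32),
    symmetric=True,
    quantized_data=np.round(blocks / scales[..., None])
    .astype(np.int8)
    .reshape(weight.shape),
    block_size=block_size,  # block_size=0 means no blockwise quantization.
)
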
+ """ + assert isinstance( + transformation_input.quant_params, qtyping.UniformQuantParams + ) + flatbuffer_quantization = schema_py_generated.QuantizationParametersT() + flatbuffer_quantization.detailsType = ( + schema_py_generated.QuantizationDetails.BlockwiseQuantization + ) + tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id] + blockwise_details = schema_py_generated.BlockwiseQuantizationT() + scale_tensor_id = transformation_utils.add_new_constant_tensor( + tensor.name + b"_scale", + transformation_input.quant_params.scale, + schema_py_generated.TensorType.FLOAT16, + transformation_input.subgraph, + transformation_input.buffers, + ) + blockwise_details.scale = scale_tensor_id + blockwise_details.blockSize = transformation_input.quant_params.block_size + # blockwise quantization allows optional zero point. + if transformation_input.quant_params.zero_point is not None: + zero_point_tensor_id = transformation_utils.add_new_constant_tensor( + tensor.name + b"_zero_point", + transformation_input.quant_params.zero_point, + schema_py_generated.TensorType.INT32, + transformation_input.subgraph, + transformation_input.buffers, + ) + blockwise_details.zeroPoint = zero_point_tensor_id + flatbuffer_quantization.details = blockwise_details + return flatbuffer_quantization + + def quantize_tensor( transformation_input: transformation_utils.TransformationInput, ) -> qtyping.TransformationInfo: """Quantize the tensor at the tensor_id in the given subgraph. Args: - transformation_input: input structure that contains all information needed + transformation_input: Input structure that contains all information needed for the transformation. Returns: TransformationInfo: - op_id: the producer index for tensor - num_ops_added: the total number of ops inserted by this operation, which - is 0 + op_id: The producer index for tensor. + num_ops_added: The total number of ops inserted by this operation, which + is 0. """ tensor = transformation_input.subgraph.tensors[transformation_input.tensor_id] - # TODO: b/336385820 - suppport quantize buffer directly when quantized_data - # is not provided + # TODO: b/336385820 - Suppport quantize buffer directly when quantized_data + # is not provided. 
if tensor.buffer: if transformation_input.quant_params.quantized_data is not None: transformation_input.buffers[tensor.buffer].data = _pack_data( @@ -121,29 +196,18 @@ def quantize_tensor( ) if isinstance(transformation_input.quant_params, qtyping.UniformQuantParams): - flatbuffer_quantization = schema_py_generated.QuantizationParametersT() - flatbuffer_quantization.scale = list( - transformation_input.quant_params.scale.flatten().astype(np.float32) - ) # flatbuffer requires scale as list[float] - flatbuffer_quantization.zeroPoint = list( - transformation_input.quant_params.zero_point.flatten().astype(np.int64) - ) # flatbuffer requires zeroPoint as list[int64] - if transformation_input.quant_params.quantized_dimension is not None: - flatbuffer_quantization.quantizedDimension = ( - transformation_input.quant_params.quantized_dimension + if transformation_input.quant_params.block_size == 0: + flatbuffer_quantization = _perform_channelwise_quantization( + transformation_input + ) + else: + flatbuffer_quantization = _perform_blockwise_quantization( + transformation_input ) tensor.quantization = flatbuffer_quantization tensor.type = quant_params_to_tflite_type( transformation_input.quant_params.num_bits ) - - if isinstance( - transformation_input.quant_params, qtyping.NonLinearQuantParams - ): - tensor.type = nonlinear_quant_params_to_tflite_type( - transformation_input.quant_params.num_bits - ) - if isinstance( transformation_input.quant_params, qtyping.NonLinearQuantParams ): diff --git a/ai_edge_quantizer/transformations/quantize_tensor_test.py b/ai_edge_quantizer/transformations/quantize_tensor_test.py index ee048e1..2fe266f 100644 --- a/ai_edge_quantizer/transformations/quantize_tensor_test.py +++ b/ai_edge_quantizer/transformations/quantize_tensor_test.py @@ -42,7 +42,7 @@ def test_quantize_constant_tensor(self): """test quantizing a constant tensor.""" subgraph = self._model.subgraphs[0] model = self._model - data = np.ones([1, 112, 112, 3], dtype=np.int8) + data = np.ones([1, 112, 112, 32], dtype=np.int8) ret = quantize_tensor.quantize_tensor( transformation_utils.TransformationInput( 7, @@ -135,6 +135,42 @@ def test_quantize_tensor_with_nonlinear_quantization(self): subgraph.tensors[4].type, schema_py_generated.TensorType.FLOAT16 ) + def test_blockwise_quantization_with_zero_point(self): + """Test blockwise quantization with explicit zero point.""" + subgraph = self._model.subgraphs[0] + model = self._model + tensor_wh = 112 + test_tensor_id = 7 + data = np.ones([1, tensor_wh, tensor_wh, 32]).astype(np.int8) + quantize_tensor.quantize_tensor( + transformation_utils.TransformationInput( + tensor_id=test_tensor_id, + op_codes=model.operatorCodes, + buffers=model.buffers, + subgraph=subgraph, + producer=1, + consumers=[3], + quant_params=qtyping.UniformQuantParams( + num_bits=8, + quantized_dimension=None, + scale=np.ones([1, tensor_wh, tensor_wh, 1]).astype(np.float16), + zero_point=np.zeros([1, tensor_wh, tensor_wh, 1]), + symmetric=True, + quantized_data=data, + block_size=32, + ), + ) + ) + quant_param = subgraph.tensors[test_tensor_id].quantization + self.assertEqual( + quant_param.detailsType, + schema_py_generated.QuantizationDetails.BlockwiseQuantization, + ) + self.assertEqual(quant_param.details.blockSize, 32) + # Check if the scale and zero point tensors are inserted correctly. 
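A note on the blockwise path added above: unlike channelwise parameters,
which inline scale/zeroPoint as value lists, _perform_blockwise_quantization
stores scale and zero point as new constant tensors and records their
subgraph tensor indices (plus the block size) in the BlockwiseQuantization
details. A sketch of reading those details back; the schema_py_generated
import path is assumed to match the one quantize_tensor.py already uses, and
treating index 0 as "absent" is a simplification:

from tensorflow.lite.python import schema_py_generated


def get_blockwise_details(subgraph, tensor_id):
  """Returns (block_size, scale_tensor, zero_point_tensor_or_none)."""
  quant = subgraph.tensors[tensor_id].quantization
  assert (
      quant.detailsType
      == schema_py_generated.QuantizationDetails.BlockwiseQuantization
  )
  details = quant.details
  # details.scale and details.zeroPoint are tensor indices into the subgraph,
  # not quantization values; this is why the test below asserts against the
  # indices 9 and 10.
  scale_tensor = subgraph.tensors[details.scale]
  zero_point_tensor = (
      subgraph.tensors[details.zeroPoint] if details.zeroPoint else None
  )
  return details.blockSize, scale_tensor, zero_point_tensor
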
diff --git a/ai_edge_quantizer/transformations/quantize_tensor_test.py b/ai_edge_quantizer/transformations/quantize_tensor_test.py
index ee048e1..2fe266f 100644
--- a/ai_edge_quantizer/transformations/quantize_tensor_test.py
+++ b/ai_edge_quantizer/transformations/quantize_tensor_test.py
@@ -42,7 +42,7 @@ def test_quantize_constant_tensor(self):
     """test quantizing a constant tensor."""
     subgraph = self._model.subgraphs[0]
     model = self._model
-    data = np.ones([1, 112, 112, 3], dtype=np.int8)
+    data = np.ones([1, 112, 112, 32], dtype=np.int8)
     ret = quantize_tensor.quantize_tensor(
         transformation_utils.TransformationInput(
             7,
@@ -135,6 +135,42 @@ def test_quantize_tensor_with_nonlinear_quantization(self):
         subgraph.tensors[4].type, schema_py_generated.TensorType.FLOAT16
     )

+  def test_blockwise_quantization_with_zero_point(self):
+    """Tests blockwise quantization with an explicit zero point."""
+    subgraph = self._model.subgraphs[0]
+    model = self._model
+    tensor_wh = 112
+    test_tensor_id = 7
+    data = np.ones([1, tensor_wh, tensor_wh, 32]).astype(np.int8)
+    quantize_tensor.quantize_tensor(
+        transformation_utils.TransformationInput(
+            tensor_id=test_tensor_id,
+            op_codes=model.operatorCodes,
+            buffers=model.buffers,
+            subgraph=subgraph,
+            producer=1,
+            consumers=[3],
+            quant_params=qtyping.UniformQuantParams(
+                num_bits=8,
+                quantized_dimension=None,
+                scale=np.ones([1, tensor_wh, tensor_wh, 1]).astype(np.float16),
+                zero_point=np.zeros([1, tensor_wh, tensor_wh, 1], np.int32),
+                symmetric=True,
+                quantized_data=data,
+                block_size=32,
+            ),
+        )
+    )
+    quant_param = subgraph.tensors[test_tensor_id].quantization
+    self.assertEqual(
+        quant_param.detailsType,
+        schema_py_generated.QuantizationDetails.BlockwiseQuantization,
+    )
+    self.assertEqual(quant_param.details.blockSize, 32)
+    # Check that the scale and zero point tensors are inserted correctly.
+    self.assertEqual(quant_param.details.scale, 9)
+    self.assertEqual(quant_param.details.zeroPoint, 10)
+
   def test_int4_constant_packed_correctly(self):
     subgraph = self._model.subgraphs[0]
     model = self._model