Skip to content

Commit

Permalink
Enable packing only for the 4 bits case.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698964660
  • Loading branch information
marialyu authored and copybara-github committed Nov 22, 2024
1 parent 38e3dd3 commit 3c0f35b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
20 changes: 10 additions & 10 deletions ai_edge_quantizer/transformations/quantize_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


# TODO: b/335014051 - support distinguishing INT, FLOAT & UINT, BFLOAT
# TODO: b/335014051 - Support distinguishing INT, FLOAT & UINT, BFLOAT.
def quant_params_to_tflite_type(
bitwidth: int,
) -> Optional[schema_py_generated.TensorType]:
"""Given specifications from quant param return the corresponding tflite dtype.
"""Given specifications from quant param return the corresponding TFLite dtype.
Args:
bitwidth: bitwidth from UniformQuantParams
bitwidth: Bit width from UniformQuantParams.
Returns:
the corresponding tflite tensortype
The corresponding TFLite tensor type.
"""
if bitwidth <= 4:
if bitwidth == 4:
return schema_py_generated.TensorType.INT4
elif bitwidth <= 8:
return schema_py_generated.TensorType.INT8
Expand Down Expand Up @@ -68,19 +68,19 @@ def nonlinear_quant_params_to_tflite_type(


def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
"""Pack the data to the corresponding bitwidth.
"""Pack the data to the corresponding bit width.
If no packing is needed, the original data is returned. Any bitwidth equal or
less than 4 bits will be packed to 4 bits.
Currently only support 4 bits. If no packing is needed, the original data is
returned.
Args:
bitwidth: Bitwidth from NonLinearQuantParams.
bitwidth: Bit width from NonLinearQuantParams.
flattened_data: The data to be packed.
Returns:
Packed data.
"""
if bitwidth <= 4:
if bitwidth == 4:
even_data = flattened_data[::2] & 0x0F
odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
if odd_data.shape[0] == even_data.shape[0] - 1:
Expand Down
41 changes: 23 additions & 18 deletions ai_edge_quantizer/transformations/quantize_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import numpy as np
from tensorflow.python.platform import googletest
from absl.testing import parameterized
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
Expand All @@ -28,7 +29,7 @@
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")


class QuantizeTensorTest(googletest.TestCase):
class QuantizeTensorTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -179,40 +180,44 @@ def test_int4_constant_packed_correctly(self):
np.testing.assert_array_equal(quant_param.zeroPoint, [1])
self.assertEqual(quant_param.quantizedDimension, 0)

def test_int5_constant_not_packed(self):
@parameterized.named_parameters(
dict(
testcase_name="int5",
num_bits=5,
),
dict(
testcase_name="int2",
num_bits=2,
),
)
def test_int_constant_not_packed(self, num_bits):
subgraph = self._model.subgraphs[0]
model = self._model
data = np.array(
[
0x0,
0x1,
0x2,
0x3,
0x4,
0x5,
0x6,
0x7,
],
dtype=np.int8,
)
tensor_id = 7
data = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7], dtype=np.int8)
expected = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7])
ret = quantize_tensor.quantize_tensor(
transformation_utils.TransformationInput(
tensor_id=7,
tensor_id=tensor_id,
op_codes=model.operatorCodes,
buffers=model.buffers,
subgraph=subgraph,
producer=-1,
consumers=[4],
quant_params=qtyping.UniformQuantParams(
5, None, np.ones(1), np.ones(1), True, data
num_bits=num_bits,
quantized_dimension=None,
scale=np.ones(1),
zero_point=np.ones(1),
symmetric=True,
quantized_data=data,
),
)
)
self.assertEqual(ret.op_id, 0)
self.assertEqual(ret.num_ops_added, 0)
np.testing.assert_array_equal(model.buffers[8].data, expected)
quant_param = subgraph.tensors[7].quantization
quant_param = subgraph.tensors[tensor_id].quantization
np.testing.assert_array_equal(quant_param.scale, [1])
np.testing.assert_array_equal(quant_param.zeroPoint, [1])
self.assertEqual(quant_param.quantizedDimension, 0)
Expand Down

0 comments on commit 3c0f35b

Please sign in to comment.