Skip to content

Commit

Permalink
Enable packing only for the 4 bits case.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698867595
  • Loading branch information
marialyu authored and copybara-github committed Nov 21, 2024
1 parent 38e3dd3 commit b579b88
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
16 changes: 8 additions & 8 deletions ai_edge_quantizer/transformations/quantize_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@
from ai_edge_litert import schema_py_generated # pylint: disable=g-direct-tensorflow-import


# TODO: b/335014051 - support distinguishing INT, FLOAT & UINT, BFLOAT
# TODO: b/335014051 - Support distinguishing INT, FLOAT & UINT, BFLOAT.
def quant_params_to_tflite_type(
bitwidth: int,
) -> Optional[schema_py_generated.TensorType]:
"""Given specifications from quant param return the corresponding tflite dtype.
"""Given specifications from quant param return the corresponding TFLite dtype.
Args:
bitwidth: bitwidth from UniformQuantParams
bitwidth: Bit width from UniformQuantParams.
Returns:
the corresponding tflite tensortype
The corresponding TFLite tensor type.
"""
if bitwidth <= 4:
if bitwidth == 4:
return schema_py_generated.TensorType.INT4
elif bitwidth <= 8:
return schema_py_generated.TensorType.INT8
Expand Down Expand Up @@ -70,8 +70,8 @@ def nonlinear_quant_params_to_tflite_type(
def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
"""Pack the data to the corresponding bitwidth.
If no packing is needed, the original data is returned. Any bitwidth equal or
less than 4 bits will be packed to 4 bits.
Only data with a bitwidth of exactly 4 is packed. For any other bitwidth, no
packing is needed and the original data is returned.
Args:
bitwidth: Bitwidth from NonLinearQuantParams.
Expand All @@ -80,7 +80,7 @@ def _pack_data(bitwidth: int, flattened_data: np.ndarray) -> np.ndarray:
Returns:
Packed data.
"""
if bitwidth <= 4:
if bitwidth == 4:
even_data = flattened_data[::2] & 0x0F
odd_data = np.left_shift(flattened_data[1::2], 4).astype(np.uint8)
if odd_data.shape[0] == even_data.shape[0] - 1:
Expand Down
41 changes: 23 additions & 18 deletions ai_edge_quantizer/transformations/quantize_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import numpy as np
from tensorflow.python.platform import googletest
from absl.testing import parameterized
from ai_edge_quantizer import qtyping
from ai_edge_quantizer.transformations import quantize_tensor
from ai_edge_quantizer.transformations import transformation_utils
Expand All @@ -28,7 +29,7 @@
TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile("..")


class QuantizeTensorTest(googletest.TestCase):
class QuantizeTensorTest(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -179,40 +180,44 @@ def test_int4_constant_packed_correctly(self):
np.testing.assert_array_equal(quant_param.zeroPoint, [1])
self.assertEqual(quant_param.quantizedDimension, 0)

def test_int5_constant_not_packed(self):
@parameterized.named_parameters(
dict(
testcase_name="int5",
num_bits=5,
),
dict(
testcase_name="int2",
num_bits=2,
),
)
def test_int_constant_not_packed(self, num_bits):
subgraph = self._model.subgraphs[0]
model = self._model
data = np.array(
[
0x0,
0x1,
0x2,
0x3,
0x4,
0x5,
0x6,
0x7,
],
dtype=np.int8,
)
tensor_id = 7
data = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7], dtype=np.int8)
expected = np.array([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7])
ret = quantize_tensor.quantize_tensor(
transformation_utils.TransformationInput(
tensor_id=7,
tensor_id=tensor_id,
op_codes=model.operatorCodes,
buffers=model.buffers,
subgraph=subgraph,
producer=-1,
consumers=[4],
quant_params=qtyping.UniformQuantParams(
5, None, np.ones(1), np.ones(1), True, data
num_bits=num_bits,
quantized_dimension=None,
scale=np.ones(1),
zero_point=np.ones(1),
symmetric=True,
quantized_data=data,
),
)
)
self.assertEqual(ret.op_id, 0)
self.assertEqual(ret.num_ops_added, 0)
np.testing.assert_array_equal(model.buffers[8].data, expected)
quant_param = subgraph.tensors[7].quantization
quant_param = subgraph.tensors[tensor_id].quantization
np.testing.assert_array_equal(quant_param.scale, [1])
np.testing.assert_array_equal(quant_param.zeroPoint, [1])
self.assertEqual(quant_param.quantizedDimension, 0)
Expand Down

0 comments on commit b579b88

Please sign in to comment.