Commit

fix ut and simplify config
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 committed Aug 15, 2024
1 parent 13b69e3 commit 98ac3c2
Showing 8 changed files with 47 additions and 187 deletions.
12 changes: 4 additions & 8 deletions onnx_neural_compressor/algorithms/utility.py
@@ -222,14 +222,10 @@ def calculate_scale_zp(rmin, rmax, qType, sym, reduce_range=False):
     dtype = _qType_to_np_type(qType)
     if isinstance(rmax, np.ndarray):
         if sym:
-            mask = abs(rmin) > abs(rmax)
-            scale = np.ones(rmin.shape).astype(rmin.dtype)
-            scale[mask] = rmin[mask]
-            scale[~mask] = rmax[~mask]
-            abs_max = round((qmax - qmin) / 2)
-            scale /= abs_max
-        else:
-            scale = (rmax - rmin) / (qmax - qmin)
+            max_range = np.maximum(abs(rmin), abs(rmax))
+            rmin = -max_range
+            rmax = max_range
+        scale = (rmax - rmin) / (qmax - qmin)
         scale[abs(scale) < np.finfo(rmax.dtype).tiny] = 1
         zero_point = (
             np.multiply(np.ones(rmax.shape), np.round((qmax + qmin) / 2.0)).astype(dtype)
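
For readers skimming the change above: the symmetric branch now simply mirrors the larger absolute bound around zero instead of building a per-channel mask, so both the symmetric and asymmetric paths share one scale formula. A minimal standalone sketch of that math (hedged: the helper name, the int8 qmin/qmax defaults, and the dtypes below are illustrative assumptions, not this library's API):

    import numpy as np

    def symmetric_scale_zp(rmin, rmax, qmin=-127, qmax=127):
        # Make the float range symmetric around zero, then derive one scale,
        # mirroring the simplified branch in calculate_scale_zp above.
        max_range = np.maximum(np.abs(rmin), np.abs(rmax))
        rmin, rmax = -max_range, max_range
        scale = (rmax - rmin) / (qmax - qmin)
        # Guard against zero-width (all-zero) channels, as the patched code does.
        scale[np.abs(scale) < np.finfo(rmax.dtype).tiny] = 1
        # The symmetric zero point sits at the midpoint of the quantized range.
        zero_point = np.full(rmax.shape, round((qmax + qmin) / 2.0), dtype=np.int8)
        return scale, zero_point

    # Per-channel example: three channels, one with a degenerate range.
    rmin = np.array([-0.5, -1.0, 0.0], dtype=np.float32)
    rmax = np.array([0.25, 2.0, 0.0], dtype=np.float32)
    print(symmetric_scale_zp(rmin, rmax))  # scales ~[0.0039, 0.0079, 1.0], zero points [0, 0, 0]
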
2 changes: 0 additions & 2 deletions onnx_neural_compressor/constants.py
@@ -20,8 +20,6 @@
 # constants for configs
 GLOBAL = "global"
 LOCAL = "local"
-DEFAULT_WHITE_LIST = "*"
-EMPTY_WHITE_LIST = None
 
 # config name
 BASE_CONFIG = "base_config"
2 changes: 1 addition & 1 deletion onnx_neural_compressor/quantization/algorithm_entry.py
@@ -192,7 +192,7 @@ def smooth_quant_entry(
         calibration_data_reader,
         execution_provider=getattr(quant_config, "execution_provider", "CPUExecutionProvider"),
     )
-    smoothed_model = smoother.transform(**quant_config.to_dict())
+    smoothed_model = smoother.transform(**quant_config.get_model_params_dict())
     with tempfile.TemporaryDirectory(prefix="ort.quant.") as tmp_dir:
         # ORT quant API requires str input
         onnx.save_model(
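
The one-line change above narrows what gets forwarded to the smoother: to_dict() would splat every config field into transform(**...), including keys such as execution_provider that the transform step does not take, while a model-params view keeps only the tuning knobs. A hedged sketch of the idea (the key names and helper below are illustrative assumptions, not the library's actual implementation):

    # Illustrative only: names are assumptions, not the real config class.
    MODEL_PARAMS = {"alpha", "folding", "scales_per_op"}  # knobs transform() accepts

    def get_model_params_dict(full_config: dict) -> dict:
        # Keep only the keys that smoother.transform() understands.
        return {k: v for k, v in full_config.items() if k in MODEL_PARAMS}

    cfg = {
        "alpha": 0.5,
        "folding": True,
        "scales_per_op": True,
        "execution_provider": "CPUExecutionProvider",  # must not reach transform()
    }
    print(get_model_params_dict(cfg))  # drops execution_provider
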
187 changes: 34 additions & 153 deletions onnx_neural_compressor/quantization/config.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions test/quantization/post_training_quant/test_operators.py
@@ -78,6 +78,9 @@ def setUpClass(cls):
     @classmethod
     def tearDownClass(cls):
         shutil.rmtree("./onnxrt_test", ignore_errors=True)
+        os.remove("int8.onnx")
+        os.remove("qdq.onnx")
+        os.remove("test.onnx")
 
     def qlinear_test(self, model, q_config, quantize_params, quantizable_op_types, **kwargs):
         quant = quantizer.StaticQuantizer(
3 changes: 3 additions & 0 deletions test/quantization/test_algorithm_utility.py
@@ -40,3 +40,6 @@ def test_is_B_transposed(self):
             beta=0.35,
         )
         self.assertFalse(quant_utils.is_B_transposed(node))
+
+if __name__ == "__main__":
+    unittest.main()
22 changes: 0 additions & 22 deletions test/quantization/test_config.py
@@ -329,28 +329,6 @@ def test_static_custom_quant_config(self):
         self.assertLess(idx, 2)
 
-    def test_config_white_lst(self):
-        global_config = config.RTNConfig(weight_bits=4)
-        # set operator instance
-        fc_out_config = config.RTNConfig(weight_dtype="fp32", white_list=["/h.4/mlp/fc_out/MatMul"])
-        # get model and quantize
-        fp32_model = self.gptj
-        qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config)
-        self.assertIsNotNone(qmodel)
-        self.assertEqual(self._count_woq_matmul(qmodel), 29)
-        self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul"))
-
-    def test_config_white_lst2(self):
-        global_config = config.RTNConfig(weight_dtype="fp32")
-        # set operator instance
-        fc_out_config = config.RTNConfig(weight_bits=4, white_list=["/h.4/mlp/fc_out/MatMul"])
-        # get model and quantize
-        fp32_model = self.gptj
-        qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config)
-        self.assertIsNotNone(qmodel)
-        self.assertEqual(self._count_woq_matmul(qmodel), 1)
-        self.assertTrue(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul"))
-
     def test_config_white_lst3(self):
 
         global_config = config.RTNConfig(weight_bits=4)
         # set operator instance
3 changes: 2 additions & 1 deletion test/utils/test_onnx_model.py
@@ -88,6 +88,7 @@ def tearDownClass(self):
         shutil.rmtree("./gptj", ignore_errors=True)
         shutil.rmtree("./large_model", ignore_errors=True)
         os.remove("matmul_add.onnx")
+        os.remove("model1.onnx")
 
     def setUp(self):
         # print the test name
@@ -102,7 +103,7 @@ def test_model_atrribute(self):
         # model_path
         self.assertEqual(model.model_path, self.matmul_add_model)
         # framework
-        self.assertEqual(model.framework(), "onnxruntime")
+        self.assertEqual(model.framework, "onnxruntime")
         # q_config
         quant_config = config.RTNConfig()
         model.q_config = quant_config
