From 98ac3c2e1b9ca09439ceaf1d04ef192fc9dbba3f Mon Sep 17 00:00:00 2001 From: Mengni Wang Date: Thu, 15 Aug 2024 02:16:08 -0700 Subject: [PATCH] fix ut and simplify config Signed-off-by: Mengni Wang --- onnx_neural_compressor/algorithms/utility.py | 12 +- onnx_neural_compressor/constants.py | 2 - .../quantization/algorithm_entry.py | 2 +- onnx_neural_compressor/quantization/config.py | 187 ++++-------------- .../post_training_quant/test_operators.py | 3 + test/quantization/test_algorithm_utility.py | 3 + test/quantization/test_config.py | 22 --- test/utils/test_onnx_model.py | 3 +- 8 files changed, 47 insertions(+), 187 deletions(-) diff --git a/onnx_neural_compressor/algorithms/utility.py b/onnx_neural_compressor/algorithms/utility.py index b38c8e21e..fb8fe8f38 100644 --- a/onnx_neural_compressor/algorithms/utility.py +++ b/onnx_neural_compressor/algorithms/utility.py @@ -222,14 +222,10 @@ def calculate_scale_zp(rmin, rmax, qType, sym, reduce_range=False): dtype = _qType_to_np_type(qType) if isinstance(rmax, np.ndarray): if sym: - mask = abs(rmin) > abs(rmax) - scale = np.ones(rmin.shape).astype(rmin.dtype) - scale[mask] = rmin[mask] - scale[~mask] = rmax[~mask] - abs_max = round((qmax - qmin) / 2) - scale /= abs_max - else: - scale = (rmax - rmin) / (qmax - qmin) + max_range = np.maximum(abs(rmin), abs(rmax)) + rmin = -max_range + rmax = max_range + scale = (rmax - rmin) / (qmax - qmin) scale[abs(scale) < np.finfo(rmax.dtype).tiny] = 1 zero_point = ( np.multiply(np.ones(rmax.shape), np.round((qmax + qmin) / 2.0)).astype(dtype) diff --git a/onnx_neural_compressor/constants.py b/onnx_neural_compressor/constants.py index 54889bda0..ca2b3e594 100644 --- a/onnx_neural_compressor/constants.py +++ b/onnx_neural_compressor/constants.py @@ -20,8 +20,6 @@ # constants for configs GLOBAL = "global" LOCAL = "local" -DEFAULT_WHITE_LIST = "*" -EMPTY_WHITE_LIST = None # config name BASE_CONFIG = "base_config" diff --git a/onnx_neural_compressor/quantization/algorithm_entry.py b/onnx_neural_compressor/quantization/algorithm_entry.py index 560b14292..e58acf2f0 100644 --- a/onnx_neural_compressor/quantization/algorithm_entry.py +++ b/onnx_neural_compressor/quantization/algorithm_entry.py @@ -192,7 +192,7 @@ def smooth_quant_entry( calibration_data_reader, execution_provider=getattr(quant_config, "execution_provider", "CPUExecutionProvider"), ) - smoothed_model = smoother.transform(**quant_config.to_dict()) + smoothed_model = smoother.transform(**quant_config.get_model_params_dict()) with tempfile.TemporaryDirectory(prefix="ort.quant.") as tmp_dir: # ORT quant API requires str input onnx.save_model( diff --git a/onnx_neural_compressor/quantization/config.py b/onnx_neural_compressor/quantization/config.py index 6522d0522..01acd7eaf 100644 --- a/onnx_neural_compressor/quantization/config.py +++ b/onnx_neural_compressor/quantization/config.py @@ -200,7 +200,16 @@ class ExampleAlgorithmConfig: return config_registry.register_config_impl(algo_name=algo_name, priority=priority) - +class Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, quantization.QuantType): + return getattr(o, "tensor_type") + if isinstance(o, quantization.QuantFormat): + return getattr(o, "value") + if isinstance(o, quantization.CalibrationMethod): + return getattr(o, "name") + return super().default(o) + class BaseConfig(ABC): """The base config for all algorithm configs.""" @@ -210,43 +219,16 @@ class BaseConfig(ABC): def __init__( self, - white_list: Optional[Union[Union[str, Callable], List[Union[str, Callable]]]] = 
constants.DEFAULT_WHITE_LIST, ) -> None: self._global_config: Optional[BaseConfig] = None # local config is the collections of operator_type configs and operator configs self._local_config: Dict[str, Optional[BaseConfig]] = {} - self._white_list = white_list self._config_mapping = OrderedDict() - def _post_init(self): - if self.white_list == constants.DEFAULT_WHITE_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return - else: - raise NotImplementedError( - f"The white list should be one of {constants.DEFAULT_WHITE_LIST}, {constants.EMPTY_WHITE_LIST}," - " a not empty list, but got {self.white_list}" - ) - @property def config_mapping(self): return self._config_mapping - @property - def white_list(self): - return self._white_list - - @white_list.setter - def white_list(self, op_name_or_type_list: Optional[List[Union[str, Callable]]]): - self._white_list = op_name_or_type_list - @property def global_config(self): return self._global_config @@ -292,7 +274,7 @@ def get_params_dict(self): def get_init_args(self): result = dict() for param, value in self.__dict__.items(): - if param not in ["_global_config", "_local_config", "_white_list", "_config_mapping"]: + if param not in ["_global_config", "_local_config", "_config_mapping"]: result[param] = value return result @@ -323,7 +305,7 @@ def from_dict(cls, config_dict): operator_config = config_dict.get(constants.LOCAL, {}) if operator_config: for op_name, op_config in operator_config.items(): - config.set_local(op_name, cls(**op_config, white_list=None)) + config.set_local(op_name, cls(**op_config)) return config def get_diff_dict(self, config) -> Dict[str, Any]: @@ -348,7 +330,7 @@ def from_json_file(cls, filename): def to_json_file(self, filename): config_dict = self.to_dict() with open(filename, "w", encoding="utf-8") as file: - json.dump(config_dict, file, indent=4) + json.dump(config_dict, file, cls=Encoder, indent=4) logger.info("Dump the config into %s.", filename) def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: @@ -367,7 +349,7 @@ def to_json_string(self, use_diff: bool = False) -> Union[str, Dict]: else: config_dict = self.to_dict() try: - return json.dumps(config_dict, indent=2) + "\n" + return json.dumps(config_dict, cls=Encoder, indent=2) + "\n" except Exception as e: logger.error("Failed to serialize the config to JSON string: %s", e) return config_dict @@ -597,7 +579,7 @@ def from_dict(cls, config_dict: OrderedDict[str, Dict], config_registry: Dict[st return config def to_json_string(self, use_diff: bool = False) -> str: - return json.dumps(self.to_dict(), indent=2) + "\n" + return json.dumps(self.to_dict(), cls=Encoder, indent=2) + "\n" def __repr__(self) -> str: return f"{self.__class__.__name__} {self.to_json_string()}" @@ -726,7 +708,6 @@ def __init__( quant_last_matmul: bool = True, quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, nodes_to_exclude: list = [], - white_list: List[Union[str, Callable]] = constants.EMPTY_WHITE_LIST, ): """Initialize weight-only quantization config. 
@@ -744,10 +725,8 @@ def __init__(
             quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True.
             quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator.
             nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization.
-            white_list (list, optional): op in white_list will be applied current config.
-                Defaults to constants.DEFAULT_WHITE_LIST.
         """
-        super().__init__(white_list=white_list)
+        super().__init__()
         self.weight_bits = weight_bits
         self.weight_dtype = weight_dtype
         self.weight_group_size = weight_group_size
@@ -833,7 +812,7 @@ def __init__(
         quant_last_matmul: bool = True,
         quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator,
         nodes_to_exclude: list = [],
-        white_list: List[Union[str, Callable]] = constants.RTN_OP_LIST,
+        white_list: list = None,
     ):
         """Init RTN weight-only quantization config.
@@ -855,8 +834,7 @@ def __init__(
             quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True.
             quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator.
             nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization.
-            white_list (list, optional): op in white_list will be applied current config.
-                Defaults to constants.DEFAULT_WHITE_LIST.
+            white_list (list, optional): ops whose name or type is in white_list will have this config applied. None means all ops.
         """
         super().__init__(
             weight_bits=weight_bits,
             weight_dtype=weight_dtype,
             weight_group_size=weight_group_size,
             weight_sym=weight_sym,
             act_dtype=act_dtype,
             accuracy_level=accuracy_level,
             providers=providers,
             quant_last_matmul=quant_last_matmul,
             quant_format=quant_format,
             nodes_to_exclude=nodes_to_exclude,
-            white_list=white_list,
         )
         self.layer_wise_quant = layer_wise_quant
         self.ratios = ratios
-        self._post_init()
-
-    def _post_init(self):
-        if self.white_list == constants.RTN_OP_LIST:
-            global_config = self.get_init_args()
-            self._global_config = self.__class__(**global_config, white_list=None)
-        elif isinstance(self.white_list, list) and len(self.white_list) > 0:
-            for op_name_or_type in self.white_list:
+        if isinstance(white_list, list) and len(white_list) > 0:
+            for op_name_or_type in white_list:
                 global_config = self.get_init_args()
-                tmp_config = self.__class__(**global_config, white_list=None)
+                tmp_config = self.__class__(**global_config)
                 self.set_local(op_name_or_type, tmp_config)
-        elif self.white_list == constants.EMPTY_WHITE_LIST:
-            return
 
     @classmethod
     def register_supported_configs(cls) -> None:
@@ -974,7 +944,7 @@ def __init__(
         quant_last_matmul: bool = True,
         quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator,
         nodes_to_exclude: list = [],
-        white_list: List[Union[str, Callable]] = constants.GPTQ_OP_LIST,
+        white_list: list = None,
     ):
         """Init GPTQ weight-only quantization config.
@@ -1002,8 +972,7 @@ def __init__(
             quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True.
             quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator.
             nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization.
-            white_list (list, optional): op in white_list will be applied current config.
-                Defaults to constants.DEFAULT_WHITE_LIST.
+            white_list (list, optional): ops whose name or type is in white_list will have this config applied. None means all ops.
""" super().__init__( weight_bits=weight_bits, @@ -1016,7 +985,6 @@ def __init__( quant_last_matmul=quant_last_matmul, quant_format=quant_format, nodes_to_exclude=nodes_to_exclude, - white_list=white_list, ) self.percdamp = percdamp self.block_size = block_size @@ -1024,19 +992,12 @@ def __init__( self.mse = mse self.perchannel = perchannel self.layer_wise_quant = layer_wise_quant - self._post_init() - def _post_init(self): - if self.white_list == constants.GPTQ_OP_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: + if isinstance(white_list, list) and len(white_list) > 0: + for op_name_or_type in white_list: global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) + tmp_config = self.__class__(**global_config) self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return @classmethod def register_supported_configs(cls) -> None: @@ -1126,7 +1087,7 @@ def __init__( quant_last_matmul: bool = True, quant_format: quantization.QuantFormat = quantization.QuantFormat.QOperator, nodes_to_exclude: list = [], - white_list: List[Union[str, Callable]] = constants.AWQ_OP_LIST, + white_list: list = None, ): """Init AWQ weight-only quantization config. @@ -1147,8 +1108,7 @@ def __init__( quant_last_matmul (bool, optional): whether to quantize the last matmul of the model, default is True. quant_format (QuantFormat, optional): use QOperator or QDQ format, default is QOperator. nodes_to_exclude (list, optional): nodes in nodes_to_exclude list will be skipped during quantization. - white_list (list, optional): op in white_list will be applied current config. - Defaults to constants.DEFAULT_WHITE_LIST. + white_list (list, optional): op_name or op_type in white_list will be applied current config. None means all ops. """ super().__init__( weight_bits=weight_bits, @@ -1161,23 +1121,15 @@ def __init__( quant_last_matmul=quant_last_matmul, quant_format=quant_format, nodes_to_exclude=nodes_to_exclude, - white_list=white_list, ) self.enable_auto_scale = enable_auto_scale self.enable_mse_search = enable_mse_search - self._post_init() - def _post_init(self): - if self.white_list == constants.AWQ_OP_LIST: - global_config = self.get_init_args() - self._global_config = self.__class__(**global_config, white_list=None) - elif isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: + if isinstance(white_list, list) and len(white_list) > 0: + for op_name_or_type in white_list: global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) + tmp_config = self.__class__(**global_config) self.set_local(op_name_or_type, tmp_config) - elif self.white_list == constants.EMPTY_WHITE_LIST: - return @classmethod def register_supported_configs(cls) -> List[_OperatorConfig]: @@ -1549,7 +1501,6 @@ def __init__( calibration_sampling_size=100, quant_last_matmul=True, execution_provider=None, - white_list: list = constants.DEFAULT_WHITE_LIST, **kwargs, ): """This is a class for static Quant Configuration. 
@@ -1619,7 +1570,7 @@ def __init__( else: os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1" - BaseConfig.__init__(self, white_list=self.op_types_to_quantize) + BaseConfig.__init__(self) self.execution_provider = execution_provider self.quant_last_matmul = quant_last_matmul self.calibration_sampling_size = calibration_sampling_size @@ -1629,7 +1580,6 @@ def __init__( self.optypes_to_exclude_output_quant = _extra_options.OpTypesToExcludeOutputQuantization self.dedicated_qdq_pair = _extra_options.DedicatedQDQPair self.add_qdq_pair_to_weight = _extra_options.AddQDQPairToWeight - self.white_list = white_list self._post_init() @staticmethod @@ -1659,11 +1609,6 @@ def _post_init(self): for valid_func in STATIC_CHECK_FUNC_LIST: op_config = valid_func(op_config, op_name_or_type, self.execution_provider, self.quant_format) self.set_local(op_name_or_type, op_config) - if isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: if config_list is None: @@ -1871,34 +1816,6 @@ def register_supported_configs(cls) -> None: ) cls.supported_configs = supported_configs - def to_dict(self): - result = {} - for key, val in self.__dict__.items(): - if key in ["_global_config", "_config_mapping"]: - continue - if key == "_local_config": - local_result = {} - for name, cfg in val.items(): - local_result[name] = cfg.to_dict() - result[key] = local_result - continue - if not isinstance(val, list): - result[key] = ( - getattr(val, "tensor_type", val) - if isinstance(val, quantization.QuantType) - else getattr(val, "value", val) - ) - else: - result[key] = [ - ( - getattr(item, "tensor_type", item) - if isinstance(item, quantization.QuantType) - else getattr(item, "value", item) - ) - for item in val - ] - return result - ######################## SmoohQuant Config ############################### @@ -1934,7 +1851,6 @@ def __init__( calib_iter: int = 100, scales_per_op: bool = True, auto_alpha_args: dict = {"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "attn_method": "min"}, - white_list: list = None, **kwargs, ): """Init smooth quant config. 
@@ -1954,7 +1870,7 @@ def __init__( kwargs (dict): kwargs in below link are supported except calibration_data_reader: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/quantize.py#L78 """ - super().__init__(white_list=white_list, **kwargs) + super().__init__(**kwargs) self.alpha = alpha self.folding = folding self.op_types = op_types @@ -2034,7 +1950,6 @@ def __init__( extra_options: dict = None, quant_last_matmul: bool = True, execution_provider: str = None, - white_list: list = constants.DEFAULT_WHITE_LIST, **kwargs, ): if execution_provider is None: @@ -2056,14 +1971,13 @@ def __init__( use_external_data_format=use_external_data_format, extra_options=extra_options, ) - BaseConfig.__init__(self, white_list=op_types_to_quantize) + BaseConfig.__init__(self) self.execution_provider = execution_provider self.quant_last_matmul = quant_last_matmul self.activation_type = quantization.QuantType.QUInt8 _extra_options = ExtraOptions(**self.extra_options) self.weight_sym = _extra_options.WeightSymmetric self.activation_sym = _extra_options.ActivationSymmetric - self.white_list = white_list self._post_init() @staticmethod @@ -2092,11 +2006,6 @@ def _post_init(self): for valid_func in DYNAMIC_CHECK_FUNC_LIST: op_config = valid_func(op_config, op_name_or_type, self.execution_provider) self.set_local(op_name_or_type, op_config) - if isinstance(self.white_list, list) and len(self.white_list) > 0: - for op_name_or_type in self.white_list: - global_config = self.get_init_args() - tmp_config = self.__class__(**global_config, white_list=None) - self.set_local(op_name_or_type, tmp_config) def to_config_mapping(self, config_list: list = None, model_info: list = None) -> OrderedDict: if config_list is None: @@ -2233,34 +2142,6 @@ def register_supported_configs(cls) -> None: ) cls.supported_configs = supported_configs - def to_dict(self): - result = {} - for key, val in self.__dict__.items(): - if key in ["_global_config", "_config_mapping"]: - continue - if key == "_local_config": - local_result = {} - for name, cfg in val.items(): - local_result[name] = cfg.to_dict() - result[key] = local_result - continue - if not isinstance(val, list): - result[key] = ( - getattr(val, "tensor_type", val) - if isinstance(val, quantization.QuantType) - else getattr(val, "value", val) - ) - else: - result[key] = [ - ( - getattr(item, "tensor_type", item) - if isinstance(item, quantization.QuantType) - else getattr(item, "value", item) - ) - for item in val - ] - return result - ##################### NC Algo Configs End ################################### diff --git a/test/quantization/post_training_quant/test_operators.py b/test/quantization/post_training_quant/test_operators.py index 45c189328..c06759f3c 100644 --- a/test/quantization/post_training_quant/test_operators.py +++ b/test/quantization/post_training_quant/test_operators.py @@ -78,6 +78,9 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): shutil.rmtree("./onnxrt_test", ignore_errors=True) + os.remove("int8.onnx") + os.remove("qdq.onnx") + os.remove("test.onnx") def qlinear_test(self, model, q_config, quantize_params, quantizable_op_types, **kwargs): quant = quantizer.StaticQuantizer( diff --git a/test/quantization/test_algorithm_utility.py b/test/quantization/test_algorithm_utility.py index 4301545c7..28525eeca 100644 --- a/test/quantization/test_algorithm_utility.py +++ b/test/quantization/test_algorithm_utility.py @@ -40,3 +40,6 @@ def test_is_B_transposed(self): beta=0.35, ) 
self.assertFalse(quant_utils.is_B_transposed(node)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/quantization/test_config.py b/test/quantization/test_config.py index 81ccd245d..7a4d3ba7d 100644 --- a/test/quantization/test_config.py +++ b/test/quantization/test_config.py @@ -329,28 +329,6 @@ def test_static_custom_quant_config(self): self.assertLess(idx, 2) def test_config_white_lst(self): - global_config = config.RTNConfig(weight_bits=4) - # set operator instance - fc_out_config = config.RTNConfig(weight_dtype="fp32", white_list=["/h.4/mlp/fc_out/MatMul"]) - # get model and quantize - fp32_model = self.gptj - qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config) - self.assertIsNotNone(qmodel) - self.assertEqual(self._count_woq_matmul(qmodel), 29) - self.assertFalse(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) - - def test_config_white_lst2(self): - global_config = config.RTNConfig(weight_dtype="fp32") - # set operator instance - fc_out_config = config.RTNConfig(weight_bits=4, white_list=["/h.4/mlp/fc_out/MatMul"]) - # get model and quantize - fp32_model = self.gptj - qmodel = algos.rtn_quantize_entry(fp32_model, quant_config=global_config + fc_out_config) - self.assertIsNotNone(qmodel) - self.assertEqual(self._count_woq_matmul(qmodel), 1) - self.assertTrue(self._check_node_is_quantized(qmodel, "/h.4/mlp/fc_out/MatMul")) - - def test_config_white_lst3(self): global_config = config.RTNConfig(weight_bits=4) # set operator instance diff --git a/test/utils/test_onnx_model.py b/test/utils/test_onnx_model.py index f27f64e1f..999b0985b 100644 --- a/test/utils/test_onnx_model.py +++ b/test/utils/test_onnx_model.py @@ -88,6 +88,7 @@ def tearDownClass(self): shutil.rmtree("./gptj", ignore_errors=True) shutil.rmtree("./large_model", ignore_errors=True) os.remove("matmul_add.onnx") + os.remove("model1.onnx") def setUp(self): # print the test name @@ -102,7 +103,7 @@ def test_model_atrribute(self): # model_path self.assertEqual(model.model_path, self.matmul_add_model) # framework - self.assertEqual(model.framework(), "onnxruntime") + self.assertEqual(model.framework, "onnxruntime") # q_config quant_config = config.RTNConfig() model.q_config = quant_config
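Note on the calculate_scale_zp change in algorithms/utility.py: the removed symmetric branch copied the signed rmin or rmax value into scale with a mask before dividing by round((qmax - qmin) / 2), so a dominant negative rmin could yield a negative scale; the new code takes the elementwise magnitude first and then reuses the common (rmax - rmin) / (qmax - qmin) formula for both the symmetric and asymmetric cases. A minimal NumPy sketch of the new symmetric path (the int8 qmin/qmax bounds are illustrative; the real values are derived from qType):

    import numpy as np

    qmin, qmax = -128, 127  # illustrative int8 bounds
    rmin = np.array([-2.0, -0.5], dtype=np.float32)
    rmax = np.array([1.0, 1.5], dtype=np.float32)

    # Symmetric: widen both ends to the larger magnitude, then apply the
    # shared formula; the resulting scale is always non-negative.
    max_range = np.maximum(np.abs(rmin), np.abs(rmax))
    rmin, rmax = -max_range, max_range
    scale = (rmax - rmin) / (qmax - qmin)
    scale[np.abs(scale) < np.finfo(rmax.dtype).tiny] = 1  # guard zero-width ranges
    zero_point = np.round((qmax + qmin) / 2.0) * np.ones(rmax.shape)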
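Note on the new Encoder class in quantization/config.py: it centralizes the enum handling that the two removed per-class to_dict overrides duplicated, so any config dict containing QuantType, QuantFormat, or CalibrationMethod values can be serialized through json.dump/json.dumps with cls=Encoder. A self-contained sketch of the idea (importing the types from onnxruntime directly for illustration; the patch itself goes through the package's own quantization module):

    import json

    from onnxruntime import quantization

    class Encoder(json.JSONEncoder):
        """Map ORT quantization enums to plain JSON-serializable values."""

        def default(self, o):
            if isinstance(o, quantization.QuantType):
                return o.tensor_type  # e.g. QuantType.QInt8 -> TensorProto.INT8 (an int)
            if isinstance(o, quantization.QuantFormat):
                return o.value  # the underlying enum value
            if isinstance(o, quantization.CalibrationMethod):
                return o.name  # e.g. "MinMax"
            return super().default(o)

    print(json.dumps({"activation_type": quantization.QuantType.QInt8}, cls=Encoder))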
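Note on the white_list simplification: the DEFAULT_WHITE_LIST/EMPTY_WHITE_LIST sentinels and the per-class _post_init overrides are gone; white_list now defaults to None (the config applies globally), and a non-empty list registers a local copy of the config for each listed op name or type via set_local. A usage sketch inferred from the diff and the white-list tests (import path and op name are taken from the test files; treat them as illustrative):

    from onnx_neural_compressor.quantization import config

    # Global 4-bit RTN config; white_list=None means it applies to all ops.
    global_cfg = config.RTNConfig(weight_bits=4)

    # Pin one op to fp32 by listing it; this becomes a local config entry.
    fc_out_cfg = config.RTNConfig(weight_dtype="fp32", white_list=["/h.4/mlp/fc_out/MatMul"])

    # Configs still compose additively, as exercised in test_config.py.
    combined = global_cfg + fc_out_cfg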