From 008492d471aaf9658aef94da03a81540b18256b0 Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Thu, 16 May 2024 10:08:44 +0800
Subject: [PATCH] Support ipex cpu WOQ backend (#1546)

* support ipex cpu woq

Signed-off-by: changwangss

---------

Signed-off-by: changwangss
Signed-off-by: Dong, Bo
Co-authored-by: Dong, Bo
---
 .../text-generation/quantization/README.md    |   2 +-
 .../quantization/run_generation_cpu_woq.py    |   9 +-
 .../transformers/llm/quantization/utils.py    | 109 +++++++++++----
 .../transformers/modeling/modeling_auto.py    | 131 ++++++++++++++----
 .../transformers/utils/config.py              |   5 +
 tests/CI/test_weight_only.py                  |  21 +++
 6 files changed, 222 insertions(+), 55 deletions(-)

diff --git a/examples/huggingface/pytorch/text-generation/quantization/README.md b/examples/huggingface/pytorch/text-generation/quantization/README.md
index d36c2ec0e75..6fe9d66aa70 100644
--- a/examples/huggingface/pytorch/text-generation/quantization/README.md
+++ b/examples/huggingface/pytorch/text-generation/quantization/README.md
@@ -104,7 +104,7 @@ pip install -r requirements_cpu_woq.txt
 > ```
 
 ### Run
-We provide compression technologies such as `WeightOnlyQuant` with `Rtn/Awq/Teq/GPTQ/AutoRound` algorithms and `BitsandBytes`, `load_in_4bit` and `load_in_8bit` work on CPU device, besides we provided use [neural-speed](https://github.com/intel/neural-speed) by `--use_neural_speed` to accelerate the optimized model, [here](https://github.com/intel/neural-speed/blob/main/docs/supported_models.md) is neural-speed supported list.
+We provide compression technologies such as `WeightOnlyQuant` with `Rtn/Awq/Teq/GPTQ/AutoRound` algorithms and `BitsandBytes`; `load_in_4bit` and `load_in_8bit` work on the CPU device. In addition, pass `--use_ipex` to accelerate the model with Intel Extension for PyTorch, or pass `--use_neural_speed` to accelerate the optimized model with [neural-speed](https://github.com/intel/neural-speed) ([here](https://github.com/intel/neural-speed/blob/main/docs/supported_models.md) is the list of supported models).
 The followings are command to show how to use it.
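For illustration, a minimal command sketch that enables the new ipex path (the `--model` argument and the specific values are assumptions for demonstration; the Performance commands below show the exact options shipped with this example):

```shell
# Hypothetical invocation: RTN 4-bit weight-only quantization with the quantized
# linear layers executed through Intel Extension for PyTorch (--use_ipex).
python run_generation_cpu_woq.py \
    --model facebook/opt-125m \
    --woq \
    --woq_algo Rtn \
    --bits 4 \
    --use_ipex \
    --benchmark
```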
#### Performance ```shell diff --git a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py index 21bed936218..a373f36e848 100644 --- a/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py +++ b/examples/huggingface/pytorch/text-generation/quantization/run_generation_cpu_woq.py @@ -28,6 +28,7 @@ "--max_new_tokens", default=32, type=int, help="output max new tokens" ) parser.add_argument("--output_dir", nargs="?", default="./saved_results") +parser.add_argument("--use_ipex", action="store_true") # ============Benchmark configs============== parser.add_argument("--benchmark", action="store_true") parser.add_argument("--iters", default=100, type=int, help="num iter") @@ -207,7 +208,6 @@ if args.woq: if args.woq_algo == "Rtn": quantization_config = RtnConfig( - tokenizer=tokenizer, bits=args.bits, sym=True if args.scheme == "sym" else False, group_size=args.group_size, @@ -215,6 +215,7 @@ scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, layer_wise=args.layer_wise, + use_ipex=args.use_ipex, ) elif args.woq_algo == "Awq": quantization_config = AwqConfig( @@ -228,6 +229,7 @@ scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, calib_iters=args.calib_iters, + use_ipex=args.use_ipex, ) elif args.woq_algo == "Teq": quantization_config = TeqConfig( @@ -241,6 +243,7 @@ scale_dtype=args.scale_dtype, weight_dtype=args.weight_dtype, calib_iters=args.calib_iters, + use_ipex=args.use_ipex, ) elif args.woq_algo == "GPTQ": quantization_config = GPTQConfig( @@ -260,6 +263,7 @@ weight_dtype=args.weight_dtype, calib_iters=args.calib_iters, layer_wise=args.layer_wise, + use_ipex=args.use_ipex, ) elif args.woq_algo == "AutoRound": quantization_config = AutoRoundConfig( @@ -277,6 +281,7 @@ lr=args.lr, minmax_lr=args.minmax_lr, use_quant_input=args.use_quant_input, + use_ipex=args.use_ipex, ) else: assert False, "Please set the correct '--woq_algo'" @@ -388,6 +393,8 @@ model_args += ",model_format=neural_speed" args = LMEvalParser(model = "hf", model_args=model_args, + #user_model=user_model, + #tokenizer=tokenizer, tasks = args.tasks, device = "cpu", batch_size = args.batch_size) diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py index bbf38d7fdd7..86c909ac4f9 100644 --- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py +++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py @@ -183,34 +183,79 @@ def _replace_linear( or device == torch.device("cpu") or device == "auto" ): - from .nn.modules import ( - QuantizedLinearQBits, - ) # TODO: QuantizedLinearINT4, QuantizedLinearINT8 - - use_optimum_format = getattr(module, "use_optimum_format", False) or \ - quantization_config.weight_dtype not in [ - "fp8_e5m2", - "fp8_e4m3", - "fp4", - "nf4", - "int4_fullrange", - ] - - model._modules[name] = QuantizedLinearQBits( - in_features, - out_features, - module.bias is not None, - compute_dtype=quantization_config.compute_dtype, - compress_statistics=False, - weight_dtype=quantization_config.weight_dtype, - scale_dtype=quantization_config.scale_dtype, - blocksize=quantization_config.group_size, - scheme=quantization_config.scheme, - compression_dtype=getattr(module, "compression_dtype", torch.int32), - compression_dim=getattr(module, "compression_dim", 1), - device=device, - 
use_optimum_format=use_optimum_format, - ) + if is_ipex_available() and quantization_config.use_ipex: + from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear + from intel_extension_for_pytorch.utils.weight_only_quantization import \ + _convert_optimum_format_to_desired + + qweight, scales, qzeros = _convert_optimum_format_to_desired(module.qweight, + module.scales, + module.qzeros) + + weight_dtype = { + 4: ipex.quantization.WoqWeightDtype.INT4, + 8: ipex.quantization.WoqWeightDtype.INT8, + } + compute_dtype = { + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "bf16": ipex.quantization.WoqLowpMode.BF16, + "fp16": ipex.quantization.WoqLowpMode.FP16, + "int8": ipex.quantization.WoqLowpMode.INT8, + + } + + ipex_qconfig_mapping = ( + ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[quantization_config.compute_dtype], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, + ) + ) + tmp_linear = torch.nn.Linear( + in_features, + out_features, + True if hasattr(module, "bias") else False + ) + tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig + model._modules[name] = ipex_linear.from_float_and_int4_weight( + mod = tmp_linear, + qweight = qweight, + scales = scales, + zero_points = qzeros, + bias = module.bias if hasattr(module, "bias") else None, + group_size = quantization_config.group_size, + g_idx = module.g_idx if hasattr(module, "g_idx") else None, + ) + else: + from .nn.modules import ( + QuantizedLinearQBits, + ) # TODO: QuantizedLinearINT4, QuantizedLinearINT8 + + use_optimum_format = getattr(module, "use_optimum_format", False) or \ + quantization_config.weight_dtype not in [ + "fp8_e5m2", + "fp8_e4m3", + "fp4", + "nf4", + "int4_fullrange", + ] + + model._modules[name] = QuantizedLinearQBits( + in_features, + out_features, + module.bias is not None, + compute_dtype=quantization_config.compute_dtype, + compress_statistics=False, + weight_dtype=quantization_config.weight_dtype, + scale_dtype=quantization_config.scale_dtype, + blocksize=quantization_config.group_size, + scheme=quantization_config.scheme, + compression_dtype=getattr(module, "compression_dtype", torch.int32), + compression_dim=getattr(module, "compression_dim", 1), + device=device, + use_optimum_format=use_optimum_format, + ) elif device == "xpu" or device == torch.device("xpu"): from intel_extension_for_pytorch.nn.utils._quantize_convert \ import WeightOnlyQuantizedLinear as ipex_linear # pylint: disable=E0401 @@ -265,7 +310,9 @@ def _replace_linear( model._modules[name].source_cls = type(module) # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) - if device == "cpu" or device == torch.device("cpu") or device == "auto": + if quantization_config.use_ipex: + pass + elif (device == "cpu" or device == torch.device("cpu") or device == "auto"): if quantization_config.weight_dtype in [ "fp8_e5m2", "fp8_e4m3", @@ -560,7 +607,11 @@ def default_calib_func(model): if config.weight_dtype not in ["nf4", "fp4", "int4_fullrange"]: inc_model = inc_model.export_compressed_model(use_optimum_format=True) inc_model.eval() + if config.use_ipex: + optimum_format_state_dict = inc_model.state_dict() q_model = replace_linear(inc_model, None, None, config, device=device) + if config.use_ipex: + setattr(q_model, "optimum_format_state_dict", optimum_format_state_dict) else: q_model = 
replace_linear( inc_model.model, None, None, config, device=device diff --git a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py index ede4c684427..381a7a0f8bd 100644 --- a/intel_extension_for_transformers/transformers/modeling/modeling_auto.py +++ b/intel_extension_for_transformers/transformers/modeling/modeling_auto.py @@ -186,6 +186,8 @@ def convert_model_to_public(model): module.qweight.data = module.qweight.t_().contiguous() module.scales.data = module.scales.t_().contiguous() module.weight_transposed = False + elif model.quantization_config.use_ipex: + pass elif model.quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", @@ -195,7 +197,6 @@ def convert_model_to_public(model): ]: model = recover_export_model(model) - def make_contiguous(model): for param in model.parameters(): if param.data.ndimension() > 1: @@ -223,6 +224,7 @@ def save_low_bit( os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) torch.save(self.quantized_state_dict(), weights_file) return + convert_model_to_public(self) os.makedirs(save_directory, exist_ok=True) # use transformers original `save_pretrained` function @@ -231,6 +233,33 @@ def save_low_bit( self.save_pretrained( save_directory=save_directory, push_to_hub=push_to_hub, **kwargs ) + + if self.quantization_config.use_ipex: + def save_linear_parameters(model, save_directory): + # only can save to pytorch model.bin due to ipex. + weights_file = os.path.join( + os.path.abspath(os.path.expanduser(save_directory)), SAFE_WEIGHTS_NAME) + os.remove(weights_file) + weights_file = os.path.join( + os.path.abspath(os.path.expanduser(save_directory)), WEIGHTS_NAME) + linear_parameters = {} + from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_cpu_linear + for name, module in model.named_modules(): + if isinstance(module, ipex_cpu_linear): + linear_parameters[name + ".ipex_scales"] = module._op_context.get_scales().contiguous() + linear_parameters[name + ".ipex_weight"] = \ + module._op_context.to_public(module._op_context.get_weight()).contiguous() + linear_parameters[name + ".ipex_zeros"] = module._op_context.get_zero_points().contiguous() + if module._op_context.get_bias() is not None: + linear_parameters[name + ".ipex_bias"] = module._op_context.get_bias().contiguous() + if module._op_context.get_g_idx() is not None: + linear_parameters[name + ".ipex_g_idx"] = module._op_context.get_g_idx().contiguous() + others_parameters = model.state_dict() + linear_parameters.update(others_parameters) + + torch.save(linear_parameters, weights_file) + + save_linear_parameters(self, save_directory) self.save_pretrained = types.MethodType(save_low_bit, self) # We conveniently save all the keys of the model to have them on hand, # so that when using 'low_cpumem load', @@ -1814,42 +1843,96 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs): # restore default dtype if dtype_orig is not None: torch.set_default_dtype(dtype_orig) - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = model_class._load_pretrained_model( - model, - None, - loaded_state_dict_keys, # XXX: rename? 
- resolved_archive_file, - pretrained_model_name_or_path, - sharded_metadata=sharded_metadata, - _fast_init=_fast_init, - low_cpu_mem_usage=True, - offload_folder=offload_folder, - offload_state_dict=offload_state_dict, - dtype=torch_dtype, - keep_in_fp32_modules=[], - ) + + if is_ipex_available() and quantization_config.use_ipex: + import intel_extension_for_pytorch as ipex + from intel_extension_for_pytorch.nn.modules import WeightOnlyQuantizedLinear as ipex_linear + def replace_ipex_cpu_woq_linear(model, current_name=[]): + for name, module in model.named_children(): + current_name.append(name) + if isinstance(module, WeightOnlyLinear): + weight_dtype = { + 4: ipex.quantization.WoqWeightDtype.INT4, + 8: ipex.quantization.WoqWeightDtype.INT8, + } + compute_dtype = { + "fp32": ipex.quantization.WoqLowpMode.NONE, # follow the activation datatype. + "bf16": ipex.quantization.WoqLowpMode.BF16, + "fp16": ipex.quantization.WoqLowpMode.FP16, + "int8": ipex.quantization.WoqLowpMode.INT8, + + } + + ipex_qconfig_mapping = ( + ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype[quantization_config.bits], + lowp_mode=compute_dtype[quantization_config.compute_dtype], + act_quant_mode=ipex.quantization.WoqActQuantMode.PER_IC_BLOCK, + group_size=quantization_config.group_size, + ) + ) + tmp_linear = torch.nn.Linear( + module.in_features, + module.out_features, + True if hasattr(module, "bias") else False + ) + tmp_linear.qconfig = ipex_qconfig_mapping.global_qconfig + target_linear = ipex_linear.from_float_and_int4_weight( + mod = tmp_linear, + qweight = state_dict.pop('.'.join(current_name) + ".ipex_weight"), + scales = state_dict.pop('.'.join(current_name) + ".ipex_scales"), + zero_points = state_dict.pop('.'.join(current_name) + ".ipex_zeros"), + bias = state_dict.pop('.'.join(current_name) + ".ipex_bias") \ + if '.'.join(current_name) + ".ipex_bias" in state_dict else None, + group_size = quantization_config.group_size, + g_idx = state_dict.pop('.'.join(current_name) + ".ipex_g_idx") \ + if '.'.join(current_name) + ".ipex_g_idx" in state_dict else None, + ) + setattr(model, name, target_linear) + else: + replace_ipex_cpu_woq_linear(module, current_name) + current_name.pop() + + replace_ipex_cpu_woq_linear(model) + model.load_state_dict(state_dict, strict=False, assign=True) + else: + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = model_class._load_pretrained_model( + model, + None, + loaded_state_dict_keys, # XXX: rename? 
+ resolved_archive_file, + pretrained_model_name_or_path, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=True, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + keep_in_fp32_modules=[], + ) # make sure token embedding weights are still tied if needed model.tie_weights() # Set model in evaluation mode to deactivate DropOut modules by default model.eval() + if quantization_config.weight_dtype not in [ "fp8_e5m2", "fp8_e4m3", "nf4", "fp4", "int4_fullrange", - ]: + ] and not quantization_config.use_ipex: model = replace_linear( - model.float(), + model, quantization_config=quantization_config, device="cpu" if device_map == "auto" else device_map, empty_weights=True, diff --git a/intel_extension_for_transformers/transformers/utils/config.py b/intel_extension_for_transformers/transformers/utils/config.py index 9376a0a9d8b..c7d0ae80ed4 100644 --- a/intel_extension_for_transformers/transformers/utils/config.py +++ b/intel_extension_for_transformers/transformers/utils/config.py @@ -775,6 +775,7 @@ def __init__( self.dataset = None self.calib_func = None self.calib_iters = None + self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: """Removes all attributes from config which correspond to the default config attributes @@ -874,6 +875,7 @@ def __init__( ) else: self.double_quant_scale_dtype = double_quant_scale_dtype + self.use_ipex = kwargs.pop("use_ipex", False) self.post_init_gptq() def post_init_gptq(self): @@ -952,6 +954,7 @@ def __init__( self.calib_iters = kwargs.get("calib_iters", 100) self.scheme = "asym" if self.zero_point else "sym" self.sym = True if not self.zero_point else False + self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: """Removes all attributes from config which correspond to the default config attributes @@ -1013,6 +1016,7 @@ def __init__( self.calib_dataloader = kwargs.get("calib_dataloader", None) self.calib_func = kwargs.get("calib_func", None) self.calib_iters = kwargs.get("calib_iters", 100) + self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: """Removes all attributes from config which correspond to the default config attributes @@ -1106,6 +1110,7 @@ def __init__( ) else: self.double_quant_scale_dtype = double_quant_scale_dtype + self.use_ipex = kwargs.pop("use_ipex", False) def to_diff_dict(self) -> Dict[str, Any]: """Removes all attributes from config which correspond to the default config attributes diff --git a/tests/CI/test_weight_only.py b/tests/CI/test_weight_only.py index 1952d8e5e69..2632d8c07ad 100644 --- a/tests/CI/test_weight_only.py +++ b/tests/CI/test_weight_only.py @@ -83,6 +83,7 @@ def setUpClass(cls): def tearDownClass(cls) -> None: shutil.rmtree(cls.workspace, ignore_errors=True) shutil.rmtree('tmp', ignore_errors=True) + shutil.rmtree('saved_results', ignore_errors=True) def test_woq_config(self): config = RtnConfig( @@ -154,6 +155,26 @@ def test_int4(self): print(output_quant) assert torch.allclose(output, output_quant, rtol=0.01) + def test_woq_with_ipex_cpu(self): + model_name_or_path = "facebook/opt-125m" + config = RtnConfig(bits=4, use_ipex=True) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + quantization_config=config, + use_llm_runtime=False + ) + input_ids = model.dummy_inputs["input_ids"] + output = model(input_ids) + model.save_pretrained("./saved_results") + model = AutoModelForCausalLM.from_pretrained( + "saved_results", + 
use_llm_runtime=False + ) + output_loading = model(input_ids) + assert torch.allclose(output.logits, output_loading.logits, rtol=0.01) + + + def test_auto_model(self): model = AutoModelForCausalLM.from_pretrained( llama_model_path, load_in_4bit=True, use_neural_speed=False)
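
For reference, a minimal end-to-end sketch of the workflow this patch enables, mirroring the new `test_woq_with_ipex_cpu` test above; the import path is the one this repository conventionally exposes and is assumed here rather than shown in the diff:

```python
# Sketch of the ipex CPU weight-only quantization path added by this PR,
# following tests/CI/test_weight_only.py::test_woq_with_ipex_cpu.
import torch
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig

# RTN 4-bit weight-only quantization; use_ipex=True routes the quantized linears
# to intel_extension_for_pytorch's WeightOnlyQuantizedLinear instead of QBits.
config = RtnConfig(bits=4, use_ipex=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=config,
    use_llm_runtime=False,
)

input_ids = model.dummy_inputs["input_ids"]
output = model(input_ids)

# With use_ipex, save_low_bit() serializes the ipex linears into pytorch_model.bin.
model.save_pretrained("./saved_results")

# Reloading rebuilds the ipex WeightOnlyQuantizedLinear modules from the saved state dict.
loaded = AutoModelForCausalLM.from_pretrained("saved_results", use_llm_runtime=False)
output_loading = loaded(input_ids)

assert torch.allclose(output.logits, output_loading.logits, rtol=0.01)
```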