diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index 3318c9ecd6..79cdc21a99 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/assets/wechat_npu.jpg b/assets/wechat_npu.jpg
index 2449e5ee31..5104d61ccc 100644
Binary files a/assets/wechat_npu.jpg and b/assets/wechat_npu.jpg differ
diff --git a/examples/README.md b/examples/README.md
index d79518df90..cc2afc463a 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
 #### Supervised Fine-Tuning on Multiple Nodes
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
 ```
 
 #### Multimodal Supervised Fine-Tuning
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 6fa935fe3a..b41d7ab8a8 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
 #### 在多机上进行指令监督微调
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
 ```
 
 #### 多模态指令监督微调
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index dcfa117db8..a000374127 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -169,6 +169,14 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
             )
         },
     )
+    combination_type: Optional[str] = field(
+        default=None,
+        metadata={"help": "The method used to merge multiple LoRA adapters, one of ['cat', 'svd', 'linear']."},
+    )
+    combination_weights: Optional[list[float]] = field(
+        default=None,
+        metadata={"help": "List of weights for each adapter, used together with `combination_type`."},
+    )
     adapter_folder: Optional[str] = field(
         default=None,
         metadata={"help": "The folder containing the adapter weights to load."},
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 9edd87dd2c..728d6a3de8 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
 
 import torch
-from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, PeftModel, PeftModelForCausalLM, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
@@ -177,10 +177,60 @@ def _setup_lora_tuning(
             "token": model_args.hf_hub_token,
         }
 
-        for adapter in adapter_to_merge:
-            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
+        if len(adapter_to_merge) > 1 and model_args.combination_type is not None:
+            if model_args.combination_weights is None:
+                raise ValueError(
+                    f"`combination_weights` must be provided when using '{model_args.combination_type}' "
+                    "to merge LoRA adapters."
+                )
+            elif len(model_args.combination_weights) != len(adapter_to_merge):
+                raise ValueError("The number of combination_weights must match the number of adapters.")
+
+            adapter_names = []
+            for idx, adapter in enumerate(adapter_to_merge):
+                adapter_name = f"ad_{idx}"
+                if idx == 0:
+                    model = PeftModelForCausalLM.from_pretrained(
+                        model, adapter, adapter_name=adapter_name, **init_kwargs
+                    )
+                else:
+                    model.load_adapter(adapter, adapter_name)
+                adapter_names.append(adapter_name)
+
+            # merge_and_unload() re-applies the LoRA scaling (lora_alpha / r) when folding the
+            # merged adapter back into the base weights, so the user-provided weights are adjusted
+            # here to cancel out that extra scaling and keep the effective merge weights as requested.
+            adapter_scaling = []
+            for adapter_name in adapter_names:
+                adapter_config = model.peft_config[adapter_name]
+                adapter_scaling.append(adapter_config.lora_alpha / adapter_config.r)
+
+            weights = model_args.combination_weights
+            if model_args.combination_type in ["cat", "svd"]:
+                weights = [weight / scaling for scaling, weight in zip(adapter_scaling, weights)]
+            elif model_args.combination_type == "linear":
+                weights = [weight**2 / scaling for scaling, weight in zip(adapter_scaling, weights)]
+
+            weighted_adapter_name = "merged_weighted_ad"
+            model.add_weighted_adapter(
+                adapters=adapter_names,
+                weights=weights,
+                adapter_name=weighted_adapter_name,
+                combination_type=model_args.combination_type,
+            )
+            model.set_adapter(weighted_adapter_name)
+            for name in adapter_names:
+                model.delete_adapter(name)
             model = model.merge_and_unload()
+        else:
+            for adapter in adapter_to_merge:
+                model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
+                model = model.merge_and_unload()
+
         if len(adapter_to_merge) > 0:
             logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")