From bd4813045757e4018da554fa001bb14241df5fa8 Mon Sep 17 00:00:00 2001
From: divyanshuaggarwal
Date: Sat, 15 Jul 2023 17:05:29 +0530
Subject: [PATCH] Add LoRA fine-tuning Script for T5 XL/XXL (#1926)

Co-authored-by: Divyanshu Aggarwal
---
 .gitignore                      |   1 +
 docs/training.md                |  43 +++++-
 fastchat/train/train_lora_t5.py | 226 ++++++++++++++++++++++++++++++++
 3 files changed, 264 insertions(+), 6 deletions(-)
 create mode 100644 fastchat/train/train_lora_t5.py

diff --git a/.gitignore b/.gitignore
index 03111a17c..94b6e614d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@ dist
 .DS_Store
 wandb
 output
+checkpoints_flant5_3b
 
 # Data
 *.pkl
diff --git a/docs/training.md b/docs/training.md
index 628557ad5..4bf0e7950 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -3,7 +3,7 @@ You can use the following command to train FastChat-T5 with 4 x A100 (40GB).
 ```bash
 torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \
     --model_name_or_path google/flan-t5-xl \
-    --data_path /data/dummy.json \
+    --data_path ./data/dummy_conversation.json \
     --bf16 True \
     --output_dir ./checkpoints_flant5_3b \
     --num_train_epochs 3 \
@@ -32,17 +32,17 @@ After training, please use our post-processing [function](https://github.com/lm-
 ### Fine-tuning using (Q)LoRA
 You can use the following command to train Vicuna-7B using QLoRA with ZeRO2. Note that ZeRO3 is not currently supported with QLoRA but ZeRO3 does support LoRA, which has a reference configuration under playground/deepspeed_config_s3.json. To use QLoRA, you must have bitsandbytes>=0.39.0 and transformers>=4.30.0 installed.
 ```bash
-deepspeed train_lora.py \
-    --model_name_or_path ~/model_weights/llama-7b \
+deepspeed fastchat/train/train_lora.py \
+    --model_name_or_path ~/model_weights/llama-7b \
     --lora_r 8 \
     --lora_alpha 16 \
     --lora_dropout 0.05 \
-    --data_path \
+    --data_path ./data/dummy_conversation.json \
     --bf16 True \
     --output_dir ./checkpoints \
     --num_train_epochs 3 \
-    --per_device_train_batch_size 4 \
-    --per_device_eval_batch_size 4 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
     --gradient_accumulation_steps 1 \
     --evaluation_strategy "no" \
     --save_strategy "steps" \
@@ -58,3 +58,34 @@ deepspeed train_lora.py \
     --q_lora True \
     --deepspeed playground/deepspeed_config_s2.json \
 ```
+
+For T5-XL or XXL
+
+```bash
+deepspeed fastchat/train/train_lora_t5.py \
+    --model_name_or_path google/flan-t5-xl \
+    --data_path ./data/dummy_conversation.json \
+    --bf16 True \
+    --output_dir ./checkpoints_flant5_3b \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 4 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 300 \
+    --save_total_limit 1 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --model_max_length 2048 \
+    --preprocessed_path ./preprocessed_data/processed.json \
+    --gradient_checkpointing True \
+    --q_lora True \
+    --deepspeed playground/deepspeed_config_s2.json
+
+```
+
+
diff --git a/fastchat/train/train_lora_t5.py b/fastchat/train/train_lora_t5.py
new file mode 100644
index 000000000..399de4820
--- /dev/null
+++ b/fastchat/train/train_lora_t5.py
@@ -0,0 +1,226 @@
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import copy
import os
from dataclasses import dataclass, field
import random
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import torch.distributed as dist


from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

import transformers
from torch.utils.data import Dataset
from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed

from fastchat.train.train_flant5 import (
    smart_tokenizer_and_embedding_resize,
    make_supervised_data_module,
)

from fastchat.train.train_lora import get_peft_state_maybe_zero_3

from fastchat.model.model_adapter import get_conversation_template

default_conversation = get_conversation_template("t5")

# TODO: import and use code from ../data/dataset.py

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
# T5 uses "</s>" for these special tokens; the defaults mirror fastchat/train/train_flant5.py.
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


@dataclass
class LoraArguments:
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    lazy_preprocess: bool = False
    num_data: int = -1
    preprocessed_path: str = field(
        default=None, metadata={"help": "Path to the preprocessed training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ }, + ) + + +def safe_save_model_for_hf_trainer( + trainer: transformers.Trainer, output_dir: str, state_dict: dict +): + """Collects the state dict and dump to disk.""" + + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def train(): + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP and ZeRO3 are both currently incompatible with QLoRA." + ) + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + model = transformers.AutoModelForSeq2SeqLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + device_map=device_map, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=compute_dtype, + ) + if lora_args.q_lora + else None, + ) + + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type=TaskType.SEQ_2_SEQ_LM, + ) + + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + if not ddp and torch.cuda.device_count() > 1: + # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available + model.is_parallelizable = True + model.model_parallel = True + + model = get_peft_model(model, lora_config) + if training_args.deepspeed is not None and training_args.local_rank == 0: + model.print_trainable_parameters() + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + # Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend + # a space before special tokens. 
+ tokenizer = transformers.T5Tokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"], + tokenizer=tokenizer, + model=model, + ) + + data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) + + trainer = Trainer( + model=model, tokenizer=tokenizer, args=training_args, **data_module + ) + + if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + # use deepspeed engine internal function to gather state dict + # state_dict_zero3 contains whole parameters of base and lora adapters + # we will not extract lora parameters since peft save_pretrained will do that + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125 + # https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19 + state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + if training_args.local_rank == 0: + state_dict = state_dict_zero3 + else: + # in other mode we use original code from fastchat team, to make sure our change is minimum + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), lora_args.lora_bias + ) + + if training_args.local_rank == 0: + safe_save_model_for_hf_trainer( + trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict + ) + + +if __name__ == "__main__": + train()
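
# Illustrative aside: after training, the adapter weights written to --output_dir can be
# applied to the base model for inference. A minimal sketch, assuming the checkpoint was
# saved in PEFT's adapter format and that the model name and path below match your run:
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#     from peft import PeftModel
#
#     base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl")
#     model = PeftModel.from_pretrained(base, "./checkpoints_flant5_3b")
#     model = model.merge_and_unload()  # optionally fold the LoRA deltas into the base weights
#     tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl", use_fast=False)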