Skip to content

Commit

Permalink
Add LoRA fine-tuning Script for T5 XL/XXL (lm-sys#1926)
Browse files Browse the repository at this point in the history
Co-authored-by: Divyanshu Aggarwal <[email protected]>
  • Loading branch information
divyanshuaggarwal and Divyanshu Aggarwal authored Jul 15, 2023
1 parent 62459e9 commit bd48130
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dist
.DS_Store
wandb
output
checkpoints_flant5_3b

# Data
*.pkl
Expand Down
43 changes: 37 additions & 6 deletions docs/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ You can use the following command to train FastChat-T5 with 4 x A100 (40GB).
```bash
torchrun --nproc_per_node=4 --master_port=9778 fastchat/train/train_flant5.py \
--model_name_or_path google/flan-t5-xl \
--data_path /data/dummy.json \
--data_path ./data/dummy_conversation.json \
--bf16 True \
--output_dir ./checkpoints_flant5_3b \
--num_train_epochs 3 \
Expand Down Expand Up @@ -32,17 +32,17 @@ After training, please use our post-processing [function](https://github.com/lm-
### Fine-tuning using (Q)LoRA
You can use the following command to train Vicuna-7B using QLoRA using ZeRO2. Note that ZeRO3 is not currently supported with QLoRA but ZeRO3 does support LoRA, which has a reference configuraiton under playground/deepspeed_config_s3.json. To use QLoRA, you must have bitsandbytes>=0.39.0 and transformers>=4.30.0 installed.
```bash
deepspeed train_lora.py \
--model_name_or_path ~/model_weights/llama-7b \
deepspeed fastchat/train/train_lora.py \
--model_name_or_path ~/model_weights/llama-7b \
--lora_r 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--data_path <path-to-data> \
--data_path ./data/dummy_conversation.json \
--bf16 True \
--output_dir ./checkpoints \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
Expand All @@ -58,3 +58,34 @@ deepspeed train_lora.py \
--q_lora True \
--deepspeed playground/deepspeed_config_s2.json \
```

For T5-XL or XXL

```bash
deepspeed fastchat/train/train_lora_t5.py \
--model_name_or_path google/flan-t5-xl \
--data_path ./data/dummy_conversation.json \
--bf16 True \
--output_dir ./checkpoints_flant5_3b \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 300 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 2048 \
--preprocessed_path ./preprocessed_data/processed.json \
--gradient_checkpointing True \
--q_lora True \
--deepspeed playground/deepspeed_config_s2.json

```


226 changes: 226 additions & 0 deletions fastchat/train/train_lora_t5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import copy
import os
from dataclasses import dataclass, field
import random
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch
import torch.distributed as dist


from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

import transformers
from torch.utils.data import Dataset
from transformers import Trainer, AddedToken, BitsAndBytesConfig, deepspeed

from fastchat.train.train_flant5 import (
smart_tokenizer_and_embedding_resize,
make_supervised_data_module,
)

from fastchat.train.train_lora import get_peft_state_maybe_zero_3

from fastchat.model.model_adapter import get_conversation_template

default_conversation = get_conversation_template("t5")

# TODO: import and use code from ../data/dataset.py

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


@dataclass
class LoraArguments:
lora_r: int = 8
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_target_modules: List[str] = field(default_factory=lambda: ["q", "v"])
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
lazy_preprocess: bool = False
num_data: int = -1
preprocessed_path: str = field(
default=None, metadata={"help": "Path to the preprocessed training data."}
)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)


def safe_save_model_for_hf_trainer(
trainer: transformers.Trainer, output_dir: str, state_dict: dict
):
"""Collects the state dict and dump to disk."""

if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa


def train():
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()

device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
logging.warning(
"FSDP and ZeRO3 are both currently incompatible with QLoRA."
)

compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)

model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
device_map=device_map,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
)
if lora_args.q_lora
else None,
)

lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
task_type=TaskType.SEQ_2_SEQ_LM,
)

if lora_args.q_lora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
if not ddp and torch.cuda.device_count() > 1:
# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
model.is_parallelizable = True
model.model_parallel = True

model = get_peft_model(model, lora_config)
if training_args.deepspeed is not None and training_args.local_rank == 0:
model.print_trainable_parameters()

if training_args.gradient_checkpointing:
model.enable_input_require_grads()

# Dacheng: Note we can only use T5Tokenizer, otherwise it will prepend
# a space before special tokens.
tokenizer = transformers.T5Tokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)

smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
other_tokens=["<", "{", "\n", "}", "`", " ", "\\", "^", "\t"],
tokenizer=tokenizer,
model=model,
)

data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args, **data_module
)

if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
# use deepspeed engine internal function to gather state dict
# state_dict_zero3 contains whole parameters of base and lora adapters
# we will not extract lora parameters since peft save_pretrained will do that
# https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/peft_model.py#L125
# https://github.com/huggingface/peft/blob/3714aa2fff158fdfa637b2b65952580801d890b2/src/peft/utils/save_and_load.py#L19
state_dict_zero3 = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
if training_args.local_rank == 0:
state_dict = state_dict_zero3
else:
# in other mode we use original code from fastchat team, to make sure our change is minimum
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), lora_args.lora_bias
)

if training_args.local_rank == 0:
safe_save_model_for_hf_trainer(
trainer=trainer, output_dir=training_args.output_dir, state_dict=state_dict
)


if __name__ == "__main__":
train()

0 comments on commit bd48130

Please sign in to comment.