
Commit

update sft log
mst272 committed Nov 6, 2024
1 parent bf33956 commit 97383a8
Showing 3 changed files with 41 additions and 31 deletions.
66 changes: 38 additions & 28 deletions main_train.py
@@ -1,5 +1,6 @@
import os
from os.path import join
import random
from loguru import logger
import torch
import torch.nn as nn
@@ -11,7 +12,8 @@
from utils.data_process import MultiRoundDataProcess
from utils.data_collator import SftDataCollator
from train_args.common_args import CommonArgs
from datasets import load_dataset

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def initial_args():
@@ -20,19 +22,11 @@ def initial_args():
if args.train_args_path == "sft_args":
parser_b = HfArgumentParser((sft_TrainArgument,))
train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
print("Loaded instance sft_args")
# elif args.train_args_path == "dpo_args":
# parser_c = HfArgumentParser((dpo_TrainArgument,))
# train_args, = parser_c.parse_args_into_dataclasses(args=remaining_args)
# print(f"Loaded instance dpo_args")
else:
raise ValueError("Invalid train_args_path choice")

if not os.path.exists(train_args.output_dir):
os.mkdir(train_args.output_dir)
logger.add(join(train_args.output_dir, 'train.log'))
logger.info("train_args:{}".format(train_args))
logger.info("common_args:{}".format(train_args))
set_seed(train_args.seed)

assert sum([train_args.fp16, train_args.bf16]) == 1, "only one of fp16 and bf16 can be True"
@@ -54,7 +48,6 @@ def find_all_linear_names(model, train_mode):
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
lora_module_names = list(lora_module_names)
logger.info(f'LoRA target module names: {lora_module_names}')
return lora_module_names


@@ -76,14 +69,12 @@ def create_tokenizer(args):

assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')

return tokenizer


def create_model(args, train_args):
logger.info(f'Loading model from base model: {args.model_name_or_path}')
logger.info(f'Train model with {args.train_mode}')
target_modules = None
# Determine the training precision
torch_dtype = torch.bfloat16 if train_args.bf16 else torch.float32
model_kwargs = dict(
@@ -139,16 +130,14 @@ def load_model(model_kwargs):
model = get_peft_model(model, peft_config)
if not train_args.bf16:
cast_mixed_precision_params(model, dtype=torch.float16)
logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
model.print_trainable_parameters()

# Count the total number of model parameters
total = sum(p.numel() for p in model.parameters())
logger.info("Total model params: %.2fM" % (total / 1e6))
# logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
# model.print_trainable_parameters()

return {
'model': model,
'peft_config': peft_config,
'target_modules': target_modules
}


@@ -157,27 +146,19 @@ def load_sft_dataset(args, tokenizer):
return train_dataset


def load_dpo_dataset(args, tokenizer):
# trl has changed a lot; handle it this way for now
if args.task_type == 'dpo':
train_dataset = load_dataset(data_files=args.train_data_path, path='json')
train_dataset = train_dataset['train']
return train_dataset


def create_trainer(args, train_args):
tokenizer = create_tokenizer(args)
model_dict = create_model(args, train_args)
model = model_dict['model']
# peft_config = model_dict['peft_config']

if args.task_type == 'sft':
logger.info('Train model with sft task')
train_dataset = load_sft_dataset(args, tokenizer)
data_collator = SftDataCollator(tokenizer, args.max_len)
elif args.task_type == 'pretrain':
pass

log_out(args, train_args, tokenizer, train_dataset, model, model_dict['target_modules'])
# sft or pretrain
if args.task_type == 'sft':
trainer = Trainer(
@@ -192,12 +173,41 @@ def create_trainer(args, train_args):
return trainer


def log_out(args, train_args, tokenizer, train_dataset, model, target_modules):
total = sum(p.numel() for p in model.parameters())
logger.add(join(train_args.output_dir, 'train.log'))
if train_args.local_rank == 0:
logger.info("train_args:{}".format(train_args))
logger.info("common_args:{}".format(args))
logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
logger.info(f'Loading model from base model: {args.model_name_or_path}')
logger.info("Total model params: %.2fM" % (total / 1e6))
logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
trainable_params, all_param = model.get_nb_trainable_parameters()
logger.info(
f"trainable params: {trainable_params:,d} || "
f"all params: {all_param:,d} || "
f"trainable%: {100 * trainable_params / all_param:.4f}"
)
logger.info(f'Train model with {args.task_type} task')
logger.info(f'Train model with {args.train_mode}')
logger.info(f'LoRA target module names: {target_modules}')
logger.info(f'Loading data: {args.train_data_path}')
logger.info(f"Training dataset samples:{len(train_dataset)}")
for index in random.sample(range(len(train_dataset)), 3):
logger.info(
f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['target_mask']}.")
logger.info(
f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")


def main():
args, train_args = initial_args()
# Load the trainer
trainer = create_trainer(args, train_args)
# Start training
logger.info("*** starting training ***")
if train_args.local_rank == 0:
logger.info("*** starting training ***")
train_result = trainer.train()
# Transformers was updated to save the final training result automatically
# final_save_path = join(train_args.output_dir)
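The net effect of the main_train.py changes is that per-process logging now flows through log_out() and is gated on local_rank, so only one rank writes the run metadata. Below is a minimal standalone sketch of that gating pattern; it assumes the standard LOCAL_RANK environment variable set by distributed launchers, and the output_dir and message values are illustrative, not taken from the commit.

import os
from os.path import join
from loguru import logger


def rank0_log(output_dir: str, message: str) -> None:
    # Every process attaches the file sink, but only local rank 0 emits records,
    # so a multi-GPU deepspeed launch produces a single, de-duplicated train.log.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    os.makedirs(output_dir, exist_ok=True)
    logger.add(join(output_dir, "train.log"))
    if local_rank == 0:
        logger.info(message)


if __name__ == "__main__":
    rank0_log("./output", "train_args, dataset stats, and model size would be logged here")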
2 changes: 1 addition & 1 deletion run_example.sh
@@ -8,7 +8,7 @@ MODEL_PATH=""
# train_args_path: [sft_args,dpo_args]

# Launch with deepspeed
deepspeed --include localhost:0,1 main_train.py\
deepspeed --master_port 29507 --include localhost:0,1 main_train.py\
--train_data_path "$DATA_PATH" \
--model_name_or_path "$MODEL_PATH" \
--max_len 1024 \
4 changes: 2 additions & 2 deletions utils/data_process.py
@@ -7,10 +7,10 @@ class MultiRoundDataProcess(Dataset):
def __init__(self, file, tokenizer, max_length, auto_adapt=True):
self.tokenizer = tokenizer
self.max_length = max_length
logger.info(f'Loading data: {file}')
# logger.info(f'Loading data: {file}')
with open(file, 'r', encoding='utf8') as f:
data_list = f.readlines()
logger.info(f"There are {len(data_list)} data in dataset")
# logger.info(f"There are {len(data_list)} data in dataset")
self.data_list = data_list
self.auto_adapt = auto_adapt

