diff --git a/main_train.py b/main_train.py
index c57d561..5011110 100644
--- a/main_train.py
+++ b/main_train.py
@@ -1,5 +1,6 @@
 import os
 from os.path import join
+import random
 from loguru import logger
 import torch
 import torch.nn as nn
@@ -11,7 +12,8 @@
 from utils.data_process import MultiRoundDataProcess
 from utils.data_collator import SftDataCollator
 from train_args.common_args import CommonArgs
-from datasets import load_dataset
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
 def initial_args():
@@ -20,19 +22,11 @@ def initial_args():
     if args.train_args_path == "sft_args":
         parser_b = HfArgumentParser((sft_TrainArgument,))
         train_args, = parser_b.parse_args_into_dataclasses(args=remaining_args)
-        print("Loaded instance sft_args")
-    # elif args.train_args_path == "dpo_args":
-    #     parser_c = HfArgumentParser((dpo_TrainArgument,))
-    #     train_args, = parser_c.parse_args_into_dataclasses(args=remaining_args)
-    #     print(f"Loaded instance dpo_args")
     else:
         raise ValueError("Invalid train_args_path choice")
 
     if not os.path.exists(train_args.output_dir):
         os.mkdir(train_args.output_dir)
-    logger.add(join(train_args.output_dir, 'train.log'))
-    logger.info("train_args:{}".format(train_args))
-    logger.info("common_args:{}".format(train_args))
 
     set_seed(train_args.seed)
     assert sum([train_args.fp16, train_args.bf16]) == 1, "only one of fp16 and bf16 can be True"
@@ -54,7 +48,6 @@ def find_all_linear_names(model, train_mode):
     if 'lm_head' in lora_module_names:  # needed for 16-bit
         lora_module_names.remove('lm_head')
     lora_module_names = list(lora_module_names)
-    logger.info(f'LoRA target module names: {lora_module_names}')
     return lora_module_names
 
 
@@ -76,14 +69,12 @@ def create_tokenizer(args):
 
     assert tokenizer.pad_token_id is not None, "pad_token_id should not be None"
     assert tokenizer.eos_token_id is not None, "eos_token_id should not be None"
-    logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
 
     return tokenizer
 
 
 def create_model(args, train_args):
-    logger.info(f'Loading model from base model: {args.model_name_or_path}')
-    logger.info(f'Train model with {args.train_mode}')
+    target_modules = None
     # Determine the training precision
     torch_dtype = torch.bfloat16 if train_args.bf16 else torch.float32
     model_kwargs = dict(
@@ -139,16 +130,14 @@ def load_model(model_kwargs):
         model = get_peft_model(model, peft_config)
         if not train_args.bf16:
             cast_mixed_precision_params(model, dtype=torch.float16)
-        logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
-        model.print_trainable_parameters()
-        # Compute the total number of model parameters
-        total = sum(p.numel() for p in model.parameters())
-        logger.info("Total model params: %.2fM" % (total / 1e6))
+        # logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
+        # model.print_trainable_parameters()
 
 
         return {
             'model': model,
             'peft_config': peft_config,
+            'target_modules': target_modules
         }
 
 
@@ -157,14 +146,6 @@ def load_sft_dataset(args, tokenizer):
     return train_dataset
 
 
-def load_dpo_dataset(args, tokenizer):
-    # trl has changed a lot; keep this workaround for now
-    if args.task_type == 'dpo':
-        train_dataset = load_dataset(data_files=args.train_data_path, path='json')
-        train_dataset = train_dataset['train']
-    return train_dataset
-
-
 def create_trainer(args, train_args):
     tokenizer = create_tokenizer(args)
     model_dict = create_model(args, train_args)
@@ -172,12 +153,12 @@ def create_trainer(args, train_args):
     # peft_config = model_dict['peft_config']
 
     if args.task_type == 'sft':
-        logger.info('Train model with sft task')
         train_dataset = load_sft_dataset(args, tokenizer)
         data_collator = SftDataCollator(tokenizer, args.max_len)
     elif args.task_type == 'pretrain':
         pass
 
+    log_out(args, train_args, tokenizer, train_dataset, model, model_dict['target_modules'])
     # sft or pretrain
     if args.task_type == 'sft':
         trainer = Trainer(
@@ -192,12 +173,41 @@ def create_trainer(args, train_args):
     return trainer
 
 
+def log_out(args, train_args, tokenizer, train_dataset, model, target_modules):
+    total = sum(p.numel() for p in model.parameters())
+    logger.add(join(train_args.output_dir, 'train.log'))
+    if train_args.local_rank == 0:
+        logger.info("train_args:{}".format(train_args))
+        logger.info("common_args:{}".format(args))
+        logger.info(f'vocab_size of tokenizer: {tokenizer.vocab_size}')
+        logger.info(f'Loading model from base model: {args.model_name_or_path}')
+        logger.info("Total model params: %.2fM" % (total / 1e6))
+        logger.info(f'memory footprint of model: {model.get_memory_footprint() / (1024 * 1024 * 1024)} GB')
+        trainable_params, all_param = model.get_nb_trainable_parameters()
+        logger.info(
+            f"trainable params: {trainable_params:,d} || "
+            f"all params: {all_param:,d} || "
+            f"trainable%: {100 * trainable_params / all_param:.4f}"
+        )
+        logger.info(f'Train model with {args.task_type} task')
+        logger.info(f'Train model with {args.train_mode}')
+        logger.info(f'LoRA target module names: {target_modules}')
+        logger.info(f'Loading data: {args.train_data_path}')
+        logger.info(f"Training dataset samples:{len(train_dataset)}")
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(
+                f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['target_mask']}.")
+            logger.info(
+                f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.")
+
+
 def main():
     args, train_args = initial_args()
     # Load the trainer
     trainer = create_trainer(args, train_args)
     # Start training
-    logger.info("*** starting training ***")
+    if train_args.local_rank == 0:
+        logger.info("*** starting training ***")
     train_result = trainer.train()
     # Transformers now saves the final training results automatically
     # final_save_path = join(train_args.output_dir)
diff --git a/run_example.sh b/run_example.sh
index 2d4038f..efdd38d 100644
--- a/run_example.sh
+++ b/run_example.sh
@@ -8,7 +8,7 @@ MODEL_PATH=""
 
 # train_args_path: [sft_args,dpo_args]
 # Launch with deepspeed
-deepspeed --include localhost:0,1 main_train.py\
+deepspeed --master_port 29507 --include localhost:0,1 main_train.py\
     --train_data_path "$DATA_PATH" \
     --model_name_or_path "$MODEL_PATH" \
     --max_len 1024 \
diff --git a/utils/data_process.py b/utils/data_process.py
index 2ae276f..9a2f215 100644
--- a/utils/data_process.py
+++ b/utils/data_process.py
@@ -7,10 +7,10 @@ class MultiRoundDataProcess(Dataset):
     def __init__(self, file, tokenizer, max_length, auto_adapt=True):
         self.tokenizer = tokenizer
         self.max_length = max_length
-        logger.info(f'Loading data: {file}')
+        # logger.info(f'Loading data: {file}')
         with open(file, 'r', encoding='utf8') as f:
             data_list = f.readlines()
-        logger.info(f"There are {len(data_list)} data in dataset")
+        # logger.info(f"There are {len(data_list)} data in dataset")
         self.data_list = data_list
         self.auto_adapt = auto_adapt