
Commit

update CPO, SimPO, CPO-SimPO
mst272 committed Jul 18, 2024
1 parent 318d88e commit 96644e3
Showing 8 changed files with 81 additions and 76 deletions.
8 changes: 3 additions & 5 deletions rlhf/README.md
@@ -24,14 +24,15 @@
- ✅ Reward model training
- ✅ RLOO
- ✅ PPO
- SimPO, KTO, and others (to be updated)
- ✅ SimPO
- ✅ CPO
- ✅ CPO-SimPO



## Quick Start

### Data format requirements
**PPO, RLOO:**

The data format requires the following three fields:
- prompt
@@ -44,9 +44,6 @@ There are also many datasets on huggingface, for example: ```trl-internal-testing/hh-rlhf-h

The data format is jsonl; see the example data in ```rlhf/data_example/data.jsonl```.
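A hedged illustration of one record (shown as a Python dict rather than a raw jsonl line): only `prompt` is visible in this diff excerpt, so the remaining field names (`chosen`, `rejected`) and the chat-message structure are assumptions taken from `load_data_all` in `rlhf/rlhf_train.py`, and the content is invented.

```python
# Assumed shape of one line in rlhf/data_example/data.jsonl. Field names beyond
# "prompt" and all message content are illustrative, inferred from load_data_all()
# rather than read from the example file itself.
example_record = {
    "prompt": [{"role": "user", "content": "Summarize RLHF in one sentence."}],
    "chosen": [{"role": "assistant", "content": "RLHF aligns a model using human preference feedback."}],
    "rejected": [{"role": "assistant", "content": "RLHF is a file format."}],
}
```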

**KTO:**



### Step 1: Train the Reward Model

8 changes: 6 additions & 2 deletions rlhf/common_args.py
@@ -6,6 +6,9 @@
class TrainArgPath(Enum):
    PPO_ARGS = 'rlhf_args/ppo_config.py'
    RLOO_ARGS = 'rlhf_args/rloo_config.py'
    CPO_ARGS = 'rlhf_args/cpo_config.py'
    SimPO_ARGS = 'rlhf_args/simpo_config.py'
    CPOSimPO_ARGS = 'rlhf_args/cpo-simpo_config.py'


@dataclass
@@ -14,12 +17,13 @@ class CommonArgs:
    Some commonly used custom arguments
    """
    train_args_path: TrainArgPath = field(default=TrainArgPath.RLOO_ARGS.value,
                                          metadata={"help": "Training-argument config for the current mode; currently supports [PPO,RLOO]"})
                                          metadata={"help": "Training-argument config for the current mode; currently supports [PPO,RLOO,CPO,SimPO,CPOSimPO]"})
    # Fine-tuning method: selection and configuration
    train_mode: str = field(default='lora', metadata={"help": "Training method to use: [qlora, lora, full]"})
    use_dora: bool = field(default=False,
                           metadata={"help": "Only usable when train_mode==lora. Whether to use DoRA (a LoRA-based variant)"})
    rlhf_type: str = field(default="RLOO", metadata={"help": "RLHF method to use; currently supports [PPO,RLOO]"})
    rlhf_type: str = field(default="RLOO",
                           metadata={"help": "RLHF method to use; currently supports [PPO,RLOO,CPO,SimPO,CPOSimPO]"})

    # LoRA-related configuration
    lora_rank: Optional[int] = field(default=64, metadata={"help": "lora rank"})
16 changes: 16 additions & 0 deletions rlhf/rlhf_args/cpo-simpo_config.py
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from typing import Literal
from cpo_config import CPOConfig


@dataclass
class CPOSimPOConfig(CPOConfig):
    """
    Based on CPOConfig; only the fields below need to change.
    """
    loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo"
    """The type of loss to use."""
    cpo_alpha: float = 0.5
    """A non-zero cpo_alpha enables the combined use of CPO and SimPO, which gives more stable training and
    improved performance."""

18 changes: 16 additions & 2 deletions rlhf/rlhf_args/cpo_config.py
@@ -6,23 +6,37 @@
@dataclass
class CPOConfig(TrainingArguments):
    max_length: Optional[int] = None
    """The maximum length of the sequences in the batch."""
    max_prompt_length: Optional[int] = None
    max_completion_length: Optional[int] = None
    """The maximum length of the prompt."""
    max_target_length: Optional[int] = None
    """The maximum length of the target."""

    beta: float = 0.1
    """The beta factor in CPO loss."""
    label_smoothing: float = 0
    """The label smoothing factor. This argument is required if you want to use the default data collator."""
    loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "sigmoid"
    """The type of loss to use."""
    disable_dropout: bool = True
    """Whether or not to disable dropouts in `model`."""
    cpo_alpha: float = 1.0
    """A hyperparameter that controls the strength of the BC regularizer in CPO training."""
    simpo_gamma: float = 0.5
    """A target reward margin for the SimPO loss, used only when the "simpo" option is enabled."""

    label_pad_token_id: int = -100
    """The label pad token id."""
    padding_value: int = None
    """The padding value if it is different to the tokenizer's pad_token_id."""
    truncation_mode: str = "keep_end"
    """The truncation mode to use, either `keep_end` or `keep_start`."""
    generate_during_eval: bool = False
    """Whether to sample and log generations during evaluation step."""
    is_encoder_decoder: Optional[bool] = None
    """If no model is provided, we need to know if the model_init returns an encoder-decoder."""
    model_init_kwargs: Optional[Dict] = None
    """Dict of Optional kwargs to pass when instantiating the model from a string"""

    dataset_num_proc: Optional[int] = None
    """The number of workers to use to tokenize the data. Defaults to None."""
16 changes: 16 additions & 0 deletions rlhf/rlhf_args/simpo_config.py
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from typing import Literal
from cpo_config import CPOConfig


@dataclass
class SimPOConfig(CPOConfig):
    """
    Based on CPOConfig; only the fields below need to change.
    """
    loss_type: Literal["sigmoid", "hinge", "ipo", "simpo"] = "simpo"
    """The type of loss to use."""
    cpo_alpha: float = 0
    """A hyperparameter that controls the strength of the BC regularizer in CPO training."""
    simpo_gamma: float = 0.5
    """A target reward margin for the SimPO loss, used only when the "simpo" option is enabled."""
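Taken together, the three configs differ only in `loss_type` and `cpo_alpha`: CPOConfig keeps `loss_type="sigmoid"` with `cpo_alpha=1.0`, SimPOConfig switches to `loss_type="simpo"` with `cpo_alpha=0`, and CPOSimPOConfig uses `loss_type="simpo"` with a non-zero `cpo_alpha=0.5`. Below is a minimal sketch of how those two knobs interact in the "simpo" loss branch, written from the CPO and SimPO papers and the field docstrings above; it is an illustration, not TRL's actual CPOTrainer code.

```python
import torch.nn.functional as F


def simpo_style_loss(chosen_logps, rejected_logps, chosen_nll,
                     beta=0.1, simpo_gamma=0.5, cpo_alpha=0.0):
    """Illustrative sketch only; not TRL's exact implementation.

    chosen_logps / rejected_logps: length-averaged log-probs of the chosen and
    rejected responses under the policy (tensors of shape [batch]).
    chosen_nll: NLL of the chosen response, i.e. the BC regularizer from CPO.
    """
    # SimPO ranking term: target reward margin simpo_gamma, scaled by beta.
    margin = chosen_logps - rejected_logps - simpo_gamma / beta
    ranking_loss = -F.logsigmoid(beta * margin).mean()
    # cpo_alpha = 0   -> pure SimPO        (SimPOConfig above)
    # cpo_alpha = 1.0 -> CPO's default BC strength (CPOConfig above)
    # cpo_alpha = 0.5 -> the CPO-SimPO combination (CPOSimPOConfig above)
    return ranking_loss + cpo_alpha * chosen_nll
```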
35 changes: 22 additions & 13 deletions rlhf/rlhf_train.py
@@ -80,8 +80,16 @@ def tokenize(element):
    return train_dataset, eval_dataset


def load_data_chosen_rej():
    pass
def load_data_all(tokenizer, train_data_path, eval_samples):
    raw_datasets = pd.read_json(train_data_path, lines=True)
    for i in range(len(raw_datasets)):
        raw_datasets.loc[i, 'prompt'] = tokenizer.apply_chat_template(raw_datasets['prompt'][i], tokenize=False)
        raw_datasets.loc[i, 'chosen'] = tokenizer.apply_chat_template(raw_datasets['chosen'][i], tokenize=False)
        raw_datasets.loc[i, 'rejected'] = tokenizer.apply_chat_template(raw_datasets['rejected'][i], tokenize=False)
    raw_datasets = Dataset.from_pandas(raw_datasets, preserve_index=False)
    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
    return train_dataset, eval_dataset
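Note that `apply_chat_template(..., tokenize=False)` flattens each message list into a single template-formatted string before the data reaches the trainer; the exact markup depends entirely on the tokenizer's chat template. A small, hedged illustration with invented content:

```python
# Hypothetical input and output for load_data_all(); the rendered string depends
# on the model's chat template, so the markup below is only an example.
prompt_messages = [{"role": "user", "content": "Explain CPO briefly."}]
# tokenizer.apply_chat_template(prompt_messages, tokenize=False) returns one string,
# e.g. for a ChatML-style template:
#   "<|im_start|>user\nExplain CPO briefly.<|im_end|>\n"
# so each resulting dataset row holds plain strings for 'prompt', 'chosen', 'rejected'.
```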


def main():
@@ -187,21 +195,22 @@ def main():
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )
    # elif args.rlhf_type == 'CPO':
    #     from trl import CPOTrainer
    #     trainer = CPOTrainer(
    #         policy,
    #         args=cpo_args,
    #         train_dataset=train_dataset,
    #         eval_dataset=eval_dataset,
    #         tokenizer=tokenizer,
    #         peft_config=get_peft_config(model_config),
    #     )
    # TODO: is there a more elegant way to implement this?
    elif args.rlhf_type in ['CPO', 'SimPO', 'CPOSimPO']:
        from trl import CPOTrainer
        train_dataset, eval_dataset = load_data_all(tokenizer, config.train_data_path, config.eval_samples)
        trainer = CPOTrainer(
            policy,
            args=config,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
        )
    else:
        raise Exception
    trainer.train()
    trainer.save_model(config.output_dir)
    trainer.generate_completions()
    # trainer.generate_completions()


if __name__ == "__main__":
6 changes: 2 additions & 4 deletions train_args/sft/lora_qlora/base.py
@@ -27,15 +27,13 @@ class TrainArgument(TrainingArguments):
    save_total_limit: Optional[int] = field(default=2, metadata={"help": "If a value is passed, will limit the total "
                                                                          "amount of checkpoints. Deletes the older "
                                                                          "checkpoints in"})
    lr_scheduler_type: Union[SchedulerType, str] = field(default="constant_with_warmup",
    lr_scheduler_type: Union[SchedulerType, str] = field(default="cosine",
                                                         metadata={"help": "The scheduler type to use."})
    warmup_steps: int = field(default=10, metadata={"help": "Linear warmup over warmup_steps."})
    optim: Union[OptimizerNames, str] = field(default='paged_adamw_32bit', metadata={"help": "The optimizer to use."})
    optim: Union[OptimizerNames, str] = field(default='adamw_torch', metadata={"help": "The optimizer to use."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    report_to: Optional[List[str]] = field(default='tensorboard', metadata={
        "help": "The list of integrations to report the results and logs to."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    remove_unused_columns: Optional[bool] = field(default=False, metadata={
        "help": "Remove columns not required by the model when using an nlp.Dataset."})
    bf16: bool = field(default=True, metadata={
50 changes: 0 additions & 50 deletions utils/data_collator.py
@@ -3,56 +3,6 @@
from loguru import logger


class MyDataCollator(object):
    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # maximum length within this batch
        lengths = [len(x['input_ids']) for x in batch if x['input_ids'] is not None]
        # cap it at max_len
        batch_length = min(self.max_len, max(lengths))

        input_ids_batch, attention_mask_batch, target_mask_batch = [], [], []

        for x in batch:
            print(x)
            input_ids = x["input_ids"]
            attention_mask = x['attention_mask']
            target_mask = x['labels']
            if input_ids is None:
                logger.info('some input_ids is None, skipping this example')
                continue
            padding_len = batch_length - len(input_ids)
            # pad up to the batch length
            input_ids = input_ids + [self.pad_token_id] * padding_len
            attention_mask = attention_mask + [0] * padding_len
            target_mask = target_mask + [0] * padding_len
            # truncate to max_len
            input_ids = input_ids[:self.max_len]
            attention_mask = attention_mask[:self.max_len]
            target_mask = target_mask[:self.max_len]
            # collect this example into the batch lists
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
            target_mask_batch.append(target_mask)

        # convert the lists to tensors to get the final model inputs
        input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long)
        attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long)
        target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long)

        labels = torch.where(target_mask_batch == 1, input_ids_batch, -100)

        return {
            'input_ids': input_ids_batch,
            'attention_mask': attention_mask_batch,
            'labels': labels
        }


class SftDataCollator:
    def __init__(self, tokenizer, max_length):
        self.tokenizer = tokenizer
