Skip to content

Commit

Permalink
Optimize args
Browse files Browse the repository at this point in the history
  • Loading branch information
mst272 committed Apr 28, 2024
1 parent afd2a05 commit fddb1de
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions utils/args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
from enum import Enum


# 目前支持的template 类型
class TemplateName(Enum):
QWEN = 'qwen'
YI = 'Yi'


@dataclass
Expand All @@ -10,7 +17,7 @@ class CommonArgs:
max_len: int = field(metadata={"help": "最大输入长度"})
train_data_path: Optional[str] = field(metadata={"help": "训练集路径"})
model_name_or_path: str = field(metadata={"help": "下载的所需模型路径"})
template_name: str = field(default="", metadata={"help": "sft时的数据格式"})
template_name: TemplateName = field(default=TemplateName.QWEN, metadata={"help": "sft时的数据格式"})
train_mode: str = field(default="qlora", metadata={"help": "选择采用的训练方式:[qlora, lora]"})
task_type: str = field(default="sft", metadata={"help": "预训练任务:[pretrain, sft, dpo]"})

Expand Down

0 comments on commit fddb1de

Please sign in to comment.