Commit
fix TPOT, report name conflict, query template (#321)
* update perf TPOT

* fix app name conflict

* update query template

* add timeout
Yunnglin authored Feb 24, 2025
1 parent 120e967 commit 5bc798a
Showing 13 changed files with 57 additions and 36 deletions.
2 changes: 2 additions & 0 deletions docs/en/get_started/parameters.md
@@ -34,12 +34,14 @@ Run `evalscope eval --help` to get a complete list of parameter descriptions.
- `--template-type`: Model inference template, deprecated, refer to `--chat-template`.
- `--api-url`: (Valid only when `eval-type=service`) Model API endpoint, defaults to `None`; supports passing in local or remote OpenAI API format endpoints, for example, `http://127.0.0.1:8000/v1/chat/completions`.
- `--api-key`: (Valid only when `eval-type=service`) Model API endpoint key, defaults to `EMPTY`.
- `--timeout`: (Valid only when `eval-type=service`) The timeout for model API requests, defaults to `60` seconds.

## Dataset Parameters
- `--datasets`: Dataset names; multiple datasets can be passed, separated by spaces. Datasets are automatically downloaded from ModelScope; see the [Dataset List](./supported_dataset.md#supported-datasets) for supported datasets.
- `--dataset-args`: Configuration parameters for the evaluation datasets, passed in `json` format, where the key is the dataset name and the value is its parameters; note that the keys must correspond one-to-one with the values in the `--datasets` parameter:
- `dataset_id` (or `local_path`): Local path for the dataset, once specified, it will attempt to load local data.
- `prompt_template`: The prompt template for the evaluation dataset. When specified, it will be used to generate prompts. For example, the template for the `gsm8k` dataset is `Question: {query}\nLet's think step by step\nAnswer:`. The question from the dataset will be filled into the `query` field of the template.
- `query_template`: The query template for the evaluation dataset. When specified, it will be used to generate queries. For example, the template for `general_mcq` is `Question: {question}\n{choices}\nAnswer: {answer}\n\n`. The questions from the dataset will be inserted into the `question` field of the template, options will be inserted into the `choices` field, and answers will be inserted into the `answer` field (answer insertion is only effective for few-shot scenarios).
- `system_prompt`: System prompt for the evaluation dataset.
- `subset_list`: List of subsets for the evaluation dataset, once specified, only subset data will be used.
- `few_shot_num`: Number of few-shots.
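Putting the new service-side `--timeout` and the dataset-level `query_template` together, here is a hedged sketch of how the options above might be combined; the endpoint, model name and dataset layout are placeholders taken from examples elsewhere in this commit, not an official recipe.

```python
import json

# Hypothetical --dataset-args payload: a general_mcq dataset with a custom
# query_template (field names match the general_mcq adapter changed below).
dataset_args = {
    'general_mcq': {
        'local_path': 'custom_eval/text/mcq',
        'subset_list': ['example'],
        'query_template': 'Question: {question}\n{choices}\nAnswer: {answer}\n\n',
    }
}

# The CLI takes this as a JSON string, roughly:
#   evalscope eval --model Qwen2.5-0.5B-Instruct --eval-type service \
#       --api-url http://127.0.0.1:8000/v1/chat/completions --timeout 120 \
#       --datasets general_mcq --dataset-args '<output of the line below>'
print(json.dumps(dataset_args, ensure_ascii=False))
```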
7 changes: 4 additions & 3 deletions docs/zh/get_started/parameters.md
@@ -8,8 +8,6 @@
- Specify a local path to the model, e.g. `/path/to/model`, to load the model from local files;
- When the evaluation target is a model API service, specify the model id of the service, e.g. `Qwen2.5-0.5B-Instruct`
- `--model-id`: Alias of the model being evaluated, used for report display. Defaults to the last part of `model`, e.g. the `model-id` of `Qwen/Qwen2.5-0.5B-Instruct` is `Qwen2.5-0.5B-Instruct`
- `--api-url`: (Valid only when `eval-type=service`) Model API endpoint, defaults to `None`; supports passing in local or remote OpenAI API format endpoints, for example `http://127.0.0.1:8000/v1/chat/completions`
- `--api-key`: (Valid only when `eval-type=service`) Model API endpoint key, defaults to `EMPTY`
- `--model-args`: Model loading parameters, comma-separated `key=value` pairs, parsed into a dictionary; default parameters:
- `revision`: Model revision, defaults to `master`
- `precision`: Model precision, defaults to `torch.float16`
@@ -34,13 +32,16 @@
```
- `--chat-template`: Model inference template, defaults to `None`, meaning transformers' `apply_chat_template` is used; a jinja template string can be passed in to customize the inference template
- `--template-type`: Model inference template, deprecated, refer to `--chat-template`

- `--api-url`: (Valid only when `eval-type=service`) Model API endpoint, defaults to `None`; supports passing in local or remote OpenAI API format endpoints, for example `http://127.0.0.1:8000/v1/chat/completions`
- `--api-key`: (Valid only when `eval-type=service`) Model API endpoint key, defaults to `EMPTY`
- `--timeout`: (Valid only when `eval-type=service`) Timeout for model API requests, defaults to `60` seconds

## Dataset Parameters
- `--datasets`: Dataset names; multiple datasets can be passed, separated by spaces. Datasets are automatically downloaded from ModelScope; see the [Dataset List](./supported_dataset.md#支持的数据集) for supported datasets
- `--dataset-args`: Configuration parameters for the evaluation datasets, passed as a `json` string and parsed into a dictionary; note that the keys must correspond to the values in the `--datasets` parameter:
- `dataset_id` (or `local_path`): Local path of the dataset; once specified, data will be loaded from the local path
- `prompt_template`: Prompt template for the evaluation dataset; when specified, it is used to generate prompts. For example, the template for `gsm8k` is `Question: {query}\nLet's think step by step\nAnswer:`, and the question from the dataset is filled into the `query` field of the template
- `query_template`: Query template for the evaluation dataset; when specified, it is used to generate queries. For example, the template for `general_mcq` is `问题:{question}\n{choices}\n答案: {answer}\n\n`, where the question is filled into the `question` field, the options into the `choices` field, and the answer into the `answer` field (answer filling only takes effect for few-shot)
- `system_prompt`: System prompt for the evaluation dataset
- `subset_list`: List of subsets of the evaluation dataset; once specified, only the subset data is used
- `few_shot_num`: Number of few-shot examples
1 change: 1 addition & 0 deletions evalscope/arguments.py
@@ -71,6 +71,7 @@ def add_argument(parser: argparse.ArgumentParser):
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility.')
parser.add_argument('--api-key', type=str, default='EMPTY', help='The API key for the remote API model.')
parser.add_argument('--api-url', type=str, default=None, help='The API url for the remote API model.')
parser.add_argument('--timeout', type=float, default=60, help='The timeout for the remote API model.')
# yapf: enable


1 change: 1 addition & 0 deletions evalscope/benchmarks/benchmark.py
@@ -24,6 +24,7 @@ class BenchmarkMeta:
eval_split: Optional[str] = None
prompt_template: Optional[str] = None
system_prompt: Optional[str] = None
query_template: Optional[str] = None

def _update(self, args: dict):
if args.get('local_path'):
2 changes: 2 additions & 0 deletions evalscope/benchmarks/data_adapter.py
@@ -24,6 +24,7 @@ def __init__(self,
eval_split: Optional[str] = None,
prompt_template: Optional[str] = None,
system_prompt: Optional[str] = None,
query_template: Optional[str] = None,
**kwargs):
"""
Data Adapter for the benchmark. You need to implement the following methods:
@@ -52,6 +53,7 @@ def __init__(self,
self.eval_split = eval_split
self.prompt_template = prompt_template
self.system_prompt = system_prompt
self.query_template = query_template
self.config_kwargs = kwargs
self.category_map = kwargs.get('category_map', {})

16 changes: 6 additions & 10 deletions evalscope/benchmarks/general_mcq/general_mcq_adapter.py
@@ -24,7 +24,7 @@
train_split='dev',
eval_split='val',
prompt_template='请回答问题,并选出其中的正确答案\n{query}',
)
query_template='问题:{question}\n{choices}\n答案: {answer}\n\n')
class GeneralMCQAdapter(DataAdapter):

choices = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
@@ -115,15 +115,11 @@ def parse_pred_result(self, result: str, raw_input_d: dict = None, eval_type: st
def match(self, gold: str, pred: str) -> float:
return exact_match(gold=gold, pred=pred)

@classmethod
def _format_example(cls, input_d: dict, include_answer=True):
example = '问题:' + input_d['question']
for choice in cls.choices:
if choice in input_d:
example += f'\n{choice}. {input_d[f"{choice}"]}'
def _format_example(self, input_d: dict, include_answer=True):
choices_str = '\n'.join([f'{choice}. {input_d[choice]}' for choice in self.choices if choice in input_d])

if include_answer:
example += '\n答案: ' + input_d['answer'] + '\n\n'
return self.query_template.format(
question=input_d['question'], choices=choices_str, answer=input_d['answer'])
else:
example += '\n答案: '
return example
return self.query_template.format(question=input_d['question'], choices=choices_str, answer='').rstrip()
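For reference, a standalone sketch of what the refactored `_format_example` now renders with the default `query_template` registered above; the sample record is invented.

```python
# Mirrors GeneralMCQAdapter._format_example with an invented record.
query_template = '问题:{question}\n{choices}\n答案: {answer}\n\n'
choices = ['A', 'B', 'C', 'D']

record = {'question': '1 + 1 = ?', 'A': '1', 'B': '2', 'C': '3', 'answer': 'B'}
choices_str = '\n'.join(f'{c}. {record[c]}' for c in choices if c in record)

# Few-shot example: the answer is filled in.
print(query_template.format(question=record['question'], choices=choices_str, answer=record['answer']))
# Target question: the answer slot is left empty and trailing whitespace is stripped.
print(query_template.format(question=record['question'], choices=choices_str, answer='').rstrip())
```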
7 changes: 3 additions & 4 deletions evalscope/benchmarks/general_qa/general_qa_adapter.py
@@ -22,6 +22,7 @@
few_shot_num=0,
train_split=None,
eval_split='test',
prompt_template='请回答问题\n{query}',
)
class GeneralQAAdapter(DataAdapter):
# TODO: set few_shot_num
@@ -62,10 +63,8 @@ def gen_prompt(self, input_d: dict, subset_name: str, few_shot_list: list, **kwa
logger.warning('The history is not included in the prompt for GeneralQA. \
To be supported in the future.')

prompt = input_d.get('question', '') or input_d.get('query', '')

# if len(history) > 0:
# prompt = '\n'.join(history) + '\n' + prompt
query = input_d.get('question', '') or input_d.get('query', '')
prompt = self.prompt_template.format(query=query)
return {'data': [prompt], 'system_prompt': self.system_prompt}

def get_gold_answer(self, input_d: dict) -> str:
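To make the GeneralQA change concrete (previously the raw question was sent as the prompt; now it is wrapped by the adapter's `prompt_template`), here is a tiny hedged illustration with an invented record:

```python
# Effect of the new GeneralQA prompt_template (sample record is invented).
prompt_template = '请回答问题\n{query}'
record = {'question': 'What is the capital of France?'}

query = record.get('question', '') or record.get('query', '')
print(prompt_template.format(query=query))
```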
1 change: 1 addition & 0 deletions evalscope/config.py
@@ -68,6 +68,7 @@ class TaskConfig:
seed: Optional[int] = 42
api_url: Optional[str] = None # Only used for server model
api_key: Optional[str] = 'EMPTY' # Only used for server model
timeout: Optional[float] = 60 # Only used for server model

def __post_init__(self):
if (not self.model_id) and self.model:
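A hedged sketch of a service-mode `TaskConfig` that exercises the new `timeout` field; the endpoint, key and model name are placeholders, and only fields that appear somewhere in this commit are assumed.

```python
from evalscope.config import TaskConfig

# Minimal service-mode config; pass it to run_task as in the test at the bottom of this commit.
task_cfg = TaskConfig(
    model='Qwen2.5-0.5B-Instruct',
    eval_type='service',
    api_url='http://127.0.0.1:8000/v1/chat/completions',
    api_key='EMPTY',
    timeout=120,   # seconds; falls back to 60 when omitted
    datasets=['gsm8k'],
)
```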
6 changes: 5 additions & 1 deletion evalscope/models/base_adapter.py
@@ -46,7 +46,11 @@ def initialize_model_adapter(task_cfg: 'TaskConfig', model_adapter_cls: 'BaseMod
elif task_cfg.eval_type == EvalType.SERVICE:
from evalscope.models import ServerModelAdapter
return ServerModelAdapter(
api_url=task_cfg.api_url, model_id=task_cfg.model, api_key=task_cfg.api_key, seed=task_cfg.seed)
api_url=task_cfg.api_url,
model_id=task_cfg.model,
api_key=task_cfg.api_key,
seed=task_cfg.seed,
timeout=task_cfg.timeout)
else:
return model_adapter_cls(
model=base_model, generation_config=task_cfg.generation_config, chat_template=task_cfg.chat_template)
6 changes: 5 additions & 1 deletion evalscope/models/server_adapter.py
@@ -25,6 +25,7 @@ def __init__(self, api_url: str, model_id: str, api_key: str = 'EMPTY', **kwargs
self.model_id = model_id
self.api_key = api_key
self.seed = kwargs.get('seed', None)
self.timeout = kwargs.get('timeout', 60)
self.model_cfg = {'api_url': api_url, 'model_id': model_id, 'api_key': api_key}
super().__init__(model=None, model_cfg=self.model_cfg, **kwargs)

@@ -93,7 +94,10 @@ def make_request(self, content: dict, infer_cfg: dict = {}) -> dict:
def send_request(self, request_json: dict, max_retries: int = 3) -> dict:
for attempt in range(max_retries):
response = requests.post(
self.api_url, json=request_json, headers={'Authorization': f'Bearer {self.api_key}'})
self.api_url,
json=request_json,
headers={'Authorization': f'Bearer {self.api_key}'},
timeout=self.timeout)
if response.status_code == 200:
response_data = response.json()
return response_data
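Note that with an explicit `timeout`, `requests.post` raises `requests.exceptions.Timeout` instead of hanging indefinitely; the visible part of the retry loop above only checks status codes, so a caller-side sketch of handling that exception might look like this (illustrative only, not the adapter's actual error handling):

```python
import requests

# Illustrative only (assumed endpoint and payload): with a finite timeout,
# requests.post raises requests.exceptions.Timeout rather than blocking forever.
def post_with_timeout(api_url: str, payload: dict, api_key: str, timeout: float = 60) -> dict:
    try:
        response = requests.post(
            api_url,
            json=payload,
            headers={'Authorization': f'Bearer {api_key}'},
            timeout=timeout,  # applies to connect and read separately
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.Timeout:
        # A caller could count this as a failed attempt and retry or give up.
        return {'error': f'request timed out after {timeout}s'}
```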
8 changes: 6 additions & 2 deletions evalscope/perf/utils/benchmark_util.py
@@ -23,6 +23,7 @@ class BenchmarkData:
n_chunks: int = 0
n_chunks_time: float = 0.0
max_gpu_memory_cost = 0
time_per_output_token: float = 0.0

prompt_tokens = None
completion_tokens = None
@@ -37,6 +38,7 @@ def _calculate_query_stream_metric(self) -> Tuple[float, int, float]:
self.first_chunk_latency = self.query_latency
self.n_chunks = 1
self.n_chunks_time = self.query_latency
self.time_per_output_token = self.query_latency / self.completion_tokens

def _calculate_tokens(self, api_plugin):
self.prompt_tokens, self.completion_tokens = \
@@ -63,6 +65,7 @@ class BenchmarkMetrics:
start_time: Optional[float] = None
total_time: float = 1.0
n_total_queries: int = 0
n_time_per_output_token: float = 0.0

avg_first_chunk_latency: float = -1
avg_latency: float = -1
@@ -92,6 +95,7 @@ def update_metrics(self, benchmark_data: BenchmarkData, api_plugin):
self.total_first_chunk_latency += benchmark_data.first_chunk_latency
self.n_total_chunks += benchmark_data.n_chunks
self.total_chunks_time += benchmark_data.n_chunks_time
self.n_time_per_output_token += benchmark_data.time_per_output_token
else:
self.n_failed_queries += 1

@@ -108,7 +112,7 @@ def calculate_averages(self):
self.avg_prompt_tokens = self.n_total_prompt_tokens / self.n_succeed_queries
self.avg_completion_tokens = self.n_total_completion_tokens / self.n_succeed_queries
self.avg_token_per_seconds = self.n_total_completion_tokens / self.total_time
self.avg_time_per_token = self.total_time / self.n_total_completion_tokens
self.avg_time_per_token = self.n_time_per_output_token / self.n_succeed_queries
self.qps = self.n_succeed_queries / self.total_time
except ZeroDivisionError as e:
logger.exception(e)
@@ -125,7 +129,7 @@ def create_message(self, default_ndigits=3):
'Average QPS': round(self.qps, default_ndigits),
'Average latency (s)': round(self.avg_latency, default_ndigits),
'Average time to first token (s)': round(self.avg_first_chunk_latency, default_ndigits),
'Average time per output token (s)': round(self.avg_time_per_token, 5),
'Average time per output token (s)': round(self.avg_time_per_token, default_ndigits),
'Average input tokens per request': round(self.avg_prompt_tokens, default_ndigits),
'Average output tokens per request': round(self.avg_completion_tokens, default_ndigits),
'Average package latency (s)': round(self.avg_chunk_time, default_ndigits),
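This is the TPOT fix referenced in the commit title. Previously the average time per output token was total wall-clock time divided by total completion tokens, which understates per-request latency when queries run concurrently; now each successful query contributes its own latency-per-token and those values are averaged over succeeded queries. A hedged numeric illustration:

```python
# Invented numbers: two successful queries measured individually.
queries = [
    {'latency': 2.0, 'completion_tokens': 100},  # 0.020 s per output token
    {'latency': 3.0, 'completion_tokens': 50},   # 0.060 s per output token
]

# New metric: mean of per-query time-per-output-token.
new_tpot = sum(q['latency'] / q['completion_tokens'] for q in queries) / len(queries)
print(round(new_tpot, 3))  # 0.04

# Old metric: wall-clock total_time / total completion tokens. If both queries
# ran in parallel, total_time is about 3.0 s, so the old value understates the
# per-query average (0.02 vs 0.04).
old_tpot = 3.0 / sum(q['completion_tokens'] for q in queries)
print(round(old_tpot, 3))  # 0.02
```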
18 changes: 11 additions & 7 deletions evalscope/report/app.py
@@ -19,6 +19,9 @@
logger = get_logger()

PLOTLY_THEME = 'plotly_dark'
REPORT_TOKEN = '@@'
MODEL_TOKEN = '::'
DATASET_TOKEN = ', '


def scan_for_report_folders(root_path):
@@ -42,18 +45,19 @@ def scan_for_report_folders(root_path):
datasets = []
for dataset_item in glob.glob(os.path.join(model_item, '*.json')):
datasets.append(os.path.basename(dataset_item).split('.')[0])
datasets = ','.join(datasets)
reports.append(f'{os.path.basename(folder)}@{os.path.basename(model_item)}:{datasets}')
datasets = DATASET_TOKEN.join(datasets)
reports.append(
f'{os.path.basename(folder)}{REPORT_TOKEN}{os.path.basename(model_item)}{MODEL_TOKEN}{datasets}')

reports = sorted(reports, reverse=True)
logger.debug(f'reports: {reports}')
return reports


def process_report_name(report_name: str):
prefix, report_name = report_name.split('@')
model_name, datasets = report_name.split(':')
datasets = datasets.split(',')
prefix, report_name = report_name.split(REPORT_TOKEN)
model_name, datasets = report_name.split(MODEL_TOKEN)
datasets = datasets.split(DATASET_TOKEN)
return prefix, model_name, datasets


@@ -519,8 +523,8 @@ def create_single_model_tab(sidebar: SidebarComponents, lang: str):
outputs=[report_list, task_config, dataset_radio, work_dir, model_name])
def update_single_report_data(root_path, report_name):
report_list, datasets, task_cfg = load_single_report(root_path, report_name)
work_dir = os.path.join(root_path, report_name.split('@')[0])
model_name = report_name.split('@')[1].split(':')[0]
work_dir = os.path.join(root_path, report_name.split(REPORT_TOKEN)[0])
model_name = report_name.split(REPORT_TOKEN)[1].split(MODEL_TOKEN)[0]
return (report_list, task_cfg, gr.update(choices=datasets, value=datasets[0]), work_dir, model_name)

@report_list.change(inputs=[report_list], outputs=[score_plot, score_table, sunburst_plot])
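The delimiter change above is the "report name conflict" part of the commit: the single characters '@', ':' and ',' were replaced with the multi-character tokens '@@', '::' and ', ', which are less likely to collide with characters appearing in folder or model names. A small self-contained sketch of the encoding and decoding:

```python
# Mirrors the new report-name tokens defined in evalscope/report/app.py.
REPORT_TOKEN = '@@'
MODEL_TOKEN = '::'
DATASET_TOKEN = ', '

def build_report_name(folder: str, model: str, datasets: list) -> str:
    return f'{folder}{REPORT_TOKEN}{model}{MODEL_TOKEN}{DATASET_TOKEN.join(datasets)}'

def process_report_name(report_name: str):
    prefix, rest = report_name.split(REPORT_TOKEN)
    model_name, datasets = rest.split(MODEL_TOKEN)
    return prefix, model_name, datasets.split(DATASET_TOKEN)

name = build_report_name('20250224_120000', 'Qwen2.5-0.5B-Instruct', ['gsm8k', 'arc'])
print(name)                       # 20250224_120000@@Qwen2.5-0.5B-Instruct::gsm8k, arc
print(process_report_name(name))  # ('20250224_120000', 'Qwen2.5-0.5B-Instruct', ['gsm8k', 'arc'])
```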
18 changes: 10 additions & 8 deletions tests/cli/test_run.py
@@ -101,7 +101,8 @@ def test_run_custom_task(self):
'local_path': 'custom_eval/text/mcq', # Custom dataset path
'subset_list': [
'example' # Evaluation dataset name, i.e. the * in the *_dev.csv files above
]
],
'query_template': 'Question: {question}\n{choices}\nAnswer: {answer}' # Question template
},
'general_qa': {
'local_path': 'custom_eval/text/qa', # Custom dataset path
@@ -148,16 +149,16 @@ def test_run_server_model(self):
# 'ifeval',
# 'mmlu',
# 'mmlu_pro',
# 'race',
'race',
# 'trivia_qa',
# 'cmmlu',
# 'humaneval',
# 'gsm8k',
# 'bbh',
'competition_math',
'math_500',
'aime24',
'gpqa',
# 'competition_math',
# 'math_500',
# 'aime24',
# 'gpqa',
# 'arc',
# 'ceval',
# 'hellaswag',
@@ -200,9 +201,10 @@
debug=True,
generation_config={
'temperature': 0.7,
'n': 5
'n': 1
},
use_cache='/mnt/data/data/user/maoyunlin.myl/eval-scope/outputs/20250212_150525'
# use_cache='/mnt/data/data/user/maoyunlin.myl/eval-scope/outputs/20250212_150525',
timeout=60
)

run_task(task_cfg=task_cfg)
