From 4f2d6ad78979229d7a34e9c486365803f06cd646 Mon Sep 17 00:00:00 2001
From: Xiangyu Zhao <98592339+PhoenixZ810@users.noreply.github.com>
Date: Thu, 23 Jan 2025 21:38:24 +0800
Subject: [PATCH] Refine Lmdeploy with InternVL2.5-78B-MPO & QVQ-72B-preview (#749)

* change config

* modify LmdeployWrapper

* modify build_mpo_prompt.py with WeMath

---
 vlmeval/api/lmdeploy.py       | 25 +++++++++++++++++++------
 vlmeval/config.py             |  2 ++
 vlmeval/vlm/internvl/utils.py |  3 ++-
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/vlmeval/api/lmdeploy.py b/vlmeval/api/lmdeploy.py
index 72b10236..f6112d39 100644
--- a/vlmeval/api/lmdeploy.py
+++ b/vlmeval/api/lmdeploy.py
@@ -57,7 +57,7 @@ def build_prompt(self, line, dataset=None):
                       'DUDE', 'SLIDEVQA', 'GQA', 'MMLongBench_DOC'], dataset):
             prompt = question + '\nAnswer the question using a single word or phrase.'
         elif listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse',
-                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial'], dataset):
+                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial', 'WeMath', 'LogicVista'], dataset):
             prompt = question
             if os.getenv('USE_COT') == '1':
                 prompt = build_qa_cot_prompt(line, prompt)
@@ -154,7 +154,7 @@ class LMDeployWrapper(BaseAPI):
     prompt_map = {
         'cogvlm2': CogVLM2_PromptUtil(),
         'internvl2': InternVL2_PromptUtil(),
-        'internvl2-8b-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
+        'internvl2-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
     }

     def __init__(self,
@@ -170,7 +170,6 @@ def __init__(self,
                  **kwargs):
         self.fail_msg = 'Failed to obtain answer via API. '
         self.max_tokens = max_tokens
-        self.temperature = temperature
         self.timeout = timeout

         key = os.environ.get('LMDEPLOY_API_KEY', key)
@@ -188,6 +187,8 @@ def __init__(self,
         self.set_prompt_pattern(self.model)
         if hasattr(self, 'custom_prompt'):
             self.logger.info(f'using custom prompt {self.custom_prompt}')
+        self.temperature = temperature
+        self.logger.info(f'Init temperature: {self.temperature}')

     def set_dump_image(self, dump_image_func):
         if self.custom_prompt in self.prompt_map:
@@ -212,15 +213,26 @@
             self.max_tokens = 2048
             self.temperature = 0.0
             self.custom_prompt = 'cogvlm2'
-        if 'InternVL2-'.lower() in model_name.lower():
+        if 'InternVL2'.lower() in model_name.lower():
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2'
+            if 'mpo' in model_name.lower():
+                self.max_tokens = 4096
+                self.logger.info('Use custom prompt internvl2-mpo-cot')
+                self.custom_prompt = 'internvl2-mpo-cot'
+            else:
+                self.logger.info('Use custom prompt internvl2')
+                self.custom_prompt = 'internvl2'
         if 'internvl2-8b-mpo-cot'.lower() in model_name.lower():
             self.use_mpo_prompt = True
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2-8b-mpo-cot'
+            self.logger.info('Use custom prompt internvl2-mpo-cot')
+            self.custom_prompt = 'internvl2-mpo-cot'
+        if 'qvq'.lower() in model_name.lower():
+            self.max_tokens = 4096
+            self.temperature = 0.0
+            self.logger.info('QVQ model detected, do not use custom prompt')

     def prepare_itlist(self, inputs):
         assert np.all([isinstance(x, dict) for x in inputs])
@@ -263,6 +275,7 @@ def generate_inner(self, inputs, **kwargs) -> str:

         input_msgs = self.prepare_inputs(inputs)
         temperature = kwargs.pop('temperature', self.temperature)
+        self.logger.info(f'Generate temperature: {temperature}')
         max_tokens = kwargs.pop('max_tokens', self.max_tokens)

         headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
diff --git a/vlmeval/config.py b/vlmeval/config.py
index 6e1234a4..121e9955 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -109,6 +109,8 @@
     'TeleMM': partial(TeleMMAPI, model='TeleAI/TeleMM', temperature=0, retry=10),
     # lmdeploy api
     'lmdeploy': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10),
+    'lmdeploy_internvl_78B_MPO': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=100),
+    'lmdeploy_qvq_72B_preview': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=300),
     # Taichu-VL
     'Taichu-VL-2B': partial(TaichuVLAPI, model='Taichu-VL-2B', url='https://platform.wair.ac.cn/api/v1/infer/10381/v1/chat/completions'),
 }
diff --git a/vlmeval/vlm/internvl/utils.py b/vlmeval/vlm/internvl/utils.py
index 7e3a917f..ad3cefa4 100644
--- a/vlmeval/vlm/internvl/utils.py
+++ b/vlmeval/vlm/internvl/utils.py
@@ -338,7 +338,8 @@ def build_mpo_prompt(message, line, dataset):
     if listinstr(['MathVerse', 'MathVision'], dataset):
         question_orig = question_orig.split('Question:', 1)[-1].strip()
         question_orig = question_orig.replace('Choices:\n', '').strip()
-
+    if listinstr(['WeMath'], dataset):
+        question_orig = question_orig.replace('Regarding the format, please answer following the template below, and be sure to include two <> symbols:\n: <> : <>', '').strip()
     options = {
         cand: line[cand]
         for cand in string.ascii_uppercase
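Usage note: a minimal sketch of how the two new registry entries are expected to be exercised,
assuming VLMEvalKit's supported_VLM registry and an lmdeploy api_server already listening at
http://0.0.0.0:23333; the served checkpoint and the demo message below are illustrative, not
part of this patch.

    # Resolve one of the new entries added to vlmeval/config.py above.
    from vlmeval.config import supported_VLM

    # Builds an LMDeployAPI client (the partial supplies api_base, temperature=0, retry, timeout).
    # set_prompt_pattern() keys off the model name reported by the server, so serving an
    # InternVL2.5 MPO checkpoint takes the 'mpo' branch (custom prompt 'internvl2-mpo-cot',
    # max_tokens=4096), while a QVQ checkpoint gets max_tokens=4096 and no custom prompt.
    model = supported_VLM['lmdeploy_internvl_78B_MPO']()

    # Standard VLMEvalKit interleaved message: an image path plus a text part (illustrative).
    message = [
        dict(type='image', value='demo.jpg'),
        dict(type='text', value='Describe the image.'),
    ]
    print(model.generate(message))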