Refine Lmdeploy with InternVL2.5-78B-MPO & QVQ-72B-preview (#749)
* change config

* modify LmdeployWrapper

* modify build_mpo_prompt.py with WeMath
PhoenixZ810 authored Jan 23, 2025
1 parent 61f6b3d commit 4f2d6ad
Showing 3 changed files with 23 additions and 7 deletions.
vlmeval/api/lmdeploy.py (25 changes: 19 additions & 6 deletions)
@@ -57,7 +57,7 @@ def build_prompt(self, line, dataset=None):
                         'DUDE', 'SLIDEVQA', 'GQA', 'MMLongBench_DOC'], dataset):
             prompt = question + '\nAnswer the question using a single word or phrase.'
         elif listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse',
-                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial'], dataset):
+                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial', 'WeMath', 'LogicVista'], dataset):
             prompt = question
             if os.getenv('USE_COT') == '1':
                 prompt = build_qa_cot_prompt(line, prompt)
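For context, the dataset check above uses listinstr from vlmeval.smp, which matches dataset names by substring. A minimal sketch of that behavior (the real helper may differ in detail):

    def listinstr(lst, s):
        # True if any item in lst occurs as a substring of s
        return any(item in s for item in lst)

    # After this change, WeMath and LogicVista take the branch that keeps the
    # raw question and, when USE_COT=1, wraps it with build_qa_cot_prompt.
    assert listinstr(['WeMath', 'LogicVista'], 'WeMath')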
@@ -154,7 +154,7 @@ class LMDeployWrapper(BaseAPI):
     prompt_map = {
         'cogvlm2': CogVLM2_PromptUtil(),
         'internvl2': InternVL2_PromptUtil(),
-        'internvl2-8b-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
+        'internvl2-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
     }

     def __init__(self,
@@ -170,7 +170,6 @@ def __init__(self,
                  **kwargs):
         self.fail_msg = 'Failed to obtain answer via API. '
         self.max_tokens = max_tokens
-        self.temperature = temperature
         self.timeout = timeout

         key = os.environ.get('LMDEPLOY_API_KEY', key)
@@ -188,6 +187,8 @@ def __init__(self,
         self.set_prompt_pattern(self.model)
         if hasattr(self, 'custom_prompt'):
             self.logger.info(f'using custom prompt {self.custom_prompt}')
+        self.temperature = temperature
+        self.logger.info(f'Init temperature: {self.temperature}')

     def set_dump_image(self, dump_image_func):
         if self.custom_prompt in self.prompt_map:
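Net effect of this hunk together with the deletion in the previous one: self.temperature is now assigned after set_prompt_pattern runs, so a temperature passed by the caller overrides the per-model default (typically 0.0) that set_prompt_pattern writes. A toy sketch of why the ordering matters, with invented values:

    class Demo:
        def __init__(self, temperature):
            self.set_prompt_pattern()        # writes the model default first
            self.temperature = temperature   # caller's value, applied last, wins

        def set_prompt_pattern(self):
            self.temperature = 0.0           # per-model default

    print(Demo(temperature=0.7).temperature)  # 0.7, not 0.0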
@@ -212,15 +213,26 @@ def set_prompt_pattern(self, model_name):
             self.max_tokens = 2048
             self.temperature = 0.0
             self.custom_prompt = 'cogvlm2'
-        if 'InternVL2-'.lower() in model_name.lower():
+        if 'InternVL2'.lower() in model_name.lower():
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2'
+            if 'mpo' in model_name.lower():
+                self.max_tokens = 4096
+                self.logger.info('Use custom prompt internvl2-mpo-cot')
+                self.custom_prompt = 'internvl2-mpo-cot'
+            else:
+                self.logger.info('Use custom prompt internvl2')
+                self.custom_prompt = 'internvl2'
         if 'internvl2-8b-mpo-cot'.lower() in model_name.lower():
             self.use_mpo_prompt = True
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2-8b-mpo-cot'
+            self.logger.info('Use custom prompt internvl2-mpo-cot')
+            self.custom_prompt = 'internvl2-mpo-cot'
+        if 'qvq'.lower() in model_name.lower():
+            self.max_tokens = 4096
+            self.temperature = 0.0
+            self.logger.info('QVQ model detected, do not use custom prompt')

     def prepare_itlist(self, inputs):
         assert np.all([isinstance(x, dict) for x in inputs])
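Two behavioral points in this hunk: dropping the trailing hyphen from 'InternVL2-' lets the check also match InternVL2.5-style names (e.g. InternVL2_5-78B-MPO), and MPO variants now get a 4096-token budget plus the internvl2-mpo-cot prompt, while QVQ models get 4096 tokens with no custom prompt. A standalone sketch of the cascade (model names are just examples):

    def describe(model_name: str) -> str:
        # mirrors the substring checks above; later blocks can override earlier ones
        name = model_name.lower()
        desc = 'default'
        if 'internvl2' in name:
            desc = 'internvl2-mpo-cot, 4096 tokens' if 'mpo' in name else 'internvl2, 1024 tokens'
        if 'qvq' in name:
            desc = 'no custom prompt, 4096 tokens'
        return desc

    print(describe('InternVL2_5-78B-MPO'))  # internvl2-mpo-cot, 4096 tokens
    print(describe('QVQ-72B-Preview'))      # no custom prompt, 4096 tokens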
@@ -263,6 +275,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
         input_msgs = self.prepare_inputs(inputs)

         temperature = kwargs.pop('temperature', self.temperature)
+        self.logger.info(f'Generate temperature: {temperature}')
         max_tokens = kwargs.pop('max_tokens', self.max_tokens)

         headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
vlmeval/config.py (2 changes: 2 additions & 0 deletions)
@@ -109,6 +109,8 @@
     'TeleMM': partial(TeleMMAPI, model='TeleAI/TeleMM', temperature=0, retry=10),
     # lmdeploy api
     'lmdeploy': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10),
+    'lmdeploy_internvl_78B_MPO': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=100),
+    'lmdeploy_qvq_72B_preview': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=300),
     # Taichu-VL
     'Taichu-VL-2B': partial(TaichuVLAPI, model='Taichu-VL-2B', url='https://platform.wair.ac.cn/api/v1/infer/10381/v1/chat/completions'),
 }
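These entries register the two served models under CLI-selectable names; the longer timeouts (100s and 300s) allow for the models' long chain-of-thought outputs. A sketch of how a registry entry is consumed (assuming the dict above is VLMEvalKit's supported_VLM and an lmdeploy server is already listening at the api_base):

    from vlmeval.config import supported_VLM

    model = supported_VLM['lmdeploy_qvq_72B_preview']()  # builds an LMDeployAPI client
    # model.generate(...) would then POST to the configured chat/completions endpoint.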
vlmeval/vlm/internvl/utils.py (3 changes: 2 additions & 1 deletion)
@@ -338,7 +338,8 @@ def build_mpo_prompt(message, line, dataset):
     if listinstr(['MathVerse', 'MathVision'], dataset):
         question_orig = question_orig.split('Question:', 1)[-1].strip()
         question_orig = question_orig.replace('Choices:\n', '').strip()
-
+    if listinstr(['WeMath'], dataset):
+        question_orig = question_orig.replace('Regarding the format, please answer following the template below, and be sure to include two <> symbols:\n<Thought process>: <<your thought process>> <Answer>: <<your option>>', '').strip()
     options = {
         cand: line[cand]
         for cand in string.ascii_uppercase
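The new branch strips WeMath's built-in answer-format template from the question so that build_mpo_prompt can attach its own CoT instructions instead. A small illustration with a hypothetical question:

    template = ('Regarding the format, please answer following the template below, '
                'and be sure to include two <> symbols:\n'
                '<Thought process>: <<your thought process>> <Answer>: <<your option>>')
    question_orig = 'How many triangles are in the figure?\n' + template  # hypothetical
    cleaned = question_orig.replace(template, '').strip()
    print(cleaned)  # How many triangles are in the figure?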
