Refine LMDeploy with InternVL2.5-78B-MPO & QVQ-72B-Preview #749

Merged · 4 commits · Jan 23, 2025

This PR refines the LMDeploy API wrapper to support InternVL2.5-78B-MPO and QVQ-72B-Preview, registers two new model configs, and extends the prompt builders for WeMath and LogicVista.
25 changes: 19 additions & 6 deletions vlmeval/api/lmdeploy.py
@@ -57,7 +57,7 @@ def build_prompt(self, line, dataset=None):
                       'DUDE', 'SLIDEVQA', 'GQA', 'MMLongBench_DOC'], dataset):
             prompt = question + '\nAnswer the question using a single word or phrase.'
         elif listinstr(['MathVista', 'MathVision', 'VCR', 'MTVQA', 'MMVet', 'MathVerse',
-                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial'], dataset):
+                        'MMDU', 'CRPE', 'MIA-Bench', 'MM-Math', 'DynaMath', 'QSpatial', 'WeMath', 'LogicVista'], dataset):
             prompt = question
             if os.getenv('USE_COT') == '1':
                 prompt = build_qa_cot_prompt(line, prompt)
@@ -154,7 +154,7 @@ class LMDeployWrapper(BaseAPI):
     prompt_map = {
         'cogvlm2': CogVLM2_PromptUtil(),
         'internvl2': InternVL2_PromptUtil(),
-        'internvl2-8b-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
+        'internvl2-mpo-cot': InternVL2_PromptUtil(use_mpo_prompt=True),
     }

     def __init__(self,
@@ -170,7 +170,6 @@ def __init__(self,
                  **kwargs):
         self.fail_msg = 'Failed to obtain answer via API. '
         self.max_tokens = max_tokens
-        self.temperature = temperature
         self.timeout = timeout

         key = os.environ.get('LMDEPLOY_API_KEY', key)
@@ -188,6 +187,8 @@ def __init__(self,
         self.set_prompt_pattern(self.model)
         if hasattr(self, 'custom_prompt'):
             self.logger.info(f'using custom prompt {self.custom_prompt}')
+        self.temperature = temperature
+        self.logger.info(f'Init temperature: {self.temperature}')

     def set_dump_image(self, dump_image_func):
         if self.custom_prompt in self.prompt_map:
@@ -212,15 +213,26 @@ def set_prompt_pattern(self, model_name):
             self.max_tokens = 2048
             self.temperature = 0.0
             self.custom_prompt = 'cogvlm2'
-        if 'InternVL2-'.lower() in model_name.lower():
+        if 'InternVL2'.lower() in model_name.lower():
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2'
+            if 'mpo' in model_name.lower():
+                self.max_tokens = 4096
+                self.logger.info('Use custom prompt internvl2-mpo-cot')
+                self.custom_prompt = 'internvl2-mpo-cot'
+            else:
+                self.logger.info('Use custom prompt internvl2')
+                self.custom_prompt = 'internvl2'
         if 'internvl2-8b-mpo-cot'.lower() in model_name.lower():
             self.use_mpo_prompt = True
             self.max_tokens = 1024
             self.temperature = 0.0
-            self.custom_prompt = 'internvl2-8b-mpo-cot'
+            self.logger.info('Use custom prompt internvl2-mpo-cot')
+            self.custom_prompt = 'internvl2-mpo-cot'
+        if 'qvq'.lower() in model_name.lower():
+            self.max_tokens = 4096
+            self.temperature = 0.0
+            self.logger.info('QVQ model detected, do not use custom prompt')

     def prepare_itlist(self, inputs):
         assert np.all([isinstance(x, dict) for x in inputs])
@@ -263,6 +275,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
         input_msgs = self.prepare_inputs(inputs)

         temperature = kwargs.pop('temperature', self.temperature)
+        self.logger.info(f'Generate temperature: {temperature}')
         max_tokens = kwargs.pop('max_tokens', self.max_tokens)

         headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.key}'}
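The functional change in set_prompt_pattern above is the relaxed substring match: 'InternVL2-' became 'InternVL2', so a name like InternVL2_5-78B-MPO (which has no hyphen directly after "InternVL2") is now routed, and any "mpo" variant gets the CoT prompt with a 4096-token budget, while QVQ models get 4096 tokens and no custom prompt. Relatedly, self.temperature = temperature now runs after set_prompt_pattern, so the temperature passed to the constructor is no longer overwritten by the per-pattern defaults. A minimal standalone sketch of the routing (illustrative only, not part of the PR; the real method also mutates self and logs through self.logger):

def resolve_pattern(model_name: str):
    # Mirrors the matching order above: later blocks may override earlier ones.
    name = model_name.lower()
    max_tokens, temperature, custom_prompt = None, None, None
    if 'internvl2' in name:                      # was 'internvl2-' before this PR
        max_tokens, temperature = 1024, 0.0
        if 'mpo' in name:
            max_tokens, custom_prompt = 4096, 'internvl2-mpo-cot'
        else:
            custom_prompt = 'internvl2'
    if 'qvq' in name:                            # new in this PR
        max_tokens, temperature = 4096, 0.0      # no custom prompt for QVQ
    return max_tokens, temperature, custom_prompt

print(resolve_pattern('InternVL2_5-78B-MPO'))    # (4096, 0.0, 'internvl2-mpo-cot')
print(resolve_pattern('QVQ-72B-Preview'))        # (4096, 0.0, None)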
2 changes: 2 additions & 0 deletions vlmeval/config.py
@@ -109,6 +109,8 @@
     'TeleMM': partial(TeleMMAPI, model='TeleAI/TeleMM', temperature=0, retry=10),
     # lmdeploy api
     'lmdeploy': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10),
+    'lmdeploy_internvl_78B_MPO': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=100),
+    'lmdeploy_qvq_72B_preview': partial(LMDeployAPI, api_base='http://0.0.0.0:23333/v1/chat/completions', temperature=0, retry=10, timeout=300),
     # Taichu-VL
     'Taichu-VL-2B': partial(TaichuVLAPI, model='Taichu-VL-2B', url='https://platform.wair.ac.cn/api/v1/infer/10381/v1/chat/completions'),
 }
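A hypothetical usage sketch for the two new entries, assuming (as elsewhere in vlmeval/config.py) that this registry is exposed as supported_VLM and that an lmdeploy server is already serving the matching model at the api_base above. The QVQ entry gets timeout=300 rather than 100, presumably to accommodate its longer reasoning outputs:

from vlmeval.config import supported_VLM

# Instantiate the new 78B MPO entry (temperature=0, timeout=100 per the config).
model = supported_VLM['lmdeploy_internvl_78B_MPO']()

# VLMEvalKit's interleaved message format: a list of type/value dicts.
response = model.generate([
    dict(type='image', value='demo.jpg'),            # path to a local image
    dict(type='text', value='Describe this image.'),
])
print(response)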
3 changes: 2 additions & 1 deletion vlmeval/vlm/internvl/utils.py
@@ -338,7 +338,8 @@ def build_mpo_prompt(message, line, dataset):
     if listinstr(['MathVerse', 'MathVision'], dataset):
         question_orig = question_orig.split('Question:', 1)[-1].strip()
         question_orig = question_orig.replace('Choices:\n', '').strip()
-
+    if listinstr(['WeMath'], dataset):
+        question_orig = question_orig.replace('Regarding the format, please answer following the template below, and be sure to include two <> symbols:\n<Thought process>: <<your thought process>> <Answer>: <<your option>>', '').strip()
     options = {
         cand: line[cand]
         for cand in string.ascii_uppercase
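For clarity, a standalone sketch of what the new WeMath branch does: the benchmark's fixed answer-format template is stripped from the question, presumably so it does not compete with the CoT instructions that build_mpo_prompt appends later. The question text below is illustrative; in the real code question_orig comes from the dataset line:

TEMPLATE = ('Regarding the format, please answer following the template below, '
            'and be sure to include two <> symbols:\n'
            '<Thought process>: <<your thought process>> '
            '<Answer>: <<your option>>')

question_orig = 'What is the perimeter of the square? ' + TEMPLATE
question_orig = question_orig.replace(TEMPLATE, '').strip()
print(question_orig)   # -> 'What is the perimeter of the square?'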