fix: Qwen models could not be adapted correctly & fix openai streaming output in synchronous mode (#289)
wojiaoyishang authored Jan 8, 2025
1 parent c337aa8 commit 0cd5df8
Showing 1 changed file with 15 additions and 47 deletions.
62 changes: 15 additions & 47 deletions lagent/llms/openai.py
@@ -146,10 +146,7 @@ def stream_chat(
         # mapping to role that openai supports
         messages = self.template_parser(inputs)
         for text in self._stream_chat(messages, **gen_params):
-            if self.model_type.lower().startswith('qwen'):
-                resp = text
-            else:
-                resp += text
+            resp += text
             if not resp:
                 continue
             # remove stop_words
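
The accumulation change above is the first half of the fix: once Qwen responses are served through the OpenAI-compatible streaming format, every chunk is an incremental delta, so treating a chunk as a full replacement (`resp = text`) would keep only the newest fragment. A minimal sketch of the accumulation pattern (not lagent code; the generator and chunk values are made up for illustration):

    # Minimal sketch: OpenAI-style streaming sends incremental deltas,
    # so the consumer must accumulate them.
    def fake_stream():
        # hypothetical delta chunks, for illustration only
        yield from ['Hel', 'lo', ' world']

    resp = ''
    for text in fake_stream():
        resp += text  # accumulate every delta
    print(resp)       # -> Hello world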
@@ -270,16 +267,11 @@ def streaming(raw_response):
                             # Context exceeds maximum length
                             yield ''
                             return
-                        if self.model_type.lower().startswith('qwen'):
-                            choice = response['output']['choices'][0]
-                            yield choice['message']['content']
-                            if choice['finish_reason'] == 'stop':
-                                return
-                        else:
-                            choice = response['choices'][0]
-                            if choice['finish_reason'] == 'stop':
-                                return
-                            yield choice['delta'].get('content', '')
+
+                        choice = response['choices'][0]
+                        if choice['finish_reason'] == 'stop':
+                            return
+                        yield choice['delta'].get('content', '')
                     except Exception as exc:
                         msg = f'response {decoded} lead to exception of {str(exc)}'
                         self.logger.error(msg)
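
With the DashScope-specific branch gone, every stream is parsed as OpenAI chat-completions SSE: each `data:` line decodes to JSON whose first choice carries a `delta` that may contain `content`, and `finish_reason == 'stop'` ends the stream. A self-contained sketch of that parsing, with made-up payloads (only the structure mirrors the code above):

    import json

    # Illustrative chunks in the OpenAI chat-completions streaming format.
    lines = [
        'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}',
        'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}',
    ]

    for line in lines:
        response = json.loads(line[len('data: '):])
        choice = response['choices'][0]
        if choice['finish_reason'] == 'stop':
            break
        print(choice['delta'].get('content', ''))  # prints "Hi", then stops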
@@ -316,7 +308,7 @@ def streaming(raw_response):
 
             response = dict()
             try:
-                raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies)
+                raw_response = requests.post(self.url, headers=header, data=json.dumps(data), proxies=self.proxies, stream=True)
                 return streaming(raw_response)
             except requests.ConnectionError:
                 errmsg = 'Got connection error ' + str(traceback.format_exc())
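
This one-argument change is the synchronous-streaming half of the fix: without `stream=True`, `requests.post` downloads the entire response body before returning, so `streaming()` only starts iterating after generation has finished and the whole output lands at once; with `stream=True` the connection stays open and `iter_lines()` yields SSE lines as they arrive. A minimal illustration of the pattern (the URL is a placeholder, not a real endpoint):

    import requests

    # stream=True defers the body download so iter_lines() can yield each
    # 'data: ...' line as the server sends it.
    raw_response = requests.post(
        'https://example.com/v1/chat/completions',  # placeholder endpoint
        json={'stream': True},
        stream=True,
    )
    for line in raw_response.iter_lines(decode_unicode=True):
        if line:
            print(line)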
@@ -384,7 +376,7 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
 
         # Model-specific processing
         data = {}
-        if model_type.lower().startswith('gpt'):
+        if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
             if 'top_k' in gen_params:
                 warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
                 gen_params.pop('top_k')
@@ … @@
             data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
             if json_mode:
                 data['response_format'] = {'type': 'json_object'}
-        elif model_type.lower().startswith('qwen'):
-            header['X-DashScope-SSE'] = 'enable'
-            gen_params.pop('skip_special_tokens', None)
-            gen_params.pop('session_id', None)
-            if 'frequency_penalty' in gen_params:
-                gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
-            gen_params['result_format'] = 'message'
-            data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
         else:
             raise NotImplementedError(f'Model type {model_type} is not supported')
 
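After this change a `qwen*` model type is routed through the same request builder as `gpt*`, so the payload uses a top-level `messages` list instead of DashScope's `input`/`parameters` envelope and no `X-DashScope-SSE` header is sent; this assumes the configured endpoint speaks the OpenAI chat-completions protocol (for example DashScope's OpenAI-compatible mode). A sketch of the body now produced for a Qwen model (values are illustrative):

    gen_params = {'temperature': 0.7, 'stream': True}  # example parameters
    data = {
        'model': 'qwen-max',  # any model name starting with "qwen"
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'n': 1,
        **gen_params,
    }
    # Previously this was wrapped as
    # {'model': ..., 'input': {'messages': [...]}, 'parameters': {...}} for DashScope.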
@@ -550,10 +534,7 @@ async def stream_chat(
         # mapping to role that openai supports
         messages = self.template_parser(inputs)
         async for text in self._stream_chat(messages, **gen_params):
-            if self.model_type.lower().startswith('qwen'):
-                resp = text
-            else:
-                resp += text
+            resp += text
             if not resp:
                 continue
             # remove stop_words
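
The asynchronous client gets the same accumulation fix, only consumed with `async for`. A minimal sketch of that pattern (again illustrative, not lagent code):

    import asyncio

    async def fake_stream():
        for piece in ['Hel', 'lo', ' world']:  # hypothetical deltas
            yield piece

    async def main():
        resp = ''
        async for text in fake_stream():
            resp += text
        print(resp)  # -> Hello world

    asyncio.run(main())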
@@ -679,16 +660,11 @@ async def streaming(raw_response):
                             # Context exceeds maximum length
                             yield ''
                             return
-                        if self.model_type.lower().startswith('qwen'):
-                            choice = response['output']['choices'][0]
-                            yield choice['message']['content']
-                            if choice['finish_reason'] == 'stop':
-                                return
-                        else:
-                            choice = response['choices'][0]
-                            if choice['finish_reason'] == 'stop':
-                                return
-                            yield choice['delta'].get('content', '')
+
+                        choice = response['choices'][0]
+                        if choice['finish_reason'] == 'stop':
+                            return
+                        yield choice['delta'].get('content', '')
                     except Exception as exc:
                         msg = f'response {decoded} lead to exception of {str(exc)}'
                         self.logger.error(msg)
@@ -798,7 +774,7 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
 
         # Model-specific processing
        data = {}
-        if model_type.lower().startswith('gpt'):
+        if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
             if 'top_k' in gen_params:
                 warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
                 gen_params.pop('top_k')
@@ … @@
             data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
             if json_mode:
                 data['response_format'] = {'type': 'json_object'}
-        elif model_type.lower().startswith('qwen'):
-            header['X-DashScope-SSE'] = 'enable'
-            gen_params.pop('skip_special_tokens', None)
-            gen_params.pop('session_id', None)
-            if 'frequency_penalty' in gen_params:
-                gen_params['repetition_penalty'] = gen_params.pop('frequency_penalty')
-            gen_params['result_format'] = 'message'
-            data = {'model': model_type, 'input': {'messages': messages}, 'parameters': {**gen_params}}
         elif model_type.lower().startswith('o1'):
             data = {'model': model_type, 'messages': messages, 'n': 1}
         else:
