Skip to content

Commit

Permalink
流式调用 30%
Browse files Browse the repository at this point in the history
  • Loading branch information
Asankilp committed Mar 7, 2025
1 parent a61d134 commit 780df08
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 33 deletions.
50 changes: 44 additions & 6 deletions nonebot_plugin_marshoai/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
current_matcher,
)
from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage

from .config import config
from .constants import SUPPORT_IMAGE_MODELS
Expand Down Expand Up @@ -96,7 +96,8 @@ async def handle_single_chat(
model_name: str,
tools_list: list,
tool_message: Optional[list] = None,
) -> ChatCompletion:
stream: bool = False,
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
"""
处理单条聊天
"""
Expand All @@ -109,20 +110,24 @@ async def handle_single_chat(
msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore
model_name=model_name,
tools=tools_list if tools_list else None,
stream=stream,
)
return response

async def handle_function_call(
self,
completion: ChatCompletion,
completion: Union[ChatCompletion, AsyncStream[ChatCompletionChunk]],
user_message: Union[str, list],
model_name: str,
tools_list: list,
):
# function call
# 需要获取额外信息,调用函数工具
tool_msg = []
choice = completion.choices[0]
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
else:
raise ValueError("Unexpected completion type")
# await UniMessage(str(response)).send()
tool_calls = choice.message.tool_calls
# try:
Expand Down Expand Up @@ -198,7 +203,10 @@ async def handle_common_chat(
tools_list=tools_list,
tool_message=tool_message,
)
choice = response.choices[0]
if isinstance(response, ChatCompletion):
choice = response.choices[0]
else:
raise ValueError("Unexpected response type")
# Sprint(choice)
# 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
Expand Down Expand Up @@ -240,3 +248,33 @@ async def handle_common_chat(
else:
await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
return None

async def handle_stream_request(
    self, user_message: Union[str, list], model_name: str, tools_list: list
) -> tuple[str, str]:
    """
    Handle a streaming chat request.

    Issues the chat request with ``stream=True`` and consumes the resulting
    ``AsyncStream`` chunk by chunk, accumulating the model's reasoning text
    (the ``reasoning_content`` delta field emitted by reasoning-capable
    models) separately from the final answer text.

    Args:
        user_message: The user's message (plain text or a content-part list).
        model_name: Name of the model to query.
        tools_list: Tool definitions forwarded to the request.

    Returns:
        A ``(reasoning_contents, answer_contents)`` tuple with the
        accumulated reasoning text and answer text.

    Raises:
        ValueError: If the response is not an ``AsyncStream``, which should
            not happen when ``stream=True`` was requested.
    """
    response = await self.handle_single_chat(
        user_message=user_message,
        model_name=model_name,
        tools_list=tools_list,
        stream=True,
    )
    if not isinstance(response, AsyncStream):
        # Mirror the error handling style of the non-streaming handlers
        # instead of silently doing nothing.
        raise ValueError("Unexpected response type")

    reasoning_contents = ""
    answer_contents = ""
    async for chunk in response:
        if not chunk.choices:
            # The final chunk carries usage statistics and no choices.
            # loguru formats with str.format, so a "{}" placeholder is
            # required for the value to appear in the log line.
            logger.info("Usage: {}", chunk.usage)
        else:
            delta = chunk.choices[0].delta
            # Some providers stream chain-of-thought in a separate
            # ``reasoning_content`` field — presumably DeepSeek-R1-style
            # models; confirm against the configured provider.
            if (
                hasattr(delta, "reasoning_content")
                and delta.reasoning_content is not None
            ):
                reasoning_contents += delta.reasoning_content
            elif delta.content is not None:
                answer_contents += delta.content
    # Return the accumulated text so callers can actually use the result;
    # the original built these strings and then discarded them.
    return reasoning_contents, answer_contents
33 changes: 6 additions & 27 deletions nonebot_plugin_marshoai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import mimetypes
import re
import uuid
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import aiofiles # type: ignore
import httpx
Expand All @@ -15,8 +15,8 @@
from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg
from nonebot_plugin_alconna import UniMessage
from openai import AsyncOpenAI, NotGiven
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai import AsyncOpenAI, AsyncStream, NotGiven
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from zhDateTime import DateTime

from ._types import DeveloperMessage
Expand Down Expand Up @@ -109,35 +109,13 @@ async def get_image_b64(url: str, timeout: int = 10) -> Optional[str]:
return None


async def make_chat(
client: ChatCompletionsClient,
msg: list,
model_name: str,
tools: Optional[list] = None,
):
"""
调用ai获取回复
参数:
client: 用于与AI模型进行通信
msg: 消息内容
model_name: 指定AI模型名
tools: 工具列表
"""
return await client.complete(
messages=msg,
model=model_name,
tools=tools,
**config.marshoai_model_args,
)


async def make_chat_openai(
client: AsyncOpenAI,
msg: list,
model_name: str,
tools: Optional[list] = None,
) -> ChatCompletion:
stream: bool = False,
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
"""
使用 Openai SDK 调用ai获取回复
Expand All @@ -152,6 +130,7 @@ async def make_chat_openai(
model=model_name,
tools=tools or NOT_GIVEN,
timeout=config.marshoai_timeout,
stream=stream,
**config.marshoai_model_args,
)

Expand Down

0 comments on commit 780df08

Please sign in to comment.