Skip to content

Commit

Permalink
Merge branch 'feat/enhance-multi-modal-support' into release/0.10.0-beta
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Oct 15, 2024
2 parents 6e9129a + a36ef84 commit 24d06eb
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 20 deletions.
4 changes: 2 additions & 2 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class Data(BaseModel):
created_by: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []

event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
Expand Down Expand Up @@ -298,7 +298,7 @@ class Data(BaseModel):
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
Expand Down
2 changes: 1 addition & 1 deletion api/core/prompt/advanced_prompt_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _get_chat_model_prompt_messages(
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
raw_prompt.replace("{{#context#}}", context or "")
raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
return []
raise ValueError(f"Invalid variable type: {type(variable)}")

def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]:
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return

Expand Down
52 changes: 36 additions & 16 deletions api/core/workflow/nodes/tool/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from os import path
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
Expand All @@ -14,6 +17,8 @@
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus


Expand Down Expand Up @@ -167,45 +172,59 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage])
result = []
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get("mime_type", "image/jpeg")
tool_file_id = response.save_as or url.split("/")[-1]
url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = response.save_as or str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)

# get tool file id
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")

result.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
remote_url=url,
related_id=tool_file_id,
filename=tool_file_id,
related_id=tool_file.id,
filename=tool_file.name,
extension=ext,
mime_type=mimetype,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
related_id=tool_file.id,
filename=tool_file.name,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get("mime_type", "application/octet-stream"),
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
mimetype = response.meta.get("mime_type", "application/octet-stream")
tool_file_id = url.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
Expand All @@ -215,10 +234,11 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage])
type=FileType(response.save_as),
transfer_method=transfer_method,
remote_url=url,
filename=tool_file_id,
related_id=tool_file_id,
filename=tool_file.name,
related_id=tool_file.id,
extension=extension,
mime_type=mimetype,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
result.append(file)

Expand Down

0 comments on commit 24d06eb

Please sign in to comment.