diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py
index 42166dbd1..a56246871 100644
--- a/metagpt/actions/action_node.py
+++ b/metagpt/actions/action_node.py
@@ -25,6 +25,9 @@ from metagpt.utils.common import OutputParser, general_after_log
 from metagpt.utils.human_interaction import HumanInteraction
 from metagpt.utils.sanitize import sanitize
+from metagpt.logs import logger
+from metagpt.provider.human_provider import HumanProvider
+from metagpt.configs.llm_config import LLMConfig
 
 
 class ReviewMode(Enum):
@@ -152,11 +155,13 @@ class ActionNode:
     # Action Output
     content: str
     instruct_content: BaseModel
+    problematic_json_history = []  # Store problematic JSON for feedback
 
     # For ActionGraph
     prevs: List["ActionNode"]  # previous nodes
     nexts: List["ActionNode"]  # next nodes
-
+    human_provider: Optional[HumanProvider] = None
+
     def __init__(
         self,
         key: str,
@@ -176,6 +181,7 @@ def __init__(
         self.schema = schema
         self.prevs = []
         self.nexts = []
+        self.human_provider = HumanProvider(LLMConfig())
 
     def __str__(self):
         return (
@@ -432,22 +438,79 @@ async def _aask_v1(
         system_msgs: Optional[list[str]] = None,
         schema="markdown",  # compatible to original format
         timeout=USE_CONFIG_TIMEOUT,
-    ) -> (str, BaseModel):
+    ) -> tuple[str, BaseModel]:
         """Use ActionOutput to wrap the output of aask"""
-        content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
-        logger.debug(f"llm raw output:\n{content}")
-        output_class = self.create_model_class(output_class_name, output_data_mapping)
-
-        if schema == "json":
-            parsed_data = llm_output_postprocess(
-                output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
-            )
-        else:  # using markdown parser
-            parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
-
-        logger.debug(f"parsed_data:\n{parsed_data}")
-        instruct_content = output_class(**parsed_data)
-        return content, instruct_content
+        self.problematic_json_history = []
+        state = 'llm'
+        original_prompt = prompt
+        content = ""
+        while True:
+            if state == 'llm':
+                content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
+                state = 'autohumanprompt'
+            elif state == 'autohumanprompt':
+                # check whether this node is a code-writing / development node
+                if self.key == 'WriteCodePlanAndChange':
+                    # fix the most common failures
+                    prompt = 'only focus on 1 goal for now. gen json with fields of "Incremental Change" and "Development Plan" json should be wrapped in [CONTENT][/CONTENT], nothing else. The incremental change should be exactly one code patch in a list of string. and dev plan is also a list of string. take care of json syntax, double quotes, escape chars. problematic json as follows:' + ','.join(self.problematic_json_history)
+                    content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
+                elif self.key == 'WP_ISSUE_TYPE':
+                    prompt = original_prompt + 'only focus on 1 goal for now. gen json with fields of "issue_type" json should be wrapped in [CONTENT][/CONTENT], nothing else. The incremental change should be exactly one code patch in a list of string. and dev plan is also a list of string. take care of json syntax, double quotes, escape chars. problematic json as follows:' + ','.join(self.problematic_json_history)
+                    content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
+                else:
+                    # other cases
+                    state = 'humanprompt'
+            if state == 'humanprompt':
+                content = await self.involve_humanprompt_intervention(content, self.problematic_json_history, original_prompt, system_msgs=system_msgs, images=images, timeout=timeout)
+                state = 'human'
+            elif state == 'human':
+                content = await self.involve_human_intervention(content, self.problematic_json_history, original_prompt, system_msgs=system_msgs, images=images, timeout=timeout)
+            logger.debug(f"llm raw output:\n{content}")
+            output_class = self.create_model_class(output_class_name, output_data_mapping)
+            parsed_data = ""
+            if schema == "json":
+                try:
+                    parsed_data = llm_output_postprocess(
+                        output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
+                    )
+                except Exception as e:
+                    logger.warning(f"Failed to parse JSON content: {str(e)}")
+                    self.problematic_json_history.append(content + str(e) + "please notice any common problem for llm gened json, e.g. escape double quote issues")  # Store problematic JSON and error message
+            else:  # using markdown parser
+                parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
+
+            logger.debug(f"parsed_data:\n{parsed_data}")
+
+            try:
+                instruct_content = output_class(**parsed_data)
+                return content, instruct_content
+            except Exception as e:
+                # If parsing fails, loop again and let the humanprompt path correct it
+                logger.warning(f"Failed to parse data into {output_class_name} class: {str(e)}")
+                # prompt = await self.involve_humanprompt_intervention(content, self.problematic_json_history, original_prompt, system_msgs=system_msgs, images=images, timeout=timeout)
+
+    async def involve_humanprompt_intervention(self,
+                                               content: str, problematic_json_history: list,
+                                               original_prompt: str,
+                                               images: Optional[Union[str, list[str]]] = None,
+                                               system_msgs: Optional[list[str]] = None,
+                                               timeout=USE_CONFIG_TIMEOUT):
+        """Involve human prompt intervention when all retries fail."""
+        logger.error("All attempts to parse JSON content failed. Involving human prompt intervention.")
+        humanprompt_response = self.human_provider.ask(content)
+        content = await self.llm.aask(humanprompt_response + f" take care double quotes in json response, and escape chars. output wrapped inside [CONTENT][/CONTENT] like previous example. nothing else. problem json:{','.join(problematic_json_history)}", system_msgs, images=images, timeout=timeout)
+        return content
+
+    async def involve_human_intervention(self,
+                                         content: str, problematic_json_history: list,
+                                         original_prompt: str,
+                                         images: Optional[Union[str, list[str]]] = None,
+                                         system_msgs: Optional[list[str]] = None,
+                                         timeout=USE_CONFIG_TIMEOUT):
+        """Involve human intervention when all retries fail."""
+        logger.error("All attempts to parse JSON content failed. Involving human intervention.")
+        human_response = self.human_provider.ask(content)
+        return human_response
 
     def get(self, key):
         return self.instruct_content.model_dump()[key]
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index ef034ca49..caf6d40c3 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -25,6 +25,7 @@ class LLMType(Enum):
     OPEN_LLM = "open_llm"
     GEMINI = "gemini"
     METAGPT = "metagpt"
+    HUMAN = "human"
     AZURE = "azure"
     OLLAMA = "ollama"  # /chat at ollama api
     OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
@@ -42,6 +43,28 @@ class LLMType(Enum):
     def __missing__(self, key):
         return self.OPENAI
 
+LLMModuleMap = {
+    LLMType.OPENAI: "metagpt.provider.openai_api",
+    LLMType.ANTHROPIC: "metagpt.provider.anthropic_api",
+    LLMType.CLAUDE: "metagpt.provider.anthropic_api",  # Same module as Anthropic
+    LLMType.SPARK: "metagpt.provider.spark_api",
+    LLMType.ZHIPUAI: "metagpt.provider.zhipuai_api",
+    LLMType.FIREWORKS: "metagpt.provider.fireworks_api",
+    LLMType.OPEN_LLM: "metagpt.provider.open_llm_api",
+    LLMType.GEMINI: "metagpt.provider.google_gemini_api",
+    LLMType.METAGPT: "metagpt.provider.metagpt_api",
+    LLMType.HUMAN: "metagpt.provider.human_provider",
+    LLMType.AZURE: "metagpt.provider.azure_openai_api",
+    LLMType.OLLAMA: "metagpt.provider.ollama_api",
+    LLMType.QIANFAN: "metagpt.provider.qianfan_api",  # Baidu BCE
+    LLMType.DASHSCOPE: "metagpt.provider.dashscope_api",  # Aliyun LingJi DashScope
+    LLMType.MOONSHOT: "metagpt.provider.moonshot_api",
+    LLMType.MISTRAL: "metagpt.provider.mistral_api",
+    LLMType.YI: "metagpt.provider.yi_api",  # lingyiwanwu
+    LLMType.OPENROUTER: "metagpt.provider.openrouter_api",
+    LLMType.BEDROCK: "metagpt.provider.bedrock_api",
+    LLMType.ARK: "metagpt.provider.ark_api",
+}
 
 class LLMConfig(YamlModel):
     """Config for LLM
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index c90f5774a..6e57083e5 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -5,33 +5,62 @@
 @Author  : alexanderwu
 @File    : __init__.py
 """
+import importlib
+from metagpt.configs.llm_config import LLMType, LLMModuleMap
 
-from metagpt.provider.google_gemini_api import GeminiLLM
-from metagpt.provider.ollama_api import OllamaLLM
-from metagpt.provider.openai_api import OpenAILLM
-from metagpt.provider.zhipuai_api import ZhiPuAILLM
-from metagpt.provider.azure_openai_api import AzureOpenAILLM
-from metagpt.provider.metagpt_api import MetaGPTLLM
-from metagpt.provider.human_provider import HumanProvider
-from metagpt.provider.spark_api import SparkLLM
-from metagpt.provider.qianfan_api import QianFanLLM
-from metagpt.provider.dashscope_api import DashScopeLLM
-from metagpt.provider.anthropic_api import AnthropicLLM
-from metagpt.provider.bedrock_api import BedrockLLM
-from metagpt.provider.ark_api import ArkLLM
+class LLMFactory:
+    def __init__(self, module_name, instance_name):
+        self.module_name = module_name
+        self.instance_name = instance_name
+        self._module = None
 
-__all__ = [
-    "GeminiLLM",
-    "OpenAILLM",
-    "ZhiPuAILLM",
-    "AzureOpenAILLM",
-    "MetaGPTLLM",
-    "OllamaLLM",
-    "HumanProvider",
-    "SparkLLM",
-    "QianFanLLM",
-    "DashScopeLLM",
-    "AnthropicLLM",
-    "BedrockLLM",
-    "ArkLLM",
+    def __getattr__(self, name):
+        if self._module is None:
+            self._module = importlib.import_module(self.module_name)
+        return getattr(self._module, name)
+
+    def __instancecheck__(self, instance):
+        if self._module is None:
+            self._module = importlib.import_module(self.module_name)
+        return isinstance(instance, getattr(self._module, self.instance_name))
+
+    def __call__(self, config):
+        # Import the module when it's called for the first time
+        if self._module is None:
+            self._module = importlib.import_module(self.module_name)
+
+        # Create an instance of the specified class from the module with the given config
+        return getattr(self._module, self.instance_name)(config)
+
+def create_llm_symbol(llm_configurations):
+    factories = {name: LLMFactory(LLMModuleMap[llm_type], name) for llm_type, name in llm_configurations}
+    # Add the factory-created llm objects to the global namespace
+    globals().update(factories)
+    return factories.keys()
+
+# List of LLM configurations
+llm_configurations = [
+    (LLMType.GEMINI, "GeminiLLM"),
+    (LLMType.OLLAMA, "OllamaLLM"),
+    (LLMType.OPENAI, "OpenAILLM"),
+    (LLMType.ZHIPUAI, "ZhiPuAILLM"),
+    (LLMType.AZURE, "AzureOpenAILLM"),
+    (LLMType.METAGPT, "MetaGPTLLM"),
+    (LLMType.HUMAN, "HumanProvider"),
+    (LLMType.SPARK, "SparkLLM"),
+    (LLMType.QIANFAN, "QianFanLLM"),
+    (LLMType.DASHSCOPE, "DashScopeLLM"),
+    (LLMType.ANTHROPIC, "AnthropicLLM"),
+    (LLMType.BEDROCK, "BedrockLLM"),
+    (LLMType.ARK, "ArkLLM"),
+    (LLMType.FIREWORKS, "FireworksLLM"),
+    (LLMType.OPEN_LLM, "OpenLLM"),
+    (LLMType.MOONSHOT, "MoonshotLLM"),
+    (LLMType.MISTRAL, "MistralLLM"),
+    (LLMType.YI, "YiLLM"),
+    (LLMType.OPENROUTER, "OpenRouterLLM"),
+    (LLMType.CLAUDE, "ClaudeLLM"),
 ]
+
+# Create all LLMFactory instances and get created symbols
+__all__ = create_llm_symbol(llm_configurations)
\ No newline at end of file
diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py
index 7f8618590..a9e739f44 100644
--- a/metagpt/provider/llm_provider_registry.py
+++ b/metagpt/provider/llm_provider_registry.py
@@ -5,9 +5,9 @@
 @Author  : alexanderwu
 @File    : llm_provider_registry.py
 """
-from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.configs.llm_config import LLMConfig, LLMType, LLMModuleMap
 from metagpt.provider.base_llm import BaseLLM
-
+import importlib
 
 class LLMProviderRegistry:
     def __init__(self):
@@ -18,6 +18,10 @@ def register(self, key, provider_cls):
 
     def get_provider(self, enum: LLMType):
         """get provider instance according to the enum"""
+        if enum not in self.providers:
+            # Import and register the provider if not already registered
+            module_name = LLMModuleMap[enum]
+            importlib.import_module(module_name)
         return self.providers[enum]
 
diff --git a/metagpt/provider/postprocess/base_postprocess_plugin.py b/metagpt/provider/postprocess/base_postprocess_plugin.py
index 48130ede8..259e1e3cb 100644
--- a/metagpt/provider/postprocess/base_postprocess_plugin.py
+++ b/metagpt/provider/postprocess/base_postprocess_plugin.py
@@ -44,7 +44,14 @@ def run_extract_content_from_output(self, content: str, right_key: str) -> str:
     def run_retry_parse_json_text(self, content: str) -> Union[dict, list]:
         """inherited class can re-implement the function"""
         # logger.info(f"extracted json CONTENT from output:\n{content}")
-        parsed_data = retry_parse_json_text(output=content)  # should use output=content
+        import tolerantjson as tjson
+        try:
+            parsed_data = retry_parse_json_text(output=content)  # should use output=content
+        except:
+            try:
+                parsed_data = tjson.tolerate(content)
+            except:
+                parsed_data = tjson.tolerate(content[8:-14])
         return parsed_data
 
     def run(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py
index 17e095c5f..a1ba1acd6 100644
--- a/metagpt/utils/repair_llm_raw_output.py
+++ b/metagpt/utils/repair_llm_raw_output.py
@@ -314,7 +314,7 @@ def re_extract_content(cont: str, pattern: str) -> str:
     pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"
     new_content = re_extract_content(raw_content, pattern)
 
-    if not new_content.startswith("{"):
+    if not new_content.startswith("{") and not new_content.startswith("```json"):
         # TODO find a more general pattern
         # # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
         logger.warning(f"extract_content try another pattern: {pattern}")
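
Note on the metagpt/provider/__init__.py rewrite: the snippet below is a minimal, self-contained sketch of the lazy-import pattern that the new LLMFactory relies on. LazyFactory and module_map here are illustrative stand-ins wired to stdlib modules so the example runs anywhere; the actual patch maps LLMType members to provider modules via LLMModuleMap and exposes the resulting factories through __all__.

    # Sketch only: LazyFactory/module_map are hypothetical stand-ins, not MetaGPT symbols.
    import importlib

    class LazyFactory:
        def __init__(self, module_name, attr_name):
            self.module_name = module_name
            self.attr_name = attr_name
            self._module = None  # the real module is imported only on first use

        def _target(self):
            # Deferred import: nothing is loaded until the factory is actually used.
            if self._module is None:
                self._module = importlib.import_module(self.module_name)
            return getattr(self._module, self.attr_name)

        def __call__(self, *args, **kwargs):
            # Instantiating the factory triggers the deferred import.
            return self._target()(*args, **kwargs)

        def __instancecheck__(self, instance):
            # Lets isinstance(obj, factory) behave like isinstance(obj, real_class).
            return isinstance(instance, self._target())

    # Stand-in mapping (stdlib modules) mirroring the role of LLMModuleMap.
    module_map = {"JSONDecoder": "json", "OrderedDict": "collections"}
    factories = {name: LazyFactory(module, name) for name, module in module_map.items()}

    decoder = factories["JSONDecoder"]()                   # "json" is imported only here
    print(decoder.decode('{"ok": true}'))                  # {'ok': True}
    print(isinstance(decoder, factories["JSONDecoder"]))   # True, via __instancecheck__

The point of the pattern is that importing metagpt.provider no longer pulls in every provider SDK up front; a provider module is imported only the first time its factory is called (or checked with isinstance), and llm_provider_registry.get_provider uses the same LLMModuleMap to import-and-register a provider on demand.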