diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index 458903c24..dc1a1b47d 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -3,7 +3,7 @@ from langchain import hub from langchain.agents import AgentExecutor from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent -from langchain_core.chat_history import ChatMessageHistory +from langchain_community.chat_message_histories import ChatMessageHistory from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_openai import ChatOpenAI @@ -20,6 +20,7 @@ RevealSymbolTool, SearchTool, SemanticEditTool, + SemanticSearchTool, ViewFileTool, ) @@ -59,6 +60,7 @@ def create_codebase_agent( MoveSymbolTool(codebase), RevealSymbolTool(codebase), SemanticEditTool(codebase), + SemanticSearchTool(codebase), CommitTool(codebase), ] diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index fcfcd2997..f00b193a3 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -312,20 +312,18 @@ def _run( return json.dumps(result, indent=2) +class SemanticSearchInput(BaseModel): + query: str = Field(..., description="The natural language search query") + k: int = Field(default=5, description="Number of results to return") + preview_length: int = Field(default=200, description="Length of content preview in characters") + + class SemanticSearchTool(BaseTool): """Tool for semantic code search.""" name: ClassVar[str] = "semantic_search" description: ClassVar[str] = "Search the codebase using natural language queries and semantic similarity" - args_schema: ClassVar[type[BaseModel]] = type( - "SemanticSearchInput", - (BaseModel,), - { - "query": (str, Field(..., description="The natural language search query")), - "k": (int, Field(default=5, description="Number of results to return")), - "preview_length": (int, Field(default=200, description="Length of content preview in characters")), - }, - ) + args_schema: ClassVar[type[BaseModel]] = SemanticSearchInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: