diff --git a/src/llmling/tools/registry.py b/src/llmling/tools/registry.py index 655fc01..d7fbbec 100644 --- a/src/llmling/tools/registry.py +++ b/src/llmling/tools/registry.py @@ -123,32 +123,13 @@ def add_container( self.register(f"{prefix}{name}", func) logger.debug("Registered callable %s as %s", name, f"{prefix}{name}") - def get_schema(self, name: str) -> py2openai.OpenAIFunctionTool: - """Get OpenAI function schema for a registered function. - - Args: - name: Name of the registered function - - Returns: - OpenAI function schema - - Raises: - ToolError: If function not found - """ - try: - tool = self.get(name) - return tool.get_schema() - except KeyError as exc: - msg = f"Function {name} not found" - raise ToolError(msg) from exc - def get_schemas(self) -> list[py2openai.OpenAIFunctionTool]: """Get schemas for all registered functions. Returns: List of OpenAI function schemas """ - return [self.get_schema(name) for name in self._items] + return [tool.get_schema() for tool in self._items.values()] async def execute(self, _name: str, **params: Any) -> Any: """Execute a registered function. diff --git a/tests/test_tools.py b/tests/test_tools.py index 6d84e37..1e75e21 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import pytest from llmling.tools.base import LLMCallableTool @@ -73,7 +75,7 @@ async def test_failing_tool(): class TestDynamicTool: def test_init(self) -> None: """Test tool initialization.""" - tool = LLMCallableTool.from_callable( + tool = LLMCallableTool[Any].from_callable( EXAMPLE_IMPORT, name_override="name", description_override="desc" ) assert tool.name == "name" @@ -82,17 +84,17 @@ def test_init(self) -> None: def test_default_name(self) -> None: """Test default name from import path.""" - tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT) + tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT) assert tool.name == "example_tool" def test_default_description(self) -> None: """Test default description from docstring.""" - tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT) + tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT) assert "repeats text" in tool.description.lower() def test_schema_generation(self) -> None: """Test schema generation from function signature.""" - tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT) + tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT) schema = tool.get_schema() assert schema["function"]["name"] == "example_tool" @@ -103,14 +105,14 @@ def test_schema_generation(self) -> None: @pytest.mark.asyncio async def test_execution(self) -> None: """Test tool execution.""" - tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT) + tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT) result = await tool.execute(text="test", repeat=2) assert result == "testtest" @pytest.mark.asyncio async def test_execution_failure(self) -> None: """Test tool execution failure.""" - tool = LLMCallableTool.from_callable(FAILING_IMPORT) + tool = LLMCallableTool[Any].from_callable(FAILING_IMPORT) with pytest.raises(Exception, match="Intentional"): await tool.execute(text="test") @@ -166,10 +168,10 @@ async def test_execute_with_validation(self, registry: ToolRegistry) -> None: def test_schema_generation(self, registry: ToolRegistry) -> None: """Test schema generation for registered tools.""" registry["analyze_ast"] = ANALYZE_IMPORT - schema = registry.get_schema("analyze_ast") + schema = registry["analyze_ast"].get_schema() assert "code" in schema["function"]["parameters"]["properties"] - assert schema["function"]["parameters"]["required"] == ["code"] + assert schema["function"]["parameters"]["required"] == ["code"] # type: ignore assert "Analyze Python code AST" in schema["function"]["description"] @@ -182,7 +184,7 @@ async def test_tool_integration() -> None: registry["analyze"] = ANALYZE_IMPORT # Get schema - schema = registry.get_schema("analyze") + schema = registry["analyze"].get_schema() assert schema["function"]["name"] == "analyze_ast" # Execute tool code = """