Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Feb 11, 2025
1 parent c82f84d commit bb32757
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 29 deletions.
21 changes: 1 addition & 20 deletions src/llmling/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,32 +123,13 @@ def add_container(
self.register(f"{prefix}{name}", func)
logger.debug("Registered callable %s as %s", name, f"{prefix}{name}")

def get_schema(self, name: str) -> py2openai.OpenAIFunctionTool:
"""Get OpenAI function schema for a registered function.
Args:
name: Name of the registered function
Returns:
OpenAI function schema
Raises:
ToolError: If function not found
"""
try:
tool = self.get(name)
return tool.get_schema()
except KeyError as exc:
msg = f"Function {name} not found"
raise ToolError(msg) from exc

def get_schemas(self) -> list[py2openai.OpenAIFunctionTool]:
"""Get schemas for all registered functions.
Returns:
List of OpenAI function schemas
"""
return [self.get_schema(name) for name in self._items]
return [tool.get_schema() for tool in self._items.values()]

async def execute(self, _name: str, **params: Any) -> Any:
"""Execute a registered function.
Expand Down
20 changes: 11 additions & 9 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import pytest

from llmling.tools.base import LLMCallableTool
Expand Down Expand Up @@ -73,7 +75,7 @@ async def test_failing_tool():
class TestDynamicTool:
def test_init(self) -> None:
"""Test tool initialization."""
tool = LLMCallableTool.from_callable(
tool = LLMCallableTool[Any].from_callable(
EXAMPLE_IMPORT, name_override="name", description_override="desc"
)
assert tool.name == "name"
Expand All @@ -82,17 +84,17 @@ def test_init(self) -> None:

def test_default_name(self) -> None:
"""Test default name from import path."""
tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT)
tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT)
assert tool.name == "example_tool"

def test_default_description(self) -> None:
"""Test default description from docstring."""
tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT)
tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT)
assert "repeats text" in tool.description.lower()

def test_schema_generation(self) -> None:
"""Test schema generation from function signature."""
tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT)
tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT)
schema = tool.get_schema()

assert schema["function"]["name"] == "example_tool"
Expand All @@ -103,14 +105,14 @@ def test_schema_generation(self) -> None:
@pytest.mark.asyncio
async def test_execution(self) -> None:
"""Test tool execution."""
tool = LLMCallableTool.from_callable(EXAMPLE_IMPORT)
tool = LLMCallableTool[Any].from_callable(EXAMPLE_IMPORT)
result = await tool.execute(text="test", repeat=2)
assert result == "testtest"

@pytest.mark.asyncio
async def test_execution_failure(self) -> None:
"""Test tool execution failure."""
tool = LLMCallableTool.from_callable(FAILING_IMPORT)
tool = LLMCallableTool[Any].from_callable(FAILING_IMPORT)
with pytest.raises(Exception, match="Intentional"):
await tool.execute(text="test")

Expand Down Expand Up @@ -166,10 +168,10 @@ async def test_execute_with_validation(self, registry: ToolRegistry) -> None:
def test_schema_generation(self, registry: ToolRegistry) -> None:
"""Test schema generation for registered tools."""
registry["analyze_ast"] = ANALYZE_IMPORT
schema = registry.get_schema("analyze_ast")
schema = registry["analyze_ast"].get_schema()

assert "code" in schema["function"]["parameters"]["properties"]
assert schema["function"]["parameters"]["required"] == ["code"]
assert schema["function"]["parameters"]["required"] == ["code"] # type: ignore
assert "Analyze Python code AST" in schema["function"]["description"]


Expand All @@ -182,7 +184,7 @@ async def test_tool_integration() -> None:
registry["analyze"] = ANALYZE_IMPORT

# Get schema
schema = registry.get_schema("analyze")
schema = registry["analyze"].get_schema()
assert schema["function"]["name"] == "analyze_ast"
# Execute tool
code = """
Expand Down

0 comments on commit bb32757

Please sign in to comment.