From 226b37d07b9e95e3e4759439758c0954d90e5b4c Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 21 Jan 2025 20:08:53 -0500 Subject: [PATCH] Make ChatAgent an ABC (#5129) --- .../src/autogen_agentchat/base/_chat_agent.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index 256f752bfa80..36a80efe019c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,5 +1,6 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, AsyncGenerator, Mapping, Protocol, Sequence, runtime_checkable +from typing import Any, AsyncGenerator, Mapping, Sequence from autogen_core import CancellationToken @@ -19,17 +20,18 @@ class Response: or :class:`ChatMessage`.""" -@runtime_checkable -class ChatAgent(TaskRunner, Protocol): +class ChatAgent(ABC, TaskRunner): """Protocol for a chat agent.""" @property + @abstractmethod def name(self) -> str: """The name of the agent. This is used by team to uniquely identify the agent. It should be unique within the team.""" ... @property + @abstractmethod def description(self) -> str: """The description of the agent. This is used by team to make decisions about which agents to use. The description should @@ -37,15 +39,18 @@ def description(self) -> str: ... @property + @abstractmethod def produced_message_types(self) -> Sequence[type[ChatMessage]]: """The types of messages that the agent produces in the :attr:`Response.chat_message` field. They must be :class:`ChatMessage` types.""" ... + @abstractmethod async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response: """Handles incoming messages and returns a response.""" ... + @abstractmethod def on_messages_stream( self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]: @@ -53,18 +58,22 @@ def on_messages_stream( and the final item is the response.""" ... + @abstractmethod async def on_reset(self, cancellation_token: CancellationToken) -> None: """Resets the agent to its initialization state.""" ... + @abstractmethod async def save_state(self) -> Mapping[str, Any]: """Save agent state for later restoration""" ... + @abstractmethod async def load_state(self, state: Mapping[str, Any]) -> None: """Restore agent from saved state""" ... + @abstractmethod async def close(self) -> None: """Called when the runtime is stopped or any stop method is called""" ...