Skip to content

Commit

Permalink
Replicate messages before agent hooks (#265)
Browse files Browse the repository at this point in the history
change message replication
  • Loading branch information
braisedpork1964 authored Oct 29, 2024
1 parent cd34d8d commit e8af4cc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 16 deletions.
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,9 @@ class PrefixedMessageHook(Hook):
self.senders = senders or []

def before_agent(self, agent, messages, session_id):
for i, message in enumerate(messages):
for message in messages:
if message.sender in self.senders:
message = message.copy(deep=True)
message.content = self.prefix + message.content
messages[i] = message
return messages

class AsyncBlogger(AsyncAgent):
def __init__(self, model_path, writer_prompt, critic_prompt, critic_prefix='', max_turn=3):
Expand Down
5 changes: 1 addition & 4 deletions docs/en/get_started/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,9 @@ class PrefixedMessageHook(Hook):
self.senders = senders or []

def before_agent(self, agent, messages, session_id):
for i, message in enumerate(messages):
for message in messages:
if message.sender in self.senders:
message = message.copy(deep=True)
message.content = self.prefix + message.content
messages[i] = message
return messages

class AsyncBlogger(AsyncAgent):
def __init__(self, model_path, writer_prompt, critic_prompt, critic_prefix='', max_turn=3):
Expand Down
14 changes: 6 additions & 8 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,10 @@ def __call__(
) -> AgentMessage:
# message.receiver = self.name
message = [
AgentMessage(sender='user', content=m) if isinstance(m, str) else m
for m in message
AgentMessage(sender='user', content=m)
if isinstance(m, str) else copy.deepcopy(m) for m in message
]
for hook in self._hooks.values():
message = copy.deepcopy(message)
result = hook.before_agent(self, message, session_id)
if result:
message = result
Expand All @@ -87,8 +86,8 @@ def __call__(
content=response_message,
)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
response_message = copy.deepcopy(response_message)
result = hook.after_agent(self, response_message, session_id)
if result:
response_message = result
Expand Down Expand Up @@ -177,11 +176,10 @@ async def __call__(self,
session_id=0,
**kwargs) -> AgentMessage:
message = [
AgentMessage(sender='user', content=m) if isinstance(m, str) else m
for m in message
AgentMessage(sender='user', content=m)
if isinstance(m, str) else copy.deepcopy(m) for m in message
]
for hook in self._hooks.values():
message = copy.deepcopy(message)
result = hook.before_agent(self, message, session_id)
if result:
message = result
Expand All @@ -194,8 +192,8 @@ async def __call__(self,
content=response_message,
)
self.update_memory(response_message, session_id=session_id)
response_message = copy.deepcopy(response_message)
for hook in self._hooks.values():
response_message = copy.deepcopy(response_message)
result = hook.after_agent(self, response_message, session_id)
if result:
response_message = result
Expand Down

0 comments on commit e8af4cc

Please sign in to comment.