Skip to content

Commit

Permalink
.Net Agents - ChatHistory Reducer Pattern (#7570)
Browse files Browse the repository at this point in the history
### Motivation and Context
<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Introduce ability to reduce the length of the chat-history when using
the _Agent Framework_.

A prolonged agent dialog can exceed the allowed token limit. As such,
introducing a technique for managing chat-history is critical. Such a
technique must be available for both `AgentChat` and _No-Chat_
invocation modes.

> Note: This consideration only applies to `ChatCompletionAgent` as the
Assistant API manages this internally for `OpenAIAssistantAgent`.

### Description
<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

- Define `IChatHistoryReducer` contract
- Add `ChatCompletionAgent.HistoryReducer` nullable/optional property
- Update `ChatCompletionAgent.GetChannelKeys()` to assign agents to
dedicted channels based on reducer.
- Implement "Top N" truncating reducer: `ChatHistoryTruncationReducer`
- Implement summarizing reducer: `ChatHistorySummarizationReducer`
- Allow prompt customization for `ChatHistorySummarizationReducer`
- Update `ChatHistoryChannel` to make use of agent-reducer, if present
- Add `ChatCompletion_HistoryReducer` sample under _Concepts_

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
crickman authored Aug 7, 2024
1 parent 4365415 commit 3b0d086
Show file tree
Hide file tree
Showing 20 changed files with 1,363 additions and 125 deletions.
173 changes: 173 additions & 0 deletions dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) Microsoft. All rights reserved.
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Agents;

/// <summary>
/// Demonstrate creation of <see cref="ChatCompletionAgent"/> and
/// eliciting its response to three explicit user messages.
/// </summary>
public class ChatCompletion_HistoryReducer(ITestOutputHelper output) : BaseTest(output)
{
private const string TranslatorName = "NumeroTranslator";
private const string TranslatorInstructions = "Add one to latest user number and spell it in spanish without explanation.";

/// <summary>
/// Demonstrate the use of <see cref="ChatHistoryTruncationReducer"/> when directly
/// invoking a <see cref="ChatCompletionAgent"/>.
/// </summary>
[Fact]
public async Task TruncatedAgentReductionAsync()
{
// Define the agent
ChatCompletionAgent agent = CreateTruncatingAgent(10, 10);

await InvokeAgentAsync(agent, 50);
}

/// <summary>
/// Demonstrate the use of <see cref="ChatHistorySummarizationReducer"/> when directly
/// invoking a <see cref="ChatCompletionAgent"/>.
/// </summary>
[Fact]
public async Task SummarizedAgentReductionAsync()
{
// Define the agent
ChatCompletionAgent agent = CreateSummarizingAgent(10, 10);

await InvokeAgentAsync(agent, 50);
}

/// <summary>
/// Demonstrate the use of <see cref="ChatHistoryTruncationReducer"/> when using
/// <see cref="AgentGroupChat"/> to invoke a <see cref="ChatCompletionAgent"/>.
/// </summary>
[Fact]
public async Task TruncatedChatReductionAsync()
{
// Define the agent
ChatCompletionAgent agent = CreateTruncatingAgent(10, 10);

await InvokeChatAsync(agent, 50);
}

/// <summary>
/// Demonstrate the use of <see cref="ChatHistorySummarizationReducer"/> when using
/// <see cref="AgentGroupChat"/> to invoke a <see cref="ChatCompletionAgent"/>.
/// </summary>
[Fact]
public async Task SummarizedChatReductionAsync()
{
// Define the agent
ChatCompletionAgent agent = CreateSummarizingAgent(10, 10);

await InvokeChatAsync(agent, 50);
}

// Proceed with dialog by directly invoking the agent and explicitly managing the history.
private async Task InvokeAgentAsync(ChatCompletionAgent agent, int messageCount)
{
ChatHistory chat = [];

int index = 1;
while (index <= messageCount)
{
// Provide user input
chat.Add(new ChatMessageContent(AuthorRole.User, $"{index}"));
Console.WriteLine($"# {AuthorRole.User}: '{index}'");

// Reduce prior to invoking the agent
bool isReduced = await agent.ReduceAsync(chat);

// Invoke and display assistant response
await foreach (ChatMessageContent message in agent.InvokeAsync(chat))
{
chat.Add(message);
Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}: '{message.Content}'");
}

index += 2;

// Display the message count of the chat-history for visibility into reduction
Console.WriteLine($"@ Message Count: {chat.Count}\n");

// Display summary messages (if present) if reduction has occurred
if (isReduced)
{
int summaryIndex = 0;
while (chat[summaryIndex].Metadata?.ContainsKey(ChatHistorySummarizationReducer.SummaryMetadataKey) ?? false)
{
Console.WriteLine($"\tSummary: {chat[summaryIndex].Content}");
++summaryIndex;
}
}
}
}

// Proceed with dialog with AgentGroupChat.
private async Task InvokeChatAsync(ChatCompletionAgent agent, int messageCount)
{
AgentGroupChat chat = new();

int lastHistoryCount = 0;

int index = 1;
while (index <= messageCount)
{
// Provide user input
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, $"{index}"));
Console.WriteLine($"# {AuthorRole.User}: '{index}'");

// Invoke and display assistant response
await foreach (ChatMessageContent message in chat.InvokeAsync(agent))
{
Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}: '{message.Content}'");
}

index += 2;

// Display the message count of the chat-history for visibility into reduction
// Note: Messages provided in descending order (newest first)
ChatMessageContent[] history = await chat.GetChatMessagesAsync(agent).ToArrayAsync();
Console.WriteLine($"@ Message Count: {history.Length}\n");

// Display summary messages (if present) if reduction has occurred
if (history.Length < lastHistoryCount)
{
int summaryIndex = history.Length - 1;
while (history[summaryIndex].Metadata?.ContainsKey(ChatHistorySummarizationReducer.SummaryMetadataKey) ?? false)
{
Console.WriteLine($"\tSummary: {history[summaryIndex].Content}");
--summaryIndex;
}
}

lastHistoryCount = history.Length;
}
}

private ChatCompletionAgent CreateSummarizingAgent(int reducerMessageCount, int reducerThresholdCount)
{
Kernel kernel = this.CreateKernelWithChatCompletion();
return
new()
{
Name = TranslatorName,
Instructions = TranslatorInstructions,
Kernel = kernel,
HistoryReducer = new ChatHistorySummarizationReducer(kernel.GetRequiredService<IChatCompletionService>(), reducerMessageCount, reducerThresholdCount),
};
}

private ChatCompletionAgent CreateTruncatingAgent(int reducerMessageCount, int reducerThresholdCount) =>
new()
{
Name = TranslatorName,
Instructions = TranslatorInstructions,
Kernel = this.CreateKernelWithChatCompletion(),
HistoryReducer = new ChatHistoryTruncationReducer(reducerMessageCount, reducerThresholdCount),
};
}
42 changes: 0 additions & 42 deletions dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Agents.Extensions;
/// <summary>
/// Extension methods for <see cref="ChatMessageContent"/>
/// </summary>
internal static class ChatHistoryExtensions
public static class ChatHistoryExtensions
{
/// <summary>
/// Enumeration of chat-history in descending order.
Expand Down
39 changes: 36 additions & 3 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;
Expand All @@ -14,15 +17,18 @@ namespace Microsoft.SemanticKernel.Agents;
/// NOTE: Enable OpenAIPromptExecutionSettings.ToolCallBehavior for agent plugins.
/// (<see cref="ChatCompletionAgent.ExecutionSettings"/>)
/// </remarks>
public sealed class ChatCompletionAgent : ChatHistoryKernelAgent
public sealed class ChatCompletionAgent : KernelAgent, IChatHistoryHandler
{
/// <summary>
/// Optional execution settings for the agent.
/// </summary>
public PromptExecutionSettings? ExecutionSettings { get; set; }

/// <inheritdoc/>
public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
public IChatHistoryReducer? HistoryReducer { get; init; }

/// <inheritdoc/>
public async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -63,7 +69,7 @@ await chatCompletionService.GetChatMessageContentsAsync(
}

/// <inheritdoc/>
public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
public async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Expand Down Expand Up @@ -103,6 +109,33 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
}
}

/// <inheritdoc/>
protected override IEnumerable<string> GetChannelKeys()
{
// Agents with different reducers shall not share the same channel.
// Agents with the same or equivalent reducer shall share the same channel.
if (this.HistoryReducer != null)
{
// Explicitly include the reducer type to eliminate the possibility of hash collisions
// with custom implementations of IChatHistoryReducer.
yield return this.HistoryReducer.GetType().FullName!;

yield return this.HistoryReducer.GetHashCode().ToString(CultureInfo.InvariantCulture);
}
}

/// <inheritdoc/>
protected override Task<AgentChannel> CreateChannelAsync(CancellationToken cancellationToken)
{
ChatHistoryChannel channel =
new()
{
Logger = this.LoggerFactory.CreateLogger<ChatHistoryChannel>()
};

return Task.FromResult<AgentChannel>(channel);
}

private ChatHistory SetupAgentChatHistory(IReadOnlyList<ChatMessageContent> history)
{
ChatHistory chat = [];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Agents.Extensions;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;

/// <summary>
/// A <see cref="AgentChannel"/> specialization for that acts upon a <see cref="IChatHistoryHandler"/>.
/// </summary>
public class ChatHistoryChannel : AgentChannel
public sealed class ChatHistoryChannel : AgentChannel
{
private readonly ChatHistory _history;

/// <inheritdoc/>
protected internal sealed override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
protected override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
Agent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Expand All @@ -26,6 +27,9 @@ public class ChatHistoryChannel : AgentChannel
throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})");
}

// Pre-process history reduction.
await this._history.ReduceAsync(historyHandler.HistoryReducer, cancellationToken).ConfigureAwait(false);

// Capture the current message count to evaluate history mutation.
int messageCount = this._history.Count;
HashSet<ChatMessageContent> mutatedHistory = [];
Expand Down Expand Up @@ -74,15 +78,15 @@ bool IsMessageVisible(ChatMessageContent message) =>
}

/// <inheritdoc/>
protected internal sealed override Task ReceiveAsync(IEnumerable<ChatMessageContent> history, CancellationToken cancellationToken)
protected override Task ReceiveAsync(IEnumerable<ChatMessageContent> history, CancellationToken cancellationToken)
{
this._history.AddRange(history);

return Task.CompletedTask;
}

/// <inheritdoc/>
protected internal sealed override IAsyncEnumerable<ChatMessageContent> GetHistoryAsync(CancellationToken cancellationToken)
protected override IAsyncEnumerable<ChatMessageContent> GetHistoryAsync(CancellationToken cancellationToken)
{
return this._history.ToDescendingAsync();
}
Expand Down
Loading

0 comments on commit 3b0d086

Please sign in to comment.