diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs new file mode 100644 index 000000000000..6e0816bc8470 --- /dev/null +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.History; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Agents; + +/// +/// Demonstrate creation of and +/// eliciting its response to three explicit user messages. +/// +public class ChatCompletion_HistoryReducer(ITestOutputHelper output) : BaseTest(output) +{ + private const string TranslatorName = "NumeroTranslator"; + private const string TranslatorInstructions = "Add one to latest user number and spell it in spanish without explanation."; + + /// + /// Demonstrate the use of when directly + /// invoking a . + /// + [Fact] + public async Task TruncatedAgentReductionAsync() + { + // Define the agent + ChatCompletionAgent agent = CreateTruncatingAgent(10, 10); + + await InvokeAgentAsync(agent, 50); + } + + /// + /// Demonstrate the use of when directly + /// invoking a . + /// + [Fact] + public async Task SummarizedAgentReductionAsync() + { + // Define the agent + ChatCompletionAgent agent = CreateSummarizingAgent(10, 10); + + await InvokeAgentAsync(agent, 50); + } + + /// + /// Demonstrate the use of when using + /// to invoke a . + /// + [Fact] + public async Task TruncatedChatReductionAsync() + { + // Define the agent + ChatCompletionAgent agent = CreateTruncatingAgent(10, 10); + + await InvokeChatAsync(agent, 50); + } + + /// + /// Demonstrate the use of when using + /// to invoke a . + /// + [Fact] + public async Task SummarizedChatReductionAsync() + { + // Define the agent + ChatCompletionAgent agent = CreateSummarizingAgent(10, 10); + + await InvokeChatAsync(agent, 50); + } + + // Proceed with dialog by directly invoking the agent and explicitly managing the history. + private async Task InvokeAgentAsync(ChatCompletionAgent agent, int messageCount) + { + ChatHistory chat = []; + + int index = 1; + while (index <= messageCount) + { + // Provide user input + chat.Add(new ChatMessageContent(AuthorRole.User, $"{index}")); + Console.WriteLine($"# {AuthorRole.User}: '{index}'"); + + // Reduce prior to invoking the agent + bool isReduced = await agent.ReduceAsync(chat); + + // Invoke and display assistant response + await foreach (ChatMessageContent message in agent.InvokeAsync(chat)) + { + chat.Add(message); + Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}: '{message.Content}'"); + } + + index += 2; + + // Display the message count of the chat-history for visibility into reduction + Console.WriteLine($"@ Message Count: {chat.Count}\n"); + + // Display summary messages (if present) if reduction has occurred + if (isReduced) + { + int summaryIndex = 0; + while (chat[summaryIndex].Metadata?.ContainsKey(ChatHistorySummarizationReducer.SummaryMetadataKey) ?? false) + { + Console.WriteLine($"\tSummary: {chat[summaryIndex].Content}"); + ++summaryIndex; + } + } + } + } + + // Proceed with dialog with AgentGroupChat. + private async Task InvokeChatAsync(ChatCompletionAgent agent, int messageCount) + { + AgentGroupChat chat = new(); + + int lastHistoryCount = 0; + + int index = 1; + while (index <= messageCount) + { + // Provide user input + chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, $"{index}")); + Console.WriteLine($"# {AuthorRole.User}: '{index}'"); + + // Invoke and display assistant response + await foreach (ChatMessageContent message in chat.InvokeAsync(agent)) + { + Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}: '{message.Content}'"); + } + + index += 2; + + // Display the message count of the chat-history for visibility into reduction + // Note: Messages provided in descending order (newest first) + ChatMessageContent[] history = await chat.GetChatMessagesAsync(agent).ToArrayAsync(); + Console.WriteLine($"@ Message Count: {history.Length}\n"); + + // Display summary messages (if present) if reduction has occurred + if (history.Length < lastHistoryCount) + { + int summaryIndex = history.Length - 1; + while (history[summaryIndex].Metadata?.ContainsKey(ChatHistorySummarizationReducer.SummaryMetadataKey) ?? false) + { + Console.WriteLine($"\tSummary: {history[summaryIndex].Content}"); + --summaryIndex; + } + } + + lastHistoryCount = history.Length; + } + } + + private ChatCompletionAgent CreateSummarizingAgent(int reducerMessageCount, int reducerThresholdCount) + { + Kernel kernel = this.CreateKernelWithChatCompletion(); + return + new() + { + Name = TranslatorName, + Instructions = TranslatorInstructions, + Kernel = kernel, + HistoryReducer = new ChatHistorySummarizationReducer(kernel.GetRequiredService(), reducerMessageCount, reducerThresholdCount), + }; + } + + private ChatCompletionAgent CreateTruncatingAgent(int reducerMessageCount, int reducerThresholdCount) => + new() + { + Name = TranslatorName, + Instructions = TranslatorInstructions, + Kernel = this.CreateKernelWithChatCompletion(), + HistoryReducer = new ChatHistoryTruncationReducer(reducerMessageCount, reducerThresholdCount), + }; +} diff --git a/dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs b/dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs deleted file mode 100644 index 3de87da3de06..000000000000 --- a/dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.SemanticKernel.ChatCompletion; - -namespace Microsoft.SemanticKernel.Agents; - -/// -/// A specialization bound to a . -/// -public abstract class ChatHistoryKernelAgent : KernelAgent, IChatHistoryHandler -{ - /// - protected internal sealed override IEnumerable GetChannelKeys() - { - yield return typeof(ChatHistoryChannel).FullName!; - } - - /// - protected internal sealed override Task CreateChannelAsync(CancellationToken cancellationToken) - { - ChatHistoryChannel channel = - new() - { - Logger = this.LoggerFactory.CreateLogger() - }; - - return Task.FromResult(channel); - } - - /// - public abstract IAsyncEnumerable InvokeAsync( - ChatHistory history, - CancellationToken cancellationToken = default); - - /// - public abstract IAsyncEnumerable InvokeStreamingAsync( - ChatHistory history, - CancellationToken cancellationToken = default); -} diff --git a/dotnet/src/Agents/Abstractions/Extensions/ChatHistoryExtensions.cs b/dotnet/src/Agents/Abstractions/Extensions/ChatHistoryExtensions.cs index a7b2273ece9e..d8ef44a416a1 100644 --- a/dotnet/src/Agents/Abstractions/Extensions/ChatHistoryExtensions.cs +++ b/dotnet/src/Agents/Abstractions/Extensions/ChatHistoryExtensions.cs @@ -8,7 +8,7 @@ namespace Microsoft.SemanticKernel.Agents.Extensions; /// /// Extension methods for /// -internal static class ChatHistoryExtensions +public static class ChatHistoryExtensions { /// /// Enumeration of chat-history in descending order. diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 1e9ea3d3208e..aae1776e145d 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -1,8 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Globalization; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Agents.History; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Agents; @@ -14,7 +17,7 @@ namespace Microsoft.SemanticKernel.Agents; /// NOTE: Enable OpenAIPromptExecutionSettings.ToolCallBehavior for agent plugins. /// () /// -public sealed class ChatCompletionAgent : ChatHistoryKernelAgent +public sealed class ChatCompletionAgent : KernelAgent, IChatHistoryHandler { /// /// Optional execution settings for the agent. @@ -22,7 +25,10 @@ public sealed class ChatCompletionAgent : ChatHistoryKernelAgent public PromptExecutionSettings? ExecutionSettings { get; set; } /// - public override async IAsyncEnumerable InvokeAsync( + public IChatHistoryReducer? HistoryReducer { get; init; } + + /// + public async IAsyncEnumerable InvokeAsync( ChatHistory history, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -63,7 +69,7 @@ await chatCompletionService.GetChatMessageContentsAsync( } /// - public override async IAsyncEnumerable InvokeStreamingAsync( + public async IAsyncEnumerable InvokeStreamingAsync( ChatHistory history, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -103,6 +109,33 @@ public override async IAsyncEnumerable InvokeStream } } + /// + protected override IEnumerable GetChannelKeys() + { + // Agents with different reducers shall not share the same channel. + // Agents with the same or equivalent reducer shall share the same channel. + if (this.HistoryReducer != null) + { + // Explicitly include the reducer type to eliminate the possibility of hash collisions + // with custom implementations of IChatHistoryReducer. + yield return this.HistoryReducer.GetType().FullName!; + + yield return this.HistoryReducer.GetHashCode().ToString(CultureInfo.InvariantCulture); + } + } + + /// + protected override Task CreateChannelAsync(CancellationToken cancellationToken) + { + ChatHistoryChannel channel = + new() + { + Logger = this.LoggerFactory.CreateLogger() + }; + + return Task.FromResult(channel); + } + private ChatHistory SetupAgentChatHistory(IReadOnlyList history) { ChatHistory chat = []; diff --git a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs b/dotnet/src/Agents/Core/ChatHistoryChannel.cs similarity index 84% rename from dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs rename to dotnet/src/Agents/Core/ChatHistoryChannel.cs index 5dcb6b9b0204..970054179577 100644 --- a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs +++ b/dotnet/src/Agents/Core/ChatHistoryChannel.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel.Agents.Extensions; +using Microsoft.SemanticKernel.Agents.History; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Agents; @@ -12,12 +13,12 @@ namespace Microsoft.SemanticKernel.Agents; /// /// A specialization for that acts upon a . /// -public class ChatHistoryChannel : AgentChannel +public sealed class ChatHistoryChannel : AgentChannel { private readonly ChatHistory _history; /// - protected internal sealed override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( + protected override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( Agent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -26,6 +27,9 @@ public class ChatHistoryChannel : AgentChannel throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})"); } + // Pre-process history reduction. + await this._history.ReduceAsync(historyHandler.HistoryReducer, cancellationToken).ConfigureAwait(false); + // Capture the current message count to evaluate history mutation. int messageCount = this._history.Count; HashSet mutatedHistory = []; @@ -74,7 +78,7 @@ bool IsMessageVisible(ChatMessageContent message) => } /// - protected internal sealed override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) + protected override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) { this._history.AddRange(history); @@ -82,7 +86,7 @@ protected internal sealed override Task ReceiveAsync(IEnumerable - protected internal sealed override IAsyncEnumerable GetHistoryAsync(CancellationToken cancellationToken) + protected override IAsyncEnumerable GetHistoryAsync(CancellationToken cancellationToken) { return this._history.ToDescendingAsync(); } diff --git a/dotnet/src/Agents/Core/History/ChatHistoryReducerExtensions.cs b/dotnet/src/Agents/Core/History/ChatHistoryReducerExtensions.cs new file mode 100644 index 000000000000..c884846baafa --- /dev/null +++ b/dotnet/src/Agents/Core/History/ChatHistoryReducerExtensions.cs @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Agents.History; + +/// +/// Discrete operations used when reducing chat history. +/// +/// +/// Allows for improved testability. +/// +internal static class ChatHistoryReducerExtensions +{ + /// + /// Extract a range of messages from the source history. + /// + /// The source history + /// The index of the first message to extract + /// The index of the last message to extract + /// The optional filter to apply to each message + public static IEnumerable Extract(this IReadOnlyList history, int startIndex, int? finalIndex = null, Func? filter = null) + { + int maxIndex = history.Count - 1; + if (startIndex > maxIndex) + { + yield break; + } + + finalIndex ??= maxIndex; + + finalIndex = Math.Min(finalIndex.Value, maxIndex); + + for (int index = startIndex; index <= finalIndex; ++index) + { + if (filter?.Invoke(history[index]) ?? false) + { + continue; + } + + yield return history[index]; + } + } + + /// + /// Identify the index of the first message that is not a summary message, as indicated by + /// the presence of the specified metadata key. + /// + /// The source history + /// The metadata key that identifies a summary message. + public static int LocateSummarizationBoundary(this IReadOnlyList history, string summaryKey) + { + for (int index = 0; index < history.Count; ++index) + { + ChatMessageContent message = history[index]; + + if (!message.Metadata?.ContainsKey(summaryKey) ?? true) + { + return index; + } + } + + return history.Count; + } + + /// + /// Identify the index of the first message at or beyond the specified targetCount that + /// does not orphan sensitive content. + /// Specifically: function calls and results shall not be separated since chat-completion requires that + /// a function-call always be followed by a function-result. + /// In addition, the first user message (if present) within the threshold window will be included + /// in order to maintain context with the subsequent assistant responses. + /// + /// The source history + /// The desired message count, should reduction occur. + /// + /// The threshold, beyond targetCount, required to trigger reduction. + /// History is not reduces it the message count is less than targetCount + thresholdCount. + /// + /// + /// Optionally ignore an offset from the start of the history. + /// This is useful when messages have been injected that are not part of the raw dialog + /// (such as summarization). + /// + /// An index that identifies the starting point for a reduced history that does not orphan sensitive content. + public static int LocateSafeReductionIndex(this IReadOnlyList history, int targetCount, int? thresholdCount = null, int offsetCount = 0) + { + // Compute the index of the truncation threshold + int thresholdIndex = history.Count - (thresholdCount ?? 0) - targetCount; + + if (thresholdIndex <= offsetCount) + { + // History is too short to truncate + return 0; + } + + // Compute the index of truncation target + int messageIndex = history.Count - targetCount; + + // Skip function related content + while (messageIndex >= 0) + { + if (!history[messageIndex].Items.Any(i => i is FunctionCallContent || i is FunctionResultContent)) + { + break; + } + + --messageIndex; + } + + // Capture the earliest non-function related message + int targetIndex = messageIndex; + + // Scan for user message within truncation range to maximize chat cohesion + while (messageIndex >= thresholdIndex) + { + // A user message provides a superb truncation point + if (history[messageIndex].Role == AuthorRole.User) + { + return messageIndex; + } + + --messageIndex; + } + + // No user message found, fallback to the earliest non-function related message + return targetIndex; + } + + /// + /// Process history reduction and mutate the provided history. + /// + /// The source history + /// The target reducer + /// The to monitor for cancellation requests. The default is . + /// True if reduction has occurred. + /// + /// Using the existing for a reduction in collection size eliminates the need + /// for re-allocation (of memory). + /// + public static async Task ReduceAsync(this ChatHistory history, IChatHistoryReducer? reducer, CancellationToken cancellationToken) + { + if (reducer == null) + { + return false; + } + + IEnumerable? reducedHistory = await reducer.ReduceAsync(history, cancellationToken).ConfigureAwait(false); + + if (reducedHistory == null) + { + return false; + } + + // Mutate the history in place + ChatMessageContent[] reduced = reducedHistory.ToArray(); + history.Clear(); + history.AddRange(reduced); + + return true; + } +} diff --git a/dotnet/src/Agents/Core/History/ChatHistorySummarizationReducer.cs b/dotnet/src/Agents/Core/History/ChatHistorySummarizationReducer.cs new file mode 100644 index 000000000000..a45bfa57011d --- /dev/null +++ b/dotnet/src/Agents/Core/History/ChatHistorySummarizationReducer.cs @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Agents.History; + +/// +/// Reduce the chat history by summarizing message past the target message count. +/// +/// +/// Summarization will always avoid orphaning function-content as the presence of +/// a function-call _must_ be followed by a function-result. When a threshold count is +/// is provided (recommended), reduction will scan within the threshold window in an attempt to +/// avoid orphaning a user message from an assistant response. +/// +public class ChatHistorySummarizationReducer : IChatHistoryReducer +{ + /// + /// Metadata key to indicate a summary message. + /// + public const string SummaryMetadataKey = "__summary__"; + + /// + /// The default summarization system instructions. + /// + public const string DefaultSummarizationPrompt = + """ + Provide a concise and complete summarizion of the entire dialog that does not exceed 5 sentences + + This summary must always: + - Consider both user and assistant interactions + - Maintain continuity for the purpose of further dialog + - Include details from any existing summary + - Focus on the most significant aspects of the dialog + + This summary must never: + - Critique, correct, interpret, presume, or assume + - Identify faults, mistakes, misunderstanding, or correctness + - Analyze what has not occurred + - Exclude details from any existing summary + """; + + /// + /// System instructions for summarization. Defaults to . + /// + public string SummarizationInstructions { get; init; } = DefaultSummarizationPrompt; + + /// + /// Flag to indicate if an exception should be thrown if summarization fails. + /// + public bool FailOnError { get; init; } = true; + + /// + /// Flag to indicate summarization is maintained in a single message, or if a series of + /// summations are generated over time. + /// + /// + /// Not using a single summary may ultimately result in a chat history that exceeds the token limit. + /// + public bool UseSingleSummary { get; init; } = true; + + /// + public async Task?> ReduceAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + { + // Identify where summary messages end and regular history begins + int insertionPoint = history.LocateSummarizationBoundary(SummaryMetadataKey); + + // First pass to determine the truncation index + int truncationIndex = history.LocateSafeReductionIndex(this._targetCount, this._thresholdCount, insertionPoint); + + IEnumerable? truncatedHistory = null; + + if (truncationIndex > 0) + { + // Second pass to extract history for summarization + IEnumerable summarizedHistory = + history.Extract( + this.UseSingleSummary ? 0 : insertionPoint, + truncationIndex, + (m) => m.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent)); + + try + { + // Summarize + ChatHistory summarizationRequest = [.. summarizedHistory, new ChatMessageContent(AuthorRole.System, this.SummarizationInstructions)]; + ChatMessageContent summary = await this._service.GetChatMessageContentAsync(summarizationRequest, cancellationToken: cancellationToken).ConfigureAwait(false); + summary.Metadata = new Dictionary { { SummaryMetadataKey, true } }; + + // Assembly the summarized history + truncatedHistory = AssemblySummarizedHistory(summary); + } + catch + { + if (this.FailOnError) + { + throw; + } + } + } + + return truncatedHistory; + + // Inner function to assemble the summarized history + IEnumerable AssemblySummarizedHistory(ChatMessageContent? summary) + { + if (insertionPoint > 0 && !this.UseSingleSummary) + { + for (int index = 0; index <= insertionPoint - 1; ++index) + { + yield return history[index]; + } + } + + if (summary != null) + { + yield return summary; + } + + for (int index = truncationIndex; index < history.Count; ++index) + { + yield return history[index]; + } + } + } + + /// + /// Initializes a new instance of the class. + /// + /// A instance to be used for summarization. + /// The desired number of target messages after reduction. + /// An optional number of messages beyond the 'targetCount' that must be present in order to trigger reduction/ + /// + /// While the 'thresholdCount' is optional, it is recommended to provided so that reduction is not triggered + /// for every incremental addition to the chat history beyond the 'targetCount'. + /// > + public ChatHistorySummarizationReducer(IChatCompletionService service, int targetCount, int? thresholdCount = null) + { + Verify.NotNull(service, nameof(service)); + Verify.True(targetCount > 0, "Target message count must be greater than zero."); + Verify.True(!thresholdCount.HasValue || thresholdCount > 0, "The reduction threshold length must be greater than zero."); + + this._service = service; + this._targetCount = targetCount; + this._thresholdCount = thresholdCount ?? 0; + } + + /// + public override bool Equals(object? obj) + { + ChatHistorySummarizationReducer? other = obj as ChatHistorySummarizationReducer; + return other != null && + this._thresholdCount == other._thresholdCount && + this._targetCount == other._targetCount; + } + + /// + public override int GetHashCode() => HashCode.Combine(nameof(ChatHistorySummarizationReducer), this._thresholdCount, this._targetCount, this.SummarizationInstructions, this.UseSingleSummary); + + private readonly IChatCompletionService _service; + private readonly int _thresholdCount; + private readonly int _targetCount; +} diff --git a/dotnet/src/Agents/Core/History/ChatHistoryTruncationReducer.cs b/dotnet/src/Agents/Core/History/ChatHistoryTruncationReducer.cs new file mode 100644 index 000000000000..be9ca7868f87 --- /dev/null +++ b/dotnet/src/Agents/Core/History/ChatHistoryTruncationReducer.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Agents.History; + +/// +/// Truncate the chat history to the target message count. +/// +/// +/// Truncation will always avoid orphaning function-content as the presence of +/// a function-call _must_ be followed by a function-result. When a threshold count is +/// is provided (recommended), reduction will scan within the threshold window in an attempt to +/// avoid orphaning a user message from an assistant response. +/// +public class ChatHistoryTruncationReducer : IChatHistoryReducer +{ + /// + public Task?> ReduceAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + { + // First pass to determine the truncation index + int truncationIndex = history.LocateSafeReductionIndex(this._targetCount, this._thresholdCount); + + IEnumerable? truncatedHistory = null; + + if (truncationIndex > 0) + { + // Second pass to truncate the history + truncatedHistory = history.Extract(truncationIndex); + } + + return Task.FromResult(truncatedHistory); + } + + /// + /// Initializes a new instance of the class. + /// + /// The desired number of target messages after reduction. + /// An optional number of messages beyond the 'targetCount' that must be present in order to trigger reduction/ + /// + /// While the 'thresholdCount' is optional, it is recommended to provided so that reduction is not triggered + /// for every incremental addition to the chat history beyond the 'targetCount'. + /// > + public ChatHistoryTruncationReducer(int targetCount, int? thresholdCount = null) + { + Verify.True(targetCount > 0, "Target message count must be greater than zero."); + Verify.True(!thresholdCount.HasValue || thresholdCount > 0, "The reduction threshold length must be greater than zero."); + + this._targetCount = targetCount; + + this._thresholdCount = thresholdCount ?? 0; + } + + /// + public override bool Equals(object? obj) + { + ChatHistoryTruncationReducer? other = obj as ChatHistoryTruncationReducer; + return other != null && + this._thresholdCount == other._thresholdCount && + this._targetCount == other._targetCount; + } + + /// + public override int GetHashCode() => HashCode.Combine(nameof(ChatHistoryTruncationReducer), this._thresholdCount, this._targetCount); + + private readonly int _thresholdCount; + private readonly int _targetCount; +} diff --git a/dotnet/src/Agents/Core/History/IChatHistoryReducer.cs b/dotnet/src/Agents/Core/History/IChatHistoryReducer.cs new file mode 100644 index 000000000000..884fbcf42bc1 --- /dev/null +++ b/dotnet/src/Agents/Core/History/IChatHistoryReducer.cs @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.SemanticKernel.Agents.History; + +/// +/// Defines a contract for a reducing chat history. +/// +public interface IChatHistoryReducer +{ + /// + /// Each reducer shall override equality evaluation so that different reducers + /// of the same configuration can be evaluated for equivalency. + /// + bool Equals(object? obj); + + /// + /// Each reducer shall implement custom hash-code generation so that different reducers + /// of the same configuration can be evaluated for equivalency. + /// + int GetHashCode(); + + /// + /// Optionally reduces the chat history. + /// + /// The source history (which may have been previously reduced) + /// The to monitor for cancellation requests. The default is . + /// The reduced history, or 'null' if no reduction has occurred + Task?> ReduceAsync(IReadOnlyList history, CancellationToken cancellationToken = default); +} diff --git a/dotnet/src/Agents/Abstractions/IChatHistoryHandler.cs b/dotnet/src/Agents/Core/IChatHistoryHandler.cs similarity index 83% rename from dotnet/src/Agents/Abstractions/IChatHistoryHandler.cs rename to dotnet/src/Agents/Core/IChatHistoryHandler.cs index 8b7dab748c81..422493a92db7 100644 --- a/dotnet/src/Agents/Abstractions/IChatHistoryHandler.cs +++ b/dotnet/src/Agents/Core/IChatHistoryHandler.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; using System.Threading; +using Microsoft.SemanticKernel.Agents.History; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Agents; @@ -10,6 +11,11 @@ namespace Microsoft.SemanticKernel.Agents; /// public interface IChatHistoryHandler { + /// + /// An optional history reducer to apply to the chat history prior processing. + /// + IChatHistoryReducer? HistoryReducer { get; init; } + /// /// Entry point for calling into an agent from a . /// @@ -26,7 +32,7 @@ IAsyncEnumerable InvokeAsync( /// The chat history at the point the channel is created. /// The to monitor for cancellation requests. The default is . /// Asynchronous enumeration of streaming content. - public abstract IAsyncEnumerable InvokeStreamingAsync( + IAsyncEnumerable InvokeStreamingAsync( ChatHistory history, CancellationToken cancellationToken = default); } diff --git a/dotnet/src/Agents/Core/IChatHistoryHandlerExtensions.cs b/dotnet/src/Agents/Core/IChatHistoryHandlerExtensions.cs new file mode 100644 index 000000000000..f3d23fea7942 --- /dev/null +++ b/dotnet/src/Agents/Core/IChatHistoryHandlerExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Agents.History; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Agents; + +/// +/// Contract for an agent that utilizes a . +/// +public static class IChatHistoryHandlerExtensions +{ + /// + /// Reduce history for an agent that implements . + /// + /// The target agent + /// The source history + /// The to monitor for cancellation requests. The default is . + /// + public static Task ReduceAsync(this IChatHistoryHandler agent, ChatHistory history, CancellationToken cancellationToken = default) => + history.ReduceAsync(agent.HistoryReducer, cancellationToken); +} diff --git a/dotnet/src/Agents/UnitTests/AgentChatTests.cs b/dotnet/src/Agents/UnitTests/AgentChatTests.cs index 89ff7f02cff2..49c36ae73c53 100644 --- a/dotnet/src/Agents/UnitTests/AgentChatTests.cs +++ b/dotnet/src/Agents/UnitTests/AgentChatTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; using System.Linq; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -110,50 +109,15 @@ async Task SynchronizedInvokeAsync() private async Task VerifyHistoryAsync(int expectedCount, IAsyncEnumerable history) { - if (expectedCount == 0) - { - Assert.Empty(history); - } - else - { - Assert.NotEmpty(history); - Assert.Equal(expectedCount, await history.CountAsync()); - } + Assert.Equal(expectedCount, await history.CountAsync()); } private sealed class TestChat : AgentChat { - public TestAgent Agent { get; } = new TestAgent(); + public MockAgent Agent { get; } = new() { Response = [new(AuthorRole.Assistant, "sup")] }; public override IAsyncEnumerable InvokeAsync( CancellationToken cancellationToken = default) => this.InvokeAgentAsync(this.Agent, cancellationToken); } - - private sealed class TestAgent : ChatHistoryKernelAgent - { - public int InvokeCount { get; private set; } - - public override async IAsyncEnumerable InvokeAsync( - ChatHistory history, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - await Task.Delay(0, cancellationToken); - - this.InvokeCount++; - - yield return new ChatMessageContent(AuthorRole.Assistant, "sup"); - } - - public override IAsyncEnumerable InvokeStreamingAsync( - ChatHistory history, - CancellationToken cancellationToken = default) - { - this.InvokeCount++; - - StreamingChatMessageContent[] contents = [new(AuthorRole.Assistant, "sup")]; - - return contents.ToAsyncEnumerable(); - } - } } diff --git a/dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs b/dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs index c4a974cbadc9..1a607ea7e6c7 100644 --- a/dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs +++ b/dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs @@ -1,11 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System.Linq; -using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.ChatCompletion; -using Moq; using Xunit; namespace SemanticKernel.Agents.UnitTests; @@ -23,9 +21,9 @@ public class AggregatorAgentTests [InlineData(AggregatorMode.Flat, 2)] public async Task VerifyAggregatorAgentUsageAsync(AggregatorMode mode, int modeOffset) { - Agent agent1 = CreateMockAgent().Object; - Agent agent2 = CreateMockAgent().Object; - Agent agent3 = CreateMockAgent().Object; + Agent agent1 = CreateMockAgent(); + Agent agent2 = CreateMockAgent(); + Agent agent3 = CreateMockAgent(); AgentGroupChat groupChat = new(agent1, agent2, agent3) @@ -81,13 +79,5 @@ public async Task VerifyAggregatorAgentUsageAsync(AggregatorMode mode, int modeO Assert.Equal(5, messages.Length); // Total messages on inner chat once synchronized (agent equivalent) } - private static Mock CreateMockAgent() - { - Mock agent = new(); - - ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test agent")]; - agent.Setup(a => a.InvokeAsync(It.IsAny(), It.IsAny())).Returns(() => messages.ToAsyncEnumerable()); - - return agent; - } + private static MockAgent CreateMockAgent() => new() { Response = [new(AuthorRole.Assistant, "test")] }; } diff --git a/dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs b/dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs index 921e0acce016..ad7428f6f0b9 100644 --- a/dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs @@ -8,7 +8,6 @@ using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.Agents.Chat; using Microsoft.SemanticKernel.ChatCompletion; -using Moq; using Xunit; namespace SemanticKernel.Agents.UnitTests.Core; @@ -39,10 +38,10 @@ public void VerifyGroupAgentChatDefaultState() [Fact] public async Task VerifyGroupAgentChatAgentMembershipAsync() { - Agent agent1 = CreateMockAgent().Object; - Agent agent2 = CreateMockAgent().Object; - Agent agent3 = CreateMockAgent().Object; - Agent agent4 = CreateMockAgent().Object; + Agent agent1 = CreateMockAgent(); + Agent agent2 = CreateMockAgent(); + Agent agent3 = CreateMockAgent(); + Agent agent4 = CreateMockAgent(); AgentGroupChat chat = new(agent1, agent2); Assert.Equal(2, chat.Agents.Count); @@ -63,9 +62,9 @@ public async Task VerifyGroupAgentChatAgentMembershipAsync() [Fact] public async Task VerifyGroupAgentChatMultiTurnAsync() { - Agent agent1 = CreateMockAgent().Object; - Agent agent2 = CreateMockAgent().Object; - Agent agent3 = CreateMockAgent().Object; + Agent agent1 = CreateMockAgent(); + Agent agent2 = CreateMockAgent(); + Agent agent3 = CreateMockAgent(); AgentGroupChat chat = new(agent1, agent2, agent3) @@ -162,7 +161,7 @@ public async Task VerifyGroupAgentChatMultiTurnTerminationAsync() [Fact] public async Task VerifyGroupAgentChatDiscreteTerminationAsync() { - Agent agent1 = CreateMockAgent().Object; + Agent agent1 = CreateMockAgent(); AgentGroupChat chat = new() @@ -186,22 +185,14 @@ public async Task VerifyGroupAgentChatDiscreteTerminationAsync() private static AgentGroupChat Create3AgentChat() { - Agent agent1 = CreateMockAgent().Object; - Agent agent2 = CreateMockAgent().Object; - Agent agent3 = CreateMockAgent().Object; + Agent agent1 = CreateMockAgent(); + Agent agent2 = CreateMockAgent(); + Agent agent3 = CreateMockAgent(); return new(agent1, agent2, agent3); } - private static Mock CreateMockAgent() - { - Mock agent = new(); - - ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test")]; - agent.Setup(a => a.InvokeAsync(It.IsAny(), It.IsAny())).Returns(() => messages.ToAsyncEnumerable()); - - return agent; - } + private static MockAgent CreateMockAgent() => new() { Response = [new(AuthorRole.Assistant, "test")] }; private sealed class TestTerminationStrategy(bool shouldTerminate) : TerminationStrategy { diff --git a/dotnet/src/Agents/UnitTests/ChatHistoryChannelTests.cs b/dotnet/src/Agents/UnitTests/Core/ChatHistoryChannelTests.cs similarity index 96% rename from dotnet/src/Agents/UnitTests/ChatHistoryChannelTests.cs rename to dotnet/src/Agents/UnitTests/Core/ChatHistoryChannelTests.cs index 7ef624c61ab9..43aae918ad52 100644 --- a/dotnet/src/Agents/UnitTests/ChatHistoryChannelTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/ChatHistoryChannelTests.cs @@ -8,7 +8,7 @@ using Microsoft.SemanticKernel.Agents; using Xunit; -namespace SemanticKernel.Agents.UnitTests; +namespace SemanticKernel.Agents.UnitTests.Core; /// /// Unit testing of . diff --git a/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryReducerExtensionsTests.cs b/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryReducerExtensionsTests.cs new file mode 100644 index 000000000000..a75533474147 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryReducerExtensionsTests.cs @@ -0,0 +1,228 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.History; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.Core.History; + +/// +/// Unit testing of . +/// +public class ChatHistoryReducerExtensionsTests +{ + /// + /// Verify the extraction of a set of messages from an input set. + /// + [Theory] + [InlineData(100, 0, 1)] + [InlineData(100, 0, 9)] + [InlineData(100, 0, 99)] + [InlineData(100, 80)] + [InlineData(100, 80, 81)] + [InlineData(100, 0)] + [InlineData(100, int.MaxValue, null, 0)] + [InlineData(100, 0, int.MaxValue, 100)] + public void VerifyChatHistoryExtraction(int messageCount, int startIndex, int? endIndex = null, int? expectedCount = null) + { + ChatHistory history = [.. MockHistoryGenerator.CreateSimpleHistory(messageCount)]; + + ChatMessageContent[] extractedHistory = history.Extract(startIndex, endIndex).ToArray(); + + int finalIndex = endIndex ?? messageCount - 1; + finalIndex = Math.Min(finalIndex, messageCount - 1); + + expectedCount ??= finalIndex - startIndex + 1; + + Assert.Equal(expectedCount, extractedHistory.Length); + + if (extractedHistory.Length > 0) + { + Assert.Contains($"#{startIndex}", extractedHistory[0].Content); + Assert.Contains($"#{finalIndex}", extractedHistory[^1].Content); + } + } + + /// + /// Verify identifying the first non-summary message index. + /// + [Theory] + [InlineData(0, 100)] + [InlineData(1, 100)] + [InlineData(100, 10)] + [InlineData(100, 0)] + public void VerifyGetFinalSummaryIndex(int summaryCount, int regularCount) + { + ChatHistory summaries = [.. MockHistoryGenerator.CreateSimpleHistory(summaryCount)]; + foreach (ChatMessageContent summary in summaries) + { + summary.Metadata = new Dictionary() { { "summary", true } }; + } + + ChatHistory history = [.. summaries, .. MockHistoryGenerator.CreateSimpleHistory(regularCount)]; + + int finalSummaryIndex = history.LocateSummarizationBoundary("summary"); + + Assert.Equal(summaryCount, finalSummaryIndex); + } + + /// + /// Verify a instance is not reduced. + /// + [Fact] + public async Task VerifyChatHistoryNotReducedAsync() + { + ChatHistory history = []; + + bool isReduced = await history.ReduceAsync(null, default); + + Assert.False(isReduced); + Assert.Empty(history); + + Mock mockReducer = new(); + mockReducer.Setup(r => r.ReduceAsync(It.IsAny>(), default)).ReturnsAsync((IEnumerable?)null); + isReduced = await history.ReduceAsync(mockReducer.Object, default); + + Assert.False(isReduced); + Assert.Empty(history); + } + + /// + /// Verify a instance is reduced. + /// + [Fact] + public async Task VerifyChatHistoryReducedAsync() + { + Mock mockReducer = new(); + mockReducer.Setup(r => r.ReduceAsync(It.IsAny>(), default)).ReturnsAsync((IEnumerable?)[]); + + ChatHistory history = [.. MockHistoryGenerator.CreateSimpleHistory(10)]; + + bool isReduced = await history.ReduceAsync(mockReducer.Object, default); + + Assert.True(isReduced); + Assert.Empty(history); + } + + /// + /// Verify starting index (0) is identified when message count does not exceed the limit. + /// + [Theory] + [InlineData(0, 1)] + [InlineData(1, 1)] + [InlineData(1, 2)] + [InlineData(1, int.MaxValue)] + [InlineData(5, 1, 5)] + [InlineData(5, 4, 2)] + [InlineData(5, 5, 1)] + [InlineData(900, 500, 400)] + [InlineData(900, 500, int.MaxValue)] + public void VerifyLocateSafeReductionIndexNone(int messageCount, int targetCount, int? thresholdCount = null) + { + // Shape of history doesn't matter since reduction is not expected + ChatHistory sourceHistory = [.. MockHistoryGenerator.CreateHistoryWithUserInput(messageCount)]; + + int reductionIndex = sourceHistory.LocateSafeReductionIndex(targetCount, thresholdCount); + + Assert.Equal(0, reductionIndex); + } + + /// + /// Verify the expected index ) is identified when message count exceeds the limit. + /// + [Theory] + [InlineData(2, 1)] + [InlineData(3, 2)] + [InlineData(3, 1, 1)] + [InlineData(6, 1, 4)] + [InlineData(6, 4, 1)] + [InlineData(6, 5)] + [InlineData(1000, 500, 400)] + [InlineData(1000, 500, 499)] + public void VerifyLocateSafeReductionIndexFound(int messageCount, int targetCount, int? thresholdCount = null) + { + // Generate history with only assistant messages + ChatHistory sourceHistory = [.. MockHistoryGenerator.CreateSimpleHistory(messageCount)]; + + int reductionIndex = sourceHistory.LocateSafeReductionIndex(targetCount, thresholdCount); + + Assert.True(reductionIndex > 0); + Assert.Equal(targetCount, messageCount - reductionIndex); + } + + /// + /// Verify the expected index ) is identified when message count exceeds the limit. + /// History contains alternating user and assistant messages. + /// + [Theory] + [InlineData(2, 1)] + [InlineData(3, 2)] + [InlineData(3, 1, 1)] + [InlineData(6, 1, 4)] + [InlineData(6, 4, 1)] + [InlineData(6, 5)] + [InlineData(1000, 500, 400)] + [InlineData(1000, 500, 499)] + public void VerifyLocateSafeReductionIndexFoundWithUser(int messageCount, int targetCount, int? thresholdCount = null) + { + // Generate history with alternating user and assistant messages + ChatHistory sourceHistory = [.. MockHistoryGenerator.CreateHistoryWithUserInput(messageCount)]; + + int reductionIndex = sourceHistory.LocateSafeReductionIndex(targetCount, thresholdCount); + + Assert.True(reductionIndex > 0); + + // The reduction length should align with a user message, if threshold is specified + bool hasThreshold = thresholdCount > 0; + int expectedCount = targetCount + (hasThreshold && sourceHistory[^targetCount].Role != AuthorRole.User ? 1 : 0); + + Assert.Equal(expectedCount, messageCount - reductionIndex); + } + + /// + /// Verify the expected index ) is identified when message count exceeds the limit. + /// History contains alternating user and assistant messages along with function + /// related content. + /// + [Theory] + [InlineData(4)] + [InlineData(4, 3)] + [InlineData(5)] + [InlineData(5, 8)] + [InlineData(6)] + [InlineData(6, 7)] + [InlineData(7)] + [InlineData(8)] + [InlineData(9)] + public void VerifyLocateSafeReductionIndexWithFunctionContent(int targetCount, int? thresholdCount = null) + { + // Generate a history with function call on index 5 and 9 and + // function result on index 6 and 10 (total length: 14) + ChatHistory sourceHistory = [.. MockHistoryGenerator.CreateHistoryWithFunctionContent()]; + + ChatHistoryTruncationReducer reducer = new(targetCount, thresholdCount); + + int reductionIndex = sourceHistory.LocateSafeReductionIndex(targetCount, thresholdCount); + + Assert.True(reductionIndex > 0); + + // The reduction length avoid splitting function call and result, regardless of threshold + int expectedCount = targetCount; + + if (sourceHistory[sourceHistory.Count - targetCount].Items.Any(i => i is FunctionCallContent)) + { + expectedCount += 1; + } + else if (sourceHistory[sourceHistory.Count - targetCount].Items.Any(i => i is FunctionResultContent)) + { + expectedCount += 2; + } + + Assert.Equal(expectedCount, sourceHistory.Count - reductionIndex); + } +} diff --git a/dotnet/src/Agents/UnitTests/Core/History/ChatHistorySummarizationReducerTests.cs b/dotnet/src/Agents/UnitTests/Core/History/ChatHistorySummarizationReducerTests.cs new file mode 100644 index 000000000000..f464b6a8214a --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Core/History/ChatHistorySummarizationReducerTests.cs @@ -0,0 +1,207 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.History; +using Microsoft.SemanticKernel.ChatCompletion; +using Moq; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.Core.History; + +/// +/// Unit testing of . +/// +public class ChatHistorySummarizationReducerTests +{ + /// + /// Ensure that the constructor arguments are validated. + /// + [Theory] + [InlineData(-1)] + [InlineData(-1, int.MaxValue)] + [InlineData(int.MaxValue, -1)] + public void VerifyChatHistoryConstructorArgumentValidation(int targetCount, int? thresholdCount = null) + { + Mock mockCompletionService = this.CreateMockCompletionService(); + + Assert.Throws(() => new ChatHistorySummarizationReducer(mockCompletionService.Object, targetCount, thresholdCount)); + } + + /// + /// Verify object state after initialization. + /// + [Fact] + public void VerifyChatHistoryInitializationState() + { + Mock mockCompletionService = this.CreateMockCompletionService(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 10); + + Assert.Equal(ChatHistorySummarizationReducer.DefaultSummarizationPrompt, reducer.SummarizationInstructions); + Assert.True(reducer.FailOnError); + + reducer = + new(mockCompletionService.Object, 10) + { + FailOnError = false, + SummarizationInstructions = "instructions", + }; + + Assert.NotEqual(ChatHistorySummarizationReducer.DefaultSummarizationPrompt, reducer.SummarizationInstructions); + Assert.False(reducer.FailOnError); + } + + /// + /// Validate hash-code expresses reducer equivalency. + /// + [Fact] + public void VerifyChatHistoryHasCode() + { + HashSet reducers = []; + + Mock mockCompletionService = this.CreateMockCompletionService(); + + int hashCode1 = GenerateHashCode(3, 4); + int hashCode2 = GenerateHashCode(33, 44); + int hashCode3 = GenerateHashCode(3000, 4000); + int hashCode4 = GenerateHashCode(3000, 4000); + + Assert.NotEqual(hashCode1, hashCode2); + Assert.NotEqual(hashCode2, hashCode3); + Assert.Equal(hashCode3, hashCode4); + Assert.Equal(3, reducers.Count); + + int GenerateHashCode(int targetCount, int thresholdCount) + { + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, targetCount, thresholdCount); + + reducers.Add(reducer); + + return reducer.GetHashCode(); + } + } + + /// + /// Validate silent summarization failure when set to 'false'. + /// + [Fact] + public async Task VerifyChatHistoryReductionSilentFailureAsync() + { + Mock mockCompletionService = this.CreateMockCompletionService(throwException: true); + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 10) { FailOnError = false }; + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + + Assert.Null(reducedHistory); + } + + /// + /// Validate exception on summarization failure when set to 'true'. + /// + [Fact] + public async Task VerifyChatHistoryReductionThrowsOnFailureAsync() + { + Mock mockCompletionService = this.CreateMockCompletionService(throwException: true); + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 10); + await Assert.ThrowsAsync(() => reducer.ReduceAsync(sourceHistory)); + } + + /// + /// Validate history not reduced when source history does not exceed target threshold. + /// + [Fact] + public async Task VerifyChatHistoryNotReducedAsync() + { + Mock mockCompletionService = this.CreateMockCompletionService(); + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 20); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + + Assert.Null(reducedHistory); + } + + /// + /// Validate history reduced when source history exceeds target threshold. + /// + [Fact] + public async Task VerifyChatHistoryReducedAsync() + { + Mock mockCompletionService = this.CreateMockCompletionService(); + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 10); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + + ChatMessageContent[] messages = VerifyReducedHistory(reducedHistory, 11); + VerifySummarization(messages[0]); + } + + /// + /// Validate history re-summarized on second occurrence of source history exceeding target threshold. + /// + [Fact] + public async Task VerifyChatHistoryRereducedAsync() + { + Mock mockCompletionService = this.CreateMockCompletionService(); + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistorySummarizationReducer reducer = new(mockCompletionService.Object, 10); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + reducedHistory = await reducer.ReduceAsync([.. reducedHistory!, .. sourceHistory]); + + ChatMessageContent[] messages = VerifyReducedHistory(reducedHistory, 11); + VerifySummarization(messages[0]); + + reducer = new(mockCompletionService.Object, 10) { UseSingleSummary = false }; + reducedHistory = await reducer.ReduceAsync([.. reducedHistory!, .. sourceHistory]); + + messages = VerifyReducedHistory(reducedHistory, 12); + VerifySummarization(messages[0]); + VerifySummarization(messages[1]); + } + + private static ChatMessageContent[] VerifyReducedHistory(IEnumerable? reducedHistory, int expectedCount) + { + Assert.NotNull(reducedHistory); + ChatMessageContent[] messages = reducedHistory.ToArray(); + Assert.Equal(expectedCount, messages.Length); + + return messages; + } + + private static void VerifySummarization(ChatMessageContent message) + { + Assert.NotNull(message.Metadata); + Assert.True(message.Metadata!.ContainsKey(ChatHistorySummarizationReducer.SummaryMetadataKey)); + } + + private Mock CreateMockCompletionService(bool throwException = false) + { + Mock mock = new(); + var setup = mock.Setup( + s => + s.GetChatMessageContentsAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + default)); + + if (throwException) + { + setup.ThrowsAsync(new HttpOperationException("whoops")); + } + else + { + setup.ReturnsAsync([new(AuthorRole.Assistant, "summary")]); + } + + return mock; + } +} diff --git a/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryTruncationReducerTests.cs b/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryTruncationReducerTests.cs new file mode 100644 index 000000000000..eebcf8fc6136 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Core/History/ChatHistoryTruncationReducerTests.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents.History; +using Xunit; + +namespace SemanticKernel.Agents.UnitTests.Core.History; + +/// +/// Unit testing of . +/// +public class ChatHistoryTruncationReducerTests +{ + /// + /// Ensure that the constructor arguments are validated. + /// + [Theory] + [InlineData(-1)] + [InlineData(-1, int.MaxValue)] + [InlineData(int.MaxValue, -1)] + public void VerifyChatHistoryConstructorArgumentValidation(int targetCount, int? thresholdCount = null) + { + Assert.Throws(() => new ChatHistoryTruncationReducer(targetCount, thresholdCount)); + } + + /// + /// Validate hash-code expresses reducer equivalency. + /// + [Fact] + public void VerifyChatHistoryHasCode() + { + HashSet reducers = []; + + int hashCode1 = GenerateHashCode(3, 4); + int hashCode2 = GenerateHashCode(33, 44); + int hashCode3 = GenerateHashCode(3000, 4000); + int hashCode4 = GenerateHashCode(3000, 4000); + + Assert.NotEqual(hashCode1, hashCode2); + Assert.NotEqual(hashCode2, hashCode3); + Assert.Equal(hashCode3, hashCode4); + Assert.Equal(3, reducers.Count); + + int GenerateHashCode(int targetCount, int thresholdCount) + { + ChatHistoryTruncationReducer reducer = new(targetCount, thresholdCount); + + reducers.Add(reducer); + + return reducer.GetHashCode(); + } + } + + /// + /// Validate history not reduced when source history does not exceed target threshold. + /// + [Fact] + public async Task VerifyChatHistoryNotReducedAsync() + { + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(10).ToArray(); + + ChatHistoryTruncationReducer reducer = new(20); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + + Assert.Null(reducedHistory); + } + + /// + /// Validate history reduced when source history exceeds target threshold. + /// + [Fact] + public async Task VerifyChatHistoryReducedAsync() + { + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistoryTruncationReducer reducer = new(10); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + + VerifyReducedHistory(reducedHistory, 10); + } + + /// + /// Validate history re-summarized on second occurrence of source history exceeding target threshold. + /// + [Fact] + public async Task VerifyChatHistoryRereducedAsync() + { + IReadOnlyList sourceHistory = MockHistoryGenerator.CreateSimpleHistory(20).ToArray(); + + ChatHistoryTruncationReducer reducer = new(10); + IEnumerable? reducedHistory = await reducer.ReduceAsync(sourceHistory); + reducedHistory = await reducer.ReduceAsync([.. reducedHistory!, .. sourceHistory]); + + VerifyReducedHistory(reducedHistory, 10); + } + + private static void VerifyReducedHistory(IEnumerable? reducedHistory, int expectedCount) + { + Assert.NotNull(reducedHistory); + ChatMessageContent[] messages = reducedHistory.ToArray(); + Assert.Equal(expectedCount, messages.Length); + } +} diff --git a/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs b/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs new file mode 100644 index 000000000000..375b6fc9aa40 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/Core/History/MockHistoryGenerator.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace SemanticKernel.Agents.UnitTests.Core.History; + +/// +/// Factory for generating chat history for various test scenarios. +/// +internal static class MockHistoryGenerator +{ + /// + /// Create a homogeneous list of assistant messages. + /// + public static IEnumerable CreateSimpleHistory(int messageCount) + { + for (int index = 0; index < messageCount; ++index) + { + yield return new ChatMessageContent(AuthorRole.Assistant, $"message #{index}"); + } + } + + /// + /// Create an alternating list of user and assistant messages. + /// + public static IEnumerable CreateHistoryWithUserInput(int messageCount) + { + for (int index = 0; index < messageCount; ++index) + { + yield return + index % 2 == 1 ? + new ChatMessageContent(AuthorRole.Assistant, $"asistant response: {index}") : + new ChatMessageContent(AuthorRole.User, $"user input: {index}"); + } + } + + /// + /// Create an alternating list of user and assistant messages with function content + /// injected at indexes: + /// + /// - 5: function call + /// - 6: function result + /// - 9: function call + /// - 10: function result + /// + /// Total message count: 14 messages. + /// + public static IEnumerable CreateHistoryWithFunctionContent() + { + yield return new ChatMessageContent(AuthorRole.User, "user input: 0"); + yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 1"); + yield return new ChatMessageContent(AuthorRole.User, "user input: 2"); + yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 3"); + yield return new ChatMessageContent(AuthorRole.User, "user input: 4"); + yield return new ChatMessageContent(AuthorRole.Assistant, [new FunctionCallContent("function call: 5")]); + yield return new ChatMessageContent(AuthorRole.Tool, [new FunctionResultContent("function result: 6")]); + yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 7"); + yield return new ChatMessageContent(AuthorRole.User, "user input: 8"); + yield return new ChatMessageContent(AuthorRole.Assistant, [new FunctionCallContent("function call: 9")]); + yield return new ChatMessageContent(AuthorRole.Tool, [new FunctionResultContent("function result: 10")]); + yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 11"); + yield return new ChatMessageContent(AuthorRole.User, "user input: 12"); + yield return new ChatMessageContent(AuthorRole.Assistant, "asistant response: 13"); + } +} diff --git a/dotnet/src/Agents/UnitTests/MockAgent.cs b/dotnet/src/Agents/UnitTests/MockAgent.cs new file mode 100644 index 000000000000..dacf7bab78d6 --- /dev/null +++ b/dotnet/src/Agents/UnitTests/MockAgent.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.Agents.History; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace SemanticKernel.Agents.UnitTests; + +/// +/// Mock definition of with a contract. +/// +internal sealed class MockAgent : KernelAgent, IChatHistoryHandler +{ + public int InvokeCount { get; private set; } + + public IReadOnlyList Response { get; set; } = []; + + public IChatHistoryReducer? HistoryReducer { get; init; } + + public IAsyncEnumerable InvokeAsync(ChatHistory history, CancellationToken cancellationToken = default) + { + this.InvokeCount++; + + return this.Response.ToAsyncEnumerable(); + } + + public IAsyncEnumerable InvokeStreamingAsync(ChatHistory history, CancellationToken cancellationToken = default) + { + this.InvokeCount++; + return this.Response.Select(m => new StreamingChatMessageContent(m.Role, m.Content)).ToAsyncEnumerable(); + } + + /// + protected internal override IEnumerable GetChannelKeys() + { + yield return typeof(ChatHistoryChannel).FullName!; + } + + /// + protected internal override Task CreateChannelAsync(CancellationToken cancellationToken) + { + ChatHistoryChannel channel = + new() + { + Logger = this.LoggerFactory.CreateLogger() + }; + + return Task.FromResult(channel); + } +}