# .Net: Fixed usage of chat system prompt (microsoft#4994)
### Motivation and Context


Resolves: microsoft#4377
Resolves: microsoft#4510

When invoking a prompt or a function from the kernel, the chat system prompt was ignored, and the prompt was sent to the AI as a system message instead of a user message.

An example of code that didn't work as expected:
```csharp
var settings = new OpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." };

// The result contains the actual answer instead of "I don't know", as defined in the system prompt.
// That's because the question was sent as a system message instead of a user message, and the ChatSystemPrompt property was ignored.
var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?", new(settings));
```
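
For illustration, here is a simplified sketch of the chat-completions request body produced after the fix (other request fields omitted). Before the fix, the `messages` array contained only a single system message holding the prompt text, and `ChatSystemPrompt` was dropped:

```json
{
  "messages": [
    { "role": "system", "content": "Reply \"I don't know\" to every question." },
    { "role": "user", "content": "Where is the most famous fish market in Seattle, Washington, USA?" }
  ]
}
```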

This fix may change the behavior of applications that rely on the prompt being sent as a system message instead of a user message. As a temporary measure, the previous behavior can be restored by using a chat prompt, as in the following example:
```csharp
KernelFunction function = KernelFunctionFactory.CreateFromPrompt(@"
    <message role=""system"">Where is the most famous fish market in Seattle, Washington, USA?</message>
");

var result = await kernel.InvokeAsync(function);
```

This is only a temporary workaround; the valid usage is shown in the first example above.
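
Alternatively, both messages can be declared explicitly in a chat prompt, which avoids relying on `ChatSystemPrompt` altogether. A minimal sketch using the same chat-prompt syntax as the workaround above:

```csharp
KernelFunction function = KernelFunctionFactory.CreateFromPrompt(@"
    <message role=""system"">Reply ""I don't know"" to every question.</message>
    <message role=""user"">Where is the most famous fish market in Seattle, Washington, USA?</message>
");

var result = await kernel.InvokeAsync(function);
```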

### Description


1. Updated `ChatCompletionServiceExtensions` to send the prompt as a user message instead of a system message (see the sketch after this list).
2. Added unit and integration tests to verify this scenario.
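
For example, with this change a plain-text prompt passed to the extension becomes a user message, and the `ChatSystemPrompt` from the settings is applied as the system message. A sketch (assuming `chatService` is any `IChatCompletionService` implementation and `kernel` is an existing `Kernel` instance):

```csharp
var settings = new OpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." };

// The prompt is added to the chat history as a user message,
// so it no longer overrides the system prompt from the settings.
var messages = await chatService.GetChatMessageContentsAsync(
    "Where is the most famous fish market in Seattle, Washington, USA?",
    settings,
    kernel);
```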

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
dmytrostruk authored Feb 13, 2024
1 parent 715d0d7 commit 48b6bb2
Showing 4 changed files with 133 additions and 8 deletions.
**File 1 of 4** — unit tests for `AzureOpenAIChatCompletionService`:

```diff
@@ -11,6 +11,7 @@
 using System.Threading.Tasks;
 using Azure.AI.OpenAI;
 using Azure.Core;
+using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Logging;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
@@ -560,6 +561,48 @@ public async Task GetStreamingChatMessageContentsWithRequiredFunctionCallAsync()
         Assert.Equal("none", secondContentJson.GetProperty("tool_choice").GetString());
     }
 
+    [Fact]
+    public async Task GetChatMessageContentsUsesPromptAndSettingsCorrectlyAsync()
+    {
+        // Arrange
+        const string Prompt = "This is test prompt";
+        const string SystemMessage = "This is test system message";
+
+        var service = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient);
+        var settings = new OpenAIPromptExecutionSettings() { ChatSystemPrompt = SystemMessage };
+
+        this._messageHandlerStub.ResponsesToReturn.Add(new HttpResponseMessage(HttpStatusCode.OK)
+        {
+            Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json"))
+        });
+
+        IKernelBuilder builder = Kernel.CreateBuilder();
+        builder.Services.AddTransient<IChatCompletionService>((sp) => service);
+        Kernel kernel = builder.Build();
+
+        // Act
+        var result = await kernel.InvokePromptAsync(Prompt, new(settings));
+
+        // Assert
+        Assert.Equal("Test chat response", result.ToString());
+
+        var requestContentByteArray = this._messageHandlerStub.RequestContents[0];
+
+        Assert.NotNull(requestContentByteArray);
+
+        var requestContent = JsonSerializer.Deserialize<JsonElement>(Encoding.UTF8.GetString(requestContentByteArray));
+
+        var messages = requestContent.GetProperty("messages");
+
+        Assert.Equal(2, messages.GetArrayLength());
+
+        Assert.Equal("This is test system message", messages[0].GetProperty("content").GetString());
+        Assert.Equal("system", messages[0].GetProperty("role").GetString());
+
+        Assert.Equal("This is test prompt", messages[1].GetProperty("content").GetString());
+        Assert.Equal("user", messages[1].GetProperty("role").GetString());
+    }
+
     public void Dispose()
     {
         this._httpClient.Dispose();
```
**File 2 of 4** — OpenAI completion integration tests:

```diff
@@ -455,6 +455,25 @@ public async Task MultipleServiceLoadPromptConfigTestAsync()
         // Assert
         Assert.Contains("Pike Place", azureResult.GetValue<string>(), StringComparison.OrdinalIgnoreCase);
     }
+
+    [Fact]
+    public async Task ChatSystemPromptIsNotIgnoredAsync()
+    {
+        // Arrange
+        var settings = new OpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." };
+
+        this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
+        var builder = this._kernelBuilder;
+        this.ConfigureAzureOpenAIChatAsText(builder);
+        Kernel target = builder.Build();
+
+        // Act
+        var result = await target.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?", new(settings));
+
+        // Assert
+        Assert.Contains("I don't know", result.ToString(), StringComparison.OrdinalIgnoreCase);
+    }
+
     #region internals
 
     private readonly XunitLogger<Kernel> _logger;
```
**File 3 of 4** — `ChatCompletionServiceExtensions`:

```diff
@@ -33,13 +33,16 @@ public static Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsyn
         CancellationToken cancellationToken = default)
     {
         // Try to parse the text as a chat history
-        if (ChatPromptParser.TryParse(prompt, out var chatHistory))
+        if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt))
         {
-            return chatCompletionService.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
+            return chatCompletionService.GetChatMessageContentsAsync(chatHistoryFromPrompt, executionSettings, kernel, cancellationToken);
         }
 
-        //Otherwise, use the prompt as the chat system message
-        return chatCompletionService.GetChatMessageContentsAsync(new ChatHistory(prompt), executionSettings, kernel, cancellationToken);
+        // Otherwise, use the prompt as the chat user message
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage(prompt);
+
+        return chatCompletionService.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
     }
 
     /// <summary>
@@ -96,12 +99,15 @@ public static IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMess
         CancellationToken cancellationToken = default)
     {
         // Try to parse the text as a chat history
-        if (ChatPromptParser.TryParse(prompt, out var chatHistory))
+        if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt))
         {
-            return chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
+            return chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistoryFromPrompt, executionSettings, kernel, cancellationToken);
         }
 
-        //Otherwise, use the prompt as the chat system message
-        return chatCompletionService.GetStreamingChatMessageContentsAsync(new ChatHistory(prompt), executionSettings, kernel, cancellationToken);
+        // Otherwise, use the prompt as the chat user message
+        var chatHistory = new ChatHistory();
+        chatHistory.AddUserMessage(prompt);
+
+        return chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
     }
 }
```
**File 4 of 4** — kernel prompt function unit tests:

```diff
@@ -612,6 +612,61 @@ void MyRenderedHandler(object? sender, PromptRenderedEventArgs e)
         // Assert
         mockTextCompletion.Verify(m => m.GetTextContentsAsync("Prompt USE SHORT, CLEAR, COMPLETE SENTENCES.", It.IsAny<OpenAIPromptExecutionSettings>(), It.IsAny<Kernel>(), It.IsAny<CancellationToken>()), Times.Once());
     }
+
+    [Theory]
+    [InlineData(KernelInvocationType.InvokePrompt)]
+    [InlineData(KernelInvocationType.InvokePromptStreaming)]
+    [InlineData(KernelInvocationType.InvokeFunction)]
+    [InlineData(KernelInvocationType.InvokeFunctionStreaming)]
+    public async Task ItUsesPromptAsUserMessageAsync(KernelInvocationType invocationType)
+    {
+        // Arrange
+        const string Prompt = "Test prompt as user message";
+
+        var fakeService = new FakeChatAsTextService();
+        IKernelBuilder builder = Kernel.CreateBuilder();
+        builder.Services.AddTransient<IChatCompletionService>((sp) => fakeService);
+        Kernel kernel = builder.Build();
+
+        var function = KernelFunctionFactory.CreateFromPrompt(Prompt);
+
+        // Act
+        switch (invocationType)
+        {
+            case KernelInvocationType.InvokePrompt:
+                await kernel.InvokePromptAsync(Prompt);
+                break;
+            case KernelInvocationType.InvokePromptStreaming:
+                await foreach (var result in kernel.InvokePromptStreamingAsync(Prompt)) { }
+                break;
+            case KernelInvocationType.InvokeFunction:
+                await kernel.InvokeAsync(function);
+                break;
+            case KernelInvocationType.InvokeFunctionStreaming:
+                await foreach (var result in kernel.InvokeStreamingAsync(function)) { }
+                break;
+        }
+
+        // Assert
+        Assert.NotNull(fakeService.ChatHistory);
+        Assert.Single(fakeService.ChatHistory);
+
+        var messageContent = fakeService.ChatHistory[0];
+
+        Assert.Equal(AuthorRole.User, messageContent.Role);
+        Assert.Equal("Test prompt as user message", messageContent.Content);
+    }
+
+    public enum KernelInvocationType
+    {
+        InvokePrompt,
+        InvokePromptStreaming,
+        InvokeFunction,
+        InvokeFunctionStreaming
+    }
+
+    #region private
+
     private sealed class FakeChatAsTextService : ITextGenerationService, IChatCompletionService
     {
         public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();
@@ -644,4 +699,6 @@ public Task<IReadOnlyList<TextContent>> GetTextContentsAsync(string prompt, Prom
             throw new NotImplementedException();
         }
     }
+
+    #endregion
 }
```
