Skip to content

Commit

Permalink
Java: Add example of a multi turnaround chat with tool calls. Fixes b…
Browse files Browse the repository at this point in the history
…ug with serializing tool call arguments (microsoft#6372)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
johnoliver authored May 22, 2024
1 parent 72f0510 commit 8b55f2b
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ private static ChatRequestMessage getChatRequestMessage(

if (id == null) {
throw new SKException(
"Require to create a tool call message, but not tool call id is available");
"Require to create a tool call message, but no tool call id is available");
}
return new ChatRequestToolMessage(content, id);
default:
Expand Down Expand Up @@ -918,7 +918,7 @@ private static ChatRequestAssistantMessage formAssistantMessage(
StringEscapeUtils.escapeJson(entry.getKey()),
StringEscapeUtils.escapeJson(
entry.getValue().toPromptString())))
.collect(Collectors.joining("{", "}", ","))
.collect(Collectors.joining(",", "{", "}"))
: "{}";

String prefix = "";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsFunctionToolCall;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.core.http.HttpHeaders;
import com.azure.core.http.HttpRequest;
import com.azure.core.http.rest.Response;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.implementation.EmbeddedResourceLoader;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments;
import com.microsoft.semantickernel.services.chatcompletion.AuthorRole;
import com.microsoft.semantickernel.services.chatcompletion.ChatHistory;
import java.nio.charset.Charset;
import java.util.Arrays;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import reactor.core.publisher.Mono;

public class OpenAiChatCompletionTest {

@Test
public void serializesToolCallsCorrectly() {
OpenAIAsyncClient client = Mockito.mock(OpenAIAsyncClient.class);
OpenAIChatCompletion chatCompletion = mockClient(client);

ChatHistory chatHistory = new ChatHistory();

chatHistory.addUserMessage(
"What is the name of the pet with id ca2fc6bc-1307-4da6-a009-d7bf88dec37b?");

chatHistory.addMessage(new OpenAIChatMessageContent(
AuthorRole.ASSISTANT,
"",
"test",
null,
Charset.defaultCharset(),
null,
Arrays.asList(
new OpenAIFunctionToolCall(
"a-tool-id",
"pluginName",
"funcName",
KernelFunctionArguments.builder()
.withVariable("id", "ca2fc6bc-1307-4da6-a009-d7bf88dec37b")
.build()))));
chatHistory.addMessage(new OpenAIChatMessageContent(
AuthorRole.TOOL,
"Snuggles",
"test",
null,
Charset.defaultCharset(),
FunctionResultMetadata.build("a-tool-id"),
null));

chatCompletion
.getChatMessageContentsAsync(chatHistory, null, null).block();

Mockito.verify(client, Mockito.times(1))
.getChatCompletionsWithResponse(
Mockito.any(),
Mockito.<ChatCompletionsOptions>argThat(options -> {
ChatRequestAssistantMessage message = ((ChatRequestAssistantMessage) options
.getMessages()
.get(1));
ChatCompletionsFunctionToolCall toolcall = ((ChatCompletionsFunctionToolCall) message
.getToolCalls()
.get(0));
return toolcall
.getFunction()
.getArguments()
.equals("{\"id\": \"ca2fc6bc-1307-4da6-a009-d7bf88dec37b\"}");
}),
Mockito.any());
}

private static OpenAIChatCompletion mockClient(OpenAIAsyncClient client) {
Mockito.when(client.getChatCompletionsWithResponse(Mockito.any(),
Mockito.<ChatCompletionsOptions>any(), Mockito.any()))
.thenReturn(Mono.just(
new Response<ChatCompletions>() {
@Override
public int getStatusCode() {
return 200;
}

@Override
public HttpHeaders getHeaders() {
return new HttpHeaders();
}

@Override
public HttpRequest getRequest() {
return null;
}

@Override
public ChatCompletions getValue() {
try {
String message = EmbeddedResourceLoader.readFile("chatCompletion.txt",
OpenAiChatCompletionTest.class);

return new ObjectMapper()
.readValue(String.format(message, "Snuggles"),
ChatCompletions.class);

} catch (Exception e) {
throw new RuntimeException(e);
}
}
}));
return new OpenAIChatCompletion(
client,
"test",
"test",
"test");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"choices" : [
{
"content_filter_results" : {
"hate" : {
"filtered" : false,
"severity" : "safe"
},
"self_harm" : {
"filtered" : false,
"severity" : "safe"
},
"sexual" : {
"filtered" : false,
"severity" : "safe"
},
"violence" : {
"filtered" : false,
"severity" : "safe"
}
},
"finish_reason" : "stop",
"index" : 0,
"message" : {
"content" : "%s",
"role" : "assistant"
}
}
],
"created" : 1707253039,
"id" : "chatcmpl-xxx",
"prompt_filter_results" : [
{
"content_filter_results" : {
"hate" : {
"filtered" : false,
"severity" : "safe"
},
"self_harm" : {
"filtered" : false,
"severity" : "safe"
},
"sexual" : {
"filtered" : false,
"severity" : "safe"
},
"violence" : {
"filtered" : false,
"severity" : "safe"
}
},
"prompt_index" : 0
}
],
"usage" : {
"completion_tokens" : 131,
"prompt_tokens" : 26,
"total_tokens" : 157
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIFunctionToolCall;
import com.microsoft.semantickernel.implementation.CollectionUtil;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
import com.microsoft.semantickernel.plugin.KernelPlugin;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
Expand Down Expand Up @@ -86,12 +87,14 @@ public void nonAutoInvokedIsNotCalled() throws NoSuchMethodException {
.build())
.block();

List<OpenAIFunctionToolCall> toolCalls = ((OpenAIChatMessageContent<?>) CollectionUtil.getLastOrNull(result))
List<OpenAIFunctionToolCall> toolCalls = ((OpenAIChatMessageContent<?>) CollectionUtil.getLastOrNull(
result))
.getToolCall();

Assertions.assertNotNull(toolCalls);
Assertions.assertEquals(1, toolCalls.size());
Assertions.assertEquals("apluginname", CollectionUtil.getLastOrNull(toolCalls).getPluginName());
Assertions.assertEquals("apluginname",
CollectionUtil.getLastOrNull(toolCalls).getPluginName());
Assertions.assertEquals("doIt", CollectionUtil.getLastOrNull(toolCalls).getFunctionName());
Assertions.assertEquals("call_abc123", CollectionUtil.getLastOrNull(toolCalls).getId());

Expand Down Expand Up @@ -135,7 +138,8 @@ public void toolIsInvoked() throws NoSuchMethodException {
.build())
.block();

Assertions.assertTrue(CollectionUtil.getLastOrNull(result).getContent().contains("tool call done"));
Assertions.assertTrue(
CollectionUtil.getLastOrNull(result).getContent().contains("tool call done"));
Mockito.verify(testPlugin, Mockito.times(1)).doIt();

result = chatCompletionService
Expand All @@ -162,10 +166,66 @@ public void toolIsInvoked() throws NoSuchMethodException {
.build())
.block();

Assertions.assertTrue(CollectionUtil.getLastOrNull(result).getContent().contains("tool call done"));
Assertions.assertTrue(
CollectionUtil.getLastOrNull(result).getContent().contains("tool call done"));
Mockito.verify(testPlugin, Mockito.times(3)).doIt();
}


@Test
public void toolCallingHistoryPassed() throws NoSuchMethodException {
ChatCompletionService chatCompletionService = getChatCompletionService();

TestPlugin testPlugin = Mockito.spy(new TestPlugin());
KernelFunction<String> method = KernelFunctionFromMethod.<String>builder()
.withFunctionName("doIt")
.withMethod(TestPlugin.class.getMethod("doIt"))
.withTarget(testPlugin)
.withPluginName("apluginname")
.build();

Kernel kernel = Kernel.builder()
.withAIService(ChatCompletionService.class, chatCompletionService)
.withPlugin(
new KernelPlugin(
"apluginname",
"A plugin description",
Map.of("doIt", method)))
.build();

ChatHistory messages = new ChatHistory();
messages.addMessage(
new ChatMessageContent<>(
AuthorRole.USER,
"Call A function"));

List<ChatMessageContent<?>> result = chatCompletionService
.getChatMessageContentsAsync(
messages,
kernel,
InvocationContext.builder()
.withToolCallBehavior(
ToolCallBehavior.allowAllKernelFunctions(true))
.withReturnMode(InvocationReturnMode.FULL_HISTORY)
.build())
.block();

ChatHistory newHistory = new ChatHistory(result);
newHistory.addMessage(AuthorRole.USER, "do something else");

List<ChatMessageContent<?>> result2 = chatCompletionService
.getChatMessageContentsAsync(
newHistory,
kernel,
InvocationContext.builder()
.withToolCallBehavior(
ToolCallBehavior.allowAllKernelFunctions(true))
.withReturnMode(InvocationReturnMode.FULL_HISTORY)
.build())
.block();
Assertions.assertTrue(result2.size() == 6);
}

private ChatCompletionService getChatCompletionService() {
wm
.addStubMapping(
Expand Down Expand Up @@ -221,6 +281,7 @@ private ChatCompletionService getChatCompletionService() {
.build();
}


public static MappingBuilder buildTextResponse(String bodyMatcher, String responseBody) {
return post(urlEqualTo(
"//openai/deployments/gpt-35-turbo-2/completions?api-version=2024-03-01-preview"))
Expand Down
Loading

0 comments on commit 8b55f2b

Please sign in to comment.