Skip to content

Commit

Permalink
Fix Moonshot Chat model toolcalling token usage
Browse files Browse the repository at this point in the history
 - Accumulate the token usage when tool calling is invoked
   - Fix both call() and stream() methods
     - Add `usage` field to the Chat completion choice as the usage is returned via Choice
 - Add Moonshot chat model ITs for function calling tests
  • Loading branch information
ilayaperumalg committed Dec 13, 2024
1 parent 218c967 commit de9a356
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -74,6 +76,7 @@
* MoonshotChatModel is a {@link ChatModel} implementation that uses the Moonshot API.
*
* @author Geng Rong
* @author Ilayaperumal Gopinathan
*/
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {

Expand Down Expand Up @@ -179,6 +182,10 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met

/**
 * Executes a chat completion for the given prompt.
 * <p>
 * This is the entry point of a conversation, so there is no earlier response to hand
 * over: {@code null} is passed as the previous chat response to
 * {@code internalCall}.
 * @param prompt the prompt to send to the model
 * @return the chat response produced by {@code internalCall}
 */
@Override
public ChatResponse call(Prompt prompt) {
	ChatResponse noPreviousResponse = null;
	return internalCall(prompt, noPreviousResponse);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand Down Expand Up @@ -217,8 +224,11 @@ public ChatResponse call(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
MoonshotApi.Usage usage = completionEntity.getBody().usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), cumulativeUsage));

observationContext.setResponse(chatResponse);

Expand All @@ -231,7 +241,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return response;
}
Expand All @@ -243,6 +253,10 @@ public ChatOptions getDefaultOptions() {

/**
 * Executes a streaming chat completion for the given prompt.
 * <p>
 * This is the entry point of a conversation, so there is no earlier response to hand
 * over: {@code null} is passed as the previous chat response to
 * {@code internalStream}.
 * @param prompt the prompt to send to the model
 * @return the stream of chat responses produced by {@code internalStream}
 */
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
	ChatResponse noPreviousResponse = null;
	return internalStream(prompt, noPreviousResponse);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand Down Expand Up @@ -286,8 +300,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
MoonshotApi.Usage usage = chatCompletion2.usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);

return new ChatResponse(generations, from(chatCompletion2));
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
Expand All @@ -302,7 +319,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return Flux.just(response);
})
Expand All @@ -324,6 +341,16 @@ private ChatResponseMetadata from(ChatCompletion result) {
.build();
}

/**
 * Builds the response metadata for a chat completion, recording the supplied usage
 * instead of reading one from the completion itself.
 * <p>
 * Missing values are replaced by neutral defaults: an empty string for the id and the
 * model, and {@code 0L} for the {@code created} entry.
 * @param result the chat completion to describe, must not be {@code null}
 * @param usage the usage to record in the metadata
 * @return the populated response metadata
 */
private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
	Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");

	var id = (result.id() != null) ? result.id() : "";
	var model = (result.model() != null) ? result.model() : "";
	var created = (result.created() != null) ? result.created() : 0L;

	var metadataBuilder = ChatResponseMetadata.builder();
	metadataBuilder = metadataBuilder.withId(id);
	metadataBuilder = metadataBuilder.withUsage(usage);
	metadataBuilder = metadataBuilder.withModel(model);
	metadataBuilder = metadataBuilder.withKeyValue("created", created);
	return metadataBuilder.build();
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is taken from the
* last choice of the chunk.
* @param chunk the ChatCompletionChunk to convert
Expand All @@ -335,10 +362,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
if (delta == null) {
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
}
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason());
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
}).toList();

return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
// Get the usage from the latest choice
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ public record Choice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("usage") Usage usage) {
// @formatter:on
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
: previous.finishReason());
Integer index = (current.index() != null ? current.index() : previous.index());

MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();

ChatCompletionMessage message = merge(previous.delta(), current.delta());
return new ChunkChoice(index, message, finishReason, null);
return new ChunkChoice(index, message, finishReason, usage);
}

private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.moonshot;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.moonshot.api.MockWeatherService;
import org.springframework.ai.moonshot.api.MoonshotApi;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Integration tests for {@link MoonshotChatModel} function calling through both the
 * blocking ({@code call}) and the streaming ({@code stream}) API, including the token
 * usage reported in the response metadata.
 * <p>
 * The tests talk to the real Moonshot service and therefore only run when the
 * {@code MOONSHOT_API_KEY} environment variable is set.
 *
 * @author Ilayaperumal Gopinathan
 */
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "MOONSHOT_API_KEY", matches = ".+")
public class MoonShotChatModelIT {

	/**
	 * Tool definition for the {@code getCurrentWeather} function that is sent to the
	 * model with every prompt of this test class.
	 */
	private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
			MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
					"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
					{
					"type": "object",
					"properties": {
					"location": {
					"type": "string",
					"description": "The city and state e.g. San Francisco, CA"
					},
					"lat": {
					"type": "number",
					"description": "The city latitude"
					},
					"lon": {
					"type": "number",
					"description": "The city longitude"
					},
					"unit": {
					"type": "string",
					"enum": ["C", "F"]
					}
					},
					"required": ["location", "lat", "lon", "unit"]
					}
					"""));

	@Autowired
	private MoonshotChatModel chatModel;

	/**
	 * Creates the prompt shared by all tests: a question about the weather in San
	 * Francisco, with the weather tool and a {@link MockWeatherService} callback
	 * registered in the prompt options.
	 * @return a fresh prompt instance
	 */
	private static Prompt weatherPrompt() {
		var promptOptions = MoonshotChatOptions.builder()
			.withModel(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
			.withTools(Arrays.asList(FUNCTION_TOOL))
			.withFunctionCallbacks(List.of(FunctionCallback.builder()
				.function("getCurrentWeather", new MockWeatherService())
				.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
				.inputType(MockWeatherService.Request.class)
				.build()))
			.build();
		return new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
				promptOptions);
	}

	/**
	 * Verifies that a blocking call resolves the tool call, answers with the weather
	 * for San Francisco and reports a total token count within the expected range.
	 */
	@Test
	public void toolFunctionCall() {
		ChatResponse chatResponse = this.chatModel.call(weatherPrompt());

		assertThat(chatResponse).isNotNull();
		// An assertThat(...) without a chained assertion verifies nothing, so the
		// output is checked explicitly here.
		assertThat(chatResponse.getResult().getOutput()).isNotNull();
		assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
		assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
	}

	/**
	 * Verifies that the concatenated text of all streamed responses contains the
	 * weather answer for San Francisco.
	 */
	@Test
	public void testStreamFunctionCall() {
		Flux<ChatResponse> chatResponse = this.chatModel.stream(weatherPrompt());

		String content = chatResponse.collectList()
			.block()
			.stream()
			.map(ChatResponse::getResults)
			.flatMap(List::stream)
			.map(Generation::getOutput)
			.map(AssistantMessage::getText)
			.collect(Collectors.joining());

		assertThat(content).contains("San Francisco");
		assertThat(content).contains("30.0");
	}

	/**
	 * Verifies that the last streamed response carries usage metadata whose total
	 * token count is within the expected range.
	 */
	@Test
	public void testStreamFunctionCallUsage() {
		ChatResponse chatResponse = this.chatModel.stream(weatherPrompt()).blockLast();

		assertThat(chatResponse).isNotNull();
		assertThat(chatResponse.getMetadata()).isNotNull();
		assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
	}

	/**
	 * Minimal Spring configuration that wires a {@link MoonshotApi} and a
	 * {@link MoonshotChatModel} using the API key from the environment.
	 */
	@SpringBootConfiguration
	public static class Config {

		@Bean
		public MoonshotApi moonshotApi() {
			return new MoonshotApi(getApiKey());
		}

		/**
		 * Reads the API key from the {@code MOONSHOT_API_KEY} environment variable.
		 * @return the API key
		 * @throws IllegalArgumentException if the variable is missing or blank
		 */
		private String getApiKey() {
			String apiKey = System.getenv("MOONSHOT_API_KEY");
			if (!StringUtils.hasText(apiKey)) {
				throw new IllegalArgumentException(
						"You must provide an API key. Put it in an environment variable under the name MOONSHOT_API_KEY");
			}
			return apiKey;
		}

		@Bean
		public MoonshotChatModel moonshotChatModel(MoonshotApi api) {
			return new MoonshotChatModel(api);
		}

	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void beforeEach() {
public void moonshotChatTransientError() {

var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP);
ChatCompletionFinishReason.STOP, null);
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
List.of(choice), new MoonshotApi.Usage(10, 10, 10));

Expand Down

0 comments on commit de9a356

Please sign in to comment.