Skip to content

Commit

Permalink
Advancing Tool Support - Part 1
Browse files Browse the repository at this point in the history
* Defined new APIs consolidating the “tool” naming as opposed to the current “function”, aligning with the industry and solving the confusion between “function tool” and “Java Function”: ToolCallback and ToolCallingChatOptions. They extend the current ones to ensure backward compatibility, but FunctionCallback and FunctionCallingOptions can be considered deprecated.
* Enhanced support for methods as tools, introducing support for declarative Tool-annotated methods via MethodToolCallback and MethodToolCallbackProvider (deprecating the existing MethodInvokingFunctionCallback).
* Improved tool execution logic with granular support for returning the result directly to the client and exception handling.
* Improved JSON Schema generation and parsing logic, consolidating the usage of the victools/jsonschema-generator library and dropping the non-maintained Jackson JSON Schema Module. This makes it possible to use tools with input lists/arrays, which the latter library was not supporting.
* Extended ChatClient API with new methods tools() and toolCallbacks(). The existing functions() methods can be considered deprecated.

Relates to spring-projectsgh-2049

Signed-off-by: Thomas Vitale <[email protected]>
  • Loading branch information
ThomasVitale committed Jan 12, 2025
1 parent 329e6c0 commit df3a3b4
Show file tree
Hide file tree
Showing 48 changed files with 4,609 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.aot;

import org.springframework.ai.tool.execution.DefaultToolCallResultConverter;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

/**
* Registers runtime hints for the tool calling APIs.
*
* @author Thomas Vitale
*/
public class ToolRuntimeHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
hints.reflection().registerType(DefaultToolCallResultConverter.class, mcs);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -215,6 +215,12 @@ interface ChatClientRequestSpec {

<T extends ChatOptions> ChatClientRequestSpec options(T options);

ChatClientRequestSpec tools(String... toolNames);

ChatClientRequestSpec tools(Object... toolObjects);

ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);

/**
* @deprecated use {@link #functions(FunctionCallback...)} instead.
*/
Expand Down Expand Up @@ -293,6 +299,12 @@ interface Builder {

Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);

Builder defaultTools(String... toolNames);

Builder defaultTools(Object... toolObjects);

Builder defaultToolCallbacks(FunctionCallback... toolCallbacks);

/**
* @deprecated use {@link #defaultFunctions(FunctionCallback...)} instead.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,6 +32,7 @@
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.springframework.ai.tool.ToolCallbacks;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -782,10 +783,9 @@ public Builder mutate() {
builder.defaultOptions(this.chatOptions);
}

// workaround to set the missing fields.
builder.defaultRequest.getMessages().addAll(this.messages);
builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks);
builder.defaultRequest.getToolContext().putAll(this.toolContext);
builder.addMessages(this.messages);
builder.addToolCallbacks(this.functionCallbacks);
builder.addToolContext(this.toolContext);

return builder;
}
Expand Down Expand Up @@ -836,6 +836,30 @@ public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
return this;
}

@Override
public ChatClientRequestSpec tools(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
this.functionNames.addAll(List.of(toolNames));
return this;
}

@Override
public ChatClientRequestSpec tools(Object... toolObjects) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
return this;
}

@Override
public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
return this;
}

@Override
public <I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function) {
Expand Down Expand Up @@ -888,10 +912,7 @@ public <I, O> ChatClientRequestSpec function(String name, String description, @N
}

public ChatClientRequestSpec functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "functionBeanNames cannot be null");
Assert.noNullElements(functionBeanNames, "functionBeanNames cannot contain null elements");
this.functionNames.addAll(List.of(functionBeanNames));
return this;
return tools(functionBeanNames);
}

public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,10 +30,12 @@
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -147,6 +149,24 @@ public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
return this;
}

@Override
public Builder defaultTools(String... toolNames) {
this.defaultRequest.functions(toolNames);
return this;
}

@Override
public Builder defaultTools(Object... toolObjects) {
this.defaultRequest.functions(ToolCallbacks.from(toolObjects));
return this;
}

@Override
public Builder defaultToolCallbacks(FunctionCallback... toolCallbacks) {
this.defaultRequest.functions(toolCallbacks);
return this;
}

public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
this.defaultRequest.function(name, description, function);
return this;
Expand All @@ -173,4 +193,17 @@ public Builder defaultToolContext(Map<String, Object> toolContext) {
return this;
}

void addMessages(List<Message> messages) {
this.defaultRequest.messages(messages);
}

void addToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.defaultRequest.toolCallbacks(toolCallbacks.toArray(FunctionCallback[]::new));
}

void addToolContext(Map<String, Object> toolContext) {
this.defaultRequest.toolContext(toolContext);
}

}
Loading

0 comments on commit df3a3b4

Please sign in to comment.