diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml index 363837b0..c2dffa73 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/pom.xml @@ -15,11 +15,14 @@ A library for building stateful, multi-agents applications with LLMs - - 17 - 17 - UTF-8 - + + 17 + 17 + UTF-8 + 1.3 + 1.16.0 + 4.4 + @@ -123,6 +126,24 @@ true + + org.apache.commons + commons-collections4 + ${commons-collections4.version} + + + + org.apache.commons + commons-exec + ${commons-exec.version} + + + + commons-codec + commons-codec + ${commons-codec.version} + + diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutor.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutor.java new file mode 100644 index 00000000..2833173e --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutor.java @@ -0,0 +1,34 @@ +package com.alibaba.cloud.ai.graph.node.code; + +import com.alibaba.cloud.ai.graph.node.code.entity.CodeBlock; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult; + +import java.util.List; + +/** + * @author HeYQ + * @since 2024-12-02 17:15 + */ + +public interface CodeExecutor { + + /** + * Execute code blocks and return the result. This method should be implemented by the + * code executor. + * @param codeBlockList The code blocks to execute. + * @param codeExecutionConfig The configuration of the code execution. + * @return CodeExecutionResult The result of the code execution. + * @throws Exception ValueError: Errors in user inputs + */ + + CodeExecutionResult executeCodeBlocks(List codeBlockList, CodeExecutionConfig codeExecutionConfig) + throws Exception; + + /** + * Restart the code executor. This method should be implemented by the code executor. + * This method is called when the agent is reset. + */ + void restart(); + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java new file mode 100644 index 00000000..677dd334 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/CodeExecutorNodeAction.java @@ -0,0 +1,96 @@ +package com.alibaba.cloud.ai.graph.node.code; + +import com.alibaba.cloud.ai.graph.action.NodeAction; +import com.alibaba.cloud.ai.graph.node.AbstractNode; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeBlock; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult; +import com.alibaba.cloud.ai.graph.state.NodeState; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.TypeReference; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * @author HeYQ + * @since 2024-11-28 11:47 + */ + +public class CodeExecutorNodeAction extends AbstractNode implements NodeAction { + + private final CodeExecutor codeExecutor; + + private final String codeLanguage; + + private final String code; + + private final CodeExecutionConfig codeExecutionConfig; + + public CodeExecutorNodeAction(CodeExecutor codeExecutor, String codeLanguage, String code, + CodeExecutionConfig config) { + this.codeExecutor = codeExecutor; + this.codeLanguage = codeLanguage; + this.code = code; + this.codeExecutionConfig = config; + } + + @Override + public Map apply(NodeState state) throws Exception { + List codeBlockList = new ArrayList<>(10); + codeBlockList.add(new CodeBlock(codeLanguage, code)); + CodeExecutionResult codeExecutionResult = codeExecutor.executeCodeBlocks(codeBlockList, + this.codeExecutionConfig); + if (codeExecutionResult.exitCode() != 0) { + throw new RuntimeException("code execution failed, exit code: " + codeExecutionResult.exitCode() + + ", logs: " + codeExecutionResult.logs()); + } + return JSONObject.parseObject(codeExecutionResult.logs(), new TypeReference>() { + }); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private CodeExecutor codeExecutor; + + private String codeLanguage; + + private String code; + + private CodeExecutionConfig config; + + public Builder() { + } + + public Builder codeExecutor(CodeExecutor codeExecutor) { + this.codeExecutor = codeExecutor; + return this; + } + + public Builder codeLanguage(String codeLanguage) { + this.codeLanguage = codeLanguage; + return this; + } + + public Builder code(String code) { + this.code = code; + return this; + } + + public Builder config(CodeExecutionConfig config) { + this.config = config; + return this; + } + + public CodeExecutorNodeAction build() { + return new CodeExecutorNodeAction(codeExecutor, codeLanguage, code, config); + } + + } + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/LocalCommandlineCodeExecutor.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/LocalCommandlineCodeExecutor.java new file mode 100644 index 00000000..4e95f382 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/LocalCommandlineCodeExecutor.java @@ -0,0 +1,134 @@ +package com.alibaba.cloud.ai.graph.node.code; + +import com.alibaba.cloud.ai.graph.node.code.entity.CodeBlock; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionResult; +import com.alibaba.cloud.ai.graph.utils.FileUtils; +import org.apache.commons.codec.digest.DigestUtils; +import org.apache.commons.exec.CommandLine; +import org.apache.commons.exec.DefaultExecutor; +import org.apache.commons.exec.ExecuteException; +import org.apache.commons.exec.ExecuteWatchdog; +import org.apache.commons.exec.PumpStreamHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +/** + * @author HeYQ + * @since 2024-12-02 17:23 + */ + +public class LocalCommandlineCodeExecutor implements CodeExecutor { + + private static final Logger logger = LoggerFactory.getLogger(LocalCommandlineCodeExecutor.class); + + @Override + public CodeExecutionResult executeCodeBlocks(List codeBlockList, CodeExecutionConfig codeExecutionConfig) + throws Exception { + StringBuilder allLogs = new StringBuilder(); + CodeExecutionResult result; + for (int i = 0; i < codeBlockList.size(); i++) { + CodeBlock codeBlock = codeBlockList.get(i); + String language = codeBlock.language(); + String code = codeBlock.code(); + logger.info("\n>>>>>>>> EXECUTING CODE BLOCK {} (inferred language is {})...", i + 1, language); + + if (Set.of("bash", "shell", "sh", "python").contains(language.toLowerCase())) { + result = executeCode(language, code, codeExecutionConfig); + } + else { + // the language is not supported, then return an error message. + result = new CodeExecutionResult(1, "unknown language " + language); + } + + allLogs.append("\n").append(result.logs()); + if (result.exitCode() != 0) { + return new CodeExecutionResult(result.exitCode(), allLogs.toString()); + } + } + return new CodeExecutionResult(0, allLogs.toString()); + } + + @Override + public void restart() { + + logger.warn("Restarting local command line code executor is not supported. No action is taken."); + } + + public CodeExecutionResult executeCode(String language, String code, CodeExecutionConfig config) throws Exception { + if (Objects.isNull(language) || Objects.isNull(code)) { + throw new Exception("Either language or code must be provided."); + } + String workDir = config.getWorkDir(); + String codeHash = DigestUtils.md5Hex(code); + String fileExt = language.startsWith("python") ? "py" : language; + String filename = String.format("tmp_code_%s.%s", codeHash, fileExt); + + // write the code string to a file specified by the filename. + FileUtils.writeCodeToFile(workDir, filename, code); + + CodeExecutionResult executionResult = executeCodeLocally(language, workDir, filename, config.getTimeout()); + + FileUtils.deleteFile(workDir, filename); + return executionResult; + } + + private CodeExecutionResult executeCodeLocally(String language, String workDir, String filename, int timeout) + throws Exception { + // set up the command based on language + String executable = getExecutableForLanguage(language); + CommandLine commandLine = new CommandLine(executable); + commandLine.addArgument(filename); + + // set up the execution environment + DefaultExecutor executor = new DefaultExecutor(); + executor.setWorkingDirectory(new File(workDir)); + executor.setExitValue(0); + + // set up the streams for the output of the subprocess + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + ByteArrayOutputStream errorStream = new ByteArrayOutputStream(); + PumpStreamHandler streamHandler = new PumpStreamHandler(outputStream, errorStream); + executor.setStreamHandler(streamHandler); + + // set up a watchdog to terminate the process if it exceeds the timeout + ExecuteWatchdog watchdog = new ExecuteWatchdog(TimeUnit.SECONDS.toMillis(timeout)); + executor.setWatchdog(watchdog); + + try { + // execute the command + executor.execute(commandLine); + // process completed before the watchdog terminated it + String output = outputStream.toString(); + return new CodeExecutionResult(0, output.trim()); + } + catch (ExecuteException e) { + // process finished with an exit value (possibly non-zero) + String errorOutput = errorStream.toString().replace(Path.of(workDir).toAbsolutePath() + File.separator, ""); + + return new CodeExecutionResult(e.getExitValue(), errorOutput.trim()); + } + catch (IOException e) { + // returns a special result if the process was killed by the watchdog + throw new Exception("Error executing code.", e); + } + } + + private String getExecutableForLanguage(String language) throws Exception { + return switch (language) { + case "python" -> language; + case "shell", "bash", "sh", "powershell" -> "sh"; + default -> throw new Exception("Language not recognized in code execution:" + language); + }; + } + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeBlock.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeBlock.java new file mode 100644 index 00000000..907faf98 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeBlock.java @@ -0,0 +1,8 @@ +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * @author HeYQ + * @since 0.0.1 + */ +public record CodeBlock(String language, String code) { +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionConfig.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionConfig.java new file mode 100644 index 00000000..6569457d --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionConfig.java @@ -0,0 +1,43 @@ +package com.alibaba.cloud.ai.graph.node.code.entity; + +import lombok.Builder; +import lombok.Data; + +/** + * Config for the code execution. + * + * @author HeYQ + * @since 0.0.1 + */ +@Data +@Builder +public class CodeExecutionConfig { + + /** + * the working directory for the code execution. + */ + @Builder.Default + private String workDir = "extensions"; + + /** + * the docker image to use for code execution. + */ + private String docker; + + /** + * the maximum execution time in seconds. + */ + @Builder.Default + private int timeout = 600; + + /** + * the number of messages to look back for code execution. default value is 1, and -1 + * indicates auto mode. + */ + @Builder.Default + private int lastMessagesNumber = 1; + + @Builder.Default + private int codeMaxDepth = 5; + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionResult.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionResult.java new file mode 100644 index 00000000..b8aef06f --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/node/code/entity/CodeExecutionResult.java @@ -0,0 +1,18 @@ +package com.alibaba.cloud.ai.graph.node.code.entity; + +/** + * Represents the result of code execution. + * + * @param exitCode 0 if the code executes successfully. + * @param logs the error message if the code fails to execute, the stdout otherwise. + * @param extra commandLine code_file or the docker image name after container run when + * docker is used. + * @author HeYQ + * @since 0.0.1 + */ +public record CodeExecutionResult(int exitCode, String logs, String extra) { + + public CodeExecutionResult(int exitCode, String logs) { + this(exitCode, logs, null); + } +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/utils/FileUtils.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/utils/FileUtils.java new file mode 100644 index 00000000..4611edab --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/main/java/com/alibaba/cloud/ai/graph/utils/FileUtils.java @@ -0,0 +1,53 @@ +package com.alibaba.cloud.ai.graph.utils; + +import lombok.SneakyThrows; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +/** + * @author HeYQ + * @since 2024-11-28 11:47 + */ +public class FileUtils { + + private FileUtils() { + throw new IllegalStateException("Utility class"); + } + + /** + * Writes the given code string to a file specified by the filename. The file is + * created in the provided working directory. Intermediate directories in the path + * will be created if they do not exist. + * @param workDir The working directory where the file needs to be created. + * @param filename The name of the file to write the code to. + * @param code The code to write to the file. Must not be null. + */ + @SneakyThrows(IOException.class) + public static void writeCodeToFile(String workDir, String filename, String code) { + if (code == null) { + throw new IllegalArgumentException("Code must not be null"); + } + Path filepath = Path.of(workDir, filename); + // ensure the parent directory exists + Path fileDir = filepath.getParent(); + if (fileDir != null && !Files.exists(fileDir)) { + Files.createDirectories(fileDir); + } + // write the code to the file + Files.writeString(filepath, code); + } + + /** + * Deletes the file specified by the filename from the provided working directory. + * @param workDir The working directory where the file to be deleted is located. + * @param filename The name of the file to be deleted. + */ + @SneakyThrows(IOException.class) + public static void deleteFile(String workDir, String filename) { + Path filepath = Path.of(workDir, filename); + Files.deleteIfExists(filepath); + } + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/CodeActionTest.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/CodeActionTest.java new file mode 100644 index 00000000..9e49e00e --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-core/src/test/java/com/alibaba/cloud/ai/graph/CodeActionTest.java @@ -0,0 +1,70 @@ +package com.alibaba.cloud.ai.graph; + +import com.alibaba.cloud.ai.graph.action.NodeAction; +import com.alibaba.cloud.ai.graph.node.code.CodeExecutorNodeAction; +import com.alibaba.cloud.ai.graph.node.code.LocalCommandlineCodeExecutor; +import com.alibaba.cloud.ai.graph.node.code.entity.CodeExecutionConfig; +import com.alibaba.cloud.ai.graph.state.NodeState; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +/** + * @author HeYQ + * @since 2024-11-28 11:49 + */ + +public class CodeActionTest { + + private CodeExecutionConfig config; + + @TempDir + Path tempDir; + + @BeforeEach + void setUp() { + // set up the configuration for each test + config = CodeExecutionConfig.builder().workDir(tempDir.toString()).build(); + } + + private LLMNodeActionTest.MockState mockState() { + Map initData = new HashMap<>(); + return new LLMNodeActionTest.MockState(initData); + } + + @Test + void testExecutePythonSuccessfully() throws Exception { + String code = """ + print({"result":'Hello, Python!'}) + """; + + NodeAction codeNode = CodeExecutorNodeAction.builder() + .codeExecutor(new LocalCommandlineCodeExecutor()) + .code(code) + .codeLanguage("python") + .config(config) + .build(); + + Map stateData = codeNode.apply(mockState()); + + // assertThat(result.exitCode()).isZero(); + // assertThat(result.logs()).contains("Hello, Python!"); + System.out.println(stateData); + } + + static class MockState extends NodeState { + + /** + * Constructs an AgentState with the given initial data. + * @param initData the initial data for the agent state + */ + public MockState(Map initData) { + super(initData); + } + + } + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/NodeType.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/NodeType.java index 94029b34..f8a81a9d 100644 --- a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/NodeType.java +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/NodeType.java @@ -16,6 +16,8 @@ public enum NodeType { AGGREGATOR("AGGREGATOR", "variable-aggregator"), + QUESTION_CLASSIFIER("QUESTION_CLASSIFIER", "question-classifier"), + HUMAN("HUMAN", "unsupported"),; private String value; diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/nodedata/QuestionClassifierNodeData.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/nodedata/QuestionClassifierNodeData.java new file mode 100644 index 00000000..3b369346 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/model/workflow/nodedata/QuestionClassifierNodeData.java @@ -0,0 +1,47 @@ +package com.alibaba.cloud.ai.model.workflow.nodedata; + +import com.alibaba.cloud.ai.model.Variable; +import com.alibaba.cloud.ai.model.VariableSelector; +import com.alibaba.cloud.ai.model.VariableType; +import com.alibaba.cloud.ai.model.workflow.NodeData; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +import java.util.List; + +/** + * @author HeYQ + * @since 2024-12-12 21:26 + */ +@EqualsAndHashCode(callSuper = true) +@Data +@Accessors(chain = true) +public class QuestionClassifierNodeData extends NodeData { + + public static final Variable DEFAULT_OUTPUT_SCHEMA = new Variable("class_name", VariableType.STRING.value()); + + private LLMNodeData.ModelConfig model; + + private LLMNodeData.MemoryConfig memoryConfig; + + private String instruction; + + private List classes; + + public QuestionClassifierNodeData(List inputs, List outputs) { + super(inputs, outputs); + } + + @Data + @AllArgsConstructor + public static class ClassConfig { + + private String id; + + private String text; + + } + +} diff --git a/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/QuestionClassifyNodeDataConverter.java b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/QuestionClassifyNodeDataConverter.java new file mode 100644 index 00000000..3956b921 --- /dev/null +++ b/spring-ai-alibaba-graph/spring-ai-alibaba-graph-studio/src/main/java/com/alibaba/cloud/ai/service/dsl/nodes/QuestionClassifyNodeDataConverter.java @@ -0,0 +1,134 @@ +package com.alibaba.cloud.ai.service.dsl.nodes; + +import com.alibaba.cloud.ai.model.VariableSelector; +import com.alibaba.cloud.ai.model.workflow.NodeType; +import com.alibaba.cloud.ai.model.workflow.nodedata.LLMNodeData; +import com.alibaba.cloud.ai.model.workflow.nodedata.QuestionClassifierNodeData; +import com.alibaba.cloud.ai.service.dsl.NodeDataConverter; +import com.alibaba.cloud.ai.utils.StringTemplateUtil; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import org.apache.commons.collections4.CollectionUtils; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * @author HeYQ + * @since 2024-12-12 23:54 + */ + +public class QuestionClassifyNodeDataConverter implements NodeDataConverter { + + @Override + public Boolean supportType(String nodeType) { + return NodeType.QUESTION_CLASSIFIER.value().equals(nodeType); + } + + @Override + public QuestionClassifierNodeData parseDifyData(Map data) { + List inputs = Optional.ofNullable((List) data.get("query_variable_selector")) + .filter(CollectionUtils::isNotEmpty) + .map(variables -> Collections.singletonList(new VariableSelector(variables.get(0), variables.get(1)))) + .orElse(Collections.emptyList()); + + // convert model config + Map modelData = (Map) data.get("model"); + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.LOWER_CASE); + LLMNodeData.ModelConfig modelConfig = new LLMNodeData.ModelConfig().setMode((String) modelData.get("mode")) + .setName((String) modelData.get("name")) + .setProvider((String) modelData.get("provider")) + .setCompletionParams( + objectMapper.convertValue(modelData.get("completion_params"), LLMNodeData.CompletionParams.class)); + + QuestionClassifierNodeData nodeData = new QuestionClassifierNodeData(inputs, + List.of(QuestionClassifierNodeData.DEFAULT_OUTPUT_SCHEMA)) + .setModel(modelConfig); + + // covert instructions + String instruction = (String) data.get("instructions"); + if (instruction != null && !instruction.isBlank()) { + nodeData.setInstruction(instruction); + } + + // covert classes + if (data.containsKey("classes")) { + List> classes = (List>) data.get("classes"); + nodeData.setClasses(classes.stream() + .map(item -> new QuestionClassifierNodeData.ClassConfig((String) item.get("id"), + (String) item.get("text"))) + .toList()); + } + + // convert memory config + if (data.containsKey("memory")) { + Map memoryData = (Map) data.get("memory"); + String lastMessageTemplate = (String) memoryData.get("query_prompt_template"); + Map window = (Map) memoryData.get("window"); + Boolean windowEnabled = (Boolean) window.get("enabled"); + Integer windowSize = (Integer) window.get("size"); + LLMNodeData.MemoryConfig memory = new LLMNodeData.MemoryConfig().setWindowEnabled(windowEnabled) + .setWindowSize(windowSize) + .setLastMessageTemplate(lastMessageTemplate) + .setIncludeLastMessage(false); + nodeData.setMemoryConfig(memory); + } + + return nodeData; + } + + @Override + public Map dumpDifyData(QuestionClassifierNodeData nodeData) { + Map data = new HashMap<>(); + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.LOWER_CASE); + objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + + // put memory + LLMNodeData.MemoryConfig memory = nodeData.getMemoryConfig(); + if (memory != null) { + data.put("memory", + Map.of("query_prompt_template", StringTemplateUtil.toDifyTmpl(memory.getLastMessageTemplate()), + "role_prefix", Map.of("assistant", "", "user", ""), "window", + Map.of("enabled", memory.getWindowEnabled(), "size", memory.getWindowSize()))); + } + + // put model + LLMNodeData.ModelConfig model = nodeData.getModel(); + data.put("model", Map.of("mode", model.getMode(), "name", model.getName(), "provider", model.getProvider(), + "completion_params", objectMapper.convertValue(model.getCompletionParams(), Map.class))); + + // put query_variable_selector + List inputs = nodeData.getInputs(); + Optional.ofNullable(inputs) + .filter(CollectionUtils::isNotEmpty) + .map(inputList -> inputList.stream() + .findFirst() + .map(input -> List.of(input.getNamespace(), input.getName())) + .orElse(Collections.emptyList())) + .ifPresent(variables -> data.put("query_variable_selector", variables)); + + // put instructions + data.put("instructions", + nodeData.getInstruction() != null ? nodeData.getInstruction() : ""); + + // put Classes + if (!CollectionUtils.isEmpty(nodeData.getClasses())) { + data.put("classes", + nodeData.getClasses() + .stream() + .map(item -> Map.of("id", item.getId(), "text", item.getText())) + .toList()); + } + + return data; + } + +}