diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ad41914..2c7a29b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,7 +25,8 @@ jobs: - name: Build and test with JDK ${{ matrix.java_version }} run: | mvn -B -U -T4C clean test \ - -Dmaven.javadoc.skip=true surefire-report:report + -Dmaven.javadoc.skip=true \ + -DskipITs=true surefire-report:report - name: Upload Surefire Reports with JDK ${{ matrix.java_version }} uses: actions/upload-artifact@v4 with: diff --git a/README.md b/README.md index 1e0ec9a..a6baf80 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# JavAI Workflow πŸ¦œπŸ”€: Build programmatically custom agentic workflows, AI Agents, RAG systems for java +# JavAI Workflow πŸ¦œπŸ”‚β˜•: Build programmatically custom agentic workflows, AI Agents, RAG systems for java [![Build Status](https://github.com/czelabueno/langchain4j-workflow/actions/workflows/ci.yaml/badge.svg)](https://github.com/czelabueno/langchain4j-workflow/actions/workflows/ci.yaml) [![Maven Central](https://maven-badges.herokuapp.com/maven-central/dev.langchain4j/langchain4j-workflow/badge.svg)](https://maven-badges.herokuapp.com/maven-central/dev.langchain4j/langchain4j-workflow) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) @@ -6,6 +6,20 @@ An open-source Java library to build, package, integrate, orchestrate and monitor agentic AI systems for java developers πŸ’‘ ![Workflow Image](docs/jai-worflow-anatomy.png) +
+ Node, Module, and Workflow Definitions + +### Node +A `Node` represents a single unit of work within the workflow. It encapsulates a specific function or task that processes the stateful bean and updates it. Nodes can be synchronous or asynchronous (streaming). + +### Module +A `Module` is a collection of nodes grouped together to perform a higher-level function. Modules can be reused across different workflows, providing modularity and reusability. + +### Workflow +A `Workflow` is a directed graph of nodes, modules and edges that defines the sequence of operations to be performed. It manages the state transitions and execution flow, ensuring that each node processes the stateful bean in the correct order. + +
+ > 🌟 **Starring me**: If you find this repository beneficial, don't forget to give it a star! 🌟 It's a simple way to show your appreciation and help this project grow! @@ -42,20 +56,23 @@ jAI Workflow is influenced by [LangFlow](https://github.com/langflow-ai/langflow - **Integration**: - Model Context Protocol (MCP) integration as server and client. - Define remote module as MCP server. -- **Observability**: - - OpenTelemetry integration (metrics and traces). - - Debugging mode logging structure. +- **API**: + - Publish workflow as API (SSE for streaming runs and REST for sync runs). ### πŸ—ΊοΈ Future Features - **Deployment Model**: - Dockerize workflow - Kubernetes deployment - Cloud deployment -- **API**: - - Publish workflow as API (SSE for streaming runs and REST for sync runs). +- **Observability**: + - OpenTelemetry integration (metrics and traces). + - Debugging mode logging structure. - **Playground**: - Web-based playground to add, test and run jAI workflows APIs. - Chatbot Q&A viewer. - Graph tracing visualization for debugging + - + ![jai-workflow-playground-prototype](docs/jai-workflow-playground.gif) + > _This is a prototype, the final version will be available soon. Open an issue if you want to share your ideas or contribute to this feature._ ## Architecture jAI Workflow is designed with a modular architecture, enabling you to define custom workflows, modules, or agents to build RAG systems as LEGO-like. A module can be decoupled and integrated into any other workflow. @@ -88,93 +105,105 @@ If you would want to use jAI workflow without LangChain4j or with other framewor 0.2.0 ``` -### Example +### langChain4j-workflow example Define a stateful bean with fields that will be used to store the state of the workflow: ```java // Define a stateful bean -public class MyStatefulBean { - int value = 0; +public class MyStatefulBean extends AbstractStatefulBean { + private List relevantDocuments; + private String webSearchResponse; + // other fields you need } ``` - -Create a simple workflow with 4 nodes and conditional edges: +Define functions that determines statefulBean state. To simplify this, you can use a java class with static methods: +```java +public class MyStatefulBeanFunctions { + public static MyStatefulBean searchWeb(MyStatefulBean statefulBean) { + // This is a simple example, you can use LangChain4j to search the web using any WebSearchEngine. + statefulBean.webSearchResponse = "Web search response"; + return statefulBean; + }; + public static MyStatefulBean extractRelevantDocuments(MyStatefulBean statefulBean, String... uris) { + // This is a simple example, you can use LangChain4j to extract relevant content of the URIs using any RAG pattern. + statefulBean.relevantDocuments = Arrays.asList("Relevant Content 1", "Relevant Content 2"); + return statefulBean; + }; + public static UserMessage generateUserMessageUsingPrompt(MyStatefulBean statefulBean) { + return UserMessage.from(answerPrompt(statefulBean).text()); + } + private static Prompt answerPrompt(MyStatefulBean statefulBean) { + String question = statefulBean.getQuestion(); + String context = String.join("\n\n", statefulBean.getRelevantDocuments()); + MyStructuredPrompt generateAnswerPrompt = new MyStructuredPrompt(question, context); + return StructuredPromptProcessor.toPrompt(generateAnswerPrompt); + } +} +``` +Create a simple workflow with 3 nodes and conditional edges: ```java public class Example { public static void main(String[] args) { MyStatefulBean myStatefulBean = new MyStatefulBean(); - - // Define functions that determines statefulBean state - Function node1Func = obj -> { - obj.value +=1; - System.out.println("Node 1: [" + obj.value + "]"); - return "Node1: function proceed"; - }; - Function node2Func = obj -> { - obj.value +=2; - System.out.println("Node 2: [" + obj.value + "]"); - return "Node2: function proceed"; + String[] documents = new String[]{ + "https://lilianweng.github.io/posts/2023-06-23-agent/", + "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/" }; - Function node3Func = obj -> { - obj.value +=3; - System.out.println("Node 3: [" + obj.value + "]"); - return "Node3: function proceed"; - }; - Function node4Func = obj -> { - obj.value +=4; - System.out.println("Node 4: [" + obj.value + "]"); - return "Node4: function proceed"; - }; - + + StreamingChatLanguageModel streamingModel = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST) + .temperature(0.0) + .build(); + // Create the nodes and associate them with the functions to be used during execution. - Node node1 = Node.from("node1", node1Func); - Node node2 = Node.from("node2", node2Func); - Node node3 = Node.from("node3", node3Func); - Node node4 = Node.from("node4", node4Func); - + Node retrieveNode = Node.from( + "Retrieve Node", + obj -> MyStatefulBeanFunctions.extractRelevantDocuments(obj, documents)); + Node webSearchNode = Node.from( + "Web Searching Node", + obj -> MyStatefulBeanFunctions.searchWeb(obj)); + StreamingNode generateAnswerNode = StreamingNode.from( + "Generation Node", + obj -> MyStatefulBeanFunctions.generateUserMessageUsingPrompt(obj), + streamingModel); // Create workflow - StateWorkflow workflow = DefaultStateWorkflow.builder() - .statefulBean(myStatefulBean) - .addNodes(Arrays.asList(node1, node2, node3)) + DefaultJAiWorkflow workflow = DefaultJAiWorkflow.builder() + .statefulBean(statefulBean) + .runStream(true) + .nodes(Arrays.asList(retrieveNode, webSearchNode, generateAnswerNode)) .build(); + StateWorkflow stateWorkflow = workflow.workflow(); + // You can add more nodes after workflow build. E.g. node4 - workflow.addNode(node4); + stateWorkflow.addNode(node4); // Define edges - workflow.putEdge(node1, node2); - workflow.putEdge(node2, node3); + stateWorkflow.putEdge(retrieveNode, webSearchNode); // Conditional edge - workflow.putEdge(node3, Conditional.eval(obj -> { - System.out.println("Stateful Value [" + obj.value + "]"); - if (obj.value > 6) { - return node4; + stateWorkflow.putEdge(webSearchNode, Conditional.eval(obj -> { + if (obj.webSearchResponse != null) { + return generateAnswerNode; } else { - return node2; + return retrieveNode; } })); - workflow.putEdge(node4, WorkflowStateName.END); + stateWorkflow.putEdge(generateAnswerNode, WorkflowStateName.END); // Define which node to start - workflow.startNode(node1); + stateWorkflow.startNode(retrieveNode); - // Run workflow normally - workflow.run(); - // OR - // Run workflow in streaming mode - workflow.runStream(node -> { - System.out.println("Processing node: " + node.getName()); - }); + // Start conversation with the workflow in streaming mode + String question = "Summarizes the importance of building agents with LLMs"; + Flux tokens = workflow.answerStream(question); + tokens.subscribe(System.out::println); // Print all computed transitions - String transitions = workflow.prettyTransitions(); + String transitions = stateWorkflow.prettyTransitions(); System.out.println("Transitions: \n"); System.out.println(transitions); - - // Generate workflow image - workflow.generateWorkflowImage("image/my-workflow.svg"); - // workflow.generateWorkflowImage(); // if you use this method, it'll use by default the root path and default image name. } } ``` @@ -182,22 +211,64 @@ Now you can check the output of the workflow execution. ```shell STARTING workflow in stream mode.. -Processing node: node1 -Node 1: [1] -Processing node: node2 -Node 2: [3] -Processing node: node3 -Node 3: [6] -Stateful Value [6] -Processing node: node2 -Node 2: [8] -Processing node: node3 -Node 3: [11] -Stateful Value [11] -Processing node: node4 -Node 4: [15] +Processing node: Retrieve Node +Retrieve Node: processed +Processing node: Web Searching Node +Web Searching Node: processed +Processing node: Retrieve Node +Retrieve Node: processed +Processing node: Web Searching Node +Web Searching Node: processed +Processing node: Generation Node +Generation Node: processed Reached END state ``` +The LLM answer will be printed by tokens in the console: +```shell +Building +agen +ts with +LLMs +is +important +for three +key reasons. +Firstly, +LLMs serve as +a powerful +general problem +solver, +extending +their capabilities +beyond +just +generating text. +Secondly, they +act as +the brain + of an +autonomous +agent system, + enabling tasks +like planning +and task +decomposition. +Lastly, +proof-of-concept +demos like +AutoGPT +and +BabyAGI +showcase the + potential of +LLM-powered +agents in +handling + complex +tasks + efficiently. +``` + You can print all computed transitions: ```shell @@ -210,6 +281,8 @@ You can generate a workflow image with all computed transitions: ``` ![Workflow Image](jai-workflow-core/image/my-workflow.svg) +Check the full example in the [langchain4j-worflow tests](https://github.com/czelabueno/jai-workflow/blob/main/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/JAiWorkflowIT.java) + ## LLM examples You can check all examples in the [langchain4j-worflow-examples](https://github.com/czelabueno/langchain4j-workflow-examples) repository where show you how-to implement multiple RAG patterns, agent architectures and AI papers using LangChain4j and jAI Workflow. diff --git a/docs/jai-workflow-playground.gif b/docs/jai-workflow-playground.gif new file mode 100644 index 0000000..3ab1fe1 Binary files /dev/null and b/docs/jai-workflow-playground.gif differ diff --git a/jai-workflow-core/pom.xml b/jai-workflow-core/pom.xml index 33a2715..5a8f406 100644 --- a/jai-workflow-core/pom.xml +++ b/jai-workflow-core/pom.xml @@ -4,7 +4,7 @@ com.github.czelabueno jai-workflow-parent - ${revision} + 0.2.0 jai-workflow-core diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/DefaultStateWorkflow.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/DefaultStateWorkflow.java index 9492f4b..254131e 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/DefaultStateWorkflow.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/DefaultStateWorkflow.java @@ -5,6 +5,7 @@ import com.github.czelabueno.jai.workflow.transition.Transition; import com.github.czelabueno.jai.workflow.graph.GraphImageGenerator; import com.github.czelabueno.jai.workflow.graph.graphviz.GraphvizImageGenerator; +import com.github.czelabueno.jai.workflow.transition.TransitionState; import lombok.Builder; import lombok.NonNull; import lombok.Singular; @@ -23,7 +24,7 @@ public class DefaultStateWorkflow implements StateWorkflow { private static final Logger log = LoggerFactory.getLogger(DefaultStateWorkflow.class); - private final Map, List> adjList; + private final Map, List> adjList; private volatile Node startNode; private final T statefulBean; private final List transitions; @@ -82,6 +83,22 @@ public void startNode(Node startNode){ this.startNode = startNode; } + @Override + public Node getLastNode() { + if (adjList.isEmpty() || adjList == null) + throw new IllegalStateException("No nodes added to the workflow"); + + return adjList.entrySet() + .stream() + .filter(entry -> entry.getValue().contains(WorkflowStateName.END)) + .>map(Map.Entry::getKey) + .findFirst() + .orElseGet(() -> adjList.keySet() + .>stream() + .reduce((first, second) -> second) + .orElseThrow()); + } + @Override public T run() { transitions.clear(); // clean previous transitions @@ -98,11 +115,11 @@ private void runNode(Node node) { synchronized (statefulBean){ node.execute(statefulBean); } - List nextNodes; + List nextNodes; synchronized (adjList) { nextNodes = adjList.get(node); } - for (Object nextNode : nextNodes) { + for (TransitionState nextNode : nextNodes) { if (nextNode instanceof WorkflowStateName) { WorkflowStateName next = (WorkflowStateName) nextNode; if (next == WorkflowStateName.END) { @@ -127,11 +144,11 @@ private void runNode(Node node) { public T runStream(Consumer> eventConsumer) { transitions.clear(); // clean previous transitions log.debug("STARTING workflow in stream mode.."); - Queue queue = new LinkedBlockingQueue<>(); + Queue queue = new LinkedBlockingQueue<>(); queue.add(startNode); transitions.add(Transition.from(WorkflowStateName.START, startNode)); while (!queue.isEmpty()) { - Object current = queue.poll(); + TransitionState current = queue.poll(); if (current instanceof Node) { Node currentNode = (Node) current; //eventConsumer.accept(currentNode); @@ -139,12 +156,12 @@ public T runStream(Consumer> eventConsumer) { currentNode.execute(statefulBean); } eventConsumer.accept(currentNode); - List nextNodes; + List nextNodes; synchronized (adjList) { nextNodes = adjList.get(currentNode); } if (nextNodes != null) { - for (Object next : nextNodes) { + for (TransitionState next : nextNodes) { if (next instanceof WorkflowStateName) { WorkflowStateName nextState = (WorkflowStateName) next; if (nextState == WorkflowStateName.END) { @@ -180,13 +197,13 @@ public String prettyTransitions() { StringBuilder sb = new StringBuilder(); Object lastTo = null; for (Transition transition : transitions) { - if (transition.getFrom().equals(lastTo)) { - sb.append(" -> ").append(transition.getTo() instanceof Node ? ((Node) transition.getTo()).getName() : transition.getTo().toString()); + if (transition.from().equals(lastTo)) { + sb.append(" -> ").append(transition.to() instanceof Node ? ((Node) transition.to()).getName() : transition.to().toString()); } else { if (sb.length() > 0) sb.append(" "); - sb.append(transition.getFrom() instanceof Node ? ((Node)transition.getFrom()).getName() : transition.getFrom().toString()).append(" -> ").append(transition.getTo() instanceof Node ? ((Node) transition.getTo()).getName() : transition.getTo().toString()); + sb.append(transition.from() instanceof Node ? ((Node)transition.from()).getName() : transition.from().toString()).append(" -> ").append(transition.to() instanceof Node ? ((Node) transition.to()).getName() : transition.to().toString()); } - lastTo = transition.getTo() instanceof Node ? (Node) transition.getTo() : transition.getTo(); + lastTo = transition.to() instanceof Node ? (Node) transition.to() : transition.to(); } return sb.toString(); } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/StateWorkflow.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/StateWorkflow.java index f319fee..47013d0 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/StateWorkflow.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/StateWorkflow.java @@ -8,25 +8,93 @@ import java.util.List; import java.util.function.Consumer; +/** + * Interface representing a state workflow. + * + * @param the type of the stateful bean used in the workflow + */ public interface StateWorkflow { + + /** + * Adds a node to the workflow. + * + * @param node the node to add + */ void addNode(Node node); + /** + * Creates an edge between two nodes in the workflow. + * + * @param from the starting node of the edge + * @param to the ending node of the edge + */ void putEdge(Node from, Node to); + /** + * Creates an edge between a node and a conditional node in the workflow. + * + * @param from the starting node of the edge + * @param conditional the conditional node to evaluate + */ void putEdge(Node from, Conditional conditional); + /** + * Creates an edge between a node and a workflow state in the workflow. + * + * @param from the starting node of the edge + * @param state the workflow state to transition to + */ void putEdge(Node from, WorkflowStateName state); + /** + * Sets the starting node of the workflow. + * + * @param startNode the starting node + */ void startNode(Node startNode); + /** + * Returns the last node defined in the workflow. + * + * @return the last node defined in the workflow + */ + Node getLastNode(); + + /** + * Runs the workflow synchronously. + * + * @return the stateful bean after the workflow execution + */ T run(); + /** + * Runs the workflow in stream mode, consuming events with the specified consumer. + * + * @param eventConsumer the consumer to process node events + * @return the stateful bean after the workflow execution + */ T runStream(Consumer> eventConsumer); + /** + * Returns the list of computed transitions in the workflow. + * + * @return the list of computed transitions + */ List getComputedTransitions(); + /** + * Generates an image of the workflow and saves it to the specified output path. + * + * @param outputPath the path to save the workflow image + * @throws IOException if an I/O error occurs + */ void generateWorkflowImage(String outputPath) throws IOException; + /** + * Generates an image of the workflow and saves it to the default path "workflow-image.svg". + * + * @throws IOException if an I/O error occurs + */ default void generateWorkflowImage() throws IOException { generateWorkflowImage("workflow-image.svg"); } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/WorkflowStateName.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/WorkflowStateName.java index 52a3e30..98b3cf0 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/WorkflowStateName.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/WorkflowStateName.java @@ -1,6 +1,20 @@ package com.github.czelabueno.jai.workflow; -public enum WorkflowStateName { +import com.github.czelabueno.jai.workflow.transition.TransitionState; + +/** + * Enum representing the possible states in a workflow. + *

+ * This class implements the {@link TransitionState} interface. + */ +public enum WorkflowStateName implements TransitionState { + /** + * The starting state of the workflow. + */ START, + + /** + * The ending state of the workflow. + */ END } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/GraphImageGenerator.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/GraphImageGenerator.java index 8dc15cb..3456e20 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/GraphImageGenerator.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/GraphImageGenerator.java @@ -6,10 +6,28 @@ import java.io.IOException; import java.util.List; +/** + * Interface for generating graph images from workflow transitions computed. + */ public interface GraphImageGenerator { + /** + * Generates a graph image from the given list of transitions and saves it to the default output path. + * + * @param transitions the list of transitions to generate the graph image from + * @throws IOException if an I/O error occurs during image generation + */ default void generateImage(List transitions) throws IOException { generateImage(transitions, "workflow-image.svg"); } + + /** + * Generates a graph image from the given list of transitions and saves it to the specified output path. + * + * @param transitions the list of transitions to generate the graph image from + * @param outputPath the path to save the generated graph image + * @throws IOException if an I/O error occurs during image generation + * @throws IllegalArgumentException if the output path is null or empty + */ void generateImage(List transitions, @NonNull String outputPath) throws IOException; } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGenerator.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGenerator.java index 3e6e076..0c0e0cb 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGenerator.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGenerator.java @@ -12,6 +12,9 @@ import java.io.IOException; import java.util.List; +/** + * Implementation of {@link GraphImageGenerator} that uses Graphviz java library and DOT language to generate workflow images. + */ public class GraphvizImageGenerator implements GraphImageGenerator { private static final Logger log = LoggerFactory.getLogger(GraphvizImageGenerator.class); @@ -23,10 +26,23 @@ private GraphvizImageGenerator(GraphvizImageGeneratorBuilder builder) { this.dotFormat = builder.dotFormat; } + /** + * Returns a new builder instance for creating a {@link GraphvizImageGenerator}. + * + * @return a new {@link GraphvizImageGeneratorBuilder} instance + */ public static GraphvizImageGeneratorBuilder builder() { return new GraphvizImageGeneratorBuilder(); } + /** + * Generates a graph image from the given list of transitions and saves it to the specified output path. + * + * @param transitions the list of transitions to generate the graph image from + * @param outputPath the path to save the generated graph image + * @throws IOException if an I/O error occurs during image generation + * @throws IllegalArgumentException if the output path is null or empty + */ @Override public void generateImage(List transitions, String outputPath) throws IOException { if (outputPath == null || outputPath.isEmpty()) { @@ -34,7 +50,7 @@ public void generateImage(List transitions, String outputPath) throw } // Generate image using Graphviz from dot format log.debug("Generating workflow image.."); - Graphviz.useEngine(new GraphvizJdkEngine()); + Graphviz.useEngine(new GraphvizJdkEngine()); // Use GraalJS as the default engine log.debug("Using default image format: " + DEFAULT_IMAGE_FORMAT); if (dotFormat == null) { if (transitions == null || transitions.isEmpty()) { @@ -50,19 +66,39 @@ public void generateImage(List transitions, String outputPath) throw log.debug("Workflow image saved to: " + outputPath); } + /** + * Builder class for {@link GraphvizImageGenerator}. + */ public static class GraphvizImageGeneratorBuilder { private String dotFormat; + /** + * Sets the dot format for the graph image. + * + * @param dotFormat the dot format string + * @return the current {@link GraphvizImageGeneratorBuilder} instance + */ public GraphvizImageGeneratorBuilder dotFormat(String dotFormat) { this.dotFormat = dotFormat; return this; } + /** + * Builds and returns a new {@link GraphvizImageGenerator} instance. + * + * @return a new {@link GraphvizImageGenerator} instance + */ public GraphvizImageGenerator build() { return new GraphvizImageGenerator(this); } } + /** + * Generates the default dot format string from the given list of transitions. + * + * @param transitions the list of transitions + * @return the generated dot format string + */ private String defaultDotFormat(List transitions) { StringBuilder sb = new StringBuilder(); sb.append("digraph workflow {").append(System.lineSeparator()); @@ -71,26 +107,26 @@ private String defaultDotFormat(List transitions) { sb.append(" ").append("beautify=true").append(System.lineSeparator()); sb.append(System.lineSeparator()); for (Transition transition : transitions) { - if (transition.getTo() instanceof Node) { + if (transition.to() instanceof Node) { sb.append(" ") // NodeFrom -> NodeTo - .append(transition.getFrom() instanceof Node ? - sanitizeNodeName(((Node) transition.getFrom()).getName()) : - transition.getFrom().toString().toLowerCase()) + .append(transition.from() instanceof Node ? + sanitizeNodeName(((Node) transition.from()).getName()) : + transition.from().toString().toLowerCase()) .append(" -> ") - .append(sanitizeNodeName(((Node) transition.getTo()).getName())).append(";") + .append(sanitizeNodeName(((Node) transition.to()).getName())).append(";") .append(System.lineSeparator()); - } else if (transition.getTo() == WorkflowStateName.END && transition.getFrom() instanceof Node) { + } else if (transition.to() == WorkflowStateName.END && transition.from() instanceof Node) { sb.append(" ") // NodeFrom -> END - .append(sanitizeNodeName(((Node) transition.getFrom()).getName())) + .append(sanitizeNodeName(((Node) transition.from()).getName())) .append(" -> ") - .append(((WorkflowStateName) transition.getTo()).toString().toLowerCase()).append(";") + .append(((WorkflowStateName) transition.to()).toString().toLowerCase()).append(";") .append(System.lineSeparator()) .append(System.lineSeparator()); } else { sb.append(" ") // NodeFrom -> NodeTo - .append(sanitizeNodeName(transition.getFrom().toString().toLowerCase())) + .append(sanitizeNodeName(transition.from().toString().toLowerCase())) .append(" -> ") - .append(sanitizeNodeName(transition.getTo().toString().toLowerCase())).append(";") + .append(sanitizeNodeName(transition.to().toString().toLowerCase())).append(";") .append(System.lineSeparator()); } } @@ -104,6 +140,12 @@ private String defaultDotFormat(List transitions) { return sb.toString(); } + /** + * Sanitizes the node name by removing special characters and converting it to camel case. + * + * @param nodeName the node name to sanitize + * @return the sanitized node name + */ private static String sanitizeNodeName(String nodeName) { // Remove special characters String sanitized = nodeName.replaceAll("[^a-zA-Z0-9 ]", ""); diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Conditional.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Conditional.java index 2d2ef25..90bc910 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Conditional.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Conditional.java @@ -1,23 +1,51 @@ package com.github.czelabueno.jai.workflow.node; +import com.github.czelabueno.jai.workflow.transition.TransitionState; import lombok.NonNull; import java.util.Objects; import java.util.function.Function; -public class Conditional { +/** + * Represents a conditional node in a workflow that evaluates a condition function. + *

+ * Implements the {@link TransitionState} interface. + * + * @param the stateful bean POJO defined by the user. It is used to store the state of the workflow. + */ +public class Conditional implements TransitionState { private final Function> condition; + /** + * Constructs a Conditional with the specified condition function. + * + * @param condition the condition function to evaluate + * @throws NullPointerException if the condition function is null + */ public Conditional(@NonNull Function> condition) { this.condition = Objects.requireNonNull(condition, "Condition function cannot be null"); } + /** + * Evaluates the condition function with the given stateful bean. + * + * @param input the stateful bean as input to the condition function + * @return the resulting Node from the condition function + * @throws NullPointerException if the input is null + */ public Node evaluate(T input) { Objects.requireNonNull(input, "Function Input cannot be null"); return condition.apply(input); } + /** + * Creates a new Conditional with the specified condition function. + * + * @param condition the condition function to evaluate + * @param the stateful bean as input to the condition function + * @return a new Conditional instance + */ public static Conditional eval(Function> condition) { return new Conditional<>(condition); } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Node.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Node.java index 55c103d..1266229 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Node.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/node/Node.java @@ -1,12 +1,21 @@ package com.github.czelabueno.jai.workflow.node; +import com.github.czelabueno.jai.workflow.transition.TransitionState; import lombok.Getter; import lombok.NonNull; import java.util.Objects; import java.util.function.Function; -public class Node { +/** + * Represents a node in a workflow that executes a function with a given input and produces an output. + *

+ * This class implements the {@link TransitionState} interface. + * + * @param the type of the input to the function. Normally a stateful bean POJO defined by the user. + * @param the type of the output from the function. Normally a stateful bean POJO defined by the user. + */ +public class Node implements TransitionState { @Getter private final String name; @@ -16,6 +25,14 @@ public class Node { @Getter private R functionOutput; + /** + * Constructs a Node with the specified name and function. + * + * @param name the name of the node + * @param function the function to execute + * @throws IllegalArgumentException if the node name is empty + * @throws NullPointerException if the name or function is null + */ public Node(@NonNull String name, @NonNull Function function) { if (name.trim().isEmpty()) { throw new IllegalArgumentException("Node name cannot be empty"); @@ -24,6 +41,13 @@ public Node(@NonNull String name, @NonNull Function function) { this.function = function; } + /** + * Executes the function with the given input and stores the input and output. + * + * @param input the input to the function + * @return the output from the function + * @throws IllegalArgumentException if the input is null + */ public R execute(T input) { if (input == null) { throw new IllegalArgumentException("Function input cannot be null"); @@ -34,6 +58,15 @@ public R execute(T input) { return output; } + /** + * Creates a new Node with the specified name and function. + * + * @param name the name of the node + * @param function the function to execute + * @param the type of the input to the function + * @param the type of the output from the function + * @return a new Node instance + */ public static Node from(String name, Function function) { return new Node<>(name, function); } diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/Transition.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/Transition.java index 7787176..955b790 100644 --- a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/Transition.java +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/Transition.java @@ -2,15 +2,25 @@ import com.github.czelabueno.jai.workflow.node.Node; import com.github.czelabueno.jai.workflow.WorkflowStateName; -import lombok.Getter; import lombok.NonNull; -@Getter -public class Transition { - private final Object from; // // Can be Node or WorflowState - private final Object to; // Can be Node or WorflowState +/** + * Represents a transition between two states in a workflow. + * The states can be instances of {@link Node} or {@link WorkflowStateName}. + */ +public record Transition(TransitionState from, TransitionState to) { - public Transition(@NonNull Object from, @NonNull Object to) { + /** + * Constructs a Transition with the specified from and to states. + * + * @param from the starting state of the transition, must be an instance of {@link Node} or {@link WorkflowStateName} + * @param to the ending state of the transition, must be an instance of {@link Node} or {@link WorkflowStateName} + * @throws IllegalArgumentException if the from state is {@link WorkflowStateName#END}, + * if the to state is {@link WorkflowStateName#START}, + * or if the transition is from {@link WorkflowStateName#START} to {@link WorkflowStateName#END} + * @throws NullPointerException if the from or to state is null + */ + public Transition(@NonNull TransitionState from, @NonNull TransitionState to) { if (from == WorkflowStateName.END) { throw new IllegalArgumentException("Cannot transition from an END state"); } @@ -24,10 +34,22 @@ public Transition(@NonNull Object from, @NonNull Object to) { this.to = to; } - public static Transition from(Object from, Object to) { + /** + * Creates a new Transition with the specified from and to states. + * + * @param from the starting state of the transition, must be an instance of {@link Node} or {@link WorkflowStateName} + * @param to the ending state of the transition, must be an instance of {@link Node} or {@link WorkflowStateName} + * @return a new Transition instance + */ + public static Transition from(TransitionState from, TransitionState to) { return new Transition(from, to); } + /** + * Returns a string representation of the transition. + * + * @return a string representation of the transition in the format "from -> to" + */ @Override public String toString() { String transition = ""; diff --git a/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/TransitionState.java b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/TransitionState.java new file mode 100644 index 0000000..6ebd35c --- /dev/null +++ b/jai-workflow-core/src/main/java/com/github/czelabueno/jai/workflow/transition/TransitionState.java @@ -0,0 +1,10 @@ +package com.github.czelabueno.jai.workflow.transition; + +/** + * Marker interface representing a state in a workflow transition. + *

+ * Classes implementing this interface can be used as states in a {@link Transition}. + *

+ */ +public interface TransitionState{ +} diff --git a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/StateWorkflowTest.java b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/StateWorkflowTest.java index 3c63365..fc3fd0c 100644 --- a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/StateWorkflowTest.java +++ b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/StateWorkflowTest.java @@ -4,7 +4,6 @@ import com.github.czelabueno.jai.workflow.node.Node; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.io.IOException; diff --git a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGeneratorTest.java b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGeneratorTest.java index f7d706c..b826fe7 100644 --- a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGeneratorTest.java +++ b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/graph/graphviz/GraphvizImageGeneratorTest.java @@ -1,5 +1,6 @@ package com.github.czelabueno.jai.workflow.graph.graphviz; +import com.github.czelabueno.jai.workflow.node.Node; import com.github.czelabueno.jai.workflow.transition.Transition; import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeEach; @@ -36,7 +37,7 @@ void test_builder_and_doFormat() { } @Test - void test_generateImage_invalid_transitions_and_doFormat() { + void test_generateImage_invalid_transitions() { // given List transitions = Collections.EMPTY_LIST; // when @@ -51,7 +52,12 @@ void test_generateImage_invalid_transitions_and_doFormat() { @Test void test_generate_Image_invalid_outputPath() { // given - List transitions = Arrays.asList(new Transition("a", "b")); + List transitions = Arrays.asList( + Transition.from( + Node.from("a", s -> s + "1"), + Node.from("b", s -> s + "2") + ) + ); // when GraphvizImageGenerator generator = builder.build(); // then @@ -93,7 +99,12 @@ void test_generate_Image_with_custom_outputPath() { @Test void test_generate_Image_with_transition_and_default_dotFormat() { // given - List transitions = Arrays.asList(new Transition("a", "b")); + List transitions = Arrays.asList( + Transition.from( + Node.from("a", s -> s + "1"), + Node.from("b", s -> s + "2") + ) + ); GraphvizImageGenerator generator = builder.build(); // when assertThat(generator).isNotNull(); @@ -107,7 +118,12 @@ void test_generate_Image_with_transition_and_default_dotFormat() { @Test void test_generate_Image_is_SVG_format() { // given - List transitions = Arrays.asList(new Transition("a", "b")); + List transitions = Arrays.asList( + Transition.from( + Node.from("a", s -> s + "1"), + Node.from("b", s -> s + "2") + ) + ); GraphvizImageGenerator generator = builder.build(); // when assertThat(generator).isNotNull(); diff --git a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/transition/TransitionTest.java b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/transition/TransitionTest.java index 7ae60e8..ec53d21 100644 --- a/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/transition/TransitionTest.java +++ b/jai-workflow-core/src/test/java/com/github/czelabueno/jai/workflow/transition/TransitionTest.java @@ -4,6 +4,9 @@ import com.github.czelabueno.jai.workflow.node.Node; import org.junit.jupiter.api.Test; +import java.util.Arrays; +import java.util.List; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -11,10 +14,13 @@ class TransitionTest { @Test void should_build_transition_using_from() { - Transition transition = Transition.from("from", "to"); + Transition transition = Transition.from( + Node.from("from", s -> s + "1"), + Node.from("to", s -> s + "2") + ); - assertThat(transition.getFrom()).isEqualTo("from"); - assertThat(transition.getTo()).isEqualTo("to"); + assertThat(((Node)transition.from()).getName()).isEqualTo("from"); + assertThat(((Node)transition.to()).getName()).isEqualTo("to"); assertThat(transition).hasToString("from -> to"); } @@ -27,8 +33,8 @@ void should_build_transition_using_nodes() { // when Transition transition = Transition.from(from, to); // then - assertThat(transition.getFrom()).isEqualTo(from); - assertThat(transition.getTo()).isEqualTo(to); + assertThat(transition.from()).isEqualTo(from); + assertThat(transition.to()).isEqualTo(to); assertThat(transition).hasToString("node1 -> node2"); } @@ -41,8 +47,8 @@ void should_build_transition_using_node_and_workflowState() { // when Transition transition = Transition.from(from, to); // then - assertThat(transition.getFrom()).isEqualTo(from); - assertThat(transition.getTo()).isEqualTo(to); + assertThat(transition.from()).isEqualTo(from); + assertThat(transition.to()).isEqualTo(to); assertThat(transition).hasToString("node1 -> END"); } @@ -55,8 +61,8 @@ void should_build_transition_using_workflowState_and_node() { // when Transition transition = Transition.from(from, to); // then - assertThat(transition.getFrom()).isEqualTo(from); - assertThat(transition.getTo()).isEqualTo(to); + assertThat(transition.from()).isEqualTo(from); + assertThat(transition.to()).isEqualTo(to); assertThat(transition).hasToString("START -> node2"); } @@ -64,14 +70,14 @@ void should_build_transition_using_workflowState_and_node() { @Test void should_throw_illegalArgumentException_when_transition_from_END() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> Transition.from(WorkflowStateName.END, "to")) + .isThrownBy(() -> Transition.from(WorkflowStateName.END, Node.from("to", s -> s + "1"))) .withMessage("Cannot transition from an END state"); } @Test void should_throw_illegalArgumentException_when_transition_to_START() { assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> Transition.from("from", WorkflowStateName.START)) + .isThrownBy(() -> Transition.from(Node.from("from", s -> s + "2"), WorkflowStateName.START)) .withMessage("Cannot transition to a START state"); } diff --git a/langchain4j-workflow/pom.xml b/langchain4j-workflow/pom.xml index 3aaa043..760a761 100644 --- a/langchain4j-workflow/pom.xml +++ b/langchain4j-workflow/pom.xml @@ -4,12 +4,16 @@ com.github.czelabueno jai-workflow-parent - ${revision} + 0.2.0 langchain4j-workflow JavAI Workflow :: LangChain4j + + 1.0.0-alpha1 + + com.github.czelabueno @@ -20,6 +24,19 @@ dev.langchain4j langchain4j + ${langchain4j.version} + + + + dev.langchain4j + langchain4j-reactor + ${langchain4j.version} + + + + dev.langchain4j + langchain4j-document-transformer-jsoup + ${langchain4j.version} @@ -31,5 +48,40 @@ org.projectlombok lombok + + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test + + + + org.mockito + mockito-junit-jupiter + test + + + + io.projectreactor + reactor-test + 3.7.0 + test + + + + org.assertj + assertj-core + ${assertj.version} + test + + + + dev.langchain4j + langchain4j-mistral-ai + ${langchain4j.version} + test + diff --git a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/AbstractStatefulBean.java b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/AbstractStatefulBean.java index 917e6ee..ef6f4d0 100644 --- a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/AbstractStatefulBean.java +++ b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/AbstractStatefulBean.java @@ -1,19 +1,19 @@ package com.github.czelabueno.jai.workflow.langchain4j; import lombok.Data; +import reactor.core.publisher.Flux; /** * AbstractStatefulBean is an abstract class that represents a stateful bean which is responsible for holding the state of the workflow. * The state is a combination of a question, input data, output data and a response generation. * Every execution of the workflow initiates a state, which is then transferred among the nodes during their execution. *
- * Example: + * Here is the simplest example of a stateful bean: *
- * 
  * public class MyStatefulBean extends AbstractStatefulBean {
- *     // my input/output data fields
- *   }
- * 
+ *     private List documents;
+ *     // other additional input/output fields that you want to store
+ * }
  * 
*/ @Data @@ -21,4 +21,5 @@ public abstract class AbstractStatefulBean { private String question; private String generation; + private Flux generationStream; } diff --git a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/JAiWorkflow.java b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/JAiWorkflow.java index bd8f146..57a856f 100644 --- a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/JAiWorkflow.java +++ b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/JAiWorkflow.java @@ -2,23 +2,55 @@ import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; - -import java.util.List; +import reactor.core.publisher.Flux; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +/** + * The {@link JAiWorkflow} interface defines the entry-point contract for a workflow that processes user messages + * and generates AI responses. It provides basic methods for synchronous and asynchronous (streaming) responses. + */ public interface JAiWorkflow { + /** + * Generates an AI response to the given question. + * This method ensures that the question is not null before processing. + * + * @param question the question to be answered + * @return the AI response as a string + * @throws IllegalArgumentException if the question is null + */ default String answer(String question){ ensureNotNull(question, "question"); return answer(new UserMessage(question)).text(); } + + /** + * Generates an AI response to the given user message. + * + * @param question the UserMessage containing the question + * @return the AI response as an AiMessage + */ AiMessage answer(UserMessage question); - default List answerStream(String question){ + /** + * Generates a streaming AI response to the given question. + * This method ensures that the question is not null before processing. + * + * @param question the question to be answered + * @return a Flux stream of the AI response tokens + * @throws IllegalArgumentException if the question is null + */ + default Flux answerStream(String question){ ensureNotNull(question, "question"); return answerStream(new UserMessage(question)); } - List answerStream(UserMessage question); + /** + * Generates a streaming AI response to the given user message. + * + * @param question the UserMessage containing the question + * @return a Flux stream of the AI response tokens + */ + Flux answerStream(UserMessage question); } diff --git a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/internal/DefaultJAiWorkflow.java b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/internal/DefaultJAiWorkflow.java index 8dbc04a..dbce714 100644 --- a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/internal/DefaultJAiWorkflow.java +++ b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/internal/DefaultJAiWorkflow.java @@ -4,12 +4,14 @@ import com.github.czelabueno.jai.workflow.StateWorkflow; import com.github.czelabueno.jai.workflow.langchain4j.AbstractStatefulBean; import com.github.czelabueno.jai.workflow.langchain4j.JAiWorkflow; +import com.github.czelabueno.jai.workflow.langchain4j.node.StreamingNode; import com.github.czelabueno.jai.workflow.node.Node; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.UserMessage; import lombok.Builder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import java.io.IOException; import java.nio.file.Path; @@ -18,7 +20,13 @@ import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; -class DefaultJAiWorkflow implements JAiWorkflow { +/** + * DefaultJAiWorkflow is a default implementation of the JAiWorkflow interface. + * It defines the workflow for processing user messages and generating AI responses. + * + * @param the type of the stateful bean, which extends AbstractStatefulBean + */ +public class DefaultJAiWorkflow implements JAiWorkflow { private static final Logger log = LoggerFactory.getLogger(DefaultJAiWorkflow.class); @@ -26,31 +34,46 @@ class DefaultJAiWorkflow implements JAiWorkflow private final Boolean generateWorkflowImage; private final Path workflowImageOutputPath; private final T statefulBean; - private final List> nodes; private DefaultStateWorkflow workflow; + /** + * Constructs a new DefaultJAiWorkflow with the specified parameters. + * + * @param statefulBean the stateful bean holding the state of the workflow + * @param nodes the list of nodes to be processed in the workflow + * @param runStream flag indicating whether to run the workflow in stream mode + * @param generateWorkflowImage flag indicating whether to generate a workflow image + * @param workflowImageOutputPath the output path for the workflow image + */ @Builder public DefaultJAiWorkflow(T statefulBean, - List> nodes, + List> nodes, Boolean runStream, Boolean generateWorkflowImage, Path workflowImageOutputPath) { + this.statefulBean = ensureNotNull(statefulBean, "%s cannot be null. jAI workflow cannot created without stateful bean definition", "statefulBean"); + ensureNotNull(nodes, "%s cannot be null. jAI workflow cannot created without nodes definition", "nodes"); + this.workflow = createWorkflow(statefulBean, nodes); this.runStream = getOrDefault(runStream, false); // check if workflowOutputPath is valid this.generateWorkflowImage = workflowImageOutputPath != null || getOrDefault(generateWorkflowImage, false); this.workflowImageOutputPath = workflowImageOutputPath; - this.statefulBean = ensureNotNull(statefulBean, "statefulBean"); - this.nodes = ensureNotNull(nodes, "%s cannot be null. jAI workflow cannot created without nodes definition", "nodes"); - this.workflow = createWorkflow(statefulBean); } + /** + * Returns the current workflow. + * + * @return the current workflow + */ public StateWorkflow workflow() { - if (workflow == null) { - workflow = createWorkflow(statefulBean); - } - return workflow; + return this.workflow; } + /** + * Sets the workflow to the specified workflow. + * + * @param workflow the workflow to be set + */ public void setWorkflow(DefaultStateWorkflow workflow) { this.workflow = workflow; } @@ -58,17 +81,53 @@ public void setWorkflow(DefaultStateWorkflow workflow) { @Override public AiMessage answer(UserMessage question) { // Define a stateful bean - this.statefulBean.setQuestion(question.text()); + this.statefulBean.setQuestion(question.singleText()); + // Run workflow in stream mode or not + if (this.runStream) { + workflow().runStream(node -> log.debug("Node processed: " + node.getName())); + } else { + workflow().run(); + } + generateWorkflowImageIfNeeded(); + return AiMessage.from(this.statefulBean.getGeneration()); + } + + @Override + public Flux answerStream(UserMessage question) { + if (!runStream || !isLastNodeAStreamingNode(workflow())) { + throw new IllegalStateException("The last node of the workflow must be a StreamingNode to run in stream mode"); + } + // Define a stateful bean + this.statefulBean.setQuestion(question.singleText()); // Run workflow in stream mode or not if (this.runStream) { workflow().runStream(node -> { + if (node instanceof StreamingNode) { + log.debug("StreamingNode processed: " + node.getName()); + } log.debug("Node processed: " + node.getName()); }); - } else { - workflow().run(); } + generateWorkflowImageIfNeeded(); + return this.statefulBean.getGenerationStream(); + } + + private DefaultStateWorkflow createWorkflow( + T statefulBean, + List> nodes) { + return DefaultStateWorkflow.builder() + .statefulBean(statefulBean) + .addNodes(nodes) + .build(); + } + + private Boolean isLastNodeAStreamingNode(StateWorkflow workflow) { + return workflow.getLastNode() instanceof StreamingNode; + } + + private void generateWorkflowImageIfNeeded() { // Generate workflow image if required - if (this.generateWorkflowImage) { + if (generateWorkflowImage) { try { if (workflowImageOutputPath != null) { workflow().generateWorkflowImage(workflowImageOutputPath.toAbsolutePath().toString()); @@ -79,18 +138,5 @@ public AiMessage answer(UserMessage question) { log.error("Error generating workflow image", e); } } - return AiMessage.from(this.statefulBean.getGeneration()); - } - - @Override - public List answerStream(UserMessage question) { - return null; // TODO: Implement streaming response and condition last node execution has a StreamingChatCompletion node - } - - private DefaultStateWorkflow createWorkflow(T statefulBean) { - return DefaultStateWorkflow.builder() - .statefulBean(statefulBean) - .addNodes(nodes) - .build(); } } diff --git a/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNode.java b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNode.java new file mode 100644 index 0000000..babf701 --- /dev/null +++ b/langchain4j-workflow/src/main/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNode.java @@ -0,0 +1,154 @@ +package com.github.czelabueno.jai.workflow.langchain4j.node; + +import com.github.czelabueno.jai.workflow.langchain4j.AbstractStatefulBean; +import com.github.czelabueno.jai.workflow.node.Node; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import lombok.NonNull; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; + +import static dev.langchain4j.internal.Utils.getOrDefault; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; + +/** + * StreamingNode is a specialized type of {@link Node} that handles streaming responses from a {@link StreamingChatLanguageModel}. + * It extends the generic Node class with specific types for stateful beans and reactive streams. + * + * @param the type of the stateful bean, which extends AbstractStatefulBean + */ +public class StreamingNode extends Node> { + + /** + * Constructs a new StreamingNode with the specified name, messages, and StreamingChatLanguageModel. + * + * @param name the name of the node + * @param messages the list of ChatMessage to be processed by the streamingChatLanguageModel + * @param doUserMessage a function to generate a user message from the stateful bean + * @param streamingChatLanguageModel the streaming chat language model to generate responses + */ + public StreamingNode(String name, + List messages, + Function doUserMessage, + @NonNull StreamingChatLanguageModel streamingChatLanguageModel) { + super(ensureNotBlank(name, "name"), (T statefulBean) -> streamingFunction(statefulBean, messages, doUserMessage, streamingChatLanguageModel)); + } + + /** + * Creates a new StreamingNode from the specified parameters. + * + * @param name the name of the node + * @param messages the list of ChatMessage to be processed by the streamingChatLanguageModel + * @param doUserMessage a function to generate a user message from the stateful bean + * @param streamingChatLanguageModel the streaming chat language model to generate responses + * @param the type of the stateful bean, which extends AbstractStatefulBean + * @return a new StreamingNode instance + */ + public static StreamingNode from(String name, + List messages, + Function doUserMessage, + @NonNull StreamingChatLanguageModel streamingChatLanguageModel) { + return new StreamingNode(name, messages, doUserMessage, streamingChatLanguageModel); + } + + /** + * Creates a new StreamingNode from the specified parameters. + * + * @param name the name of the node + * @param doUserMessage a function to generate a user message from the stateful bean + * @param streamingChatLanguageModel the streaming chat language model to generate responses + * @param the type of the stateful bean, which extends AbstractStatefulBean + * @return a new StreamingNode instance + */ + public static StreamingNode from(String name, + Function doUserMessage, + @NonNull StreamingChatLanguageModel streamingChatLanguageModel) { + return from(name, null, doUserMessage, streamingChatLanguageModel); + } + + /** + * Creates a new StreamingNode from the specified parameters. + * + * @param name the name of the node + * @param messages the list of ChatMessage to be processed by the streamingChatLanguageModel + * @param streamingChatLanguageModel the streaming chat language model to generate responses + * @param the type of the stateful bean, which extends AbstractStatefulBean + * @return a new StreamingNode instance + */ + public static StreamingNode from(String name, + List messages, + @NonNull StreamingChatLanguageModel streamingChatLanguageModel) { + return from(name, messages, null, streamingChatLanguageModel); + } + + /** + * Creates a new StreamingNode from the specified parameters. + * + * @param name the name of the node + * @param streamingChatLanguageModel the streaming chat language model to generate responses + * @param the type of the stateful bean, which extends AbstractStatefulBean + * @return a new StreamingNode instance + */ + public static StreamingNode from(String name, + @NonNull StreamingChatLanguageModel streamingChatLanguageModel) { + return from(name, null, null, streamingChatLanguageModel); + } + + /** + * A static function that handles the token of responses from the StreamingChatLanguageModel. + * It sets up a sink to collect the streamed tokens and completes the stateful bean with the final response. + * + * @param statefulBean the stateful bean holding the state of the workflow + * @param messages the list of ChatMessage to be processed by the streamingChatLanguageModel + * @param doUserMessage a function to generate a user message from the stateful bean + * @param streamingChatLanguageModel the streaming chat language model to generate responses + * @param the type of the stateful bean, which extends AbstractStatefulBean + * @return a Flux stream of the generated tokens + */ + private static Flux streamingFunction( + T statefulBean, + List messages, + Function doUserMessage, + StreamingChatLanguageModel streamingChatLanguageModel) { + Sinks.Many sink = Sinks.many().unicast().onBackpressureBuffer(); + CompletableFuture futureResponse = new CompletableFuture<>(); + if (messages == null || messages.isEmpty()) { + messages = doUserMessage != null ? + List.of(doUserMessage.apply(statefulBean)) : + List.of(UserMessage.from(getOrDefault(statefulBean.getQuestion(),"No question provided."))); + } + + streamingChatLanguageModel.generate( + messages, + new StreamingResponseHandler() { + @Override + public void onNext(String token) { + sink.tryEmitNext(token); + } + + @Override + public void onComplete(Response response) { + futureResponse.complete(response.content()); + sink.tryEmitComplete(); + } + + @Override + public void onError(Throwable throwable) { + sink.tryEmitError(throwable); + } + } + ); + statefulBean.setGenerationStream(sink.asFlux().cache()); + statefulBean.setGeneration(futureResponse.join().text()); + return statefulBean.getGenerationStream(); + } +} diff --git a/langchain4j-workflow/src/test/java/JAiWorkflowIT.java b/langchain4j-workflow/src/test/java/JAiWorkflowIT.java new file mode 100644 index 0000000..8e859f6 --- /dev/null +++ b/langchain4j-workflow/src/test/java/JAiWorkflowIT.java @@ -0,0 +1,146 @@ +import com.github.czelabueno.jai.workflow.StateWorkflow; +import com.github.czelabueno.jai.workflow.WorkflowStateName; +import com.github.czelabueno.jai.workflow.langchain4j.internal.DefaultJAiWorkflow; +import com.github.czelabueno.jai.workflow.langchain4j.node.StreamingNode; +import com.github.czelabueno.jai.workflow.node.Node; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.mistralai.MistralAiChatModel; +import dev.langchain4j.model.mistralai.MistralAiChatModelName; +import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; +import com.github.czelabueno.jai.workflow.langchain4j.workflow.NodeFunctionsMock; +import com.github.czelabueno.jai.workflow.langchain4j.workflow.StatefulBeanMock; + +import java.util.Arrays; +import java.util.List; + +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +// JAiWorkflowIT is an integration test class that demonstrates how to use JAiWorkflow with LangChain4j to build agentic systems and orchestrated AI workflows. +// The workflow tested in this class is a simple example that retrieves documents, grades them, and generates a summary of the documents using the Mistral AI API. +// +// Workflow definition: +// START -> Retrieve Node -> Grade Documents Node -> Generate Node -> END +// +// The setUp method initializes the JAiWorkflow and JAiWorkflowStreaming objects with the MistralAiChatModel and MistralAiStreamingChatModel classes, respectively. +// These models are used to generate AI responses in both synchronous and streaming modes. +// +// The should_answer_question method tests the synchronous answer method of the JAiWorkflow class by providing a question and checking if the answer contains the expected text. +// The should_answer_stream_question method tests the streaming answerStream method of the JAiWorkflow class by providing a question and checking if the answer contains the expected tokens. +// +// This integration test class showcases how JAiWorkflow and LangChain4j can be combined to create complex AI-driven workflows that can process and generate information in a structured manner. +class JAiWorkflowIT { + + String[] documents = new String[]{ + "https://lilianweng.github.io/posts/2023-06-23-agent/", + "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/", + "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/" + }; + + ChatLanguageModel model = MistralAiChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST) + .temperature(0.0) + .build(); + + StreamingChatLanguageModel streamingModel = MistralAiStreamingChatModel.builder() + .apiKey(System.getenv("MISTRAL_AI_API_KEY")) + .modelName(MistralAiChatModelName.MISTRAL_LARGE_LATEST) + .temperature(0.0) + .build(); + + DefaultJAiWorkflow jAiWorkflow; + DefaultJAiWorkflow jAiWorkflowStreaming; + + @BeforeEach() + void setUp() { + // Define a stateful bean to store the state of the workflow + StatefulBeanMock statefulBean = new StatefulBeanMock(); + + // Define nodes with your custom functions + Node retrieveNode = Node.from("Retrieve Node", obj -> NodeFunctionsMock.retrieve(obj, documents)); + Node gradeDocumentsNode = Node.from("Grade Documents Node", obj -> NodeFunctionsMock.gradeDocuments(obj)); + Node generateNode = Node.from("Generate Node", obj -> NodeFunctionsMock.generate(obj, model)); + StreamingNode generateStreamingNode = StreamingNode.from( + "Generate Node", + obj -> NodeFunctionsMock.generateUserMessageFromStatefulBean(obj), + streamingModel); + + // Build workflows of the synchronous and streaming ways + jAiWorkflow = buildWorkflow(statefulBean, false, retrieveNode, gradeDocumentsNode, generateNode); + jAiWorkflowStreaming = buildWorkflow(statefulBean, true, retrieveNode, gradeDocumentsNode, generateStreamingNode); + // Define your workflow transitions using edges and the entry point of the workflow + StateWorkflow workflow = jAiWorkflow.workflow(); + workflow.putEdge(retrieveNode, gradeDocumentsNode); + workflow.putEdge(gradeDocumentsNode, generateNode); + workflow.putEdge(generateNode, WorkflowStateName.END); + workflow.startNode(retrieveNode); + + StateWorkflow workflowStreaming = jAiWorkflowStreaming.workflow(); + workflowStreaming.putEdge(retrieveNode, gradeDocumentsNode); + workflowStreaming.putEdge(gradeDocumentsNode, generateStreamingNode); + workflowStreaming.putEdge(generateStreamingNode, WorkflowStateName.END); + workflowStreaming.startNode(retrieveNode); + } + + @Test + void should_answer_question() { + // given + String question = "Summarizes the importance of building agents with LLMs"; + + // when + String answer = jAiWorkflow.answer(question); + + // then + assertThat(answer).containsIgnoringWhitespaces("brain of an autonomous agent system"); + } + + @Test + void should_answer_stream_with_non_streamingNode_throw_IllegalStateException() { + // given + String question = "Summarizes the importance of building agents with LLMs"; + + // when + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(() -> jAiWorkflow.answerStream(question)) + .withMessage("The last node of the workflow must be a StreamingNode to run in stream mode"); + } + + @Test + void should_answer_stream_question() { + // given + String question = "Summarizes the importance of building agents with LLMs"; + List expectedTokens = Arrays.asList("building", "agent", "system","general","problem", "solver"); + + // when + Flux tokens = jAiWorkflowStreaming.answerStream(question); + + // then + StepVerifier.create(tokens) + .expectNextMatches(token -> expectedTokens.stream().anyMatch(token.toLowerCase()::contains)) + .expectNextCount(1) + .thenCancel() + .verify(); + String answer = tokens.collectList().block().stream().collect(joining()); + assertThat(expectedTokens) + .anySatisfy(token -> assertThat(answer).containsIgnoringWhitespaces(token)); + } + + private DefaultJAiWorkflow buildWorkflow(StatefulBeanMock statefulBean, Boolean runStream, List> nodes) { + return DefaultJAiWorkflow.builder() + .statefulBean(statefulBean) + .runStream(runStream) + .nodes(nodes) + .build(); + } + + private DefaultJAiWorkflow buildWorkflow(StatefulBeanMock statefulBean, Boolean runStream, Node... nodes) { + return buildWorkflow(statefulBean, runStream, Arrays.asList(nodes)); + } +} diff --git a/langchain4j-workflow/src/test/java/JAiWorkflowTest.java b/langchain4j-workflow/src/test/java/JAiWorkflowTest.java new file mode 100644 index 0000000..53bc9d2 --- /dev/null +++ b/langchain4j-workflow/src/test/java/JAiWorkflowTest.java @@ -0,0 +1,84 @@ +import com.github.czelabueno.jai.workflow.langchain4j.JAiWorkflow; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.UserMessage; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.*; + +class JAiWorkflowTest { + + private JAiWorkflow jAiWorkflow; + private UserMessage userMessage; + private AiMessage aiMessage; + + @BeforeEach + void setUp() { + jAiWorkflow = mock(JAiWorkflow.class); + userMessage = new UserMessage("What is the weather today?"); + aiMessage = new AiMessage("The weather is sunny today."); + } + + @Test + void should_answer_with_valid_question() { + // given + when(jAiWorkflow.answer(userMessage)).thenReturn(aiMessage); + // when + AiMessage response = jAiWorkflow.answer(userMessage); + // then + assertThat(response).isNotNull(); + assertThat(response.text()).isEqualTo("The weather is sunny today."); + verify(jAiWorkflow, times(1)).answer(userMessage); + } + + @Test + void should_throw_exception_with_null_question() { + // when + when(jAiWorkflow.answer((String) null)).thenThrow(new NullPointerException("question")); + // then + assertThatExceptionOfType(NullPointerException.class) + .isThrownBy(() -> jAiWorkflow.answer((String) null)) + .withMessage("question"); + verify(jAiWorkflow, times(1)).answer((String) null); + } + + @Test + void should_answer_stream_with_valid_question() { + // given + when(jAiWorkflow.answerStream(userMessage)).thenReturn(Flux.just("The", "weather", "is", "sunny", "today.")); + // when + Flux response = jAiWorkflow.answerStream(userMessage); + // then + assertThat(response).isNotNull(); + StepVerifier.create(response) + .expectNext("The", "weather", "is", "sunny", "today.") + .verifyComplete(); + verify(jAiWorkflow, times(1)).answerStream(userMessage); + } + + @Test + void should_throw_exception_with_null_question_stream() { + // when + when(jAiWorkflow.answerStream((String) null)).thenThrow(new NullPointerException("question")); + // then + assertThatExceptionOfType(NullPointerException.class) + .isThrownBy(() -> jAiWorkflow.answerStream((String) null)) + .withMessage("question"); + verify(jAiWorkflow, times(1)).answerStream((String) null); + } + + @Test + void should_answer_stream_with_non_streamingNode() { + // given + when(jAiWorkflow.answerStream(userMessage)).thenThrow(new IllegalStateException("The last node of the workflow must be a StreamingNode to run in stream mode")); + // when + assertThatExceptionOfType(IllegalStateException.class) + .isThrownBy(() -> jAiWorkflow.answerStream(userMessage)) + .withMessage("The last node of the workflow must be a StreamingNode to run in stream mode"); + verify(jAiWorkflow, times(1)).answerStream(userMessage); + } +} diff --git a/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNodeTest.java b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNodeTest.java new file mode 100644 index 0000000..306af0c --- /dev/null +++ b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/node/StreamingNodeTest.java @@ -0,0 +1,95 @@ +package com.github.czelabueno.jai.workflow.langchain4j.node; + +import com.github.czelabueno.jai.workflow.langchain4j.AbstractStatefulBean; +import com.github.czelabueno.jai.workflow.langchain4j.node.StreamingNode; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +import java.util.Arrays; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.*; + +class StreamingNodeTest { + + private StreamingChatLanguageModel model; + private MyStatefulBean statefulBean; + private List messages; + + class MyStatefulBean extends AbstractStatefulBean{ + List documents; + + public MyStatefulBean(List documents) { + this.documents = documents; + } + + @Override + public String toString() { + return "MyStatefulBean{" + + "documents=" + documents + + '}'; + } + } + + @BeforeEach + void setUp() { + model = mock(StreamingChatLanguageModel.class); + statefulBean = new MyStatefulBean(List.of("document1", "document2")); + messages = List.of(new UserMessage("What is the weather today?")); + } + + @Test + void should_create_streaming_node_using_from() { + // given + StreamingNode node = StreamingNode.from("streamingNode1", messages, model); + // then + assertThat(node).isNotNull(); + assertThat(node.getName()).isEqualTo("streamingNode1"); + } + + @Test + void should_streaming_function_with_valid_inputs() { + // given + List tokens = Arrays.asList("The", "weather", "is", "sunny", "today."); + doAnswer(invocation -> { + StreamingResponseHandler handler = invocation.getArgument(1); + tokens.forEach(handler::onNext); + handler.onComplete(new Response<>(new AiMessage("The weather is sunny today."))); + return null; + }).when(model).generate(anyList(), any(StreamingResponseHandler.class)); + // when + StreamingNode node = StreamingNode.from("streamingNode1", messages, model); + node.execute(statefulBean); + // then + StepVerifier.create(statefulBean.getGenerationStream()) + .expectNext("The", "weather", "is", "sunny", "today.") + .verifyComplete(); + assertThat(statefulBean.getGeneration()).isEqualTo("The weather is sunny today."); + } + + @Test + void should_throw_null_pointer_exception_if_streamingChatLanguageModel_is_null() { + // then + assertThatExceptionOfType(NullPointerException.class) + .isThrownBy(() -> StreamingNode.from("streamingNode1", messages, null)) + .withMessage("streamingChatLanguageModel is marked non-null but is null"); + } + + @Test + void should_throw_illegal_argument_exception() { + // then + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> StreamingNode.from(null, messages, model)) + .withMessage("name cannot be null or blank"); + } +} diff --git a/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/NodeFunctionsMock.java b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/NodeFunctionsMock.java new file mode 100644 index 0000000..4e7372f --- /dev/null +++ b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/NodeFunctionsMock.java @@ -0,0 +1,77 @@ +package com.github.czelabueno.jai.workflow.langchain4j.workflow; + +import dev.langchain4j.data.document.Document; +import dev.langchain4j.data.document.loader.UrlDocumentLoader; +import dev.langchain4j.data.document.parser.TextDocumentParser; +import dev.langchain4j.data.document.splitter.DocumentSplitters; +import dev.langchain4j.data.document.transformer.jsoup.HtmlToTextDocumentTransformer; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.structured.StructuredPromptProcessor; +import dev.langchain4j.rag.content.Content; +import com.github.czelabueno.jai.workflow.langchain4j.workflow.prompt.GenerateAnswerPrompt; + +import java.util.ArrayList; +import java.util.List; + +import static java.util.stream.Collectors.toList; + +public class NodeFunctionsMock { + + private NodeFunctionsMock() { + } + + public static StatefulBeanMock retrieve(StatefulBeanMock statefulBean, String... uris) { + // Parse uris to documents + List documents = new ArrayList<>(); + for (String uri : uris) { + Document document = UrlDocumentLoader.load(uri,new TextDocumentParser()); + HtmlToTextDocumentTransformer transformer = new HtmlToTextDocumentTransformer(null, null, false); + document = transformer.transform(document); + documents.add(document); + } + // Mock retrieval that only gets the first 7 segments as a relevant document + List segments = DocumentSplitters + .recursive(300,0) + .splitAll(documents); + List relevantDocuments = segments.stream() + .limit(7) + .map(segment -> new Content(segment.text())) + .collect(toList()); + statefulBean.setDocuments(relevantDocuments.stream() + .map(Content::textSegment) + .map(TextSegment::text) + .toList()); + return statefulBean; + } + + public static StatefulBeanMock gradeDocuments(StatefulBeanMock statefulBean) { + // Mock grading that return that doc is relevant + List docs = statefulBean.getDocuments(); + List filteredDocs = docs.stream() + .filter(doc -> doc.length() > 0) // feeble filter to return the first doc + .toList(); + statefulBean.setDocuments(filteredDocs); + statefulBean.setWebSearch("No"); // do not require go to web search because doc is relevant + return statefulBean; + } + + public static StatefulBeanMock generate(StatefulBeanMock statefulBean, ChatLanguageModel model) { + String generation = model.generate(generateUserMessageFromStatefulBean(statefulBean).singleText()); + statefulBean.setGeneration(generation); + return statefulBean; + } + + public static UserMessage generateUserMessageFromStatefulBean(StatefulBeanMock statefulBean) { + return UserMessage.from(answerPrompt(statefulBean).text()); + } + + private static Prompt answerPrompt(StatefulBeanMock statefulBean) { + String question = statefulBean.getQuestion(); + String context = String.join("\n\n", statefulBean.getDocuments()); + GenerateAnswerPrompt generateAnswerPrompt = new GenerateAnswerPrompt(question, context); + return StructuredPromptProcessor.toPrompt(generateAnswerPrompt); + } +} diff --git a/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/StatefulBeanMock.java b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/StatefulBeanMock.java new file mode 100644 index 0000000..ec55ce4 --- /dev/null +++ b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/StatefulBeanMock.java @@ -0,0 +1,16 @@ +package com.github.czelabueno.jai.workflow.langchain4j.workflow; + +import com.github.czelabueno.jai.workflow.langchain4j.AbstractStatefulBean; +import lombok.Data; + +import java.util.List; + +@Data +public class StatefulBeanMock extends AbstractStatefulBean { + + private List documents; + private String webSearch; + + public StatefulBeanMock() { + } +} diff --git a/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/prompt/GenerateAnswerPrompt.java b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/prompt/GenerateAnswerPrompt.java new file mode 100644 index 0000000..7c412f9 --- /dev/null +++ b/langchain4j-workflow/src/test/java/com/github/czelabueno/jai/workflow/langchain4j/workflow/prompt/GenerateAnswerPrompt.java @@ -0,0 +1,24 @@ +package com.github.czelabueno.jai.workflow.langchain4j.workflow.prompt; + +import dev.langchain4j.model.input.structured.StructuredPrompt; + +@StructuredPrompt({ + "You are an assistant for question-answering tasks. ", + "Use the following pieces of retrieved context to answer the question. ", + "If you don't know the answer, just say that you don't know. ", + "Use three sentences maximum and keep the answer concise.", + + "Question: {{question}} \n\n", + "Context: {{context}} \n\n", + "Answer:" +}) +public class GenerateAnswerPrompt { + + private String question; + private String context; + + public GenerateAnswerPrompt(String question, String context) { + this.question = question; + this.context = context; + } +} diff --git a/pom.xml b/pom.xml index ee07edd..87d29ef 100644 --- a/pom.xml +++ b/pom.xml @@ -6,27 +6,27 @@ com.github.czelabueno jai-workflow-parent - ${revision} + 0.2.0 JavAI Workflow JavAI Workflow: Flexible workflow engine to build agentic, enterprise and modular RAG applications for Java - https://github.com/czelabueno/langchain4j-workflow + https://github.com/czelabueno/jai-workflow pom GitHub - https://github.com/czelabueno/langchain4j-workflow/issues + https://github.com/czelabueno/jai-workflow/issues GitHub Actions - https://github.com/czelabueno/langchain4j-workflow/actions + https://github.com/czelabueno/jai-workflow/actions - https://github.com/czelabueno/langchain4j-workflow - scm:git:git://github.com/czelabueno/langchain4j-workflow.git - scm:git:ssh://git@github.com/czelabueno/langchain4j-workflow.git + https://github.com/czelabueno/jai-workflow/tree/main + scm:git:git://github.com/czelabueno/jai-workflow.git + scm:git:ssh://git@github.com/czelabueno/jai-workflow.git HEAD @@ -44,7 +44,7 @@ Carlos Zela c.zelabueno@gmail.com jAI Workflow - https://github.com/czelabueno/langchain4j-workflow + https://github.com/czelabueno/jai-workflow @@ -53,15 +53,14 @@ UTF-8 0.2.0 17 - 17 - 17 + ${java.version} + ${java.version} 1.18.30 1.5.3 2.0.7 0.18.1 21.3.0 - 0.32.0 5.14.2 3.25.3 5.10.0 @@ -70,11 +69,6 @@ - - dev.langchain4j - langchain4j - ${langchain4j.version} - org.projectlombok