diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index 662f8fca5..b2bbbb614 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -20,10 +20,11 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from langchain.tools.render import render_text_description_and_args from langchain_openai import ChatOpenAI from rai.agents.state_based import create_state_based_agent -from rai.node import RaiNode, describe_ros_image, wait_for_2s +from rai.node import RaiNode, describe_ros_image from rai.tools.ros.native import ( GetCameraImage, GetMsgFromTopic, @@ -31,6 +32,7 @@ ) from rai.tools.ros.native_actions import Ros2RunActionSync from rai.tools.ros.tools import GetOccupancyGridTool +from rai.tools.time import WaitForSecondsTool def main(): @@ -68,10 +70,9 @@ def main(): "/wait", ] - SYSTEM_PROMPT = "You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. " - "Do not make assumptions about the environment you are currently in. " - "Use the tooling provided to gather information about the environment." - "You can use ros2 topics, services and actions to operate." + # TODO(boczekbartek): refactor system prompt + + SYSTEM_PROMPT = "" node = RaiNode( llm=ChatOpenAI( @@ -84,7 +85,7 @@ def main(): ) tools = [ - wait_for_2s, + WaitForSecondsTool(), GetMsgFromTopic(node=node), Ros2RunActionSync(node=node), GetCameraImage(node=node), @@ -94,6 +95,18 @@ def main(): state_retriever = node.get_robot_state + SYSTEM_PROMPT = f"""You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. + Do not make assumptions about the environment you are currently in. + Use the tooling provided to gather information about the environment: + + {render_text_description_and_args(tools)} + + You can use ros2 topics, services and actions to operate. """ + + node.get_logger().info(f"{SYSTEM_PROMPT=}") + + node.system_prompt = node.initialize_system_prompt(SYSTEM_PROMPT) + app = create_state_based_agent( llm=llm, tools=tools, diff --git a/src/rai/rai/agents/state_based.py b/src/rai/rai/agents/state_based.py index f70496d2d..795bb1492 100644 --- a/src/rai/rai/agents/state_based.py +++ b/src/rai/rai/agents/state_based.py @@ -17,6 +17,7 @@ import pickle import time from functools import partial +from pathlib import Path from typing import ( Any, Callable, @@ -76,20 +77,39 @@ class Report(BaseModel): steps: List[str] = Field( ..., title="Steps", description="The steps taken to solve the problem" ) + success: bool = Field( + ..., title="Success", description="Whether the problem was solved" + ) response_to_user: str = Field( ..., title="Response", description="The response to the user" ) -def get_stored_artifacts(tool_call_id: str) -> List[Any]: - with open("artifact_database.pkl", "rb") as file: - artifact_database = pickle.load(file) +def get_stored_artifacts( + tool_call_id: str, db_path="artifact_database.pkl" +) -> List[Any]: + # TODO(boczekbartek): refactor + db_path = Path(db_path) + if not db_path.is_file(): + return [] + + with db_path.open("rb") as db: + artifact_database = pickle.load(db) if tool_call_id in artifact_database: return artifact_database[tool_call_id] + return [] -def store_artifacts(tool_call_id: str, artifacts: List[Any]): +def store_artifacts( + tool_call_id: str, artifacts: List[Any], db_path="artifact_database.pkl" +): + # TODO(boczekbartek): refactor + db_path = Path(db_path) + if not db_path.is_file(): + artifact_database = {} + with open("artifact_database.pkl", "wb") as file: + pickle.dump(artifact_database, file) with open("artifact_database.pkl", "rb") as file: artifact_database = pickle.load(file) if tool_call_id not in artifact_database: @@ -283,7 +303,7 @@ def retriever_wrapper( info = str_output(retrieved_info) state["messages"].append( HumanMultimodalMessage( - content="Retrieved state: {}".format(info), images=images, audios=audios + content=f"Retrieved state: {info}", images=images, audios=audios ) ) return state diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 34c6ab2a0..92a60f197 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -15,10 +15,7 @@ import functools import time -from collections import deque -from dataclasses import dataclass, field -from pprint import pformat -from typing import Any, Callable, Deque, Dict, List, Literal, Optional, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional import rcl_interfaces.msg import rclpy @@ -28,13 +25,11 @@ import rclpy.subscription import rclpy.task import sensor_msgs.msg -from langchain.tools import tool from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langgraph.graph.graph import CompiledGraph from rclpy.action.graph import get_action_names_and_types -from rclpy.action.server import ActionServer +from rclpy.action.server import ActionServer, GoalResponse, ServerGoalHandle from rclpy.node import Node from rclpy.qos import ( DurabilityPolicy, @@ -45,92 +40,18 @@ ) from std_srvs.srv import Trigger -from rai.agents.state_based import State -from rai.messages.multimodal import HumanMultimodalMessage +from rai.agents.state_based import Report, State +from rai.messages.multimodal import HumanMultimodalMessage, MultimodalMessage from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.tools.utils import wait_for_message +from rai.utils.ros import NodeDiscovery, RosoutBuffer from rai_interfaces.action import Task as TaskAction -class RosoutBuffer: - def __init__(self, llm, bufsize: int = 100) -> None: - self.bufsize = bufsize - self._buffer: Deque[str] = deque() - self.template = ChatPromptTemplate.from_messages( - [ - ( - "system", - "Shorten the following log keeping its format - for example merge similar or repeating lines", - ), - ("human", "{rosout}"), - ] - ) - llm = llm - self.llm = self.template | llm - - def clear(self): - self._buffer.clear() - - def append(self, line: str): - self._buffer.append(line) - if len(self._buffer) > self.bufsize: - self._buffer.popleft() - - def get_raw_logs(self, last_n: int = 30) -> str: - return "\n".join(list(self._buffer)[-last_n:]) - - def summarize(self): - if len(self._buffer) == 0: - return "No logs" - buffer = self.get_raw_logs() - response = self.llm.invoke({"rosout": buffer}) - return str(response.content) - - -@tool -def wait_for_2s(): - """Wait for 2 seconds""" - time.sleep(2) - - -@dataclass -class NodeDiscovery: - topics_and_types: Dict[str, str] = field(default_factory=dict) - services_and_types: Dict[str, str] = field(default_factory=dict) - actions_and_types: Dict[str, str] = field(default_factory=dict) - whitelist: Optional[List[str]] = field(default_factory=list) - - def set(self, topics, services, actions): - def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: - return {k: v[0] for k, v in info} - - self.topics_and_types = to_dict(topics) - self.services_and_types = to_dict(services) - self.actions_and_types = to_dict(actions) - if self.whitelist is not None: - self.__filter(self.whitelist) - - def __filter(self, whitelist: List[str]): - for d in [ - self.topics_and_types, - self.services_and_types, - self.actions_and_types, - ]: - to_remove = [k for k in d if k not in whitelist] - for k in to_remove: - d.pop(k) - - def dict(self): - return { - "topics_and_types": self.topics_and_types, - "services_and_types": self.services_and_types, - "actions_and_types": self.actions_and_types, - } - - class RaiBaseNode(Node): def __init__( self, + whitelist: Optional[List[str]] = None, *args, **kwargs, ): @@ -144,7 +65,7 @@ def __init__( self.DISCOVERY_FREQ, self.discovery, ) - self.ros_discovery_info = NodeDiscovery(whitelist=None) + self.ros_discovery_info = NodeDiscovery(whitelist=whitelist) self.discovery() self.qos_profile = QoSProfile( history=HistoryPolicy.KEEP_LAST, @@ -215,10 +136,9 @@ def __init__( *args, **kwargs, ): - super().__init__(node_name, *args, **kwargs) + super().__init__(node_name=node_name, whitelist=whitelist, *args, **kwargs) self.llm = llm - self.whitelist = whitelist self.robot_state = dict() self.state_topics = observe_topics if observe_topics is not None else [] self.state_postprocessors = ( @@ -233,20 +153,8 @@ def __init__( Trigger, "rai_whoami_identity_service" ) - self.DISCOVERY_FREQ = 2.0 - self.DISCOVERY_DEPTH = 5 - self.callback_group = rclpy.callback_groups.MutuallyExclusiveCallbackGroup() - self.qos_profile = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - reliability=ReliabilityPolicy.BEST_EFFORT, - durability=DurabilityPolicy.VOLATILE, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - - self.state_subscribers = dict() self.initialize_robot_state_interfaces(self.state_topics) self.system_prompt = self.initialize_system_prompt(system_prompt) @@ -310,6 +218,14 @@ def initialize_robot_state_interfaces(self, topics): self.state_subscribers[topic] = subscriber +def parse_task_goal(ros_action_goal: TaskAction.Goal) -> Dict[str, Any]: + return dict( + task=ros_action_goal.task, + description=ros_action_goal.description, + priority=ros_action_goal.priority, + ) + + class RaiNode(RaiGenericBaseNode): def __init__( self, @@ -332,11 +248,7 @@ def __init__( **kwargs, ) - # ---------- ROS Parameters ---------- - self.task_topic = "/task_addition_requests" - # ---------- ROS configuration ---------- - self.rosout_sub = self.create_subscription( rcl_interfaces.msg.Log, "/rosout", @@ -347,8 +259,14 @@ def __init__( # ---------- Task Queue ---------- self.task_action_server = ActionServer( - self, TaskAction, "perform_task", self.agent_loop + self, + TaskAction, + "perform_task", + execute_callback=self.agent_loop, + goal_callback=self.goal_callback, ) + # Node is busy when task is executed. Only 1 task is allowed + self.busy = False # ---------- LLM Agents ---------- self.AGENT_RECURSION_LIMIT = 100 @@ -358,51 +276,93 @@ def __init__( # self.agent_loop_thread = Thread(target=self.agent_loop) # self.agent_loop_thread.start() - def agent_loop(self, goal_handle: TaskAction.Goal): - self.get_logger().info(f"Received goal handle: {goal_handle}") - action_request = goal_handle.request - task = dict( - task=action_request.task, - description=action_request.description, - priority=action_request.priority, - ) - self.get_logger().info(f"Received task: {task}") + def goal_callback(self, _) -> GoalResponse: + """Accept or reject a client request to begin an action.""" + response = GoalResponse.REJECT if self.busy else GoalResponse.ACCEPT + self.get_logger().info(f"Received goal request. Response: {response}") + return response + + async def agent_loop(self, goal_handle: ServerGoalHandle): + self.busy = True + try: + action_request: TaskAction.Goal = goal_handle.request + task: Dict[str, Any] = parse_task_goal( + action_request + ) # TODO(boczekbartek): base model and json + + self.get_logger().info(f"Received task: {task}") + + # ---- LLM Task Handling ---- + messages = [ + SystemMessage(content=self.system_prompt), + HumanMessage(content=f"Task: {task}"), + ] - # ---- LLM Task Handling ---- - messages = [ - SystemMessage(content=self.system_prompt), - HumanMessage(content=f"Task: {task}"), - ] + payload = State(messages=messages) - payload = State(messages=messages) + state = None + for state in self.llm_app.stream( + payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} + ): - state: State = self.llm_app.invoke( - payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} - ) # type: ignore + print(state.keys()) + graph_node_name = list(state.keys())[0] + if graph_node_name == "reporter": + continue - # ---- Share Action feedback ---- - # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI + msg = state[graph_node_name]["messages"][-1] - # ---- Share Action Result ---- - report = state["messages"][-1] - report = pformat(report.json()) + if isinstance(msg, MultimodalMessage): + last_msg = msg.text + else: + last_msg = msg.content - result = TaskAction.Result() - result.success = ( - True # TODO(boczekbartek): ask llm if the action has been successful - ) - result.report = report + feedback_msg = TaskAction.Feedback() + feedback_msg.current_status = f"{graph_node_name}: {last_msg}" + + goal_handle.publish_feedback(feedback_msg) + + # ---- Share Action Result ---- + if state is None: + raise ValueError("No output from LLM") + print(state) - self.get_logger().info(f"Finished task:\n{report}") - self.clear_state() + graph_node_name = list(state.keys())[0] + if graph_node_name != "reporter": + raise ValueError(f"Unexpected output llm node: {graph_node_name}") - return report + report = state["reporter"]["messages"][ + -1 + ] # TODO define graph more strictly not as dict key + + if not isinstance(report, Report): + raise ValueError(f"Unexpected type of agent output: {type(report)}") + + if report.success: + goal_handle.succeed() + else: + goal_handle.abort() + + result = TaskAction.Result() + result.success = report.success + result.report = report.response_to_user + + self.get_logger().info(f"Finished task:\n{result}") + self.clear_state() + + return result + finally: + self.busy = False def set_app(self, app: CompiledGraph): self.llm_app = app def get_robot_state(self) -> Dict[str, str]: state_dict = dict() + + if self.robot_state is None: + return state_dict + for t in self.state_subscribers: if t not in self.robot_state: msg = "No message yet" diff --git a/src/rai/rai/tools/ros/native.py b/src/rai/rai/tools/ros/native.py index 2ae58b05f..9bcdd26ea 100644 --- a/src/rai/rai/tools/ros/native.py +++ b/src/rai/rai/tools/ros/native.py @@ -72,8 +72,6 @@ class Ros2BaseTool(BaseTool): node: rclpy.node.Node = Field(..., exclude=True, required=True) args_schema: Type[Ros2BaseInput] = Ros2BaseInput - handle_tool_error = True - handle_validation_error = True @property def logger(self) -> RcutilsLogger: diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index c464b6f90..740e6e22c 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -56,7 +56,7 @@ def _run(self): class Ros2RunActionSync(Ros2BaseTool): name: str = "Ros2RunAction" description: str = ( - "A tool for running a ros2 action. Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" + "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" ) args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput @@ -82,6 +82,10 @@ def _build_msg( def _run( self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] ): + if action_name[0] != "/": + action_name = "/" + action_name + self.node.get_logger().info(f"Action name corrected to: {action_name}") + try: goal_msg, msg_cls = self._build_msg(action_type, action_goal_args) except Exception as e: @@ -89,7 +93,13 @@ def _run( client = ActionClient(self.node, msg_cls, action_name) + retries = 0 while not client.wait_for_server(timeout_sec=1.0): + retries += 1 + if retries > 5: + raise Exception( + f"Action server '{action_name}' is not available. Make sure `action_name` is correct..." + ) self.node.get_logger().info( f"'{action_name}' action server not available, waiting..." ) diff --git a/src/rai/rai/tools/time.py b/src/rai/rai/tools/time.py index 3ce638440..3b2bd44d8 100644 --- a/src/rai/rai/tools/time.py +++ b/src/rai/rai/tools/time.py @@ -17,19 +17,9 @@ from typing import Type from langchain.pydantic_v1 import BaseModel, Field -from langchain.tools import tool from langchain_core.tools import BaseTool -@tool -def sleep_max_5s(n: int): - """Wait n seconds, max 5s""" - if n > 5: - n = 5 - - time.sleep(n) - - class WaitForSecondsToolInput(BaseModel): """Input for the WaitForSecondsTool tool.""" @@ -44,11 +34,14 @@ class WaitForSecondsTool(BaseTool): "A tool for waiting. " "Useful for pausing execution for a specified number of seconds. " "Input should be the number of seconds to wait." + "Maximum allowed time is 5 seconds" ) args_schema: Type[WaitForSecondsToolInput] = WaitForSecondsToolInput def _run(self, seconds: int): """Waits for the specified number of seconds.""" + if seconds > 5: + seconds = 5 time.sleep(seconds) return f"Waited for {seconds} seconds." diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py new file mode 100644 index 000000000..c281de74e --- /dev/null +++ b/src/rai/rai/utils/ros.py @@ -0,0 +1,91 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Optional, Tuple + +from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate + + +class RosoutBuffer: + def __init__(self, llm: BaseChatModel, bufsize: int = 100) -> None: + self.bufsize = bufsize + self._buffer: Deque[str] = deque() + self.template = ChatPromptTemplate.from_messages( + [ + ( + "system", + "Shorten the following log keeping its format - for example merge simillar or repeating lines", + ), + ("human", "{rosout}"), + ] + ) + llm = llm + self.llm = self.template | llm + + def clear(self): + self._buffer.clear() + + def append(self, line: str): + self._buffer.append(line) + if len(self._buffer) > self.bufsize: + self._buffer.popleft() + + def get_raw_logs(self, last_n: int = 30) -> str: + return "\n".join(list(self._buffer)[-last_n:]) + + def summarize(self): + if len(self._buffer) == 0: + return "No logs" + buffer = self.get_raw_logs() + response = self.llm.invoke({"rosout": buffer}) + return str(response.content) + + +@dataclass +class NodeDiscovery: + topics_and_types: Dict[str, str] = field(default_factory=dict) + services_and_types: Dict[str, str] = field(default_factory=dict) + actions_and_types: Dict[str, str] = field(default_factory=dict) + whitelist: Optional[List[str]] = field(default_factory=list) + + def set(self, topics, services, actions): + def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: + return {k: v[0] for k, v in info} + + self.topics_and_types = to_dict(topics) + self.services_and_types = to_dict(services) + self.actions_and_types = to_dict(actions) + if self.whitelist is not None: + self.__filter(self.whitelist) + + def __filter(self, whitelist: List[str]): + for d in [ + self.topics_and_types, + self.services_and_types, + self.actions_and_types, + ]: + to_remove = [k for k in d if k not in whitelist] + for k in to_remove: + d.pop(k) + + def dict(self): + return { + "topics_and_types": self.topics_and_types, + "services_and_types": self.services_and_types, + "actions_and_types": self.actions_and_types, + } diff --git a/src/rai_hmi/rai_hmi/text_hmi.py b/src/rai_hmi/rai_hmi/text_hmi.py index 66ce71810..d9dd0876f 100644 --- a/src/rai_hmi/rai_hmi/text_hmi.py +++ b/src/rai_hmi/rai_hmi/text_hmi.py @@ -228,8 +228,9 @@ def display_agent_message( return # we do not handle system messages elif isinstance(message, MissionMessage): logger.info("Displaying mission message") - avatar, content = message.render_steamlit() - st.chat_message("bot", avatar=avatar).markdown(content) + with st.expander(label=message.STATUS): + avatar, content = message.render_steamlit() + st.chat_message("bot", avatar=avatar).markdown(content) else: raise ValueError("Unknown message type")