diff --git a/src/rai_hmi/rai_hmi/voice_hmi.py b/src/rai_hmi/rai_hmi/voice_hmi.py index 36418598..c5a6ba5c 100644 --- a/src/rai_hmi/rai_hmi/voice_hmi.py +++ b/src/rai_hmi/rai_hmi/voice_hmi.py @@ -13,6 +13,13 @@ # limitations under the License. # +from typing import List + +import rclpy +from langchain.agents import AgentExecutor, create_tool_calling_agent +from langchain.prompts import ChatPromptTemplate +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI from rclpy.callback_groups import ReentrantCallbackGroup from std_msgs.msg import String @@ -20,8 +27,8 @@ class VoiceHMINode(BaseHMINode): - def __init__(self, node_name: str, robot_description_package: str): - super().__init__(node_name, robot_description_package) + def __init__(self, node_name: str): + super().__init__(node_name) self.callback_group = ReentrantCallbackGroup() self.hmi_subscription = self.create_subscription( @@ -36,20 +43,63 @@ def __init__(self, node_name: str, robot_description_package: str): String, "to_human", 10, callback_group=self.callback_group ) + self.history: List[BaseMessage] = [] + self.agent = self.initialize_agent() + + self.get_logger().info("Voice HMI node initialized") + + def initialize_agent(self): + prompt = ChatPromptTemplate.from_messages( + [ + ("system", self.system_prompt), + ("placeholder", "{chat_history}"), + ("human", "{user_input}"), + ("placeholder", "{agent_scratchpad}"), + ] + ) + llm = ChatOpenAI(model="gpt-4o") + agent = create_tool_calling_agent(llm=llm, tools=self.tools, prompt=prompt) + agent_executor = AgentExecutor(agent=agent, tools=self.tools) + return agent_executor + def handle_human_message(self, msg: String): self.processing = True # handle human message - output = "" # self.agent(msg.data, config=config) + response = self.agent.invoke( + {"user_input": msg.data, "chat_history": self.history} + ) + output = response["output"] + self.history.append(HumanMessage(msg.data)) + self.history.append(AIMessage(output)) - self.processing = False self.hmi_publisher.publish(String(data=output)) + self.processing = False def handle_feedback_request(self, feedback_query: str) -> str: self.processing = True # handle feedback request - output = "" # self.agent(feedback_query, config=config) + feedback_prompt = ( + "The task executioner is asking for feedback on the following:" + f"```\n{feedback_query}\n```" + "Please provide needed information based on the following chat history:" + ) + local_history: List[BaseMessage] = [ + SystemMessage(content=self.system_prompt), + HumanMessage(content=feedback_prompt), + ] + local_history.extend(self.history) + response = self.agent.invoke({"user_input": "", "chat_history": local_history}) + output = response["output"] self.processing = False return output + + +def main(args=None): + rclpy.init(args=args) + voice_hmi_node = VoiceHMINode("voice_hmi_node") + rclpy.spin(voice_hmi_node) + voice_hmi_node.destroy_node() + rclpy.shutdown() diff --git a/src/rai_hmi/setup.py b/src/rai_hmi/setup.py index e0b56574..757f8d35 100644 --- a/src/rai_hmi/setup.py +++ b/src/rai_hmi/setup.py @@ -35,6 +35,7 @@ entry_points={ "console_scripts": [ "hmi_node = rai_hmi.hmi_node:main", + "voice_hmi_node = rai_hmi.voice_hmi:main", ], }, )