diff --git a/examples/taxi-demo.py b/examples/taxi-demo.py new file mode 100644 index 000000000..1d78dfc7e --- /dev/null +++ b/examples/taxi-demo.py @@ -0,0 +1,92 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Queue +from typing import List + +import rclpy +from langchain_community.tools import GooglePlacesTool +from langchain_community.tools.tavily_search import TavilySearchResults +from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.tools import tool +from std_msgs.msg import String + +from rai.agents.conversational_agent import create_conversational_agent +from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks +from rai_hmi.api import GenericVoiceNode, split_message + + +@tool +def navigate(street: str, number: int, city: str, country: str) -> str: + """ + Send the destination to the navigation system. + """ + return f"Navigating to {street} {number}, {city}, {country}" + + +system_prompt = """ +**System Role: Taxi Driver in Warsaw** + +- **User Instructions**: You will be provided with a destination by the user, which may either be a specific place or an address. Sometimes, the user might describe the destination in a way that isn't clearly a place or address. + +- **Clarifying the Destination**: If the destination isn't immediately clear, your task is to ask clarifying questions to determine where the user wants to go. Once confirmed, ensure you obtain the exact address (including street name, number, etc.) to send to the navigation system. + +- **Location Context**: You are based in Warsaw, Poland, and your communication with the user must always be in English. + +- **Tools**: + - **tavily_search_results_json**: Use this tool to find an address when the user provides a non-specific description of a destination. + - **navigate**: Once the exact address is confirmed, use this to send the destination to the navigation system. + - **google_places**: Use this tool to search for specific places, businesses, or landmarks based on user descriptions. It can help if the user mentions popular destinations or well-known places in Warsaw. + +- **Communication Style**: Be friendly, helpful, and concise. While you may receive greetings or unrelated questions, keep the conversation focused on resolving the user's destination. + +- **Key Directives**: + - Do not guess or assume information; rely on tools to obtain any needed details. + - Your primary goal is to successfully navigate to the destination provided by the user. + - If you are sure about the destination, please try to resolve without additional interaction with the client. +""" + + +class TaxiDemo(GenericVoiceNode): + def __init__(self): + super().__init__("taxi_demo_node", Queue(), "") + + self.agent = create_conversational_agent( + get_llm_model("complex_model"), + [navigate, GooglePlacesTool(), TavilySearchResults()], + system_prompt, + logger=self.get_logger(), + ) + + self.history: List[BaseMessage] = [] + + def _handle_human_message(self, msg: String): + self.history.append(HumanMessage(content=msg.data)) + response = self.agent.invoke( + {"messages": self.history}, config={"callbacks": get_tracing_callbacks()} + ) + last_message = response["messages"][-1].content + for sentence in split_message(last_message): + self.hmi_publisher.publish(String(data=sentence)) + + +def main(): + rclpy.init() + node = TaxiDemo() + rclpy.spin(node) + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index 57ea6ad09..6d6078b33 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -55,13 +55,13 @@ files = [ [[package]] name = "aiohappyeyeballs" -version = "2.4.2" +version = "2.4.3" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" files = [ - {file = "aiohappyeyeballs-2.4.2-py3-none-any.whl", hash = "sha256:8522691d9a154ba1145b157d6d5c15e5c692527ce6a53c5e5f9876977f6dab2f"}, - {file = "aiohappyeyeballs-2.4.2.tar.gz", hash = "sha256:4ca893e6c5c1f5bf3888b04cb5a3bee24995398efef6e0b9f747b5e89d84fd74"}, + {file = "aiohappyeyeballs-2.4.3-py3-none-any.whl", hash = "sha256:8a7a83727b2756f394ab2895ea0765a0a8c475e3c71e98d43d76f22b4b435572"}, + {file = "aiohappyeyeballs-2.4.3.tar.gz", hash = "sha256:75cf88a15106a5002a8eb1dab212525c00d1f4c0fa96e551c9fbe6f09a621586"}, ] [[package]] @@ -455,17 +455,17 @@ files = [ [[package]] name = "boto3" -version = "1.35.29" +version = "1.35.30" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.29-py3-none-any.whl", hash = "sha256:2244044cdfa8ac345d7400536dc15a4824835e7ec5c55bc267e118af66bb27db"}, - {file = "boto3-1.35.29.tar.gz", hash = "sha256:7bbb1ee649e09e956952285782cfdebd7e81fc78384f48dfab3d66c6eaf3f63f"}, + {file = "boto3-1.35.30-py3-none-any.whl", hash = "sha256:d89c3459db89c5408e83219ab849ffd0146bc4285e75cdc67c6e45d390a12df2"}, + {file = "boto3-1.35.30.tar.gz", hash = "sha256:d2851aec8e9dc6937977acbe9a5124ecc31b3ad5f50a10cd9ae52636da3f52fa"}, ] [package.dependencies] -botocore = ">=1.35.29,<1.36.0" +botocore = ">=1.35.30,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -474,13 +474,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.29" +version = "1.35.30" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.29-py3-none-any.whl", hash = "sha256:f8e3ae0d84214eff3fb69cb4dc51cea6c43d3bde82027a94d00c52b941d6c3d5"}, - {file = "botocore-1.35.29.tar.gz", hash = "sha256:4ed28ab03675bb008a290c452c5ddd7aaa5d4e3fa1912aadbdf93057ee84362b"}, + {file = "botocore-1.35.30-py3-none-any.whl", hash = "sha256:3bb9f9dde001608671ea74681ac3cec06bbbb10cba8cb8c1387a25e843075ce0"}, + {file = "botocore-1.35.30.tar.gz", hash = "sha256:ab5350e8a50e48d371fa2d517d65c29a40c43788cb9a15387f93eac5a23df0fd"}, ] [package.dependencies] @@ -1766,6 +1766,19 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "googlemaps" +version = "4.10.0" +description = "Python client library for Google Maps Platform" +optional = false +python-versions = ">=3.5" +files = [ + {file = "googlemaps-4.10.0.tar.gz", hash = "sha256:3055fcbb1aa262a9159b589b5e6af762b10e80634ae11c59495bd44867e47d88"}, +] + +[package.dependencies] +requests = ">=2.20.0,<3.0" + [[package]] name = "greenlet" version = "3.1.1" @@ -2754,8 +2767,8 @@ langchain-core = ">=0.3.6,<0.4.0" langchain-text-splitters = ">=0.3.0,<0.4.0" langsmith = ">=0.1.17,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -2778,8 +2791,8 @@ files = [ boto3 = ">=1.34.131" langchain-core = ">=0.3.2,<0.4" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2,<3" @@ -2801,8 +2814,8 @@ langchain = ">=0.3.1,<0.4.0" langchain-core = ">=0.3.6,<0.4.0" langsmith = ">=0.1.125,<0.2.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -2812,13 +2825,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-core" -version = "0.3.6" +version = "0.3.7" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.6-py3-none-any.whl", hash = "sha256:7bb3df0117bdc628b18b6c8748de72c6f537d745d47566053ce6650d5712281c"}, - {file = "langchain_core-0.3.6.tar.gz", hash = "sha256:eb190494a5483f1965f693bb2085edb523370b20fc52dc294d3bd425773cd076"}, + {file = "langchain_core-0.3.7-py3-none-any.whl", hash = "sha256:a789875358001ca9293875c12f0b6238855325621ab66775109497b9b1648157"}, + {file = "langchain_core-0.3.7.tar.gz", hash = "sha256:9f877c00fec7fe1dca929dd3bed3999ee4c2e5c14c6744ed82cc66ddfcd15fdf"}, ] [package.dependencies] @@ -2905,28 +2918,28 @@ openai = ["openai (>=0.27.8)"] [[package]] name = "langgraph" -version = "0.2.28" +version = "0.2.31" description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = "<4.0,>=3.9.0" files = [ - {file = "langgraph-0.2.28-py3-none-any.whl", hash = "sha256:23390763c025139f71dc1f1576b31b6755fecff8dcc51a84505e24e63ec1218b"}, - {file = "langgraph-0.2.28.tar.gz", hash = "sha256:c968a1ed85025e0651d9390a7ba978447ab80d676f81dd0a049a7456754b3bce"}, + {file = "langgraph-0.2.31-py3-none-any.whl", hash = "sha256:9e5b4138aae95bfbd928b6f0f2869431060c80d7a62fc831370cf2aed3a488e8"}, + {file = "langgraph-0.2.31.tar.gz", hash = "sha256:78759ebd8abcabb1894cf64e07d221a11b970e77553a4f89e1134c3602958341"}, ] [package.dependencies] langchain-core = ">=0.2.39,<0.4" -langgraph-checkpoint = ">=1.0.2,<2.0.0" +langgraph-checkpoint = ">=1.0.14,<2.0.0" [[package]] name = "langgraph-checkpoint" -version = "1.0.12" +version = "1.0.14" description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = "<4.0.0,>=3.9.0" files = [ - {file = "langgraph_checkpoint-1.0.12-py3-none-any.whl", hash = "sha256:44fc464c82ecb643a69b1c394080c54c63969798e0c538b763bbab67911b6e21"}, - {file = "langgraph_checkpoint-1.0.12.tar.gz", hash = "sha256:a8bdcf3a39a45193f009dd2a6ebaf637dbaeb50f1b88a66b151d9ab8c5b41d21"}, + {file = "langgraph_checkpoint-1.0.14-py3-none-any.whl", hash = "sha256:a60cbf06011a5f9c9bfcde971684732acd5df39632c58ff45f02f814519e9d8c"}, + {file = "langgraph_checkpoint-1.0.14.tar.gz", hash = "sha256:5c51f8d8cca4c0ed3e75c264a7bf66a2efa60ff521ed46f05facf606df424eb1"}, ] [package.dependencies] @@ -4131,10 +4144,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4155,10 +4168,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4344,9 +4357,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -8089,4 +8102,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "a893846f73b481b6386f2ef9950a72b546a77d37b7418997202d746ede8d5932" +content-hash = "7c4152d060193b6cc8006aeb74ed228de410aaa1fcd960859fb39a110917ffc0" diff --git a/pyproject.toml b/pyproject.toml index dc57f3bea..6ab229ebc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ faiss-cpu = "^1.8.0.post1" rich = "^13.7.1" docx2txt = "^0.8" pypdf = "^4.2.0" +googlemaps = "^4.0" streamlit = "^1.37.1" deprecated = "^1.2.14"