diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2caf3cb76..ca94af87b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,25 +23,9 @@ repos:
hooks:
- id: shellcheck
- - repo: https://github.com/pycqa/autoflake
- rev: v2.3.1
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.9.4
hooks:
- - id: autoflake
- args: ["--remove-all-unused-imports", "--in-place"]
-
- - repo: https://github.com/pycqa/isort
- rev: 5.13.2
- hooks:
- - id: isort
- args: ["--profile", "black", "--filter-files"]
-
- - repo: https://github.com/psf/black
- rev: 24.4.2
- hooks:
- - id: black
-
- - repo: https://github.com/pycqa/flake8
- rev: 7.1.0
- hooks:
- - id: flake8
- args: ["--ignore=E501,E731,W503,W504,E203"]
+ - id: ruff
+ args: [--extend-select, "I,RUF022", --fix, --ignore, E731]
+ - id: ruff-format
diff --git a/examples/agriculture-demo.py b/examples/agriculture-demo.py
index e80d7a474..ac49e0e98 100644
--- a/examples/agriculture-demo.py
+++ b/examples/agriculture-demo.py
@@ -15,12 +15,6 @@
import argparse
import rclpy
-from rclpy.action import ActionClient
-from rclpy.callback_groups import ReentrantCallbackGroup
-from rclpy.executors import MultiThreadedExecutor
-from rclpy.node import Node
-from std_srvs.srv import Trigger
-
from rai.node import RaiStateBasedLlmNode, describe_ros_image
from rai.tools.ros.native import (
GetCameraImage,
@@ -30,6 +24,12 @@
Ros2ShowMsgInterfaceTool,
)
from rai.tools.time import WaitForSecondsTool
+from rclpy.action import ActionClient
+from rclpy.callback_groups import ReentrantCallbackGroup
+from rclpy.executors import MultiThreadedExecutor
+from rclpy.node import Node
+from std_srvs.srv import Trigger
+
from rai_interfaces.action import Task
diff --git a/examples/manipulation-demo-streamlit.py b/examples/manipulation-demo-streamlit.py
index bec608118..e226c1f88 100644
--- a/examples/manipulation-demo-streamlit.py
+++ b/examples/manipulation-demo-streamlit.py
@@ -16,7 +16,6 @@
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
-
from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
from rai.messages import HumanMultimodalMessage
diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py
index 0c8d0665b..6f73248c2 100644
--- a/examples/manipulation-demo.py
+++ b/examples/manipulation-demo.py
@@ -17,7 +17,6 @@
import rclpy
import rclpy.qos
from langchain_core.messages import HumanMessage
-
from rai.agents.conversational_agent import create_conversational_agent
from rai.node import RaiBaseNode
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
diff --git a/examples/rosbot-xl-demo.py b/examples/rosbot-xl-demo.py
index 90600b6ae..b6c1ab666 100644
--- a/examples/rosbot-xl-demo.py
+++ b/examples/rosbot-xl-demo.py
@@ -18,8 +18,6 @@
import rclpy
import rclpy.executors
import rclpy.logging
-from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool
-
from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetMsgFromTopic,
@@ -35,6 +33,7 @@
Ros2RunActionAsync,
)
from rai.tools.time import WaitForSecondsTool
+from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool
p = argparse.ArgumentParser()
p.add_argument("--allowlist", type=Path, required=False, default=None)
diff --git a/examples/taxi-demo.py b/examples/taxi-demo.py
index 6f5d24a48..18d498a68 100644
--- a/examples/taxi-demo.py
+++ b/examples/taxi-demo.py
@@ -20,12 +20,12 @@
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
-from std_msgs.msg import String
-
from rai.agents.conversational_agent import create_conversational_agent
from rai.tools.ros.cli import Ros2ServiceTool
from rai.tools.ros.native import Ros2PubMessageTool
from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks
+from std_msgs.msg import String
+
from rai_hmi.api import GenericVoiceNode, split_message
system_prompt = """
diff --git a/poetry.lock b/poetry.lock
index d6da53572..93059b345 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5653,6 +5653,70 @@ files = [
[package.dependencies]
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
+[[package]]
+name = "rai"
+version = "1.0.0"
+description = "Core functionality for RAI framework"
+optional = false
+python-versions = "^3.10, <3.13"
+files = []
+develop = true
+
+[package.dependencies]
+coloredlogs = "^15.0.1"
+deprecated = "^1.2.14"
+langchain = "*"
+langchain-core = "^0.3"
+langgraph = "*"
+markdown = "^3.6"
+requests = "^2.32.2"
+rich = "^13.7.1"
+tomli = "^2.0.1"
+tomli-w = "^1.1.0"
+tqdm = "^4.66.4"
+
+[package.source]
+type = "directory"
+url = "src/rai_core"
+
+[[package]]
+name = "rai-asr"
+version = "1.0.0"
+description = "Automatic Speech Recognition module for RAI framework"
+optional = false
+python-versions = "^3.10, <3.13"
+files = []
+develop = true
+
+[package.dependencies]
+faster-whisper = "^1.1.1"
+openai-whisper = "^20231117"
+pydub = "^0.25.1"
+scipy = "^1.14.0"
+sounddevice = "^0.4.7"
+torchaudio = "^2.3.1"
+
+[package.source]
+type = "directory"
+url = "src/rai_asr"
+
+[[package]]
+name = "rai-tts"
+version = "1.0.0"
+description = "Text-to-Speech module for RAI framework"
+optional = false
+python-versions = "^3.10, <3.13"
+files = []
+develop = true
+
+[package.dependencies]
+elevenlabs = "^1.4.1"
+sounddevice = "^0.4.7"
+
+[package.source]
+type = "directory"
+url = "src/rai_tts"
+
[[package]]
name = "ray"
version = "2.41.0"
@@ -8238,4 +8302,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10, <3.13"
-content-hash = "158877f96c27f3c9beb75aff8a9608a41b83740afe17ceda2a471378c4bb3545"
+content-hash = "242e440c4ce4b31fa629d198a3d79b0854e84284d013d819a4b7a24e633a1706"
diff --git a/pyproject.toml b/pyproject.toml
index 3224ea498..b313dd299 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
[tool.poetry]
-name = "rai"
+name = "rai_framework"
version = "1.0.0"
description = "RAI is a framework for building general multi-agent systems, bringing Gen AI features to ROS enabled robots."
readme = "README.md"
@@ -10,8 +10,14 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
package-mode = false
+
[tool.poetry.dependencies]
python = "^3.10, <3.13"
+
+rai = {path = "src/rai_core", develop = true}
+rai_asr = {path = "src/rai_asr", develop = true}
+rai_tts = {path = "src/rai_tts", develop = true}
+
langchain-core = "^0.3"
langchain = "*"
langgraph = "*"
@@ -78,11 +84,9 @@ visualnav_transformer = { git = "https://github.com/RobotecAI/visualnav-transfor
gdown = "^5.2.0"
[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
-[tool.isort]
-profile = "black"
[tool.pytest.ini_options]
markers = [
diff --git a/src/examples/turtlebot4/turtlebot_demo.py b/src/examples/turtlebot4/turtlebot_demo.py
index 393606374..c3ea985b8 100644
--- a/src/examples/turtlebot4/turtlebot_demo.py
+++ b/src/examples/turtlebot4/turtlebot_demo.py
@@ -23,7 +23,6 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task
-
from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetCameraImage,
diff --git a/src/rai/LICENSE b/src/rai/LICENSE
deleted file mode 100644
index d0356d5ef..000000000
--- a/src/rai/LICENSE
+++ /dev/null
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2024-present Robotec.ai
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/src/rai/package.xml b/src/rai/package.xml
deleted file mode 100644
index 58cec45b7..000000000
--- a/src/rai/package.xml
+++ /dev/null
@@ -1,23 +0,0 @@
-<?xml version="1.0"?>
-<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
-<package format="3">
-  <name>rai</name>
-  <version>1.0.0</version>
-  <description>RAI core modules</description>
-  <maintainer>Bartłomiej Boczek</maintainer>
-  <maintainer>Maciej Majek</maintainer>
-  <license>Apache-2.0</license>
-
-  <test_depend>ament_copyright</test_depend>
-  <test_depend>ament_flake8</test_depend>
-  <test_depend>ament_pep257</test_depend>
-  <test_depend>python3-pytest</test_depend>
-
-  <depend>nav2_msgs</depend>
-  <depend>nav2_simple_commander</depend>
-  <depend>tf_transformations</depend>
-
-  <export>
-    <build_type>ament_python</build_type>
-  </export>
-</package>
diff --git a/src/rai/setup.cfg b/src/rai/setup.cfg
deleted file mode 100644
index c398a856f..000000000
--- a/src/rai/setup.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[develop]
-script_dir=$base/lib/rai
-[install]
-install_scripts=$base/lib/rai
diff --git a/src/rai/setup.py b/src/rai/setup.py
deleted file mode 100644
index 9ed4bf5df..000000000
--- a/src/rai/setup.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from setuptools import find_packages, setup
-
-package_name = "rai"
-
-setup(
- name=package_name,
- version="1.0.0",
- packages=find_packages(exclude=["test"]),
- data_files=[
- ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
- ("share/" + package_name, ["package.xml"]),
- ],
- install_requires=["setuptools"],
- zip_safe=True,
- maintainer="Bartłomiej Boczek",
- maintainer_email="bartlomiej.boczek@robotec.ai",
- description="TODO: Package description",
- license="Apache-2.0",
- tests_require=["pytest"],
- entry_points={},
-)
diff --git a/src/rai_asr/launch/local.launch.py b/src/rai_asr/launch/local.launch.py
deleted file mode 100644
index 7bfa69ca0..000000000
--- a/src/rai_asr/launch/local.launch.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from launch import LaunchDescription
-from launch.actions import DeclareLaunchArgument
-from launch.substitutions import LaunchConfiguration
-from launch_ros.actions import Node
-
-
-def generate_launch_description():
- return LaunchDescription(
- [
- DeclareLaunchArgument(
- "recording_device",
- default_value="0",
- description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'",
- ),
- DeclareLaunchArgument(
- "language",
- default_value="en",
- description="Language code for the ASR model",
- ),
- DeclareLaunchArgument(
- "model_name",
- default_value="base",
- description="Model name for the ASR model",
- ),
- DeclareLaunchArgument(
- "model_vendor",
- default_value="whisper",
- description="Model vendor of the ASR",
- ),
- DeclareLaunchArgument(
- "silence_grace_period",
- default_value="1.0",
- description="Grace period in seconds after silence to stop recording",
- ),
- DeclareLaunchArgument(
- "use_wake_word",
- default_value="False",
- description="Whether to use wake word detection",
- ),
- DeclareLaunchArgument(
- "wake_word_model",
- default_value="",
- description="Wake word model to use",
- ),
- DeclareLaunchArgument(
- "wake_word_threshold",
- default_value="0.5",
- description="Threshold for wake word detection",
- ),
- DeclareLaunchArgument(
- "vad_threshold",
- default_value="0.5",
- description="Threshold for voice activity detection",
- ),
- Node(
- package="rai_asr",
- executable="asr_node",
- name="rai_asr",
- output="screen",
- emulate_tty=True,
- parameters=[
- {
- "recording_device": LaunchConfiguration("recording_device"),
- "language": LaunchConfiguration("language"),
- "model_name": LaunchConfiguration("model_name"),
- "model_vendor": LaunchConfiguration("model_vendor"),
- "silence_grace_period": LaunchConfiguration(
- "silence_grace_period"
- ),
- "use_wake_word": LaunchConfiguration("use_wake_word"),
- "wake_word_model": LaunchConfiguration("wake_word_model"),
- "wake_word_threshold": LaunchConfiguration(
- "wake_word_threshold"
- ),
- "vad_threshold": LaunchConfiguration("vad_threshold"),
- }
- ],
- ),
- ]
- )
diff --git a/src/rai_asr/launch/openai.launch.py b/src/rai_asr/launch/openai.launch.py
deleted file mode 100644
index d7b4524ea..000000000
--- a/src/rai_asr/launch/openai.launch.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from launch import LaunchDescription
-from launch.actions import DeclareLaunchArgument
-from launch.substitutions import LaunchConfiguration
-from launch_ros.actions import Node
-
-
-def generate_launch_description():
- return LaunchDescription(
- [
- DeclareLaunchArgument(
- "recording_device",
- default_value="0",
- description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'",
- ),
- DeclareLaunchArgument(
- "language",
- default_value="en",
- description="Language code for the ASR model",
- ),
- DeclareLaunchArgument(
- "model_name",
- default_value="whisper-1",
- description="Model name for the ASR model",
- ),
- DeclareLaunchArgument(
- "model_vendor",
- default_value="openai",
- description="Model vendor of the ASR",
- ),
- DeclareLaunchArgument(
- "silence_grace_period",
- default_value="1.0",
- description="Grace period in seconds after silence to stop recording",
- ),
- DeclareLaunchArgument(
- "use_wake_word",
- default_value="False",
- description="Whether to use wake word detection",
- ),
- DeclareLaunchArgument(
- "wake_word_model",
- default_value="",
- description="Wake word model to use",
- ),
- DeclareLaunchArgument(
- "wake_word_threshold",
- default_value="0.5",
- description="Threshold for wake word detection",
- ),
- DeclareLaunchArgument(
- "vad_threshold",
- default_value="0.5",
- description="Threshold for voice activity detection",
- ),
- Node(
- package="rai_asr",
- executable="asr_node",
- name="rai_asr",
- output="screen",
- emulate_tty=True,
- parameters=[
- {
- "recording_device": LaunchConfiguration("recording_device"),
- "language": LaunchConfiguration("language"),
- "model_name": LaunchConfiguration("model_name"),
- "model_vendor": LaunchConfiguration("model_vendor"),
- "silence_grace_period": LaunchConfiguration(
- "silence_grace_period"
- ),
- "use_wake_word": LaunchConfiguration("use_wake_word"),
- "wake_word_model": LaunchConfiguration("wake_word_model"),
- "wake_word_threshold": LaunchConfiguration(
- "wake_word_threshold"
- ),
- "vad_threshold": LaunchConfiguration("vad_threshold"),
- }
- ],
- ),
- ]
- )
diff --git a/src/rai_asr/package.xml b/src/rai_asr/package.xml
deleted file mode 100644
index 56547ae81..000000000
--- a/src/rai_asr/package.xml
+++ /dev/null
@@ -1,16 +0,0 @@
-<?xml version="1.0"?>
-<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
-<package format="3">
-  <name>rai_asr</name>
-  <version>0.1.0</version>
-  <description>An Automatic Speech Recognition package, leveraging Whisper for transcription and
-    Silero VAD for voice activity detection. This node captures audio, detects speech, and
-    transcribes the spoken content, publishing the transcription to a topic.</description>
-  <maintainer>mkotynia</maintainer>
-  <license>Apache-2.0</license>
-
-  <exec_depend>portaudio19-dev</exec_depend>
-  <export>
-    <build_type>ament_python</build_type>
-  </export>
-</package>
diff --git a/src/rai_asr/pyproject.toml b/src/rai_asr/pyproject.toml
new file mode 100644
index 000000000..925525951
--- /dev/null
+++ b/src/rai_asr/pyproject.toml
@@ -0,0 +1,27 @@
+[tool.poetry]
+name = "rai_asr"
+version = "1.0.0"
+description = "Automatic Speech Recognition module for RAI framework"
+authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "]
+readme = "README.md"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Development Status :: 4 - Beta",
+ "License :: OSI Approved :: Apache Software License",
+]
+packages = [
+ { include = "rai_asr", from = "." },
+]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.dependencies]
+python = "^3.10, <3.13"
+sounddevice = "^0.4.7"
+openai-whisper = "^20231117"
+scipy = "^1.14.0"
+torchaudio = "^2.3.1"
+faster-whisper = "^1.1.1"
+pydub = "^0.25.1"
diff --git a/src/rai_asr/rai_asr/__init__.py b/src/rai_asr/rai_asr/__init__.py
index ef74fc891..499007aa9 100644
--- a/src/rai_asr/rai_asr/__init__.py
+++ b/src/rai_asr/rai_asr/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""RAI ASR package."""
+
+__version__ = "0.1.0"
diff --git a/src/rai_asr/rai_asr/asr_node.py b/src/rai_asr/rai_asr/asr_node.py
deleted file mode 100755
index 7dac9e268..000000000
--- a/src/rai_asr/rai_asr/asr_node.py
+++ /dev/null
@@ -1,402 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-import time
-from typing import Literal, Optional, cast
-
-import numpy as np
-import rclpy
-import sounddevice as sd
-import torch
-from numpy.typing import NDArray
-from openwakeword.model import Model as OWWModel
-from openwakeword.utils import download_models
-from rcl_interfaces.msg import ParameterDescriptor, ParameterType
-from rclpy.callback_groups import ReentrantCallbackGroup
-from rclpy.executors import SingleThreadedExecutor
-from rclpy.node import Node
-from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
-from scipy.signal import resample
-from std_msgs.msg import String
-
-VAD_SAMPLING_RATE = 16000 # default value used by silero vad
-DEFAULT_BLOCKSIZE = 1280
-
-
-class ASRNode(Node):
- def __init__(self):
- super().__init__("rai_asr") # type: ignore
- self._declare_parameters()
- self._initialize_parameters()
- self._setup_node_components()
- self._setup_publishers_and_subscribers()
-
- self.asr_model = self._initialize_asr_model()
- self.vad_model = self._initialize_vad_model()
- self.oww_model = self._initialize_open_wake_word()
-
- self.initialize_sounddevice_stream()
-
- self.is_recording = False
- self.audio_buffer = []
- self.silence_start_time: Optional[float] = None
- self.last_transcription_time = 0
- self.hmi_lock = False
- self.tts_lock = False
-
- self.current_chunk: Optional[NDArray[np.int16]] = None
-
- self.transcription_recording_timeout = 1
- self.get_logger().info("ASR Node has been initialized") # type: ignore
-
- def _declare_parameters(self):
- self.declare_parameter(
- "use_wake_word",
- False,
- descriptor=ParameterDescriptor(
- type=ParameterType.PARAMETER_BOOL,
- description=("Whether to use wake word for starting conversation"),
- ),
- )
- self.declare_parameter(
- "wake_word_model",
- "",
- descriptor=ParameterDescriptor(
- type=ParameterType.PARAMETER_STRING,
- description=("Wake word model onnx file"),
- ),
- )
- self.declare_parameter(
- "wake_word_threshold",
- 0.1,
- descriptor=ParameterDescriptor(
- type=ParameterType.PARAMETER_DOUBLE,
- description=("Wake word threshold"),
- ),
- )
- self.declare_parameter(
- "vad_threshold",
- 0.5,
- descriptor=ParameterDescriptor(
- type=ParameterType.PARAMETER_DOUBLE,
- description=("VAD threshold"),
- ),
- )
- self.declare_parameter(
- "recording_device",
- 0,
- descriptor=ParameterDescriptor(
- type=ParameterType.PARAMETER_INTEGER,
- description=(
- "Recording device number. See available by running"
- "python -c 'import sounddevice as sd; print(sd.query_devices())'"
- ),
- ),
- )
- self.declare_parameter(
- "model_vendor",
- "whisper", # openai, whisper
- ParameterDescriptor(
- type=ParameterType.PARAMETER_STRING,
- description="Vendor of the ASR model",
- ),
- )
- self.declare_parameter(
- "language",
- "en",
- ParameterDescriptor(
- type=ParameterType.PARAMETER_STRING,
- description="Language code for the ASR model",
- ),
- )
- self.declare_parameter(
- "model_name",
- "base",
- ParameterDescriptor(
- type=ParameterType.PARAMETER_STRING,
- description="Model type for the ASR model",
- ),
- )
- self.declare_parameter(
- "silence_grace_period",
- 1.0,
- ParameterDescriptor(
- type=ParameterType.PARAMETER_DOUBLE,
- description="Grace period in seconds after silence to stop recording",
- ),
- )
-
- def _initialize_open_wake_word(self) -> Optional[OWWModel]:
- if self.use_wake_word:
- download_models()
- oww_model = OWWModel(
- wakeword_models=[
- self.wake_word_model,
- ],
- inference_framework="onnx",
- )
- self.get_logger().info("Wake word model has been initialized") # type: ignore
- return oww_model
- return None
-
- def _initialize_vad_model(self):
- model, _ = torch.hub.load(
- repo_or_dir="snakers4/silero-vad",
- model="silero_vad",
- )
- return model
-
- def _setup_node_components(self):
- self.callback_group = ReentrantCallbackGroup()
-
- def _initialize_parameters(self):
- self.silence_grace_period = cast(
- float,
- self.get_parameter("silence_grace_period")
- .get_parameter_value()
- .double_value,
- )
- self.vad_threshold = cast(
- float,
- self.get_parameter("vad_threshold").get_parameter_value().double_value,
- ) # type: ignore
- self.model_name = (
- self.get_parameter("model_name").get_parameter_value().string_value
- ) # type: ignore
- self.model_vendor = (
- self.get_parameter("model_vendor").get_parameter_value().string_value
- ) # type: ignore
- self.language = (
- self.get_parameter("language").get_parameter_value().string_value
- ) # type: ignore
-
- self.use_wake_word = cast(
- bool,
- self.get_parameter("use_wake_word").get_parameter_value().bool_value,
- )
- self.wake_word_model = cast(
- str,
- self.get_parameter("wake_word_model").get_parameter_value().string_value,
- )
- self.wake_word_threshold = cast(
- float,
- self.get_parameter("wake_word_threshold")
- .get_parameter_value()
- .double_value,
- )
- self.recording_device_number = cast(
- int,
- self.get_parameter("recording_device").get_parameter_value().integer_value,
- )
-
- if self.use_wake_word:
- if not os.path.exists(self.wake_word_model):
- raise FileNotFoundError(f"Model file {self.wake_word_model} not found")
-
- self.get_logger().info("Parameters have been initialized") # type: ignore
-
- def _setup_publishers_and_subscribers(self):
- reliable_qos = QoSProfile(
- reliability=ReliabilityPolicy.RELIABLE,
- durability=DurabilityPolicy.TRANSIENT_LOCAL,
- history=HistoryPolicy.KEEP_ALL,
- )
- self.transcription_publisher = self.create_publisher(
- String, "/from_human", qos_profile=reliable_qos
- )
- self.status_publisher = self.create_publisher(String, "/asr_status", 10)
- self.tts_status_subscriber = self.create_subscription(
- String,
- "/tts_status",
- self.tts_status_callback,
- 10,
- callback_group=self.callback_group,
- )
- self.hmi_status_subscriber = self.create_subscription(
- String,
- "/hmi_status",
- self.hmi_status_callback,
- 10,
- callback_group=self.callback_group,
- )
-
- def _initialize_asr_model(self):
- if self.model_vendor == "openai":
- from rai_asr.asr_clients import OpenAIWhisper
-
- self.model = OpenAIWhisper(
- self.model_name, VAD_SAMPLING_RATE, self.language
- )
- elif self.model_vendor == "whisper":
- from rai_asr.asr_clients import LocalWhisper
-
- self.model = LocalWhisper(self.model_name, VAD_SAMPLING_RATE, self.language)
- else:
- raise ValueError(f"Unknown model vendor: {self.model_vendor}")
-
- def tts_status_callback(self, msg: String):
- if msg.data == "processing":
- self.tts_lock = True
- elif msg.data == "waiting":
- self.tts_lock = False
-
- def hmi_status_callback(self, msg: String):
- if msg.data == "processing":
- self.hmi_lock = True
- elif msg.data == "waiting":
- self.hmi_lock = False
-
- def should_listen(self, audio_data: NDArray[np.int16]) -> bool:
- def int2float(sound: NDArray[np.int16]):
- abs_max = np.abs(sound).max()
- sound = sound.astype("float32")
- if abs_max > 0:
- sound *= 1 / 32768
- sound = sound.squeeze()
- return sound
-
- vad_confidence = self.vad_model(
- torch.tensor(int2float(audio_data[-512:])), VAD_SAMPLING_RATE
- ).item()
-
- if self.oww_model:
- if self.is_recording:
- self.get_logger().debug(f"VAD confidence: {vad_confidence}") # type: ignore
- return vad_confidence > self.vad_threshold
- else:
- predictions = self.oww_model.predict(audio_data)
- for key, value in predictions.items():
- if value > self.wake_word_threshold:
- self.get_logger().debug(f"Detected wake word: {key}") # type: ignore
- self.oww_model.reset()
- return True
- else:
- return vad_confidence > self.vad_threshold
-
- return False
-
- def sd_callback(self, indata, frames, _, status):
- if status:
- self.get_logger().warning(f"Stream status: {status}") # type: ignore
- indata = indata.flatten()
- sample_time_length = len(indata) / self.device_sample_rate
- if self.device_sample_rate != VAD_SAMPLING_RATE:
- indata = resample(indata, int(sample_time_length * VAD_SAMPLING_RATE))
-
- asr_lock = (
- time.time()
- < self.last_transcription_time + self.transcription_recording_timeout
- )
- if asr_lock or self.hmi_lock or self.tts_lock:
- return
-
- if not self.is_recording: # keep last 5 indata of audio ~ 400ms
- self.audio_buffer.append(indata)
- if len(self.audio_buffer) > 5:
- self.audio_buffer.pop(0)
-
- if self.should_listen(indata):
- self.silence_start_time = time.time()
- if not self.is_recording:
- self.start_recording()
- self.audio_buffer.append(indata)
- elif self.is_recording:
- self.audio_buffer.append(indata)
- if not isinstance(self.silence_start_time, float):
- raise ValueError(
- "Silence start time is not set, this should not happen"
- )
- if time.time() - self.silence_start_time > self.silence_grace_period:
- self.stop_recording_and_transcribe()
-
- def initialize_sounddevice_stream(self):
- sd.default.latency = ("low", "low")
- self.device_sample_rate = sd.query_devices(
- device=self.recording_device_number, kind="input"
- )[
- "default_samplerate"
- ] # type: ignore
- self.window_size_samples = int(
- DEFAULT_BLOCKSIZE * self.device_sample_rate / VAD_SAMPLING_RATE
- )
- self.stream = sd.InputStream(
- samplerate=self.device_sample_rate,
- channels=1,
- device=self.recording_device_number,
- dtype="int16",
- blocksize=self.window_size_samples,
- callback=self.sd_callback,
- )
- self.stream.start()
-
- def reset_buffer(self):
- self.audio_buffer.clear()
-
- def start_recording(self):
- self.get_logger().info("Recording...") # type: ignore
- self.publish_status("recording")
- self.is_recording = True
-
- def stop_recording_and_transcribe(self):
- self.get_logger().info("Stopped recording. Transcribing...") # type: ignore
- self.is_recording = False
- self.publish_status("transcribing")
- self.transcribe_audio()
- self.publish_status("waiting")
- self.get_logger().info("Done transcribing.") # type: ignore
-
- def transcribe_audio(self):
- combined_audio = np.concatenate(self.audio_buffer)
- self.reset_buffer() # consume the buffer, so we don't transcribe the same audio twice
-
- transcription = self.model(data=combined_audio)
-
- if transcription.lower() in ["you", ""]:
- self.get_logger().info(f"Dropping transcription: '{transcription}'")
- self.publish_status("dropping")
- else:
- self.get_logger().info(f"Transcription: {transcription}")
- self.publish_transcription(transcription)
-
- self.last_transcription_time = time.time()
-
- def publish_transcription(self, transcription: str):
- msg = String()
- msg.data = transcription
- self.transcription_publisher.publish(msg)
-
- def publish_status(
- self, status: Literal["recording", "transcribing", "dropping", "waiting"]
- ):
- msg = String()
- msg.data = status
- self.status_publisher.publish(msg)
-
-
-def main(args=None):
- rclpy.init(args=args)
- node = ASRNode()
- executor = SingleThreadedExecutor()
- executor.add_node(node)
-
- try:
- executor.spin()
- except KeyboardInterrupt:
- pass
- finally:
- executor.shutdown()
- node.destroy_node()
- rclpy.shutdown()
diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py
index 1d1a7e9de..daaa95a15 100644
--- a/src/rai_asr/rai_asr/models/__init__.py
+++ b/src/rai_asr/rai_asr/models/__init__.py
@@ -19,10 +19,10 @@
from rai_asr.models.silero_vad import SileroVAD
__all__ = [
- "BaseVoiceDetectionModel",
- "SileroVAD",
- "OpenWakeWord",
"BaseTranscriptionModel",
+ "BaseVoiceDetectionModel",
"LocalWhisper",
"OpenAIWhisper",
+ "OpenWakeWord",
+ "SileroVAD",
]
diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index 13142df87..1cf62a14c 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -21,7 +21,6 @@
class BaseVoiceDetectionModel(ABC):
-
def __call__(
self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
diff --git a/src/rai_asr/setup.cfg b/src/rai_asr/setup.cfg
deleted file mode 100644
index 8499f73be..000000000
--- a/src/rai_asr/setup.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[develop]
-script_dir=$base/lib/rai_asr
-[install]
-install_scripts=$base/lib/rai_asr
diff --git a/src/rai_asr/setup.py b/src/rai_asr/setup.py
deleted file mode 100644
index c85aa620a..000000000
--- a/src/rai_asr/setup.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-from glob import glob
-
-from setuptools import find_packages, setup
-
-package_name = "rai_asr"
-
-setup(
- name=package_name,
- version="0.0.0",
- packages=find_packages(exclude=["test"]),
- data_files=[
- ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
- ("share/" + package_name, ["package.xml"]),
- (os.path.join("share", package_name, "launch"), glob("launch/*.launch.py")),
- ],
- install_requires=["setuptools"],
- zip_safe=True,
- maintainer="mkotynia",
- maintainer_email="magdalena.kotynia@robotec.ai",
- description="An Automatic Speech Recognition package, leveraging Whisper for transcription and \
- Silero VAD for voice activity detection. This node captures audio, detects speech, and \
- transcribes the spoken content, publishing the transcription to a topic. ",
- license="Apache-2.0",
- tests_require=["pytest"],
- entry_points={
- "console_scripts": ["asr_node = rai_asr.asr_node:main"],
- },
-)
diff --git a/src/rai/resource/rai b/src/rai_core/README.md
similarity index 100%
rename from src/rai/resource/rai
rename to src/rai_core/README.md
diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml
new file mode 100644
index 000000000..48809e4a8
--- /dev/null
+++ b/src/rai_core/pyproject.toml
@@ -0,0 +1,32 @@
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "rai"
+version = "1.0.0"
+description = "Core functionality for RAI framework"
+authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "]
+readme = "README.md"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Development Status :: 4 - Beta",
+ "License :: OSI Approved :: Apache Software License",
+]
+packages = [
+ { include = "rai", from = "." },
+]
+
+[tool.poetry.dependencies]
+python = "^3.10, <3.13"
+langchain-core = "^0.3"
+langchain = "*"
+langgraph = "*"
+requests = "^2.32.2"
+coloredlogs = "^15.0.1"
+markdown = "^3.6"
+tqdm = "^4.66.4"
+rich = "^13.7.1"
+deprecated = "^1.2.14"
+tomli = "^2.0.1"
+tomli-w = "^1.1.0"
diff --git a/src/rai/rai/__init__.py b/src/rai_core/rai/__init__.py
similarity index 100%
rename from src/rai/rai/__init__.py
rename to src/rai_core/rai/__init__.py
diff --git a/src/rai/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py
similarity index 100%
rename from src/rai/rai/agents/__init__.py
rename to src/rai_core/rai/agents/__init__.py
index 2b7d4461a..e28822100 100644
--- a/src/rai/rai/agents/__init__.py
+++ b/src/rai_core/rai/agents/__init__.py
@@ -19,7 +19,7 @@
__all__ = [
"ToolRunner",
+ "VoiceRecognitionAgent",
"create_conversational_agent",
"create_state_based_agent",
- "VoiceRecognitionAgent",
]
diff --git a/src/rai/rai/agents/base.py b/src/rai_core/rai/agents/base.py
similarity index 100%
rename from src/rai/rai/agents/base.py
rename to src/rai_core/rai/agents/base.py
diff --git a/src/rai/rai/agents/conversational_agent.py b/src/rai_core/rai/agents/conversational_agent.py
similarity index 100%
rename from src/rai/rai/agents/conversational_agent.py
rename to src/rai_core/rai/agents/conversational_agent.py
diff --git a/src/rai/rai/agents/integrations/__init__.py b/src/rai_core/rai/agents/integrations/__init__.py
similarity index 100%
rename from src/rai/rai/agents/integrations/__init__.py
rename to src/rai_core/rai/agents/integrations/__init__.py
diff --git a/src/rai/rai/agents/integrations/streamlit.py b/src/rai_core/rai/agents/integrations/streamlit.py
similarity index 99%
rename from src/rai/rai/agents/integrations/streamlit.py
rename to src/rai_core/rai/agents/integrations/streamlit.py
index 18ca98683..73360893a 100644
--- a/src/rai/rai/agents/integrations/streamlit.py
+++ b/src/rai_core/rai/agents/integrations/streamlit.py
@@ -125,7 +125,7 @@ def on_tool_end(self, output: Any, **kwargs: Any) -> Any:
# Decorator function to add the Streamlit execution context to a function
def add_streamlit_context(
- fn: Callable[..., fn_return_type]
+ fn: Callable[..., fn_return_type],
) -> Callable[..., fn_return_type]:
"""
Decorator to ensure that the decorated function runs within the Streamlit execution context.
diff --git a/src/rai/rai/agents/state_based.py b/src/rai_core/rai/agents/state_based.py
similarity index 100%
rename from src/rai/rai/agents/state_based.py
rename to src/rai_core/rai/agents/state_based.py
diff --git a/src/rai/rai/agents/tool_runner.py b/src/rai_core/rai/agents/tool_runner.py
similarity index 100%
rename from src/rai/rai/agents/tool_runner.py
rename to src/rai_core/rai/agents/tool_runner.py
diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py
similarity index 98%
rename from src/rai/rai/agents/voice_agent.py
rename to src/rai_core/rai/agents/voice_agent.py
index c012837f6..339db49da 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai_core/rai/agents/voice_agent.py
@@ -187,7 +187,9 @@ def should_record(
def transcription_thread(self, identifier: str):
self.logger.info(f"transcription thread {identifier} started")
audio_data = np.concatenate(self.transcription_buffers[identifier])
- with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow
+ with (
+ self.transcription_lock
+ ): # this is only necessary for the local model... TODO: fix this somehow
transcription = self.transcription_model.transcribe(audio_data)
assert isinstance(self.connectors["ros2"], ROS2ARIConnector)
self.connectors["ros2"].send_message(
diff --git a/src/rai/rai/apps/__init__.py b/src/rai_core/rai/apps/__init__.py
similarity index 100%
rename from src/rai/rai/apps/__init__.py
rename to src/rai_core/rai/apps/__init__.py
diff --git a/src/rai/rai/apps/document_loader.py b/src/rai_core/rai/apps/document_loader.py
similarity index 100%
rename from src/rai/rai/apps/document_loader.py
rename to src/rai_core/rai/apps/document_loader.py
diff --git a/src/rai/rai/apps/high_level_api.py b/src/rai_core/rai/apps/high_level_api.py
similarity index 100%
rename from src/rai/rai/apps/high_level_api.py
rename to src/rai_core/rai/apps/high_level_api.py
diff --git a/src/rai/rai/apps/state_analyzer.py b/src/rai_core/rai/apps/state_analyzer.py
similarity index 99%
rename from src/rai/rai/apps/state_analyzer.py
rename to src/rai_core/rai/apps/state_analyzer.py
index 0a736dadb..6ebe2d0db 100644
--- a/src/rai/rai/apps/state_analyzer.py
+++ b/src/rai_core/rai/apps/state_analyzer.py
@@ -39,7 +39,6 @@ def robot_state_analyzer(
state: str,
state_analyzer_prompt: str = STATE_ANALYZER_PROMPT,
) -> State:
-
template = ChatPromptTemplate.from_messages(
[
("system", state_analyzer_prompt),
diff --git a/src/rai/rai/apps/talk_to_docs.py b/src/rai_core/rai/apps/talk_to_docs.py
similarity index 97%
rename from src/rai/rai/apps/talk_to_docs.py
rename to src/rai_core/rai/apps/talk_to_docs.py
index 44811898c..8cccbf8f3 100644
--- a/src/rai/rai/apps/talk_to_docs.py
+++ b/src/rai_core/rai/apps/talk_to_docs.py
@@ -78,7 +78,9 @@ def talk_to_docs(documentation_root: str, llm: BaseChatModel):
agent = create_tool_calling_agent(llm, [query_docs], prompt) # type: ignore
agent_executor = AgentExecutor(
- agent=agent, tools=[query_docs], return_intermediate_steps=True # type: ignore
+ agent=agent,
+ tools=[query_docs],
+ return_intermediate_steps=True, # type: ignore
)
def input_node(state: State) -> State:
diff --git a/src/rai/rai/apps/task_executor.py b/src/rai_core/rai/apps/task_executor.py
similarity index 100%
rename from src/rai/rai/apps/task_executor.py
rename to src/rai_core/rai/apps/task_executor.py
diff --git a/src/rai/rai/apps/task_planner.py b/src/rai_core/rai/apps/task_planner.py
similarity index 100%
rename from src/rai/rai/apps/task_planner.py
rename to src/rai_core/rai/apps/task_planner.py
diff --git a/src/rai/rai/cli/__init__.py b/src/rai_core/rai/cli/__init__.py
similarity index 100%
rename from src/rai/rai/cli/__init__.py
rename to src/rai_core/rai/cli/__init__.py
diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai_core/rai/cli/rai_cli.py
similarity index 99%
rename from src/rai/rai/cli/rai_cli.py
rename to src/rai_core/rai/cli/rai_cli.py
index 4df87cd17..70afb99cf 100644
--- a/src/rai/rai/cli/rai_cli.py
+++ b/src/rai_core/rai/cli/rai_cli.py
@@ -194,7 +194,7 @@ def create_rai_ws():
(package_path / "generated" / "robot_constitution.txt").touch()
default_constitution_path = (
- "src/rai/rai/cli/resources/default_robot_constitution.txt"
+ "src/rai_core/rai/cli/resources/default_robot_constitution.txt"
)
with open(default_constitution_path, "r") as file:
default_constitution = file.read()
diff --git a/src/rai/rai/cli/resources/default_robot_constitution.txt b/src/rai_core/rai/cli/resources/default_robot_constitution.txt
similarity index 100%
rename from src/rai/rai/cli/resources/default_robot_constitution.txt
rename to src/rai_core/rai/cli/resources/default_robot_constitution.txt
diff --git a/src/rai/rai/communication/__init__.py b/src/rai_core/rai/communication/__init__.py
similarity index 100%
rename from src/rai/rai/communication/__init__.py
rename to src/rai_core/rai/communication/__init__.py
index 394fdbb61..f18324d79 100644
--- a/src/rai/rai/communication/__init__.py
+++ b/src/rai_core/rai/communication/__init__.py
@@ -25,8 +25,8 @@
__all__ = [
"ARIConnector",
"ARIMessage",
- "BaseMessage",
"BaseConnector",
+ "BaseMessage",
"HRIConnector",
"HRIMessage",
"HRIPayload",
diff --git a/src/rai/rai/communication/ari_connector.py b/src/rai_core/rai/communication/ari_connector.py
similarity index 100%
rename from src/rai/rai/communication/ari_connector.py
rename to src/rai_core/rai/communication/ari_connector.py
diff --git a/src/rai/rai/communication/base_connector.py b/src/rai_core/rai/communication/base_connector.py
similarity index 99%
rename from src/rai/rai/communication/base_connector.py
rename to src/rai_core/rai/communication/base_connector.py
index 21d461b62..901256ddf 100644
--- a/src/rai/rai/communication/base_connector.py
+++ b/src/rai_core/rai/communication/base_connector.py
@@ -39,7 +39,6 @@ def __init__(
class BaseConnector(Generic[T]):
-
def _generate_handle(self) -> str:
return str(uuid4())
diff --git a/src/rai/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py
similarity index 98%
rename from src/rai/rai/communication/hri_connector.py
rename to src/rai_core/rai/communication/hri_connector.py
index ba578afdf..a71b496e2 100644
--- a/src/rai/rai/communication/hri_connector.py
+++ b/src/rai_core/rai/communication/hri_connector.py
@@ -17,9 +17,8 @@
from io import BytesIO
from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
-from langchain_core.messages import HumanMessage
from PIL import Image
from PIL.Image import Image as ImageType
from pydub import AudioSegment
@@ -166,7 +165,6 @@ def _build_message(
self,
message: LangchainBaseMessage | RAIMultimodalMessage,
) -> T:
-
return self.T_class.from_langchain(message)
def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
diff --git a/src/rai/rai/communication/ros2/__init__.py b/src/rai_core/rai/communication/ros2/__init__.py
similarity index 100%
rename from src/rai/rai/communication/ros2/__init__.py
rename to src/rai_core/rai/communication/ros2/__init__.py
diff --git a/src/rai/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py
similarity index 99%
rename from src/rai/rai/communication/ros2/api.py
rename to src/rai_core/rai/communication/ros2/api.py
index a853f0932..44e03aafc 100644
--- a/src/rai/rai/communication/ros2/api.py
+++ b/src/rai_core/rai/communication/ros2/api.py
@@ -334,7 +334,6 @@ def __post_init__(self):
class ConfigurableROS2TopicAPI(ROS2TopicAPI):
-
def __init__(self, node: rclpy.node.Node):
super().__init__(node)
self._subscribtions: dict[str, rclpy.node.Subscription] = {}
@@ -562,7 +561,7 @@ def is_goal_done(self, handle: str) -> bool:
raise ValueError(f"Invalid action handle: {handle}")
if self.actions[handle]["result_future"] is None:
raise ValueError(
- f"Result future is None for handle: {handle}. " "Was the goal accepted?"
+ f"Result future is None for handle: {handle}. Was the goal accepted?"
)
return self.actions[handle]["result_future"].done()
diff --git a/src/rai/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py
similarity index 99%
rename from src/rai/rai/communication/ros2/connectors.py
rename to src/rai_core/rai/communication/ros2/connectors.py
index f01a08258..2b4c94097 100644
--- a/src/rai/rai/communication/ros2/connectors.py
+++ b/src/rai_core/rai/communication/ros2/connectors.py
@@ -76,7 +76,6 @@ def send_message(
qos_profile: Optional[QoSProfile] = None,
**kwargs: Any,
):
-
self._topic_api.publish(
topic=target,
msg_content=message.payload,
diff --git a/src/rai/rai/communication/sound_device/__init__.py b/src/rai_core/rai/communication/sound_device/__init__.py
similarity index 100%
rename from src/rai/rai/communication/sound_device/__init__.py
rename to src/rai_core/rai/communication/sound_device/__init__.py
index 450926768..503c274d9 100644
--- a/src/rai/rai/communication/sound_device/__init__.py
+++ b/src/rai_core/rai/communication/sound_device/__init__.py
@@ -18,6 +18,6 @@
__all__ = [
"SoundDeviceAPI",
"SoundDeviceConfig",
- "SoundDeviceError",
"SoundDeviceConnector",
+ "SoundDeviceError",
]
diff --git a/src/rai/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py
similarity index 99%
rename from src/rai/rai/communication/sound_device/api.py
rename to src/rai_core/rai/communication/sound_device/api.py
index e8aa9ae88..98d554e4b 100644
--- a/src/rai/rai/communication/sound_device/api.py
+++ b/src/rai_core/rai/communication/sound_device/api.py
@@ -49,7 +49,6 @@ def __post_init__(self):
class SoundDeviceAPI:
-
def __init__(self, config: SoundDeviceConfig):
self.device_name = ""
diff --git a/src/rai/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py
similarity index 100%
rename from src/rai/rai/communication/sound_device/connector.py
rename to src/rai_core/rai/communication/sound_device/connector.py
diff --git a/src/rai/rai/config/__init__.py b/src/rai_core/rai/config/__init__.py
similarity index 100%
rename from src/rai/rai/config/__init__.py
rename to src/rai_core/rai/config/__init__.py
diff --git a/src/rai/rai/config/models.py b/src/rai_core/rai/config/models.py
similarity index 100%
rename from src/rai/rai/config/models.py
rename to src/rai_core/rai/config/models.py
diff --git a/src/rai/rai/extensions/__init__.py b/src/rai_core/rai/extensions/__init__.py
similarity index 100%
rename from src/rai/rai/extensions/__init__.py
rename to src/rai_core/rai/extensions/__init__.py
diff --git a/src/rai/rai/messages/__init__.py b/src/rai_core/rai/messages/__init__.py
similarity index 100%
rename from src/rai/rai/messages/__init__.py
rename to src/rai_core/rai/messages/__init__.py
index 929e04c5c..f5af4f43d 100644
--- a/src/rai/rai/messages/__init__.py
+++ b/src/rai_core/rai/messages/__init__.py
@@ -23,10 +23,10 @@
from .utils import preprocess_image
__all__ = [
- "HumanMultimodalMessage",
"AiMultimodalMessage",
+ "HumanMultimodalMessage",
+ "MultimodalArtifact",
"SystemMultimodalMessage",
"ToolMultimodalMessage",
- "MultimodalArtifact",
"preprocess_image",
]
diff --git a/src/rai/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py
similarity index 98%
rename from src/rai/rai/messages/multimodal.py
rename to src/rai_core/rai/messages/multimodal.py
index 8db862bee..33e0ff35c 100644
--- a/src/rai/rai/messages/multimodal.py
+++ b/src/rai_core/rai/messages/multimodal.py
@@ -72,7 +72,7 @@ def __repr_args__(self) -> Any:
v = [c for c in v if c["type"] != "image_url"]
elif k == "images":
imgs_summary = [image[0:10] + "..." for image in v]
- v = f'{len(v)} base64 encoded images: [{", ".join(imgs_summary)}]'
+ v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]"
new_args.append((k, v))
return new_args
diff --git a/src/rai/rai/messages/utils.py b/src/rai_core/rai/messages/utils.py
similarity index 100%
rename from src/rai/rai/messages/utils.py
rename to src/rai_core/rai/messages/utils.py
diff --git a/src/rai/rai/node.py b/src/rai_core/rai/node.py
similarity index 99%
rename from src/rai/rai/node.py
rename to src/rai_core/rai/node.py
index fcc2d3ae1..e39fe2300 100644
--- a/src/rai/rai/node.py
+++ b/src/rai_core/rai/node.py
@@ -303,7 +303,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
"callbacks": get_tracing_callbacks(),
},
):
-
graph_node_name = list(state.keys())[0]
if graph_node_name == "reporter":
continue
diff --git a/src/rai/rai/ros2_apis.py b/src/rai_core/rai/ros2_apis.py
similarity index 100%
rename from src/rai/rai/ros2_apis.py
rename to src/rai_core/rai/ros2_apis.py
diff --git a/src/rai/rai/tools/__init__.py b/src/rai_core/rai/tools/__init__.py
similarity index 100%
rename from src/rai/rai/tools/__init__.py
rename to src/rai_core/rai/tools/__init__.py
diff --git a/src/rai/rai/tools/debugging_assistant.py b/src/rai_core/rai/tools/debugging_assistant.py
similarity index 100%
rename from src/rai/rai/tools/debugging_assistant.py
rename to src/rai_core/rai/tools/debugging_assistant.py
diff --git a/src/rai/rai/tools/ros/__init__.py b/src/rai_core/rai/tools/ros/__init__.py
similarity index 100%
rename from src/rai/rai/tools/ros/__init__.py
rename to src/rai_core/rai/tools/ros/__init__.py
index 71e488752..c3be30b81 100644
--- a/src/rai/rai/tools/ros/__init__.py
+++ b/src/rai_core/rai/tools/ros/__init__.py
@@ -29,15 +29,15 @@
)
__all__ = [
+ "AddDescribedWaypointToDatabaseTool",
+ "GetCurrentPositionTool",
+ "GetOccupancyGridTool",
+ "Ros2BaseInput",
+ "Ros2BaseTool",
"ros2_action",
"ros2_interface",
"ros2_node",
- "ros2_topic",
"ros2_param",
"ros2_service",
- "Ros2BaseTool",
- "Ros2BaseInput",
- "AddDescribedWaypointToDatabaseTool",
- "GetOccupancyGridTool",
- "GetCurrentPositionTool",
+ "ros2_topic",
]
diff --git a/src/rai/rai/tools/ros/cli.py b/src/rai_core/rai/tools/ros/cli.py
similarity index 100%
rename from src/rai/rai/tools/ros/cli.py
rename to src/rai_core/rai/tools/ros/cli.py
diff --git a/src/rai/rai/tools/ros/deprecated.py b/src/rai_core/rai/tools/ros/deprecated.py
similarity index 96%
rename from src/rai/rai/tools/ros/deprecated.py
rename to src/rai_core/rai/tools/ros/deprecated.py
index 98c0e8ccf..ec8e3ba8e 100644
--- a/src/rai/rai/tools/ros/deprecated.py
+++ b/src/rai_core/rai/tools/ros/deprecated.py
@@ -104,7 +104,9 @@ def __init__(
def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
- cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
+ cv_image = cast(
+ cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
+ ) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
diff --git a/src/rai/rai/tools/ros/manipulation.py b/src/rai_core/rai/tools/ros/manipulation.py
similarity index 100%
rename from src/rai/rai/tools/ros/manipulation.py
rename to src/rai_core/rai/tools/ros/manipulation.py
diff --git a/src/rai/rai/tools/ros/native.py b/src/rai_core/rai/tools/ros/native.py
similarity index 100%
rename from src/rai/rai/tools/ros/native.py
rename to src/rai_core/rai/tools/ros/native.py
diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai_core/rai/tools/ros/native_actions.py
similarity index 93%
rename from src/rai/rai/tools/ros/native_actions.py
rename to src/rai_core/rai/tools/ros/native_actions.py
index ba70a5364..c3833a3d7 100644
--- a/src/rai/rai/tools/ros/native_actions.py
+++ b/src/rai_core/rai/tools/ros/native_actions.py
@@ -60,9 +60,7 @@ class Ros2BaseActionTool(Ros2BaseTool):
class Ros2RunActionSync(Ros2BaseTool):
name: str = "Ros2RunAction"
- description: str = (
- "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"
- )
+ description: str = "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"
args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput
@@ -176,9 +174,7 @@ def _run(self) -> bool:
class Ros2GetLastActionFeedback(Ros2BaseActionTool):
name: str = "Ros2GetLastActionFeedback"
- description: str = (
- "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."
- )
+ description: str = "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."
args_schema: Type[Ros2BaseInput] = Ros2BaseInput
diff --git a/src/rai/rai/tools/ros/nav2/__init__.py b/src/rai_core/rai/tools/ros/nav2/__init__.py
similarity index 100%
rename from src/rai/rai/tools/ros/nav2/__init__.py
rename to src/rai_core/rai/tools/ros/nav2/__init__.py
diff --git a/src/rai/rai/tools/ros/nav2/basic_navigator.py b/src/rai_core/rai/tools/ros/nav2/basic_navigator.py
similarity index 100%
rename from src/rai/rai/tools/ros/nav2/basic_navigator.py
rename to src/rai_core/rai/tools/ros/nav2/basic_navigator.py
diff --git a/src/rai/rai/tools/ros/nav2/navigator.py b/src/rai_core/rai/tools/ros/nav2/navigator.py
similarity index 100%
rename from src/rai/rai/tools/ros/nav2/navigator.py
rename to src/rai_core/rai/tools/ros/nav2/navigator.py
diff --git a/src/rai/rai/tools/ros/tools.py b/src/rai_core/rai/tools/ros/tools.py
similarity index 98%
rename from src/rai/rai/tools/ros/tools.py
rename to src/rai_core/rai/tools/ros/tools.py
index 7148e8eef..508c174a3 100644
--- a/src/rai/rai/tools/ros/tools.py
+++ b/src/rai_core/rai/tools/ros/tools.py
@@ -98,9 +98,7 @@ class GetOccupancyGridTool(BaseTool):
"""Get the current map as an image with the robot's position marked on it (red dot)."""
name: str = "GetOccupancyGridTool"
- description: str = (
- "A tool for getting the current map as an image with the robot's position marked on it."
- )
+ description: str = "A tool for getting the current map as an image with the robot's position marked on it."
args_schema: Type[TopicInput] = TopicInput
diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai_core/rai/tools/ros/utils.py
similarity index 100%
rename from src/rai/rai/tools/ros/utils.py
rename to src/rai_core/rai/tools/ros/utils.py
diff --git a/src/rai/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py
similarity index 100%
rename from src/rai/rai/tools/ros2/__init__.py
rename to src/rai_core/rai/tools/ros2/__init__.py
index ea7d939f5..460f4d100 100644
--- a/src/rai/rai/tools/ros2/__init__.py
+++ b/src/rai_core/rai/tools/ros2/__init__.py
@@ -24,13 +24,13 @@
)
__all__ = [
- "StartROS2ActionTool",
- "GetROS2ImageTool",
- "PublishROS2MessageTool",
- "ReceiveROS2MessageTool",
"CallROS2ServiceTool",
"CancelROS2ActionTool",
- "GetROS2TopicsNamesAndTypesTool",
+ "GetROS2ImageTool",
"GetROS2MessageInterfaceTool",
+ "GetROS2TopicsNamesAndTypesTool",
"GetROS2TransformTool",
+ "PublishROS2MessageTool",
+ "ReceiveROS2MessageTool",
+ "StartROS2ActionTool",
]
diff --git a/src/rai/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py
similarity index 100%
rename from src/rai/rai/tools/ros2/actions.py
rename to src/rai_core/rai/tools/ros2/actions.py
diff --git a/src/rai/rai/tools/ros2/services.py b/src/rai_core/rai/tools/ros2/services.py
similarity index 100%
rename from src/rai/rai/tools/ros2/services.py
rename to src/rai_core/rai/tools/ros2/services.py
diff --git a/src/rai/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py
similarity index 98%
rename from src/rai/rai/tools/ros2/topics.py
rename to src/rai_core/rai/tools/ros2/topics.py
index fce3a8bae..82dbfb347 100644
--- a/src/rai/rai/tools/ros2/topics.py
+++ b/src/rai_core/rai/tools/ros2/topics.py
@@ -98,7 +98,9 @@ def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]:
raise ValueError(
f"Unsupported message type: {message.metadata['msg_type']}"
)
- return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore
+ return "Image received successfully", MultimodalArtifact(
+ images=[preprocess_image(image)]
+ ) # type: ignore
class GetROS2TopicsNamesAndTypesTool(BaseTool):
diff --git a/src/rai/rai/tools/ros2/utils.py b/src/rai_core/rai/tools/ros2/utils.py
similarity index 100%
rename from src/rai/rai/tools/ros2/utils.py
rename to src/rai_core/rai/tools/ros2/utils.py
diff --git a/src/rai/rai/tools/time.py b/src/rai_core/rai/tools/time.py
similarity index 100%
rename from src/rai/rai/tools/time.py
rename to src/rai_core/rai/tools/time.py
diff --git a/src/rai/rai/tools/utils.py b/src/rai_core/rai/tools/utils.py
similarity index 98%
rename from src/rai/rai/tools/utils.py
rename to src/rai_core/rai/tools/utils.py
index 58da02bc6..37c3c8fc3 100644
--- a/src/rai/rai/tools/utils.py
+++ b/src/rai_core/rai/tools/utils.py
@@ -249,7 +249,9 @@ def __init__(
def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
- cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
+ cv_image = cast(
+ cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
+ ) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
diff --git a/src/rai/rai/utils/__init__.py b/src/rai_core/rai/utils/__init__.py
similarity index 100%
rename from src/rai/rai/utils/__init__.py
rename to src/rai_core/rai/utils/__init__.py
diff --git a/src/rai/rai/utils/artifacts.py b/src/rai_core/rai/utils/artifacts.py
similarity index 100%
rename from src/rai/rai/utils/artifacts.py
rename to src/rai_core/rai/utils/artifacts.py
diff --git a/src/rai/rai/utils/configurator.py b/src/rai_core/rai/utils/configurator.py
similarity index 96%
rename from src/rai/rai/utils/configurator.py
rename to src/rai_core/rai/utils/configurator.py
index 87be321fb..522b98224 100644
--- a/src/rai/rai/utils/configurator.py
+++ b/src/rai_core/rai/utils/configurator.py
@@ -316,19 +316,19 @@ def on_model_vendor_change(model_type: str):
)
def on_langfuse_change():
- st.session_state.config["tracing"]["langfuse"][
- "use_langfuse"
- ] = st.session_state.langfuse_checkbox
+ st.session_state.config["tracing"]["langfuse"]["use_langfuse"] = (
+ st.session_state.langfuse_checkbox
+ )
def on_langfuse_host_change():
- st.session_state.config["tracing"]["langfuse"][
- "host"
- ] = st.session_state.langfuse_host_input
+ st.session_state.config["tracing"]["langfuse"]["host"] = (
+ st.session_state.langfuse_host_input
+ )
def on_langsmith_change():
- st.session_state.config["tracing"]["langsmith"][
- "use_langsmith"
- ] = st.session_state.langsmith_checkbox
+ st.session_state.config["tracing"]["langsmith"]["use_langsmith"] = (
+ st.session_state.langsmith_checkbox
+ )
# Ensure tracing config exists
if "tracing" not in st.session_state.config:
@@ -397,9 +397,9 @@ def on_langsmith_change():
elif st.session_state.current_step == 4:
def on_recording_device_change():
- st.session_state.config["asr"][
- "recording_device_name"
- ] = st.session_state.recording_device_select
+ st.session_state.config["asr"]["recording_device_name"] = (
+ st.session_state.recording_device_select
+ )
def on_asr_vendor_change():
vendor = (
@@ -413,29 +413,29 @@ def on_language_change():
st.session_state.config["asr"]["language"] = st.session_state.language_input
def on_silence_grace_change():
- st.session_state.config["asr"][
- "silence_grace_period"
- ] = st.session_state.silence_grace_input
+ st.session_state.config["asr"]["silence_grace_period"] = (
+ st.session_state.silence_grace_input
+ )
def on_vad_threshold_change():
- st.session_state.config["asr"][
- "vad_threshold"
- ] = st.session_state.vad_threshold_input
+ st.session_state.config["asr"]["vad_threshold"] = (
+ st.session_state.vad_threshold_input
+ )
def on_wake_word_change():
- st.session_state.config["asr"][
- "use_wake_word"
- ] = st.session_state.wake_word_checkbox
+ st.session_state.config["asr"]["use_wake_word"] = (
+ st.session_state.wake_word_checkbox
+ )
def on_wake_word_model_change():
- st.session_state.config["asr"][
- "wake_word_model"
- ] = st.session_state.wake_word_model_input
+ st.session_state.config["asr"]["wake_word_model"] = (
+ st.session_state.wake_word_model_input
+ )
def on_wake_word_threshold_change():
- st.session_state.config["asr"][
- "wake_word_threshold"
- ] = st.session_state.wake_word_threshold_input
+ st.session_state.config["asr"]["wake_word_threshold"] = (
+ st.session_state.wake_word_threshold_input
+ )
# Ensure asr config exists
if "asr" not in st.session_state.config:
@@ -588,9 +588,9 @@ def on_tts_vendor_change():
st.session_state.config["tts"]["vendor"] = vendor
def on_keep_speaker_busy_change():
- st.session_state.config["tts"][
- "keep_speaker_busy"
- ] = st.session_state.keep_speaker_busy_checkbox
+ st.session_state.config["tts"]["keep_speaker_busy"] = (
+ st.session_state.keep_speaker_busy_checkbox
+ )
# Ensure tts config exists
if "tts" not in st.session_state.config:
diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai_core/rai/utils/model_initialization.py
similarity index 100%
rename from src/rai/rai/utils/model_initialization.py
rename to src/rai_core/rai/utils/model_initialization.py
diff --git a/src/rai/rai/utils/ros.py b/src/rai_core/rai/utils/ros.py
similarity index 100%
rename from src/rai/rai/utils/ros.py
rename to src/rai_core/rai/utils/ros.py
diff --git a/src/rai/rai/utils/ros_async.py b/src/rai_core/rai/utils/ros_async.py
similarity index 100%
rename from src/rai/rai/utils/ros_async.py
rename to src/rai_core/rai/utils/ros_async.py
diff --git a/src/rai/rai/utils/ros_executors.py b/src/rai_core/rai/utils/ros_executors.py
similarity index 100%
rename from src/rai/rai/utils/ros_executors.py
rename to src/rai_core/rai/utils/ros_executors.py
diff --git a/src/rai/rai/utils/ros_logs.py b/src/rai_core/rai/utils/ros_logs.py
similarity index 100%
rename from src/rai/rai/utils/ros_logs.py
rename to src/rai_core/rai/utils/ros_logs.py
diff --git a/src/rai_extensions/rai_nomad/package.xml b/src/rai_extensions/rai_nomad/package.xml
index 0ab802cad..c07d612ef 100644
--- a/src/rai_extensions/rai_nomad/package.xml
+++ b/src/rai_extensions/rai_nomad/package.xml
@@ -8,7 +8,6 @@
   <license>Apache-2.0</license>
   <test_depend>ament_copyright</test_depend>
-  <test_depend>ament_flake8</test_depend>
   <test_depend>ament_pep257</test_depend>
   <test_depend>python3-pytest</test_depend>
diff --git a/src/rai_extensions/rai_nomad/rai_nomad/nomad.py b/src/rai_extensions/rai_nomad/rai_nomad/nomad.py
index 5095d51ca..6e87a0be7 100644
--- a/src/rai_extensions/rai_nomad/rai_nomad/nomad.py
+++ b/src/rai_extensions/rai_nomad/rai_nomad/nomad.py
@@ -330,9 +330,9 @@ def pd_controller(self, waypoint: np.ndarray) -> Tuple[float]:
angular_vel = (
self.get_parameter("angular_vel").get_parameter_value().double_value
)
- assert (
- len(waypoint) == 2 or len(waypoint) == 4
- ), "waypoint must be a 2D or 4D vector"
+ assert len(waypoint) == 2 or len(waypoint) == 4, (
+ "waypoint must be a 2D or 4D vector"
+ )
if len(waypoint) == 2:
dx, dy = waypoint
else:
diff --git a/src/rai_extensions/rai_open_set_vision/package.xml b/src/rai_extensions/rai_open_set_vision/package.xml
index f348db3d6..9b22f49ee 100644
--- a/src/rai_extensions/rai_open_set_vision/package.xml
+++ b/src/rai_extensions/rai_open_set_vision/package.xml
@@ -8,7 +8,6 @@
   <license>Apache-2.0</license>
   <test_depend>ament_copyright</test_depend>
-  <test_depend>ament_flake8</test_depend>
   <test_depend>ament_pep257</test_depend>
   <test_depend>python3-pytest</test_depend>
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py
index fa3562537..3b7a71e20 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py
@@ -17,8 +17,8 @@
from .tools import GetDetectionTool, GetDistanceToObjectsTool
__all__ = [
- "GetDistanceToObjectsTool",
- "GetDetectionTool",
"GDINO_NODE_NAME",
"GDINO_SERVICE_NAME",
+ "GetDetectionTool",
+ "GetDistanceToObjectsTool",
]
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py
index 90b75351d..45fe9eb52 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py
@@ -21,10 +21,10 @@
import rclpy
from ament_index_python.packages import get_package_share_directory
from cv_bridge import CvBridge
-from rai_open_set_vision.vision_markup.segmenter import GDSegmenter
from rclpy.node import Node
from rai_interfaces.srv import RAIGroundedSam
+from rai_open_set_vision.vision_markup.segmenter import GDSegmenter
GSAM_NODE_NAME = "grounded_sam"
GSAM_SERVICE_NAME = "grounded_sam_segment"
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py
index 4fe1ae8c9..7927d7cf8 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py
@@ -20,12 +20,12 @@
import rclpy
from ament_index_python.packages import get_package_share_directory
-from rai_open_set_vision.vision_markup.boxer import GDBoxer
from rclpy.node import Node
from sensor_msgs.msg import Image
from rai_interfaces.msg import RAIDetectionArray
from rai_interfaces.srv import RAIGroundingDino
+from rai_open_set_vision.vision_markup.boxer import GDBoxer
class GDRequest(TypedDict):
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py
index 4091ab0d0..52330070b 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py
@@ -16,8 +16,8 @@
from .segmentation_tools import GetGrabbingPointTool, GetSegmentationTool
__all__ = [
- "GetDistanceToObjectsTool",
"GetDetectionTool",
- "GetSegmentationTool",
+ "GetDistanceToObjectsTool",
"GetGrabbingPointTool",
+ "GetSegmentationTool",
]
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
index 2fa76fba1..336ec4bc7 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py
@@ -19,19 +19,19 @@
import rclpy.qos
import sensor_msgs.msg
from pydantic import BaseModel, Field
-from rai_open_set_vision import GDINO_SERVICE_NAME
+from rai.node import RaiBaseNode
+from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
+from rai.tools.ros.utils import convert_ros_img_to_ndarray
+from rai.tools.utils import wait_for_message
+from rai.utils.ros_async import get_future_result
from rclpy.exceptions import (
ParameterNotDeclaredException,
ParameterUninitializedException,
)
from rclpy.task import Future
-from rai.node import RaiBaseNode
-from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
-from rai.tools.ros.utils import convert_ros_img_to_ndarray
-from rai.tools.utils import wait_for_message
-from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import RAIGroundingDino
+from rai_open_set_vision import GDINO_SERVICE_NAME
# --------------------- Inputs ---------------------
@@ -147,9 +147,7 @@ def _parse_detection_array(
class GetDetectionTool(GroundingDinoBaseTool):
name: str = "GetDetectionTool"
- description: str = (
- "A tool for detecting specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result."
- )
+ description: str = "A tool for detecting specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result."
args_schema: Type[Ros2GetDetectionInput] = Ros2GetDetectionInput
@@ -177,9 +175,7 @@ def _run(
class GetDistanceToObjectsTool(GroundingDinoBaseTool):
name: str = "GetDistanceToObjectsTool"
- description: str = (
- "A tool for calculating distance to specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result."
- )
+ description: str = "A tool for calculating distance to specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result."
args_schema: Type[GetDistanceToObjectsInput] = GetDistanceToObjectsInput
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
index c264bd885..29ff0fe18 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py
@@ -19,18 +19,18 @@
import rclpy
import sensor_msgs.msg
from pydantic import Field
-from rai_open_set_vision import GDINO_SERVICE_NAME
+from rai.node import RaiBaseNode
+from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
+from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray
+from rai.utils.ros_async import get_future_result
from rclpy import Future
from rclpy.exceptions import (
ParameterNotDeclaredException,
ParameterUninitializedException,
)
-from rai.node import RaiBaseNode
-from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
-from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray
-from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino
+from rai_open_set_vision import GDINO_SERVICE_NAME
# --------------------- Inputs ---------------------
@@ -186,7 +186,6 @@ def depth_to_point_cloud(
class GetGrabbingPointTool(GetSegmentationTool):
-
name: str = "GetGrabbingPointTool"
description: str = "Get the grabbing point of an object"
pcd: List[Any] = []
diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py
index e9c44c087..e91563a83 100644
--- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py
+++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py
@@ -20,13 +20,12 @@
import numpy as np
import torch
from cv_bridge import CvBridge
+from rai.tools.ros.utils import convert_ros_img_to_ndarray
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sensor_msgs.msg import Image
from vision_msgs.msg import BoundingBox2D
-from rai.tools.ros.utils import convert_ros_img_to_ndarray
-
class GDSegmenter:
def __init__(
diff --git a/src/rai_hmi/rai_hmi/agent.py b/src/rai_hmi/rai_hmi/agent.py
index 7243b699f..c54ead3da 100644
--- a/src/rai_hmi/rai_hmi/agent.py
+++ b/src/rai_hmi/rai_hmi/agent.py
@@ -17,11 +17,11 @@
from typing import List
from langchain.tools import tool
-
from rai.agents.conversational_agent import create_conversational_agent
from rai.node import RaiBaseNode
from rai.tools.ros.native import GetCameraImage, Ros2GetRobotInterfaces
from rai.utils.model_initialization import get_llm_model
+
from rai_hmi.base import BaseHMINode
from rai_hmi.chat_msgs import MissionMessage
from rai_hmi.task import Task, TaskInput
diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py
index 9502a2255..41bf656d6 100644
--- a/src/rai_hmi/rai_hmi/base.py
+++ b/src/rai_hmi/rai_hmi/base.py
@@ -22,13 +22,13 @@
from langchain_core.documents import Document
from langchain_core.tools import BaseTool
from pydantic import UUID4
+from rai.node import append_whoami_info_to_prompt
+from rai.utils.model_initialization import get_embeddings_model
from rclpy.action import ActionClient
from rclpy.node import Node
from std_msgs.msg import String
from std_srvs.srv import Trigger
-from rai.node import append_whoami_info_to_prompt
-from rai.utils.model_initialization import get_embeddings_model
from rai_hmi.chat_msgs import (
MissionAcceptanceMessage,
MissionDoneMessage,
diff --git a/src/rai_hmi/rai_hmi/ros.py b/src/rai_hmi/rai_hmi/ros.py
index 181132f3e..eed05149e 100644
--- a/src/rai_hmi/rai_hmi/ros.py
+++ b/src/rai_hmi/rai_hmi/ros.py
@@ -19,9 +19,9 @@
from typing import Optional, Tuple
import rclpy
+from rai.node import RaiBaseNode
from rclpy.executors import MultiThreadedExecutor
-from rai.node import RaiBaseNode
from rai_hmi.base import BaseHMINode
diff --git a/src/rai_hmi/rai_hmi/text_hmi.py b/src/rai_hmi/rai_hmi/text_hmi.py
index 955888062..8f89b5bf7 100644
--- a/src/rai_hmi/rai_hmi/text_hmi.py
+++ b/src/rai_hmi/rai_hmi/text_hmi.py
@@ -33,12 +33,12 @@
)
from PIL import Image
from pydantic import BaseModel
-from rclpy.node import Node
-from streamlit.delta_generator import DeltaGenerator
-
from rai.messages import HumanMultimodalMessage
from rai.node import RaiBaseNode
from rai.utils.artifacts import get_stored_artifacts
+from rclpy.node import Node
+from streamlit.delta_generator import DeltaGenerator
+
from rai_hmi.agent import initialize_agent
from rai_hmi.base import BaseHMINode
from rai_hmi.chat_msgs import EMOJIS, MissionMessage
diff --git a/src/rai_hmi/rai_hmi/voice_hmi.py b/src/rai_hmi/rai_hmi/voice_hmi.py
index e380020f2..71b9e04a3 100644
--- a/src/rai_hmi/rai_hmi/voice_hmi.py
+++ b/src/rai_hmi/rai_hmi/voice_hmi.py
@@ -22,12 +22,12 @@
import rclpy
from langchain_core.messages import HumanMessage
+from rai.node import RaiBaseNode
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
from std_msgs.msg import String
-from rai.node import RaiBaseNode
from rai_hmi.agent import initialize_agent
from rai_hmi.base import BaseHMINode
from rai_hmi.text_hmi_utils import Memory
diff --git a/src/rai_interfaces/package.xml b/src/rai_interfaces/package.xml
index 6627000a0..89f79d86c 100644
--- a/src/rai_interfaces/package.xml
+++ b/src/rai_interfaces/package.xml
@@ -18,6 +18,10 @@
   <buildtool_depend>rosidl_default_generators</buildtool_depend>
   <buildtool_depend>ament_cmake</buildtool_depend>
+  <depend>portaudio19-dev</depend>
+  <depend>nav2_msgs</depend>
+  <depend>nav2_simple_commander</depend>
+  <depend>tf_transformations</depend>
   <exec_depend>rosidl_default_runtime</exec_depend>
   <member_of_group>rosidl_interface_packages</member_of_group>
diff --git a/src/rai_asr/resource/rai_asr b/src/rai_tts/README.md
similarity index 100%
rename from src/rai_asr/resource/rai_asr
rename to src/rai_tts/README.md
diff --git a/src/rai_tts/config/elevenlabs.yaml b/src/rai_tts/config/elevenlabs.yaml
deleted file mode 100644
index 755f75617..000000000
--- a/src/rai_tts/config/elevenlabs.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-tts_node:
- ros__parameters:
- tts_client: elevenlabs
- voice: Jessica
- base_url: ""
- topic: to_human
diff --git a/src/rai_tts/config/opentts.yaml b/src/rai_tts/config/opentts.yaml
deleted file mode 100644
index 00c8a1cea..000000000
--- a/src/rai_tts/config/opentts.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-tts_node:
- ros__parameters:
- tts_client: opentts
- voice: larynx:blizzard_lessac-glow_tts
- base_url: http://localhost:5500/api/tts
- topic: to_human
diff --git a/src/rai_tts/launch/elevenlabs.launch.py b/src/rai_tts/launch/elevenlabs.launch.py
deleted file mode 100644
index c1ec695f2..000000000
--- a/src/rai_tts/launch/elevenlabs.launch.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-
-from ament_index_python.packages import get_package_share_directory
-from launch import LaunchDescription
-from launch.actions import DeclareLaunchArgument, ExecuteProcess
-from launch.conditions import IfCondition
-from launch.substitutions import LaunchConfiguration
-from launch_ros.actions import Node
-
-
-def generate_launch_description():
- config = os.path.join(
- get_package_share_directory("rai_tts"), "config", "elevenlabs.yaml"
- )
- launch_configuration = [
- DeclareLaunchArgument(
- "config_file",
- default_value=config,
- description="Path to the config file",
- ),
- Node(
- package="rai_tts",
- executable="tts_node",
- name="tts_node",
- parameters=[LaunchConfiguration("config_file")],
- ),
- ExecuteProcess(
- cmd=[
- "ffplay",
- "-f",
- "lavfi",
- "-i",
- "sine=frequency=432",
- "-af",
- "volume=0.01",
- "-nodisp",
- "-v",
- "0",
- ],
- name="ffplay_sine_wave",
- output="screen",
- condition=IfCondition(
- LaunchConfiguration("keep_speaker_busy", default=False)
- ),
- ),
- ]
-
- return LaunchDescription(launch_configuration)
diff --git a/src/rai_tts/launch/opentts.launch.py b/src/rai_tts/launch/opentts.launch.py
deleted file mode 100644
index a2149bf66..000000000
--- a/src/rai_tts/launch/opentts.launch.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-
-from ament_index_python.packages import get_package_share_directory
-from launch import LaunchDescription
-from launch.actions import DeclareLaunchArgument, ExecuteProcess
-from launch.conditions import IfCondition
-from launch.substitutions import LaunchConfiguration
-from launch_ros.actions import Node
-
-
-def generate_launch_description():
- config = os.path.join(
- get_package_share_directory("rai_tts"), "config", "opentts.yaml"
- )
-
- return LaunchDescription(
- [
- DeclareLaunchArgument(
- "config_file",
- default_value=config,
- description="Path to the config file",
- ),
- Node(
- package="rai_tts",
- executable="tts_node",
- name="tts_node",
- parameters=[LaunchConfiguration("config_file")],
- ),
- ExecuteProcess(
- cmd=[
- "ffplay",
- "-f",
- "lavfi",
- "-i",
- "sine=frequency=432",
- "-af",
- "volume=0.01",
- "-nodisp",
- "-v",
- "0",
- ],
- name="ffplay_sine_wave",
- output="screen",
- condition=IfCondition(
- LaunchConfiguration("keep_speaker_busy", default=False)
- ),
- ),
- ]
- )
diff --git a/src/rai_tts/package.xml b/src/rai_tts/package.xml
deleted file mode 100644
index aa2c599c8..000000000
--- a/src/rai_tts/package.xml
+++ /dev/null
@@ -1,16 +0,0 @@
-<?xml version="1.0"?>
-<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
-<package format="3">
-  <name>rai_tts</name>
-  <version>0.1.0</version>
-  <description>A Text To Speech package with streaming capabilities.</description>
-  <maintainer email="maciej.majek@robotec.ai">maciejmajek</maintainer>
-  <license>Apache-2.0</license>
-
-  <exec_depend>ffmpeg</exec_depend>
-  <test_depend>python3-pytest</test_depend>
-
-  <export>
-    <build_type>ament_python</build_type>
-  </export>
-</package>
diff --git a/src/rai_tts/pyproject.toml b/src/rai_tts/pyproject.toml
new file mode 100644
index 000000000..0c4bface2
--- /dev/null
+++ b/src/rai_tts/pyproject.toml
@@ -0,0 +1,23 @@
+[tool.poetry]
+name = "rai_tts"
+version = "1.0.0"
+description = "Text-to-Speech module for RAI framework"
+authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "]
+readme = "README.md"
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "Development Status :: 4 - Beta",
+ "License :: OSI Approved :: Apache Software License",
+]
+packages = [
+ { include = "rai_tts", from = "." },
+]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry.dependencies]
+python = "^3.10, <3.13"
+elevenlabs = "^1.4.1"
+sounddevice = "^0.4.7"
diff --git a/src/rai_tts/rai_tts/__init__.py b/src/rai_tts/rai_tts/__init__.py
index ef74fc891..2fac6f9fe 100644
--- a/src/rai_tts/rai_tts/__init__.py
+++ b/src/rai_tts/rai_tts/__init__.py
@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from .tts_clients import ElevenLabsClient, OpenTTSClient
+
+__all__ = ["ElevenLabsClient", "OpenTTSClient"]
diff --git a/src/rai_tts/rai_tts/tts_node.py b/src/rai_tts/rai_tts/tts_node.py
deleted file mode 100644
index a819e87d8..000000000
--- a/src/rai_tts/rai_tts/tts_node.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import re
-import subprocess
-import threading
-import time
-from queue import PriorityQueue
-from typing import NamedTuple, cast
-
-import rclpy
-from rclpy.node import Node
-from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
-from std_msgs.msg import String
-
-from .tts_clients import ElevenLabsClient, OpenTTSClient, TTSClient
-
-
-class TTSJob(NamedTuple):
- id: int
- file_path: str
-
-
-class TTSNode(Node):
- def __init__(self):
- super().__init__("rai_tts_node")
-
- self.declare_parameter("tts_client", "opentts")
- self.declare_parameter("voice", "larynx:blizzard_lessac-glow_tts")
- self.declare_parameter("base_url", "http://localhost:5500/api/tts")
- self.declare_parameter("topic", "to_human")
-
- topic_param = self.get_parameter("topic").get_parameter_value().string_value # type: ignore
- reliable_qos = QoSProfile(
- reliability=ReliabilityPolicy.RELIABLE,
- durability=DurabilityPolicy.TRANSIENT_LOCAL,
- history=HistoryPolicy.KEEP_ALL,
- )
- self.subscription = self.create_subscription( # type: ignore
- String, topic_param, self.listener_callback, qos_profile=reliable_qos # type: ignore
- )
- self.playing = False
- self.status_publisher = self.create_publisher(String, "tts_status", 10) # type: ignore
- self.queue: PriorityQueue[TTSJob] = PriorityQueue()
- self.it: int = 0
- self.job_id: int = 0
- self.queued_job_id = 0
- self.tts_client = self._initialize_client()
-
- status_publisher_timer_period_sec = 0.25
- self.create_timer(status_publisher_timer_period_sec, self.status_callback)
- threading.Thread(target=self._process_queue).start()
- self.get_logger().info("TTS Node has been started") # type: ignore
- self.threads_number = 0
- self.threads_max = 5
- self.thread_lock = threading.Lock()
-
- def status_callback(self):
- if self.threads_number == 0 and self.playing is False and self.queue.empty():
- self.status_publisher.publish(String(data="waiting"))
- else:
- self.status_publisher.publish(String(data="processing"))
-
- def listener_callback(self, msg: String):
- self.playing = True
- self.get_logger().info( # type: ignore
- f"Registering new TTS job: {self.job_id} length: {len(msg.data)} chars." # type: ignore
- )
- self.get_logger().debug(f"The job: {msg.data}") # type: ignore
-
- threading.Thread(
- target=self.start_synthesize_thread, args=(msg, self.job_id) # type: ignore
- ).start()
- self.job_id += 1
-
- def start_synthesize_thread(self, msg: String, job_id: int):
- while True:
- with self.thread_lock:
- if (
- self.threads_number < self.threads_max
- and self.queued_job_id == job_id
- ):
- threading.Thread(
- target=self.synthesize_speech, args=(job_id, msg.data) # type: ignore
- ).start()
- self.threads_number += 1
- self.queued_job_id += 1
- return
-
- def synthesize_speech(
- self,
- id: int,
- text: str,
- ) -> str:
- text = self._preprocess_text(text)
- if id > 0:
- time.sleep(0.5)
- temp_file_path = self.tts_client.synthesize_speech_to_file(text)
- self.get_logger().info(f"Job {id} completed.") # type: ignore
- tts_job = TTSJob(id, temp_file_path)
- self.queue.put(tts_job)
- with self.thread_lock:
- self.threads_number -= 1
-
- return temp_file_path
-
- def _process_queue(self):
- while rclpy.ok():
- time.sleep(0.01)
- if not self.queue.empty():
- if self.queue.queue[0][0] == self.it:
- self.it += 1
- tts_job = self.queue.get()
- self.get_logger().info( # type: ignore
- f"Playing audio for job {tts_job.id}. {tts_job.file_path}"
- )
- self._play_audio(tts_job.file_path)
-
- def _play_audio(self, filepath: str):
- self.playing = True
- self.status_publisher.publish(String(data="playing"))
- subprocess.run(
- ["ffplay", "-v", "0", "-nodisp", "-autoexit", filepath],
- stdout=subprocess.DEVNULL,
- stderr=subprocess.DEVNULL,
- )
- self.playing = False
-
- def _initialize_client(self) -> TTSClient:
- tts_client_param = cast(str, self.get_parameter("tts_client").get_parameter_value().string_value) # type: ignore
- voice_param = cast(str, self.get_parameter("voice").get_parameter_value().string_value) # type: ignore
- base_url_param = cast(str, self.get_parameter("base_url").get_parameter_value().string_value) # type: ignore
-
- if tts_client_param == "opentts":
- return OpenTTSClient(
- base_url=base_url_param,
- voice=voice_param,
- )
- elif tts_client_param == "elevenlabs":
- return ElevenLabsClient(
- voice=voice_param,
- base_url=base_url_param,
- )
- else:
- raise ValueError(f"Unknown TTS client: {tts_client_param}")
-
- def _preprocess_text(self, text: str) -> str:
- """Remove emojis from text."""
- emoji_pattern = re.compile(
- "["
- "\U0001F600-\U0001F64F" # emoticons
- "\U0001F300-\U0001F5FF" # symbols & pictographs
- "\U0001F680-\U0001F6FF" # transport & map symbols
- "\U0001F1E0-\U0001F1FF" # flags (iOS)
- "]+",
- flags=re.UNICODE,
- )
- text = emoji_pattern.sub(r"", text)
- return text
-
-
-def main():
- rclpy.init()
-
- tts_node = TTSNode()
-
- rclpy.spin(tts_node)
-
- tts_node.destroy_node()
- rclpy.shutdown()
-
-
-if __name__ == "__main__":
- main()
diff --git a/src/rai_tts/resource/rai_tts b/src/rai_tts/resource/rai_tts
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/rai_tts/setup.cfg b/src/rai_tts/setup.cfg
deleted file mode 100644
index 10e09e3c1..000000000
--- a/src/rai_tts/setup.cfg
+++ /dev/null
@@ -1,4 +0,0 @@
-[develop]
-script_dir=$base/lib/rai_tts
-[install]
-install_scripts=$base/lib/rai_tts
diff --git a/src/rai_tts/setup.py b/src/rai_tts/setup.py
deleted file mode 100644
index 4fc9f0dc3..000000000
--- a/src/rai_tts/setup.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-from glob import glob
-
-from setuptools import find_packages, setup
-
-package_name = "rai_tts"
-
-setup(
- name=package_name,
- version="0.1.0",
- packages=find_packages(exclude=["test"]),
- data_files=[
- ("share/ament_index/resource_index/packages", ["resource/" + package_name]),
- ("share/" + package_name, ["package.xml"]),
- (os.path.join("share", package_name, "launch"), glob("launch/*.launch.py")),
- (os.path.join("share", package_name, "config"), glob("config/*.yaml")),
- ],
- install_requires=["setuptools"],
- zip_safe=True,
- maintainer="maciejmajek",
- maintainer_email="maciej.majek@robotec.ai",
- description="A Text To Speech package with streaming capabilities.",
- license="Apache-2.0",
- tests_require=["pytest"],
- entry_points={
- "console_scripts": [
- "tts_node = rai_tts.tts_node:main",
- ],
- },
-)
diff --git a/src/rai_whoami/package.xml b/src/rai_whoami/package.xml
index fbbc8118f..3508fcc6f 100644
--- a/src/rai_whoami/package.xml
+++ b/src/rai_whoami/package.xml
@@ -8,7 +8,6 @@
   <license>Apache-2.0</license>
   <test_depend>ament_copyright</test_depend>
-  <test_depend>ament_flake8</test_depend>
   <test_depend>ament_pep257</test_depend>
   <test_depend>python3-pytest</test_depend>
   <depend>rai_interfaces</depend>
diff --git a/src/rai_whoami/rai_whoami/rai_whoami_node.py b/src/rai_whoami/rai_whoami/rai_whoami_node.py
index c98f17b46..d592fb54d 100644
--- a/src/rai_whoami/rai_whoami/rai_whoami_node.py
+++ b/src/rai_whoami/rai_whoami/rai_whoami_node.py
@@ -18,21 +18,20 @@
import rclpy
from ament_index_python.packages import get_package_share_directory
from langchain_community.vectorstores import FAISS
-from rai_interfaces.srv._vector_store_retrieval import (
- VectorStoreRetrieval_Request,
- VectorStoreRetrieval_Response,
-)
+from rai.utils.model_initialization import get_embeddings_model
from rclpy.node import Node
from rclpy.parameter import Parameter
from std_srvs.srv import Trigger
from std_srvs.srv._trigger import Trigger_Request, Trigger_Response
-from rai.utils.model_initialization import get_embeddings_model
from rai_interfaces.srv import VectorStoreRetrieval
+from rai_interfaces.srv._vector_store_retrieval import (
+ VectorStoreRetrieval_Request,
+ VectorStoreRetrieval_Response,
+)
class WhoAmI(Node):
-
def __init__(self):
super().__init__("rai_whoami_node")
self.declare_parameter("robot_description_package", Parameter.Type.STRING)
diff --git a/tests/communication/ros2/__init__.py b/tests/communication/ros2/__init__.py
index 4764660ff..efe9277e1 100644
--- a/tests/communication/ros2/__init__.py
+++ b/tests/communication/ros2/__init__.py
@@ -21,9 +21,9 @@
)
__all__ = [
- "shutdown_executors_and_threads",
- "multi_threaded_spinner",
- "MessagePublisher",
"ActionServer",
+ "MessagePublisher",
"MessageReceiver",
+ "multi_threaded_spinner",
+ "shutdown_executors_and_threads",
]
diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py
index d0ba3ec6b..cdae3f98f 100644
--- a/tests/communication/ros2/test_api.py
+++ b/tests/communication/ros2/test_api.py
@@ -18,9 +18,6 @@
import pytest
from action_msgs.msg import GoalStatus
from action_msgs.srv import CancelGoal
-from rclpy.executors import MultiThreadedExecutor
-from rclpy.node import Node
-
from rai.communication.ros2.api import (
ConfigurableROS2TopicAPI,
ROS2ActionAPI,
@@ -28,6 +25,8 @@
ROS2TopicAPI,
TopicConfig,
)
+from rclpy.executors import MultiThreadedExecutor
+from rclpy.node import Node
from .helpers import ActionServer_ as ActionServer
from .helpers import (
diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py
index f8a3211dd..3045d655a 100644
--- a/tests/communication/ros2/test_connectors.py
+++ b/tests/communication/ros2/test_connectors.py
@@ -16,11 +16,10 @@
from typing import Any, List
import pytest
+from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
from std_msgs.msg import String
from std_srvs.srv import SetBool
-from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
-
from .helpers import ActionServer_ as ActionServer
from .helpers import (
MessagePublisher,
diff --git a/tests/communication/sounds_device/test_api.py b/tests/communication/sounds_device/test_api.py
index 1df6f464d..b5a62a54a 100644
--- a/tests/communication/sounds_device/test_api.py
+++ b/tests/communication/sounds_device/test_api.py
@@ -16,7 +16,6 @@
import numpy as np
import pytest
import sounddevice
-
from rai.communication.sound_device import (
SoundDeviceAPI,
SoundDeviceConfig,
@@ -34,14 +33,13 @@ def mock_sd():
mock_stop = MagicMock()
mock_wait = MagicMock()
- with patch.object(sounddevice, "play", mock_play), patch.object(
- sounddevice, "rec", mock_rec
- ), patch.object(sounddevice, "open", mock_open), patch.object(
- sounddevice, "stop", mock_stop
- ), patch.object(
- sounddevice, "wait", mock_wait
+ with (
+ patch.object(sounddevice, "play", mock_play),
+ patch.object(sounddevice, "rec", mock_rec),
+ patch.object(sounddevice, "open", mock_open),
+ patch.object(sounddevice, "stop", mock_stop),
+ patch.object(sounddevice, "wait", mock_wait),
):
-
yield {
"play": mock_play,
"rec": mock_rec,
diff --git a/tests/communication/sounds_device/test_connector.py b/tests/communication/sounds_device/test_connector.py
index ba0daf9c4..a7f86caa9 100644
--- a/tests/communication/sounds_device/test_connector.py
+++ b/tests/communication/sounds_device/test_connector.py
@@ -18,14 +18,13 @@
import numpy as np
import pytest
import sounddevice
-from scipy.io import wavfile
-
from rai.communication import HRIPayload
from rai.communication.sound_device import SoundDeviceConfig, SoundDeviceError
from rai.communication.sound_device.connector import ( # Replace with actual module name
SoundDeviceConnector,
SoundDeviceMessage,
)
+from scipy.io import wavfile
@pytest.fixture
diff --git a/tests/communication/test_hri_message.py b/tests/communication/test_hri_message.py
index bba976da6..87a4385fb 100644
--- a/tests/communication/test_hri_message.py
+++ b/tests/communication/test_hri_message.py
@@ -18,7 +18,6 @@
from langchain_core.messages import HumanMessage
from PIL import Image
from pydub import AudioSegment
-
from rai.communication import HRIMessage, HRIPayload
from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage
diff --git a/tests/conftest.py b/tests/conftest.py
index ce9dfd6df..74089aea0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,9 +18,8 @@
import pytest
from _pytest.terminal import TerminalReporter
-from tabulate import tabulate
-
from rai.config.models import BEDROCK_CLAUDE_HAIKU, OPENAI_MINI
+from tabulate import tabulate
@pytest.fixture
diff --git a/tests/core/test_rai_cli.py b/tests/core/test_rai_cli.py
index 19f16f97a..fa7aa3b24 100644
--- a/tests/core/test_rai_cli.py
+++ b/tests/core/test_rai_cli.py
@@ -17,7 +17,6 @@
from unittest.mock import MagicMock, patch
import pytest
-
from rai.cli.rai_cli import create_rai_ws
diff --git a/tests/core/test_ros2_tools.py b/tests/core/test_ros2_tools.py
index 10f7c5a65..6fc48ddda 100644
--- a/tests/core/test_ros2_tools.py
+++ b/tests/core/test_ros2_tools.py
@@ -19,13 +19,12 @@
import rclpy
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai.chat_models import ChatOpenAI
+from rai.agents.state_based import create_state_based_agent
+from rai.tools.ros.native import Ros2PubMessageTool
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from std_msgs.msg import String
-from rai.agents.state_based import create_state_based_agent
-from rai.tools.ros.native import Ros2PubMessageTool
-
class Subscriber(Node):
def __init__(self) -> None:
diff --git a/tests/core/test_tool_runner.py b/tests/core/test_tool_runner.py
index 793eb8998..0c6971d26 100644
--- a/tests/core/test_tool_runner.py
+++ b/tests/core/test_tool_runner.py
@@ -16,7 +16,6 @@
from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.tools import tool
-
from rai.agents.tool_runner import ToolRunner
from rai.messages import HumanMultimodalMessage, ToolMultimodalMessage
from rai.messages.utils import preprocess_image
@@ -36,12 +35,12 @@ def test_tool_runner_invalid_call():
tool_call = ToolCall(name="bad_fn", args={"command": "list"}, id="12345")
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)
- assert isinstance(
- output["messages"][0], AIMessage
- ), "First message is not an AIMessage"
- assert isinstance(
- output["messages"][1], ToolMessage
- ), "Tool output is not a tool message"
+ assert isinstance(output["messages"][0], AIMessage), (
+ "First message is not an AIMessage"
+ )
+ assert isinstance(output["messages"][1], ToolMessage), (
+ "Tool output is not a tool message"
+ )
assert output["messages"][1].status == "error"
@@ -51,15 +50,15 @@ def test_tool_runner():
tool_call = ToolCall(name="ros2_topic", args={"command": "list"}, id="12345")
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)
- assert isinstance(
- output["messages"][0], AIMessage
- ), "First message is not an AIMessage"
- assert isinstance(
- output["messages"][1], ToolMessage
- ), "Tool output is not a tool message"
- assert (
- len(output["messages"][-1].content) > 0
- ), "Tool output is empty. At least rosout should be visible."
+ assert isinstance(output["messages"][0], AIMessage), (
+ "First message is not an AIMessage"
+ )
+ assert isinstance(output["messages"][1], ToolMessage), (
+ "Tool output is not a tool message"
+ )
+ assert len(output["messages"][-1].content) > 0, (
+ "Tool output is empty. At least rosout should be visible."
+ )
def test_tool_runner_multimodal():
@@ -71,12 +70,12 @@ def test_tool_runner_multimodal():
state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]}
output = runner.invoke(state)
- assert isinstance(
- output["messages"][0], AIMessage
- ), "First message is not an AIMessage"
- assert isinstance(
- output["messages"][1], ToolMultimodalMessage
- ), "Tool output is not a multimodal message"
- assert isinstance(
- output["messages"][2], HumanMultimodalMessage
- ), "Human output is not a multimodal message"
+ assert isinstance(output["messages"][0], AIMessage), (
+ "First message is not an AIMessage"
+ )
+ assert isinstance(output["messages"][1], ToolMultimodalMessage), (
+ "Tool output is not a multimodal message"
+ )
+ assert isinstance(output["messages"][2], HumanMultimodalMessage), (
+ "Human output is not a multimodal message"
+ )
diff --git a/tests/messages/test_multimodal.py b/tests/messages/test_multimodal.py
index 3ada5ec40..667d506fa 100644
--- a/tests/messages/test_multimodal.py
+++ b/tests/messages/test_multimodal.py
@@ -29,7 +29,6 @@
from langfuse.callback import CallbackHandler
from pydantic import BaseModel, Field
from pytest import FixtureRequest
-
from rai.messages import HumanMultimodalMessage
from rai.tools.utils import run_requested_tools
@@ -39,7 +38,6 @@ class GetImageToolInput(BaseModel):
class GetImageTool(BaseTool):
-
name: str = "GetImageTool"
description: str = "Get an image from the user"
diff --git a/tests/messages/test_transport.py b/tests/messages/test_transport.py
index 93a31e5e6..42b20869c 100644
--- a/tests/messages/test_transport.py
+++ b/tests/messages/test_transport.py
@@ -20,14 +20,13 @@
import numpy as np
import pytest
import rclpy
+from rai.node import RaiBaseNode
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.qos import QoSPresetProfiles, QoSProfile
from sensor_msgs.msg import Image
from std_msgs.msg import String
-from rai.node import RaiBaseNode
-
def get_qos_profiles() -> List[str]:
ros_distro = os.environ.get("ROS_DISTRO")
diff --git a/tests/messages/test_utils.py b/tests/messages/test_utils.py
index 268aa4c7a..baa727cba 100644
--- a/tests/messages/test_utils.py
+++ b/tests/messages/test_utils.py
@@ -18,7 +18,6 @@
import numpy as np
import pytest
from PIL import Image
-
from rai.messages.utils import preprocess_image
diff --git a/tests/smoke/import_test.py b/tests/smoke/import_test.py
index 6b7ccf202..9ea1e42e2 100644
--- a/tests/smoke/import_test.py
+++ b/tests/smoke/import_test.py
@@ -35,9 +35,7 @@ def rai_python_modules():
@pytest.mark.parametrize("module", rai_python_modules())
def test_can_import_all_modules_pathlib(module: ModuleType) -> None:
-
def import_submodules(package: ModuleType) -> None:
-
package_path = pathlib.Path(package.__file__).parent # type: ignore
importables = set()
diff --git a/tests/tools/ros2/test_action_tools.py b/tests/tools/ros2/test_action_tools.py
index 3d46c69b6..e8aa6945c 100644
--- a/tests/tools/ros2/test_action_tools.py
+++ b/tests/tools/ros2/test_action_tools.py
@@ -24,6 +24,7 @@
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import StartROS2ActionTool
+
from tests.communication.ros2.helpers import ActionServer_ as ActionServer
from tests.communication.ros2.helpers import (
multi_threaded_spinner,
diff --git a/tests/tools/ros2/test_service_tools.py b/tests/tools/ros2/test_service_tools.py
index ee813b327..ca60515ad 100644
--- a/tests/tools/ros2/test_service_tools.py
+++ b/tests/tools/ros2/test_service_tools.py
@@ -24,6 +24,7 @@
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import CallROS2ServiceTool
+
from tests.communication.ros2.helpers import (
ServiceServer,
multi_threaded_spinner,
diff --git a/tests/tools/ros2/test_topic_tools.py b/tests/tools/ros2/test_topic_tools.py
index 4d615597e..5e320bf2e 100644
--- a/tests/tools/ros2/test_topic_tools.py
+++ b/tests/tools/ros2/test_topic_tools.py
@@ -26,7 +26,6 @@
import time
from PIL import Image
-
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros2 import (
GetROS2ImageTool,
@@ -36,6 +35,7 @@
PublishROS2MessageTool,
ReceiveROS2MessageTool,
)
+
from tests.communication.ros2.helpers import (
ImagePublisher,
MessagePublisher,
diff --git a/tests/tools/test_tool_input_args_compatibility.py b/tests/tools/test_tool_input_args_compatibility.py
index 6ab845037..9845dc595 100644
--- a/tests/tools/test_tool_input_args_compatibility.py
+++ b/tests/tools/test_tool_input_args_compatibility.py
@@ -23,7 +23,7 @@
def get_all_tool_classes() -> set[BaseTool]:
"""Recursively find all classes that inherit from pydantic.BaseModel in src/rai/rai/tools"""
tools = []
- tools_path = Path("src/rai/rai/tools")
+ tools_path = Path("src/rai_core/rai/tools")
# Recursively find all .py files
for py_file in tools_path.rglob("*.py"):
diff --git a/tests/tools/test_tool_utils.py b/tests/tools/test_tool_utils.py
index fb62fb21a..1efec5425 100644
--- a/tests/tools/test_tool_utils.py
+++ b/tests/tools/test_tool_utils.py
@@ -16,11 +16,10 @@
import pytest
from geometry_msgs.msg import Point, TransformStamped
from nav2_msgs.action import NavigateToPose
+from rai.tools.ros2.utils import ros2_message_to_dict
from sensor_msgs.msg import Image
from tf2_msgs.msg import TFMessage
-from rai.tools.ros2.utils import ros2_message_to_dict
-
# TODO(`maciejmajek`): Add custom RAI messages?
@pytest.mark.parametrize(