From 5edc2b9a589dba1c45f7df29f65e48307b9083ef Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 19:27:49 +0100 Subject: [PATCH 1/8] refactor: rai as a python package --- pyproject.toml | 7 +- src/rai/LICENSE | 202 ------------------ src/rai/package.xml | 23 -- src/rai/resource/rai | 0 src/rai/setup.cfg | 4 - src/rai/setup.py | 36 ---- src/rai_core/pyproject.toml | 35 +++ src/{rai => rai_core}/rai/__init__.py | 0 src/{rai => rai_core}/rai/agents/__init__.py | 0 src/{rai => rai_core}/rai/agents/base.py | 0 .../rai/agents/conversational_agent.py | 0 .../rai/agents/integrations/__init__.py | 0 .../rai/agents/integrations/streamlit.py | 0 .../rai/agents/state_based.py | 0 .../rai/agents/tool_runner.py | 0 .../rai/agents/voice_agent.py | 0 src/{rai => rai_core}/rai/apps/__init__.py | 0 .../rai/apps/document_loader.py | 0 .../rai/apps/high_level_api.py | 0 .../rai/apps/state_analyzer.py | 0 .../rai/apps/talk_to_docs.py | 0 .../rai/apps/task_executor.py | 0 .../rai/apps/task_planner.py | 0 src/{rai => rai_core}/rai/cli/__init__.py | 0 src/{rai => rai_core}/rai/cli/rai_cli.py | 0 .../resources/default_robot_constitution.txt | 0 .../rai/communication/__init__.py | 0 .../rai/communication/ari_connector.py | 0 .../rai/communication/base_connector.py | 0 .../rai/communication/hri_connector.py | 0 .../rai/communication/ros2/__init__.py | 0 .../rai/communication/ros2/api.py | 0 .../rai/communication/ros2/connectors.py | 0 .../communication/sound_device/__init__.py | 0 .../rai/communication/sound_device/api.py | 0 .../communication/sound_device/connector.py | 0 src/{rai => rai_core}/rai/config/__init__.py | 0 src/{rai => rai_core}/rai/config/models.py | 0 .../rai/extensions/__init__.py | 0 .../rai/messages/__init__.py | 0 .../rai/messages/multimodal.py | 0 src/{rai => rai_core}/rai/messages/utils.py | 0 src/{rai => rai_core}/rai/node.py | 0 src/{rai => rai_core}/rai/ros2_apis.py | 0 src/{rai => rai_core}/rai/tools/__init__.py | 0 .../rai/tools/debugging_assistant.py | 0 .../rai/tools/ros/__init__.py | 0 src/{rai => rai_core}/rai/tools/ros/cli.py | 0 .../rai/tools/ros/deprecated.py | 0 .../rai/tools/ros/manipulation.py | 0 src/{rai => rai_core}/rai/tools/ros/native.py | 0 .../rai/tools/ros/native_actions.py | 0 .../rai/tools/ros/nav2/__init__.py | 0 .../rai/tools/ros/nav2/basic_navigator.py | 0 .../rai/tools/ros/nav2/navigator.py | 0 src/{rai => rai_core}/rai/tools/ros/tools.py | 0 src/{rai => rai_core}/rai/tools/ros/utils.py | 0 .../rai/tools/ros2/__init__.py | 0 .../rai/tools/ros2/actions.py | 0 .../rai/tools/ros2/services.py | 0 .../rai/tools/ros2/topics.py | 0 src/{rai => rai_core}/rai/tools/ros2/utils.py | 0 src/{rai => rai_core}/rai/tools/time.py | 0 src/{rai => rai_core}/rai/tools/utils.py | 0 src/{rai => rai_core}/rai/utils/__init__.py | 0 src/{rai => rai_core}/rai/utils/artifacts.py | 0 .../rai/utils/configurator.py | 0 .../rai/utils/model_initialization.py | 0 src/{rai => rai_core}/rai/utils/ros.py | 0 src/{rai => rai_core}/rai/utils/ros_async.py | 0 .../rai/utils/ros_executors.py | 0 src/{rai => rai_core}/rai/utils/ros_logs.py | 0 src/rai_interfaces/package.xml | 3 + 73 files changed, 43 insertions(+), 267 deletions(-) delete mode 100644 src/rai/LICENSE delete mode 100644 src/rai/package.xml delete mode 100644 src/rai/resource/rai delete mode 100644 src/rai/setup.cfg delete mode 100644 src/rai/setup.py create mode 100644 src/rai_core/pyproject.toml rename src/{rai => rai_core}/rai/__init__.py (100%) rename src/{rai => rai_core}/rai/agents/__init__.py (100%) rename src/{rai => rai_core}/rai/agents/base.py (100%) rename src/{rai => rai_core}/rai/agents/conversational_agent.py (100%) rename src/{rai => rai_core}/rai/agents/integrations/__init__.py (100%) rename src/{rai => rai_core}/rai/agents/integrations/streamlit.py (100%) rename src/{rai => rai_core}/rai/agents/state_based.py (100%) rename src/{rai => rai_core}/rai/agents/tool_runner.py (100%) rename src/{rai => rai_core}/rai/agents/voice_agent.py (100%) rename src/{rai => rai_core}/rai/apps/__init__.py (100%) rename src/{rai => rai_core}/rai/apps/document_loader.py (100%) rename src/{rai => rai_core}/rai/apps/high_level_api.py (100%) rename src/{rai => rai_core}/rai/apps/state_analyzer.py (100%) rename src/{rai => rai_core}/rai/apps/talk_to_docs.py (100%) rename src/{rai => rai_core}/rai/apps/task_executor.py (100%) rename src/{rai => rai_core}/rai/apps/task_planner.py (100%) rename src/{rai => rai_core}/rai/cli/__init__.py (100%) rename src/{rai => rai_core}/rai/cli/rai_cli.py (100%) rename src/{rai => rai_core}/rai/cli/resources/default_robot_constitution.txt (100%) rename src/{rai => rai_core}/rai/communication/__init__.py (100%) rename src/{rai => rai_core}/rai/communication/ari_connector.py (100%) rename src/{rai => rai_core}/rai/communication/base_connector.py (100%) rename src/{rai => rai_core}/rai/communication/hri_connector.py (100%) rename src/{rai => rai_core}/rai/communication/ros2/__init__.py (100%) rename src/{rai => rai_core}/rai/communication/ros2/api.py (100%) rename src/{rai => rai_core}/rai/communication/ros2/connectors.py (100%) rename src/{rai => rai_core}/rai/communication/sound_device/__init__.py (100%) rename src/{rai => rai_core}/rai/communication/sound_device/api.py (100%) rename src/{rai => rai_core}/rai/communication/sound_device/connector.py (100%) rename src/{rai => rai_core}/rai/config/__init__.py (100%) rename src/{rai => rai_core}/rai/config/models.py (100%) rename src/{rai => rai_core}/rai/extensions/__init__.py (100%) rename src/{rai => rai_core}/rai/messages/__init__.py (100%) rename src/{rai => rai_core}/rai/messages/multimodal.py (100%) rename src/{rai => rai_core}/rai/messages/utils.py (100%) rename src/{rai => rai_core}/rai/node.py (100%) rename src/{rai => rai_core}/rai/ros2_apis.py (100%) rename src/{rai => rai_core}/rai/tools/__init__.py (100%) rename src/{rai => rai_core}/rai/tools/debugging_assistant.py (100%) rename src/{rai => rai_core}/rai/tools/ros/__init__.py (100%) rename src/{rai => rai_core}/rai/tools/ros/cli.py (100%) rename src/{rai => rai_core}/rai/tools/ros/deprecated.py (100%) rename src/{rai => rai_core}/rai/tools/ros/manipulation.py (100%) rename src/{rai => rai_core}/rai/tools/ros/native.py (100%) rename src/{rai => rai_core}/rai/tools/ros/native_actions.py (100%) rename src/{rai => rai_core}/rai/tools/ros/nav2/__init__.py (100%) rename src/{rai => rai_core}/rai/tools/ros/nav2/basic_navigator.py (100%) rename src/{rai => rai_core}/rai/tools/ros/nav2/navigator.py (100%) rename src/{rai => rai_core}/rai/tools/ros/tools.py (100%) rename src/{rai => rai_core}/rai/tools/ros/utils.py (100%) rename src/{rai => rai_core}/rai/tools/ros2/__init__.py (100%) rename src/{rai => rai_core}/rai/tools/ros2/actions.py (100%) rename src/{rai => rai_core}/rai/tools/ros2/services.py (100%) rename src/{rai => rai_core}/rai/tools/ros2/topics.py (100%) rename src/{rai => rai_core}/rai/tools/ros2/utils.py (100%) rename src/{rai => rai_core}/rai/tools/time.py (100%) rename src/{rai => rai_core}/rai/tools/utils.py (100%) rename src/{rai => rai_core}/rai/utils/__init__.py (100%) rename src/{rai => rai_core}/rai/utils/artifacts.py (100%) rename src/{rai => rai_core}/rai/utils/configurator.py (100%) rename src/{rai => rai_core}/rai/utils/model_initialization.py (100%) rename src/{rai => rai_core}/rai/utils/ros.py (100%) rename src/{rai => rai_core}/rai/utils/ros_async.py (100%) rename src/{rai => rai_core}/rai/utils/ros_executors.py (100%) rename src/{rai => rai_core}/rai/utils/ros_logs.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 3224ea498..5d8fb1355 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.poetry] -name = "rai" +name = "rai_framework" version = "1.0.0" description = "RAI is a framework for building general multi-agent systems, bringing Gen AI features to ROS enabled robots." readme = "README.md" @@ -9,7 +9,10 @@ classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", ] -package-mode = false + +packages = [ + { include = "rai_core", from = "src" }, +] [tool.poetry.dependencies] python = "^3.10, <3.13" langchain-core = "^0.3" diff --git a/src/rai/LICENSE b/src/rai/LICENSE deleted file mode 100644 index d0356d5ef..000000000 --- a/src/rai/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2024-present Robotec.ai - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/src/rai/package.xml b/src/rai/package.xml deleted file mode 100644 index 58cec45b7..000000000 --- a/src/rai/package.xml +++ /dev/null @@ -1,23 +0,0 @@ - - - - rai - 1.0.0 - RAI core modules - Bartłomiej Boczek - Maciej Majek - Apache-2.0 - - ament_copyright - ament_flake8 - ament_pep257 - python3-pytest - - nav2_msgs - nav2_simple_commander - tf_transformations - - - ament_python - - diff --git a/src/rai/resource/rai b/src/rai/resource/rai deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/rai/setup.cfg b/src/rai/setup.cfg deleted file mode 100644 index c398a856f..000000000 --- a/src/rai/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[develop] -script_dir=$base/lib/rai -[install] -install_scripts=$base/lib/rai diff --git a/src/rai/setup.py b/src/rai/setup.py deleted file mode 100644 index 9ed4bf5df..000000000 --- a/src/rai/setup.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from setuptools import find_packages, setup - -package_name = "rai" - -setup( - name=package_name, - version="1.0.0", - packages=find_packages(exclude=["test"]), - data_files=[ - ("share/ament_index/resource_index/packages", ["resource/" + package_name]), - ("share/" + package_name, ["package.xml"]), - ], - install_requires=["setuptools"], - zip_safe=True, - maintainer="Bartłomiej Boczek", - maintainer_email="bartlomiej.boczek@robotec.ai", - description="TODO: Package description", - license="Apache-2.0", - tests_require=["pytest"], - entry_points={}, -) diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml new file mode 100644 index 000000000..abec6c940 --- /dev/null +++ b/src/rai_core/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry] +name = "rai" +version = "1.0.0" +description = "Core functionality for RAI framework" +authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", +] +packages = [ + { include = "rai", from = "." }, +] + +[tool.poetry.dependencies] +python = "^3.10, <3.13" +langchain-core = "^0.3" +langchain = "*" +langgraph = "*" +requests = "^2.32.2" +coloredlogs = "^15.0.1" +markdown = "^3.6" +tqdm = "^4.66.4" +rich = "^13.7.1" +deprecated = "^1.2.14" +tomli = "^2.0.1" +tomli-w = "^1.1.0" + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/src/rai/rai/__init__.py b/src/rai_core/rai/__init__.py similarity index 100% rename from src/rai/rai/__init__.py rename to src/rai_core/rai/__init__.py diff --git a/src/rai/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py similarity index 100% rename from src/rai/rai/agents/__init__.py rename to src/rai_core/rai/agents/__init__.py diff --git a/src/rai/rai/agents/base.py b/src/rai_core/rai/agents/base.py similarity index 100% rename from src/rai/rai/agents/base.py rename to src/rai_core/rai/agents/base.py diff --git a/src/rai/rai/agents/conversational_agent.py b/src/rai_core/rai/agents/conversational_agent.py similarity index 100% rename from src/rai/rai/agents/conversational_agent.py rename to src/rai_core/rai/agents/conversational_agent.py diff --git a/src/rai/rai/agents/integrations/__init__.py b/src/rai_core/rai/agents/integrations/__init__.py similarity index 100% rename from src/rai/rai/agents/integrations/__init__.py rename to src/rai_core/rai/agents/integrations/__init__.py diff --git a/src/rai/rai/agents/integrations/streamlit.py b/src/rai_core/rai/agents/integrations/streamlit.py similarity index 100% rename from src/rai/rai/agents/integrations/streamlit.py rename to src/rai_core/rai/agents/integrations/streamlit.py diff --git a/src/rai/rai/agents/state_based.py b/src/rai_core/rai/agents/state_based.py similarity index 100% rename from src/rai/rai/agents/state_based.py rename to src/rai_core/rai/agents/state_based.py diff --git a/src/rai/rai/agents/tool_runner.py b/src/rai_core/rai/agents/tool_runner.py similarity index 100% rename from src/rai/rai/agents/tool_runner.py rename to src/rai_core/rai/agents/tool_runner.py diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py similarity index 100% rename from src/rai/rai/agents/voice_agent.py rename to src/rai_core/rai/agents/voice_agent.py diff --git a/src/rai/rai/apps/__init__.py b/src/rai_core/rai/apps/__init__.py similarity index 100% rename from src/rai/rai/apps/__init__.py rename to src/rai_core/rai/apps/__init__.py diff --git a/src/rai/rai/apps/document_loader.py b/src/rai_core/rai/apps/document_loader.py similarity index 100% rename from src/rai/rai/apps/document_loader.py rename to src/rai_core/rai/apps/document_loader.py diff --git a/src/rai/rai/apps/high_level_api.py b/src/rai_core/rai/apps/high_level_api.py similarity index 100% rename from src/rai/rai/apps/high_level_api.py rename to src/rai_core/rai/apps/high_level_api.py diff --git a/src/rai/rai/apps/state_analyzer.py b/src/rai_core/rai/apps/state_analyzer.py similarity index 100% rename from src/rai/rai/apps/state_analyzer.py rename to src/rai_core/rai/apps/state_analyzer.py diff --git a/src/rai/rai/apps/talk_to_docs.py b/src/rai_core/rai/apps/talk_to_docs.py similarity index 100% rename from src/rai/rai/apps/talk_to_docs.py rename to src/rai_core/rai/apps/talk_to_docs.py diff --git a/src/rai/rai/apps/task_executor.py b/src/rai_core/rai/apps/task_executor.py similarity index 100% rename from src/rai/rai/apps/task_executor.py rename to src/rai_core/rai/apps/task_executor.py diff --git a/src/rai/rai/apps/task_planner.py b/src/rai_core/rai/apps/task_planner.py similarity index 100% rename from src/rai/rai/apps/task_planner.py rename to src/rai_core/rai/apps/task_planner.py diff --git a/src/rai/rai/cli/__init__.py b/src/rai_core/rai/cli/__init__.py similarity index 100% rename from src/rai/rai/cli/__init__.py rename to src/rai_core/rai/cli/__init__.py diff --git a/src/rai/rai/cli/rai_cli.py b/src/rai_core/rai/cli/rai_cli.py similarity index 100% rename from src/rai/rai/cli/rai_cli.py rename to src/rai_core/rai/cli/rai_cli.py diff --git a/src/rai/rai/cli/resources/default_robot_constitution.txt b/src/rai_core/rai/cli/resources/default_robot_constitution.txt similarity index 100% rename from src/rai/rai/cli/resources/default_robot_constitution.txt rename to src/rai_core/rai/cli/resources/default_robot_constitution.txt diff --git a/src/rai/rai/communication/__init__.py b/src/rai_core/rai/communication/__init__.py similarity index 100% rename from src/rai/rai/communication/__init__.py rename to src/rai_core/rai/communication/__init__.py diff --git a/src/rai/rai/communication/ari_connector.py b/src/rai_core/rai/communication/ari_connector.py similarity index 100% rename from src/rai/rai/communication/ari_connector.py rename to src/rai_core/rai/communication/ari_connector.py diff --git a/src/rai/rai/communication/base_connector.py b/src/rai_core/rai/communication/base_connector.py similarity index 100% rename from src/rai/rai/communication/base_connector.py rename to src/rai_core/rai/communication/base_connector.py diff --git a/src/rai/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py similarity index 100% rename from src/rai/rai/communication/hri_connector.py rename to src/rai_core/rai/communication/hri_connector.py diff --git a/src/rai/rai/communication/ros2/__init__.py b/src/rai_core/rai/communication/ros2/__init__.py similarity index 100% rename from src/rai/rai/communication/ros2/__init__.py rename to src/rai_core/rai/communication/ros2/__init__.py diff --git a/src/rai/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py similarity index 100% rename from src/rai/rai/communication/ros2/api.py rename to src/rai_core/rai/communication/ros2/api.py diff --git a/src/rai/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py similarity index 100% rename from src/rai/rai/communication/ros2/connectors.py rename to src/rai_core/rai/communication/ros2/connectors.py diff --git a/src/rai/rai/communication/sound_device/__init__.py b/src/rai_core/rai/communication/sound_device/__init__.py similarity index 100% rename from src/rai/rai/communication/sound_device/__init__.py rename to src/rai_core/rai/communication/sound_device/__init__.py diff --git a/src/rai/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py similarity index 100% rename from src/rai/rai/communication/sound_device/api.py rename to src/rai_core/rai/communication/sound_device/api.py diff --git a/src/rai/rai/communication/sound_device/connector.py b/src/rai_core/rai/communication/sound_device/connector.py similarity index 100% rename from src/rai/rai/communication/sound_device/connector.py rename to src/rai_core/rai/communication/sound_device/connector.py diff --git a/src/rai/rai/config/__init__.py b/src/rai_core/rai/config/__init__.py similarity index 100% rename from src/rai/rai/config/__init__.py rename to src/rai_core/rai/config/__init__.py diff --git a/src/rai/rai/config/models.py b/src/rai_core/rai/config/models.py similarity index 100% rename from src/rai/rai/config/models.py rename to src/rai_core/rai/config/models.py diff --git a/src/rai/rai/extensions/__init__.py b/src/rai_core/rai/extensions/__init__.py similarity index 100% rename from src/rai/rai/extensions/__init__.py rename to src/rai_core/rai/extensions/__init__.py diff --git a/src/rai/rai/messages/__init__.py b/src/rai_core/rai/messages/__init__.py similarity index 100% rename from src/rai/rai/messages/__init__.py rename to src/rai_core/rai/messages/__init__.py diff --git a/src/rai/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py similarity index 100% rename from src/rai/rai/messages/multimodal.py rename to src/rai_core/rai/messages/multimodal.py diff --git a/src/rai/rai/messages/utils.py b/src/rai_core/rai/messages/utils.py similarity index 100% rename from src/rai/rai/messages/utils.py rename to src/rai_core/rai/messages/utils.py diff --git a/src/rai/rai/node.py b/src/rai_core/rai/node.py similarity index 100% rename from src/rai/rai/node.py rename to src/rai_core/rai/node.py diff --git a/src/rai/rai/ros2_apis.py b/src/rai_core/rai/ros2_apis.py similarity index 100% rename from src/rai/rai/ros2_apis.py rename to src/rai_core/rai/ros2_apis.py diff --git a/src/rai/rai/tools/__init__.py b/src/rai_core/rai/tools/__init__.py similarity index 100% rename from src/rai/rai/tools/__init__.py rename to src/rai_core/rai/tools/__init__.py diff --git a/src/rai/rai/tools/debugging_assistant.py b/src/rai_core/rai/tools/debugging_assistant.py similarity index 100% rename from src/rai/rai/tools/debugging_assistant.py rename to src/rai_core/rai/tools/debugging_assistant.py diff --git a/src/rai/rai/tools/ros/__init__.py b/src/rai_core/rai/tools/ros/__init__.py similarity index 100% rename from src/rai/rai/tools/ros/__init__.py rename to src/rai_core/rai/tools/ros/__init__.py diff --git a/src/rai/rai/tools/ros/cli.py b/src/rai_core/rai/tools/ros/cli.py similarity index 100% rename from src/rai/rai/tools/ros/cli.py rename to src/rai_core/rai/tools/ros/cli.py diff --git a/src/rai/rai/tools/ros/deprecated.py b/src/rai_core/rai/tools/ros/deprecated.py similarity index 100% rename from src/rai/rai/tools/ros/deprecated.py rename to src/rai_core/rai/tools/ros/deprecated.py diff --git a/src/rai/rai/tools/ros/manipulation.py b/src/rai_core/rai/tools/ros/manipulation.py similarity index 100% rename from src/rai/rai/tools/ros/manipulation.py rename to src/rai_core/rai/tools/ros/manipulation.py diff --git a/src/rai/rai/tools/ros/native.py b/src/rai_core/rai/tools/ros/native.py similarity index 100% rename from src/rai/rai/tools/ros/native.py rename to src/rai_core/rai/tools/ros/native.py diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai_core/rai/tools/ros/native_actions.py similarity index 100% rename from src/rai/rai/tools/ros/native_actions.py rename to src/rai_core/rai/tools/ros/native_actions.py diff --git a/src/rai/rai/tools/ros/nav2/__init__.py b/src/rai_core/rai/tools/ros/nav2/__init__.py similarity index 100% rename from src/rai/rai/tools/ros/nav2/__init__.py rename to src/rai_core/rai/tools/ros/nav2/__init__.py diff --git a/src/rai/rai/tools/ros/nav2/basic_navigator.py b/src/rai_core/rai/tools/ros/nav2/basic_navigator.py similarity index 100% rename from src/rai/rai/tools/ros/nav2/basic_navigator.py rename to src/rai_core/rai/tools/ros/nav2/basic_navigator.py diff --git a/src/rai/rai/tools/ros/nav2/navigator.py b/src/rai_core/rai/tools/ros/nav2/navigator.py similarity index 100% rename from src/rai/rai/tools/ros/nav2/navigator.py rename to src/rai_core/rai/tools/ros/nav2/navigator.py diff --git a/src/rai/rai/tools/ros/tools.py b/src/rai_core/rai/tools/ros/tools.py similarity index 100% rename from src/rai/rai/tools/ros/tools.py rename to src/rai_core/rai/tools/ros/tools.py diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai_core/rai/tools/ros/utils.py similarity index 100% rename from src/rai/rai/tools/ros/utils.py rename to src/rai_core/rai/tools/ros/utils.py diff --git a/src/rai/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py similarity index 100% rename from src/rai/rai/tools/ros2/__init__.py rename to src/rai_core/rai/tools/ros2/__init__.py diff --git a/src/rai/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py similarity index 100% rename from src/rai/rai/tools/ros2/actions.py rename to src/rai_core/rai/tools/ros2/actions.py diff --git a/src/rai/rai/tools/ros2/services.py b/src/rai_core/rai/tools/ros2/services.py similarity index 100% rename from src/rai/rai/tools/ros2/services.py rename to src/rai_core/rai/tools/ros2/services.py diff --git a/src/rai/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py similarity index 100% rename from src/rai/rai/tools/ros2/topics.py rename to src/rai_core/rai/tools/ros2/topics.py diff --git a/src/rai/rai/tools/ros2/utils.py b/src/rai_core/rai/tools/ros2/utils.py similarity index 100% rename from src/rai/rai/tools/ros2/utils.py rename to src/rai_core/rai/tools/ros2/utils.py diff --git a/src/rai/rai/tools/time.py b/src/rai_core/rai/tools/time.py similarity index 100% rename from src/rai/rai/tools/time.py rename to src/rai_core/rai/tools/time.py diff --git a/src/rai/rai/tools/utils.py b/src/rai_core/rai/tools/utils.py similarity index 100% rename from src/rai/rai/tools/utils.py rename to src/rai_core/rai/tools/utils.py diff --git a/src/rai/rai/utils/__init__.py b/src/rai_core/rai/utils/__init__.py similarity index 100% rename from src/rai/rai/utils/__init__.py rename to src/rai_core/rai/utils/__init__.py diff --git a/src/rai/rai/utils/artifacts.py b/src/rai_core/rai/utils/artifacts.py similarity index 100% rename from src/rai/rai/utils/artifacts.py rename to src/rai_core/rai/utils/artifacts.py diff --git a/src/rai/rai/utils/configurator.py b/src/rai_core/rai/utils/configurator.py similarity index 100% rename from src/rai/rai/utils/configurator.py rename to src/rai_core/rai/utils/configurator.py diff --git a/src/rai/rai/utils/model_initialization.py b/src/rai_core/rai/utils/model_initialization.py similarity index 100% rename from src/rai/rai/utils/model_initialization.py rename to src/rai_core/rai/utils/model_initialization.py diff --git a/src/rai/rai/utils/ros.py b/src/rai_core/rai/utils/ros.py similarity index 100% rename from src/rai/rai/utils/ros.py rename to src/rai_core/rai/utils/ros.py diff --git a/src/rai/rai/utils/ros_async.py b/src/rai_core/rai/utils/ros_async.py similarity index 100% rename from src/rai/rai/utils/ros_async.py rename to src/rai_core/rai/utils/ros_async.py diff --git a/src/rai/rai/utils/ros_executors.py b/src/rai_core/rai/utils/ros_executors.py similarity index 100% rename from src/rai/rai/utils/ros_executors.py rename to src/rai_core/rai/utils/ros_executors.py diff --git a/src/rai/rai/utils/ros_logs.py b/src/rai_core/rai/utils/ros_logs.py similarity index 100% rename from src/rai/rai/utils/ros_logs.py rename to src/rai_core/rai/utils/ros_logs.py diff --git a/src/rai_interfaces/package.xml b/src/rai_interfaces/package.xml index 6627000a0..20e53abf8 100644 --- a/src/rai_interfaces/package.xml +++ b/src/rai_interfaces/package.xml @@ -18,6 +18,9 @@ rosidl_default_generators ament_cmake + nav2_msgs + nav2_simple_commander + tf_transformations rosidl_default_runtime rosidl_interface_packages From 413474e3a8e60e736012678a05a1fa1515b90eb6 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 19:34:15 +0100 Subject: [PATCH 2/8] refactor: rai_asr as a python package --- pyproject.toml | 2 + src/rai_asr/launch/local.launch.py | 95 ------- src/rai_asr/launch/openai.launch.py | 95 ------- src/rai_asr/package.xml | 16 -- src/rai_asr/pyproject.toml | 30 +++ src/rai_asr/rai_asr/asr_node.py | 402 ---------------------------- src/rai_asr/resource/rai_asr | 0 src/rai_asr/setup.cfg | 4 - src/rai_asr/setup.py | 44 --- src/rai_interfaces/package.xml | 1 + 10 files changed, 33 insertions(+), 656 deletions(-) delete mode 100644 src/rai_asr/launch/local.launch.py delete mode 100644 src/rai_asr/launch/openai.launch.py delete mode 100644 src/rai_asr/package.xml create mode 100644 src/rai_asr/pyproject.toml delete mode 100755 src/rai_asr/rai_asr/asr_node.py delete mode 100644 src/rai_asr/resource/rai_asr delete mode 100644 src/rai_asr/setup.cfg delete mode 100644 src/rai_asr/setup.py diff --git a/pyproject.toml b/pyproject.toml index 5d8fb1355..88654cd63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,9 @@ classifiers = [ packages = [ { include = "rai_core", from = "src" }, + { include = "rai_asr", from = "src" }, ] + [tool.poetry.dependencies] python = "^3.10, <3.13" langchain-core = "^0.3" diff --git a/src/rai_asr/launch/local.launch.py b/src/rai_asr/launch/local.launch.py deleted file mode 100644 index 7bfa69ca0..000000000 --- a/src/rai_asr/launch/local.launch.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from launch import LaunchDescription -from launch.actions import DeclareLaunchArgument -from launch.substitutions import LaunchConfiguration -from launch_ros.actions import Node - - -def generate_launch_description(): - return LaunchDescription( - [ - DeclareLaunchArgument( - "recording_device", - default_value="0", - description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'", - ), - DeclareLaunchArgument( - "language", - default_value="en", - description="Language code for the ASR model", - ), - DeclareLaunchArgument( - "model_name", - default_value="base", - description="Model name for the ASR model", - ), - DeclareLaunchArgument( - "model_vendor", - default_value="whisper", - description="Model vendor of the ASR", - ), - DeclareLaunchArgument( - "silence_grace_period", - default_value="1.0", - description="Grace period in seconds after silence to stop recording", - ), - DeclareLaunchArgument( - "use_wake_word", - default_value="False", - description="Whether to use wake word detection", - ), - DeclareLaunchArgument( - "wake_word_model", - default_value="", - description="Wake word model to use", - ), - DeclareLaunchArgument( - "wake_word_threshold", - default_value="0.5", - description="Threshold for wake word detection", - ), - DeclareLaunchArgument( - "vad_threshold", - default_value="0.5", - description="Threshold for voice activity detection", - ), - Node( - package="rai_asr", - executable="asr_node", - name="rai_asr", - output="screen", - emulate_tty=True, - parameters=[ - { - "recording_device": LaunchConfiguration("recording_device"), - "language": LaunchConfiguration("language"), - "model_name": LaunchConfiguration("model_name"), - "model_vendor": LaunchConfiguration("model_vendor"), - "silence_grace_period": LaunchConfiguration( - "silence_grace_period" - ), - "use_wake_word": LaunchConfiguration("use_wake_word"), - "wake_word_model": LaunchConfiguration("wake_word_model"), - "wake_word_threshold": LaunchConfiguration( - "wake_word_threshold" - ), - "vad_threshold": LaunchConfiguration("vad_threshold"), - } - ], - ), - ] - ) diff --git a/src/rai_asr/launch/openai.launch.py b/src/rai_asr/launch/openai.launch.py deleted file mode 100644 index d7b4524ea..000000000 --- a/src/rai_asr/launch/openai.launch.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from launch import LaunchDescription -from launch.actions import DeclareLaunchArgument -from launch.substitutions import LaunchConfiguration -from launch_ros.actions import Node - - -def generate_launch_description(): - return LaunchDescription( - [ - DeclareLaunchArgument( - "recording_device", - default_value="0", - description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'", - ), - DeclareLaunchArgument( - "language", - default_value="en", - description="Language code for the ASR model", - ), - DeclareLaunchArgument( - "model_name", - default_value="whisper-1", - description="Model name for the ASR model", - ), - DeclareLaunchArgument( - "model_vendor", - default_value="openai", - description="Model vendor of the ASR", - ), - DeclareLaunchArgument( - "silence_grace_period", - default_value="1.0", - description="Grace period in seconds after silence to stop recording", - ), - DeclareLaunchArgument( - "use_wake_word", - default_value="False", - description="Whether to use wake word detection", - ), - DeclareLaunchArgument( - "wake_word_model", - default_value="", - description="Wake word model to use", - ), - DeclareLaunchArgument( - "wake_word_threshold", - default_value="0.5", - description="Threshold for wake word detection", - ), - DeclareLaunchArgument( - "vad_threshold", - default_value="0.5", - description="Threshold for voice activity detection", - ), - Node( - package="rai_asr", - executable="asr_node", - name="rai_asr", - output="screen", - emulate_tty=True, - parameters=[ - { - "recording_device": LaunchConfiguration("recording_device"), - "language": LaunchConfiguration("language"), - "model_name": LaunchConfiguration("model_name"), - "model_vendor": LaunchConfiguration("model_vendor"), - "silence_grace_period": LaunchConfiguration( - "silence_grace_period" - ), - "use_wake_word": LaunchConfiguration("use_wake_word"), - "wake_word_model": LaunchConfiguration("wake_word_model"), - "wake_word_threshold": LaunchConfiguration( - "wake_word_threshold" - ), - "vad_threshold": LaunchConfiguration("vad_threshold"), - } - ], - ), - ] - ) diff --git a/src/rai_asr/package.xml b/src/rai_asr/package.xml deleted file mode 100644 index 56547ae81..000000000 --- a/src/rai_asr/package.xml +++ /dev/null @@ -1,16 +0,0 @@ - - - - rai_asr - 0.1.0 - An Automatic Speech Recognition package, leveraging Whisper for transcription and - Silero VAD for voice activity detection. This node captures audio, detects speech, and - transcribes the spoken content, publishing the transcription to a topic. - mkotynia - Apache-2.0 - - portaudio19-dev - - ament_python - - diff --git a/src/rai_asr/pyproject.toml b/src/rai_asr/pyproject.toml new file mode 100644 index 000000000..ca4524d6b --- /dev/null +++ b/src/rai_asr/pyproject.toml @@ -0,0 +1,30 @@ +[tool.poetry] +name = "rai_asr" +version = "1.0.0" +description = "Automatic Speech Recognition module for RAI framework" +authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", +] +packages = [ + { include = "rai_asr", from = "." }, +] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry.dependencies] +python = "^3.10, <3.13" +sounddevice = "^0.4.7" +openai-whisper = "^20231117" +scipy = "^1.14.0" +torchaudio = "^2.3.1" +faster-whisper = "^1.1.1" +pydub = "^0.25.1" + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/src/rai_asr/rai_asr/asr_node.py b/src/rai_asr/rai_asr/asr_node.py deleted file mode 100755 index 7dac9e268..000000000 --- a/src/rai_asr/rai_asr/asr_node.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -import time -from typing import Literal, Optional, cast - -import numpy as np -import rclpy -import sounddevice as sd -import torch -from numpy.typing import NDArray -from openwakeword.model import Model as OWWModel -from openwakeword.utils import download_models -from rcl_interfaces.msg import ParameterDescriptor, ParameterType -from rclpy.callback_groups import ReentrantCallbackGroup -from rclpy.executors import SingleThreadedExecutor -from rclpy.node import Node -from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy -from scipy.signal import resample -from std_msgs.msg import String - -VAD_SAMPLING_RATE = 16000 # default value used by silero vad -DEFAULT_BLOCKSIZE = 1280 - - -class ASRNode(Node): - def __init__(self): - super().__init__("rai_asr") # type: ignore - self._declare_parameters() - self._initialize_parameters() - self._setup_node_components() - self._setup_publishers_and_subscribers() - - self.asr_model = self._initialize_asr_model() - self.vad_model = self._initialize_vad_model() - self.oww_model = self._initialize_open_wake_word() - - self.initialize_sounddevice_stream() - - self.is_recording = False - self.audio_buffer = [] - self.silence_start_time: Optional[float] = None - self.last_transcription_time = 0 - self.hmi_lock = False - self.tts_lock = False - - self.current_chunk: Optional[NDArray[np.int16]] = None - - self.transcription_recording_timeout = 1 - self.get_logger().info("ASR Node has been initialized") # type: ignore - - def _declare_parameters(self): - self.declare_parameter( - "use_wake_word", - False, - descriptor=ParameterDescriptor( - type=ParameterType.PARAMETER_BOOL, - description=("Whether to use wake word for starting conversation"), - ), - ) - self.declare_parameter( - "wake_word_model", - "", - descriptor=ParameterDescriptor( - type=ParameterType.PARAMETER_STRING, - description=("Wake word model onnx file"), - ), - ) - self.declare_parameter( - "wake_word_threshold", - 0.1, - descriptor=ParameterDescriptor( - type=ParameterType.PARAMETER_DOUBLE, - description=("Wake word threshold"), - ), - ) - self.declare_parameter( - "vad_threshold", - 0.5, - descriptor=ParameterDescriptor( - type=ParameterType.PARAMETER_DOUBLE, - description=("VAD threshold"), - ), - ) - self.declare_parameter( - "recording_device", - 0, - descriptor=ParameterDescriptor( - type=ParameterType.PARAMETER_INTEGER, - description=( - "Recording device number. See available by running" - "python -c 'import sounddevice as sd; print(sd.query_devices())'" - ), - ), - ) - self.declare_parameter( - "model_vendor", - "whisper", # openai, whisper - ParameterDescriptor( - type=ParameterType.PARAMETER_STRING, - description="Vendor of the ASR model", - ), - ) - self.declare_parameter( - "language", - "en", - ParameterDescriptor( - type=ParameterType.PARAMETER_STRING, - description="Language code for the ASR model", - ), - ) - self.declare_parameter( - "model_name", - "base", - ParameterDescriptor( - type=ParameterType.PARAMETER_STRING, - description="Model type for the ASR model", - ), - ) - self.declare_parameter( - "silence_grace_period", - 1.0, - ParameterDescriptor( - type=ParameterType.PARAMETER_DOUBLE, - description="Grace period in seconds after silence to stop recording", - ), - ) - - def _initialize_open_wake_word(self) -> Optional[OWWModel]: - if self.use_wake_word: - download_models() - oww_model = OWWModel( - wakeword_models=[ - self.wake_word_model, - ], - inference_framework="onnx", - ) - self.get_logger().info("Wake word model has been initialized") # type: ignore - return oww_model - return None - - def _initialize_vad_model(self): - model, _ = torch.hub.load( - repo_or_dir="snakers4/silero-vad", - model="silero_vad", - ) - return model - - def _setup_node_components(self): - self.callback_group = ReentrantCallbackGroup() - - def _initialize_parameters(self): - self.silence_grace_period = cast( - float, - self.get_parameter("silence_grace_period") - .get_parameter_value() - .double_value, - ) - self.vad_threshold = cast( - float, - self.get_parameter("vad_threshold").get_parameter_value().double_value, - ) # type: ignore - self.model_name = ( - self.get_parameter("model_name").get_parameter_value().string_value - ) # type: ignore - self.model_vendor = ( - self.get_parameter("model_vendor").get_parameter_value().string_value - ) # type: ignore - self.language = ( - self.get_parameter("language").get_parameter_value().string_value - ) # type: ignore - - self.use_wake_word = cast( - bool, - self.get_parameter("use_wake_word").get_parameter_value().bool_value, - ) - self.wake_word_model = cast( - str, - self.get_parameter("wake_word_model").get_parameter_value().string_value, - ) - self.wake_word_threshold = cast( - float, - self.get_parameter("wake_word_threshold") - .get_parameter_value() - .double_value, - ) - self.recording_device_number = cast( - int, - self.get_parameter("recording_device").get_parameter_value().integer_value, - ) - - if self.use_wake_word: - if not os.path.exists(self.wake_word_model): - raise FileNotFoundError(f"Model file {self.wake_word_model} not found") - - self.get_logger().info("Parameters have been initialized") # type: ignore - - def _setup_publishers_and_subscribers(self): - reliable_qos = QoSProfile( - reliability=ReliabilityPolicy.RELIABLE, - durability=DurabilityPolicy.TRANSIENT_LOCAL, - history=HistoryPolicy.KEEP_ALL, - ) - self.transcription_publisher = self.create_publisher( - String, "/from_human", qos_profile=reliable_qos - ) - self.status_publisher = self.create_publisher(String, "/asr_status", 10) - self.tts_status_subscriber = self.create_subscription( - String, - "/tts_status", - self.tts_status_callback, - 10, - callback_group=self.callback_group, - ) - self.hmi_status_subscriber = self.create_subscription( - String, - "/hmi_status", - self.hmi_status_callback, - 10, - callback_group=self.callback_group, - ) - - def _initialize_asr_model(self): - if self.model_vendor == "openai": - from rai_asr.asr_clients import OpenAIWhisper - - self.model = OpenAIWhisper( - self.model_name, VAD_SAMPLING_RATE, self.language - ) - elif self.model_vendor == "whisper": - from rai_asr.asr_clients import LocalWhisper - - self.model = LocalWhisper(self.model_name, VAD_SAMPLING_RATE, self.language) - else: - raise ValueError(f"Unknown model vendor: {self.model_vendor}") - - def tts_status_callback(self, msg: String): - if msg.data == "processing": - self.tts_lock = True - elif msg.data == "waiting": - self.tts_lock = False - - def hmi_status_callback(self, msg: String): - if msg.data == "processing": - self.hmi_lock = True - elif msg.data == "waiting": - self.hmi_lock = False - - def should_listen(self, audio_data: NDArray[np.int16]) -> bool: - def int2float(sound: NDArray[np.int16]): - abs_max = np.abs(sound).max() - sound = sound.astype("float32") - if abs_max > 0: - sound *= 1 / 32768 - sound = sound.squeeze() - return sound - - vad_confidence = self.vad_model( - torch.tensor(int2float(audio_data[-512:])), VAD_SAMPLING_RATE - ).item() - - if self.oww_model: - if self.is_recording: - self.get_logger().debug(f"VAD confidence: {vad_confidence}") # type: ignore - return vad_confidence > self.vad_threshold - else: - predictions = self.oww_model.predict(audio_data) - for key, value in predictions.items(): - if value > self.wake_word_threshold: - self.get_logger().debug(f"Detected wake word: {key}") # type: ignore - self.oww_model.reset() - return True - else: - return vad_confidence > self.vad_threshold - - return False - - def sd_callback(self, indata, frames, _, status): - if status: - self.get_logger().warning(f"Stream status: {status}") # type: ignore - indata = indata.flatten() - sample_time_length = len(indata) / self.device_sample_rate - if self.device_sample_rate != VAD_SAMPLING_RATE: - indata = resample(indata, int(sample_time_length * VAD_SAMPLING_RATE)) - - asr_lock = ( - time.time() - < self.last_transcription_time + self.transcription_recording_timeout - ) - if asr_lock or self.hmi_lock or self.tts_lock: - return - - if not self.is_recording: # keep last 5 indata of audio ~ 400ms - self.audio_buffer.append(indata) - if len(self.audio_buffer) > 5: - self.audio_buffer.pop(0) - - if self.should_listen(indata): - self.silence_start_time = time.time() - if not self.is_recording: - self.start_recording() - self.audio_buffer.append(indata) - elif self.is_recording: - self.audio_buffer.append(indata) - if not isinstance(self.silence_start_time, float): - raise ValueError( - "Silence start time is not set, this should not happen" - ) - if time.time() - self.silence_start_time > self.silence_grace_period: - self.stop_recording_and_transcribe() - - def initialize_sounddevice_stream(self): - sd.default.latency = ("low", "low") - self.device_sample_rate = sd.query_devices( - device=self.recording_device_number, kind="input" - )[ - "default_samplerate" - ] # type: ignore - self.window_size_samples = int( - DEFAULT_BLOCKSIZE * self.device_sample_rate / VAD_SAMPLING_RATE - ) - self.stream = sd.InputStream( - samplerate=self.device_sample_rate, - channels=1, - device=self.recording_device_number, - dtype="int16", - blocksize=self.window_size_samples, - callback=self.sd_callback, - ) - self.stream.start() - - def reset_buffer(self): - self.audio_buffer.clear() - - def start_recording(self): - self.get_logger().info("Recording...") # type: ignore - self.publish_status("recording") - self.is_recording = True - - def stop_recording_and_transcribe(self): - self.get_logger().info("Stopped recording. Transcribing...") # type: ignore - self.is_recording = False - self.publish_status("transcribing") - self.transcribe_audio() - self.publish_status("waiting") - self.get_logger().info("Done transcribing.") # type: ignore - - def transcribe_audio(self): - combined_audio = np.concatenate(self.audio_buffer) - self.reset_buffer() # consume the buffer, so we don't transcribe the same audio twice - - transcription = self.model(data=combined_audio) - - if transcription.lower() in ["you", ""]: - self.get_logger().info(f"Dropping transcription: '{transcription}'") - self.publish_status("dropping") - else: - self.get_logger().info(f"Transcription: {transcription}") - self.publish_transcription(transcription) - - self.last_transcription_time = time.time() - - def publish_transcription(self, transcription: str): - msg = String() - msg.data = transcription - self.transcription_publisher.publish(msg) - - def publish_status( - self, status: Literal["recording", "transcribing", "dropping", "waiting"] - ): - msg = String() - msg.data = status - self.status_publisher.publish(msg) - - -def main(args=None): - rclpy.init(args=args) - node = ASRNode() - executor = SingleThreadedExecutor() - executor.add_node(node) - - try: - executor.spin() - except KeyboardInterrupt: - pass - finally: - executor.shutdown() - node.destroy_node() - rclpy.shutdown() diff --git a/src/rai_asr/resource/rai_asr b/src/rai_asr/resource/rai_asr deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/rai_asr/setup.cfg b/src/rai_asr/setup.cfg deleted file mode 100644 index 8499f73be..000000000 --- a/src/rai_asr/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[develop] -script_dir=$base/lib/rai_asr -[install] -install_scripts=$base/lib/rai_asr diff --git a/src/rai_asr/setup.py b/src/rai_asr/setup.py deleted file mode 100644 index c85aa620a..000000000 --- a/src/rai_asr/setup.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -from glob import glob - -from setuptools import find_packages, setup - -package_name = "rai_asr" - -setup( - name=package_name, - version="0.0.0", - packages=find_packages(exclude=["test"]), - data_files=[ - ("share/ament_index/resource_index/packages", ["resource/" + package_name]), - ("share/" + package_name, ["package.xml"]), - (os.path.join("share", package_name, "launch"), glob("launch/*.launch.py")), - ], - install_requires=["setuptools"], - zip_safe=True, - maintainer="mkotynia", - maintainer_email="magdalena.kotynia@robotec.ai", - description="An Automatic Speech Recognition package, leveraging Whisper for transcription and \ - Silero VAD for voice activity detection. This node captures audio, detects speech, and \ - transcribes the spoken content, publishing the transcription to a topic. ", - license="Apache-2.0", - tests_require=["pytest"], - entry_points={ - "console_scripts": ["asr_node = rai_asr.asr_node:main"], - }, -) diff --git a/src/rai_interfaces/package.xml b/src/rai_interfaces/package.xml index 20e53abf8..89f79d86c 100644 --- a/src/rai_interfaces/package.xml +++ b/src/rai_interfaces/package.xml @@ -18,6 +18,7 @@ rosidl_default_generators ament_cmake + portaudio19-dev nav2_msgs nav2_simple_commander tf_transformations From 638a3c86962d9ba90c514b2e2e05800bc6ba84c2 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 19:42:41 +0100 Subject: [PATCH 3/8] refactor: rai_tts as a python package --- pyproject.toml | 1 + src/rai_tts/config/elevenlabs.yaml | 6 - src/rai_tts/config/opentts.yaml | 6 - src/rai_tts/launch/elevenlabs.launch.py | 63 -------- src/rai_tts/launch/opentts.launch.py | 64 -------- src/rai_tts/package.xml | 16 -- src/rai_tts/pyproject.toml | 26 ++++ src/rai_tts/rai_tts/__init__.py | 4 + src/rai_tts/rai_tts/tts_node.py | 186 ------------------------ src/rai_tts/resource/rai_tts | 0 src/rai_tts/setup.cfg | 4 - src/rai_tts/setup.py | 45 ------ 12 files changed, 31 insertions(+), 390 deletions(-) delete mode 100644 src/rai_tts/config/elevenlabs.yaml delete mode 100644 src/rai_tts/config/opentts.yaml delete mode 100644 src/rai_tts/launch/elevenlabs.launch.py delete mode 100644 src/rai_tts/launch/opentts.launch.py delete mode 100644 src/rai_tts/package.xml create mode 100644 src/rai_tts/pyproject.toml delete mode 100644 src/rai_tts/rai_tts/tts_node.py delete mode 100644 src/rai_tts/resource/rai_tts delete mode 100644 src/rai_tts/setup.cfg delete mode 100644 src/rai_tts/setup.py diff --git a/pyproject.toml b/pyproject.toml index 88654cd63..0705512b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ classifiers = [ packages = [ { include = "rai_core", from = "src" }, { include = "rai_asr", from = "src" }, + { include = "rai_tts", from = "src" }, ] [tool.poetry.dependencies] diff --git a/src/rai_tts/config/elevenlabs.yaml b/src/rai_tts/config/elevenlabs.yaml deleted file mode 100644 index 755f75617..000000000 --- a/src/rai_tts/config/elevenlabs.yaml +++ /dev/null @@ -1,6 +0,0 @@ -tts_node: - ros__parameters: - tts_client: elevenlabs - voice: Jessica - base_url: "" - topic: to_human diff --git a/src/rai_tts/config/opentts.yaml b/src/rai_tts/config/opentts.yaml deleted file mode 100644 index 00c8a1cea..000000000 --- a/src/rai_tts/config/opentts.yaml +++ /dev/null @@ -1,6 +0,0 @@ -tts_node: - ros__parameters: - tts_client: opentts - voice: larynx:blizzard_lessac-glow_tts - base_url: http://localhost:5500/api/tts - topic: to_human diff --git a/src/rai_tts/launch/elevenlabs.launch.py b/src/rai_tts/launch/elevenlabs.launch.py deleted file mode 100644 index c1ec695f2..000000000 --- a/src/rai_tts/launch/elevenlabs.launch.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os - -from ament_index_python.packages import get_package_share_directory -from launch import LaunchDescription -from launch.actions import DeclareLaunchArgument, ExecuteProcess -from launch.conditions import IfCondition -from launch.substitutions import LaunchConfiguration -from launch_ros.actions import Node - - -def generate_launch_description(): - config = os.path.join( - get_package_share_directory("rai_tts"), "config", "elevenlabs.yaml" - ) - launch_configuration = [ - DeclareLaunchArgument( - "config_file", - default_value=config, - description="Path to the config file", - ), - Node( - package="rai_tts", - executable="tts_node", - name="tts_node", - parameters=[LaunchConfiguration("config_file")], - ), - ExecuteProcess( - cmd=[ - "ffplay", - "-f", - "lavfi", - "-i", - "sine=frequency=432", - "-af", - "volume=0.01", - "-nodisp", - "-v", - "0", - ], - name="ffplay_sine_wave", - output="screen", - condition=IfCondition( - LaunchConfiguration("keep_speaker_busy", default=False) - ), - ), - ] - - return LaunchDescription(launch_configuration) diff --git a/src/rai_tts/launch/opentts.launch.py b/src/rai_tts/launch/opentts.launch.py deleted file mode 100644 index a2149bf66..000000000 --- a/src/rai_tts/launch/opentts.launch.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os - -from ament_index_python.packages import get_package_share_directory -from launch import LaunchDescription -from launch.actions import DeclareLaunchArgument, ExecuteProcess -from launch.conditions import IfCondition -from launch.substitutions import LaunchConfiguration -from launch_ros.actions import Node - - -def generate_launch_description(): - config = os.path.join( - get_package_share_directory("rai_tts"), "config", "opentts.yaml" - ) - - return LaunchDescription( - [ - DeclareLaunchArgument( - "config_file", - default_value=config, - description="Path to the config file", - ), - Node( - package="rai_tts", - executable="tts_node", - name="tts_node", - parameters=[LaunchConfiguration("config_file")], - ), - ExecuteProcess( - cmd=[ - "ffplay", - "-f", - "lavfi", - "-i", - "sine=frequency=432", - "-af", - "volume=0.01", - "-nodisp", - "-v", - "0", - ], - name="ffplay_sine_wave", - output="screen", - condition=IfCondition( - LaunchConfiguration("keep_speaker_busy", default=False) - ), - ), - ] - ) diff --git a/src/rai_tts/package.xml b/src/rai_tts/package.xml deleted file mode 100644 index aa2c599c8..000000000 --- a/src/rai_tts/package.xml +++ /dev/null @@ -1,16 +0,0 @@ - - - - rai_tts - 0.1.0 - A Text To Speech package with streaming capabilities. - maciejmajek - Apache-2.0 - - ffmpeg - python3-pytest - - - ament_python - - diff --git a/src/rai_tts/pyproject.toml b/src/rai_tts/pyproject.toml new file mode 100644 index 000000000..ec6ce5d28 --- /dev/null +++ b/src/rai_tts/pyproject.toml @@ -0,0 +1,26 @@ +[tool.poetry] +name = "rai_tts" +version = "1.0.0" +description = "Text-to-Speech module for RAI framework" +authors = ["Maciej Majek ", "Bartłomiej Boczek ", "Kajetan Rachwał "] +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", +] +packages = [ + { include = "rai_tts", from = "." }, +] + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.poetry.dependencies] +python = "^3.10, <3.13" +elevenlabs = "^1.4.1" +sounddevice = "^0.4.7" + +[tool.isort] +profile = "black" \ No newline at end of file diff --git a/src/rai_tts/rai_tts/__init__.py b/src/rai_tts/rai_tts/__init__.py index ef74fc891..2fac6f9fe 100644 --- a/src/rai_tts/rai_tts/__init__.py +++ b/src/rai_tts/rai_tts/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .tts_clients import ElevenLabsClient, OpenTTSClient + +__all__ = ["ElevenLabsClient", "OpenTTSClient"] diff --git a/src/rai_tts/rai_tts/tts_node.py b/src/rai_tts/rai_tts/tts_node.py deleted file mode 100644 index a819e87d8..000000000 --- a/src/rai_tts/rai_tts/tts_node.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import re -import subprocess -import threading -import time -from queue import PriorityQueue -from typing import NamedTuple, cast - -import rclpy -from rclpy.node import Node -from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy -from std_msgs.msg import String - -from .tts_clients import ElevenLabsClient, OpenTTSClient, TTSClient - - -class TTSJob(NamedTuple): - id: int - file_path: str - - -class TTSNode(Node): - def __init__(self): - super().__init__("rai_tts_node") - - self.declare_parameter("tts_client", "opentts") - self.declare_parameter("voice", "larynx:blizzard_lessac-glow_tts") - self.declare_parameter("base_url", "http://localhost:5500/api/tts") - self.declare_parameter("topic", "to_human") - - topic_param = self.get_parameter("topic").get_parameter_value().string_value # type: ignore - reliable_qos = QoSProfile( - reliability=ReliabilityPolicy.RELIABLE, - durability=DurabilityPolicy.TRANSIENT_LOCAL, - history=HistoryPolicy.KEEP_ALL, - ) - self.subscription = self.create_subscription( # type: ignore - String, topic_param, self.listener_callback, qos_profile=reliable_qos # type: ignore - ) - self.playing = False - self.status_publisher = self.create_publisher(String, "tts_status", 10) # type: ignore - self.queue: PriorityQueue[TTSJob] = PriorityQueue() - self.it: int = 0 - self.job_id: int = 0 - self.queued_job_id = 0 - self.tts_client = self._initialize_client() - - status_publisher_timer_period_sec = 0.25 - self.create_timer(status_publisher_timer_period_sec, self.status_callback) - threading.Thread(target=self._process_queue).start() - self.get_logger().info("TTS Node has been started") # type: ignore - self.threads_number = 0 - self.threads_max = 5 - self.thread_lock = threading.Lock() - - def status_callback(self): - if self.threads_number == 0 and self.playing is False and self.queue.empty(): - self.status_publisher.publish(String(data="waiting")) - else: - self.status_publisher.publish(String(data="processing")) - - def listener_callback(self, msg: String): - self.playing = True - self.get_logger().info( # type: ignore - f"Registering new TTS job: {self.job_id} length: {len(msg.data)} chars." # type: ignore - ) - self.get_logger().debug(f"The job: {msg.data}") # type: ignore - - threading.Thread( - target=self.start_synthesize_thread, args=(msg, self.job_id) # type: ignore - ).start() - self.job_id += 1 - - def start_synthesize_thread(self, msg: String, job_id: int): - while True: - with self.thread_lock: - if ( - self.threads_number < self.threads_max - and self.queued_job_id == job_id - ): - threading.Thread( - target=self.synthesize_speech, args=(job_id, msg.data) # type: ignore - ).start() - self.threads_number += 1 - self.queued_job_id += 1 - return - - def synthesize_speech( - self, - id: int, - text: str, - ) -> str: - text = self._preprocess_text(text) - if id > 0: - time.sleep(0.5) - temp_file_path = self.tts_client.synthesize_speech_to_file(text) - self.get_logger().info(f"Job {id} completed.") # type: ignore - tts_job = TTSJob(id, temp_file_path) - self.queue.put(tts_job) - with self.thread_lock: - self.threads_number -= 1 - - return temp_file_path - - def _process_queue(self): - while rclpy.ok(): - time.sleep(0.01) - if not self.queue.empty(): - if self.queue.queue[0][0] == self.it: - self.it += 1 - tts_job = self.queue.get() - self.get_logger().info( # type: ignore - f"Playing audio for job {tts_job.id}. {tts_job.file_path}" - ) - self._play_audio(tts_job.file_path) - - def _play_audio(self, filepath: str): - self.playing = True - self.status_publisher.publish(String(data="playing")) - subprocess.run( - ["ffplay", "-v", "0", "-nodisp", "-autoexit", filepath], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - self.playing = False - - def _initialize_client(self) -> TTSClient: - tts_client_param = cast(str, self.get_parameter("tts_client").get_parameter_value().string_value) # type: ignore - voice_param = cast(str, self.get_parameter("voice").get_parameter_value().string_value) # type: ignore - base_url_param = cast(str, self.get_parameter("base_url").get_parameter_value().string_value) # type: ignore - - if tts_client_param == "opentts": - return OpenTTSClient( - base_url=base_url_param, - voice=voice_param, - ) - elif tts_client_param == "elevenlabs": - return ElevenLabsClient( - voice=voice_param, - base_url=base_url_param, - ) - else: - raise ValueError(f"Unknown TTS client: {tts_client_param}") - - def _preprocess_text(self, text: str) -> str: - """Remove emojis from text.""" - emoji_pattern = re.compile( - "[" - "\U0001F600-\U0001F64F" # emoticons - "\U0001F300-\U0001F5FF" # symbols & pictographs - "\U0001F680-\U0001F6FF" # transport & map symbols - "\U0001F1E0-\U0001F1FF" # flags (iOS) - "]+", - flags=re.UNICODE, - ) - text = emoji_pattern.sub(r"", text) - return text - - -def main(): - rclpy.init() - - tts_node = TTSNode() - - rclpy.spin(tts_node) - - tts_node.destroy_node() - rclpy.shutdown() - - -if __name__ == "__main__": - main() diff --git a/src/rai_tts/resource/rai_tts b/src/rai_tts/resource/rai_tts deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/rai_tts/setup.cfg b/src/rai_tts/setup.cfg deleted file mode 100644 index 10e09e3c1..000000000 --- a/src/rai_tts/setup.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[develop] -script_dir=$base/lib/rai_tts -[install] -install_scripts=$base/lib/rai_tts diff --git a/src/rai_tts/setup.py b/src/rai_tts/setup.py deleted file mode 100644 index 4fc9f0dc3..000000000 --- a/src/rai_tts/setup.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (C) 2024 Robotec.AI -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os -from glob import glob - -from setuptools import find_packages, setup - -package_name = "rai_tts" - -setup( - name=package_name, - version="0.1.0", - packages=find_packages(exclude=["test"]), - data_files=[ - ("share/ament_index/resource_index/packages", ["resource/" + package_name]), - ("share/" + package_name, ["package.xml"]), - (os.path.join("share", package_name, "launch"), glob("launch/*.launch.py")), - (os.path.join("share", package_name, "config"), glob("config/*.yaml")), - ], - install_requires=["setuptools"], - zip_safe=True, - maintainer="maciejmajek", - maintainer_email="maciej.majek@robotec.ai", - description="A Text To Speech package with streaming capabilities.", - license="Apache-2.0", - tests_require=["pytest"], - entry_points={ - "console_scripts": [ - "tts_node = rai_tts.tts_node:main", - ], - }, -) From 4d803da526ec2e87921328558ff4b6ced4696c3c Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 19:51:35 +0100 Subject: [PATCH 4/8] chore: use ruff instead of black, isort and flake8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Kajetan Rachwał --- .pre-commit-config.yaml | 26 +++++--------------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2caf3cb76..ca94af87b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,25 +23,9 @@ repos: hooks: - id: shellcheck - - repo: https://github.com/pycqa/autoflake - rev: v2.3.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.4 hooks: - - id: autoflake - args: ["--remove-all-unused-imports", "--in-place"] - - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: ["--profile", "black", "--filter-files"] - - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - - - repo: https://github.com/pycqa/flake8 - rev: 7.1.0 - hooks: - - id: flake8 - args: ["--ignore=E501,E731,W503,W504,E203"] + - id: ruff + args: [--extend-select, "I,RUF022", --fix, --ignore, E731] + - id: ruff-format From 843fd9f1f52c70d44100eaf83bef2646ea83b24a Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 19:53:36 +0100 Subject: [PATCH 5/8] chore: pre-commit --- examples/agriculture-demo.py | 12 ++-- examples/manipulation-demo-streamlit.py | 1 - examples/manipulation-demo.py | 1 - examples/rosbot-xl-demo.py | 3 +- examples/taxi-demo.py | 4 +- src/examples/turtlebot4/turtlebot_demo.py | 1 - src/rai_asr/pyproject.toml | 2 +- src/rai_asr/rai_asr/models/__init__.py | 6 +- src/rai_asr/rai_asr/models/base.py | 1 - src/rai_core/pyproject.toml | 2 +- src/rai_core/rai/agents/__init__.py | 2 +- .../rai/agents/integrations/streamlit.py | 2 +- src/rai_core/rai/agents/voice_agent.py | 4 +- src/rai_core/rai/apps/state_analyzer.py | 1 - src/rai_core/rai/apps/talk_to_docs.py | 4 +- src/rai_core/rai/communication/__init__.py | 2 +- .../rai/communication/base_connector.py | 1 - .../rai/communication/hri_connector.py | 4 +- src/rai_core/rai/communication/ros2/api.py | 3 +- .../rai/communication/ros2/connectors.py | 1 - .../communication/sound_device/__init__.py | 2 +- .../rai/communication/sound_device/api.py | 1 - src/rai_core/rai/messages/__init__.py | 4 +- src/rai_core/rai/messages/multimodal.py | 2 +- src/rai_core/rai/node.py | 1 - src/rai_core/rai/tools/ros/__init__.py | 12 ++-- src/rai_core/rai/tools/ros/deprecated.py | 4 +- src/rai_core/rai/tools/ros/native_actions.py | 8 +-- src/rai_core/rai/tools/ros/tools.py | 4 +- src/rai_core/rai/tools/ros2/__init__.py | 10 ++-- src/rai_core/rai/tools/ros2/topics.py | 4 +- src/rai_core/rai/tools/utils.py | 4 +- src/rai_core/rai/utils/configurator.py | 60 +++++++++---------- .../rai_nomad/rai_nomad/nomad.py | 6 +- .../rai_open_set_vision/__init__.py | 4 +- .../services/grounded_sam.py | 2 +- .../services/grounding_dino.py | 2 +- .../rai_open_set_vision/tools/__init__.py | 4 +- .../rai_open_set_vision/tools/gdino_tools.py | 20 +++---- .../tools/segmentation_tools.py | 11 ++-- .../vision_markup/segmenter.py | 3 +- src/rai_hmi/rai_hmi/agent.py | 2 +- src/rai_hmi/rai_hmi/base.py | 4 +- src/rai_hmi/rai_hmi/ros.py | 2 +- src/rai_hmi/rai_hmi/text_hmi.py | 6 +- src/rai_hmi/rai_hmi/voice_hmi.py | 2 +- src/rai_tts/pyproject.toml | 2 +- src/rai_whoami/rai_whoami/rai_whoami_node.py | 11 ++-- tests/communication/ros2/__init__.py | 6 +- tests/communication/ros2/test_api.py | 5 +- tests/communication/ros2/test_connectors.py | 3 +- tests/communication/sounds_device/test_api.py | 14 ++--- .../sounds_device/test_connector.py | 3 +- tests/communication/test_hri_message.py | 1 - tests/conftest.py | 3 +- tests/core/test_rai_cli.py | 1 - tests/core/test_ros2_tools.py | 5 +- tests/core/test_tool_runner.py | 49 ++++++++------- tests/messages/test_multimodal.py | 2 - tests/messages/test_transport.py | 3 +- tests/messages/test_utils.py | 1 - tests/smoke/import_test.py | 2 - tests/tools/ros2/test_action_tools.py | 1 + tests/tools/ros2/test_service_tools.py | 1 + tests/tools/ros2/test_topic_tools.py | 2 +- tests/tools/test_tool_utils.py | 3 +- 66 files changed, 164 insertions(+), 195 deletions(-) diff --git a/examples/agriculture-demo.py b/examples/agriculture-demo.py index e80d7a474..ac49e0e98 100644 --- a/examples/agriculture-demo.py +++ b/examples/agriculture-demo.py @@ -15,12 +15,6 @@ import argparse import rclpy -from rclpy.action import ActionClient -from rclpy.callback_groups import ReentrantCallbackGroup -from rclpy.executors import MultiThreadedExecutor -from rclpy.node import Node -from std_srvs.srv import Trigger - from rai.node import RaiStateBasedLlmNode, describe_ros_image from rai.tools.ros.native import ( GetCameraImage, @@ -30,6 +24,12 @@ Ros2ShowMsgInterfaceTool, ) from rai.tools.time import WaitForSecondsTool +from rclpy.action import ActionClient +from rclpy.callback_groups import ReentrantCallbackGroup +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node +from std_srvs.srv import Trigger + from rai_interfaces.action import Task diff --git a/examples/manipulation-demo-streamlit.py b/examples/manipulation-demo-streamlit.py index bec608118..e226c1f88 100644 --- a/examples/manipulation-demo-streamlit.py +++ b/examples/manipulation-demo-streamlit.py @@ -16,7 +16,6 @@ import streamlit as st from langchain_core.messages import AIMessage, HumanMessage, ToolMessage - from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke from rai.messages import HumanMultimodalMessage diff --git a/examples/manipulation-demo.py b/examples/manipulation-demo.py index 0c8d0665b..6f73248c2 100644 --- a/examples/manipulation-demo.py +++ b/examples/manipulation-demo.py @@ -17,7 +17,6 @@ import rclpy import rclpy.qos from langchain_core.messages import HumanMessage - from rai.agents.conversational_agent import create_conversational_agent from rai.node import RaiBaseNode from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool diff --git a/examples/rosbot-xl-demo.py b/examples/rosbot-xl-demo.py index 90600b6ae..b6c1ab666 100644 --- a/examples/rosbot-xl-demo.py +++ b/examples/rosbot-xl-demo.py @@ -18,8 +18,6 @@ import rclpy import rclpy.executors import rclpy.logging -from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool - from rai.node import RaiStateBasedLlmNode from rai.tools.ros.native import ( GetMsgFromTopic, @@ -35,6 +33,7 @@ Ros2RunActionAsync, ) from rai.tools.time import WaitForSecondsTool +from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool p = argparse.ArgumentParser() p.add_argument("--allowlist", type=Path, required=False, default=None) diff --git a/examples/taxi-demo.py b/examples/taxi-demo.py index 6f5d24a48..18d498a68 100644 --- a/examples/taxi-demo.py +++ b/examples/taxi-demo.py @@ -20,12 +20,12 @@ from langchain_community.tools.tavily_search import TavilySearchResults from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.tools import tool -from std_msgs.msg import String - from rai.agents.conversational_agent import create_conversational_agent from rai.tools.ros.cli import Ros2ServiceTool from rai.tools.ros.native import Ros2PubMessageTool from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks +from std_msgs.msg import String + from rai_hmi.api import GenericVoiceNode, split_message system_prompt = """ diff --git a/src/examples/turtlebot4/turtlebot_demo.py b/src/examples/turtlebot4/turtlebot_demo.py index 393606374..c3ea985b8 100644 --- a/src/examples/turtlebot4/turtlebot_demo.py +++ b/src/examples/turtlebot4/turtlebot_demo.py @@ -23,7 +23,6 @@ import rclpy.qos import rclpy.subscription import rclpy.task - from rai.node import RaiStateBasedLlmNode from rai.tools.ros.native import ( GetCameraImage, diff --git a/src/rai_asr/pyproject.toml b/src/rai_asr/pyproject.toml index ca4524d6b..dfe16c043 100644 --- a/src/rai_asr/pyproject.toml +++ b/src/rai_asr/pyproject.toml @@ -27,4 +27,4 @@ faster-whisper = "^1.1.1" pydub = "^0.25.1" [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py index 1d1a7e9de..daaa95a15 100644 --- a/src/rai_asr/rai_asr/models/__init__.py +++ b/src/rai_asr/rai_asr/models/__init__.py @@ -19,10 +19,10 @@ from rai_asr.models.silero_vad import SileroVAD __all__ = [ - "BaseVoiceDetectionModel", - "SileroVAD", - "OpenWakeWord", "BaseTranscriptionModel", + "BaseVoiceDetectionModel", "LocalWhisper", "OpenAIWhisper", + "OpenWakeWord", + "SileroVAD", ] diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py index 13142df87..1cf62a14c 100644 --- a/src/rai_asr/rai_asr/models/base.py +++ b/src/rai_asr/rai_asr/models/base.py @@ -21,7 +21,6 @@ class BaseVoiceDetectionModel(ABC): - def __call__( self, audio_data: NDArray, input_parameters: dict[str, Any] ) -> Tuple[bool, dict[str, Any]]: diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index abec6c940..48dbc1771 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -32,4 +32,4 @@ tomli = "^2.0.1" tomli-w = "^1.1.0" [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py index 2b7d4461a..e28822100 100644 --- a/src/rai_core/rai/agents/__init__.py +++ b/src/rai_core/rai/agents/__init__.py @@ -19,7 +19,7 @@ __all__ = [ "ToolRunner", + "VoiceRecognitionAgent", "create_conversational_agent", "create_state_based_agent", - "VoiceRecognitionAgent", ] diff --git a/src/rai_core/rai/agents/integrations/streamlit.py b/src/rai_core/rai/agents/integrations/streamlit.py index 18ca98683..73360893a 100644 --- a/src/rai_core/rai/agents/integrations/streamlit.py +++ b/src/rai_core/rai/agents/integrations/streamlit.py @@ -125,7 +125,7 @@ def on_tool_end(self, output: Any, **kwargs: Any) -> Any: # Decorator function to add the Streamlit execution context to a function def add_streamlit_context( - fn: Callable[..., fn_return_type] + fn: Callable[..., fn_return_type], ) -> Callable[..., fn_return_type]: """ Decorator to ensure that the decorated function runs within the Streamlit execution context. diff --git a/src/rai_core/rai/agents/voice_agent.py b/src/rai_core/rai/agents/voice_agent.py index c012837f6..339db49da 100644 --- a/src/rai_core/rai/agents/voice_agent.py +++ b/src/rai_core/rai/agents/voice_agent.py @@ -187,7 +187,9 @@ def should_record( def transcription_thread(self, identifier: str): self.logger.info(f"transcription thread {identifier} started") audio_data = np.concatenate(self.transcription_buffers[identifier]) - with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow + with ( + self.transcription_lock + ): # this is only necessary for the local model... TODO: fix this somehow transcription = self.transcription_model.transcribe(audio_data) assert isinstance(self.connectors["ros2"], ROS2ARIConnector) self.connectors["ros2"].send_message( diff --git a/src/rai_core/rai/apps/state_analyzer.py b/src/rai_core/rai/apps/state_analyzer.py index 0a736dadb..6ebe2d0db 100644 --- a/src/rai_core/rai/apps/state_analyzer.py +++ b/src/rai_core/rai/apps/state_analyzer.py @@ -39,7 +39,6 @@ def robot_state_analyzer( state: str, state_analyzer_prompt: str = STATE_ANALYZER_PROMPT, ) -> State: - template = ChatPromptTemplate.from_messages( [ ("system", state_analyzer_prompt), diff --git a/src/rai_core/rai/apps/talk_to_docs.py b/src/rai_core/rai/apps/talk_to_docs.py index 44811898c..8cccbf8f3 100644 --- a/src/rai_core/rai/apps/talk_to_docs.py +++ b/src/rai_core/rai/apps/talk_to_docs.py @@ -78,7 +78,9 @@ def talk_to_docs(documentation_root: str, llm: BaseChatModel): agent = create_tool_calling_agent(llm, [query_docs], prompt) # type: ignore agent_executor = AgentExecutor( - agent=agent, tools=[query_docs], return_intermediate_steps=True # type: ignore + agent=agent, + tools=[query_docs], + return_intermediate_steps=True, # type: ignore ) def input_node(state: State) -> State: diff --git a/src/rai_core/rai/communication/__init__.py b/src/rai_core/rai/communication/__init__.py index 394fdbb61..f18324d79 100644 --- a/src/rai_core/rai/communication/__init__.py +++ b/src/rai_core/rai/communication/__init__.py @@ -25,8 +25,8 @@ __all__ = [ "ARIConnector", "ARIMessage", - "BaseMessage", "BaseConnector", + "BaseMessage", "HRIConnector", "HRIMessage", "HRIPayload", diff --git a/src/rai_core/rai/communication/base_connector.py b/src/rai_core/rai/communication/base_connector.py index 21d461b62..901256ddf 100644 --- a/src/rai_core/rai/communication/base_connector.py +++ b/src/rai_core/rai/communication/base_connector.py @@ -39,7 +39,6 @@ def __init__( class BaseConnector(Generic[T]): - def _generate_handle(self) -> str: return str(uuid4()) diff --git a/src/rai_core/rai/communication/hri_connector.py b/src/rai_core/rai/communication/hri_connector.py index ba578afdf..a71b496e2 100644 --- a/src/rai_core/rai/communication/hri_connector.py +++ b/src/rai_core/rai/communication/hri_connector.py @@ -17,9 +17,8 @@ from io import BytesIO from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import BaseMessage as LangchainBaseMessage -from langchain_core.messages import HumanMessage from PIL import Image from PIL.Image import Image as ImageType from pydub import AudioSegment @@ -166,7 +165,6 @@ def _build_message( self, message: LangchainBaseMessage | RAIMultimodalMessage, ) -> T: - return self.T_class.from_langchain(message) def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage): diff --git a/src/rai_core/rai/communication/ros2/api.py b/src/rai_core/rai/communication/ros2/api.py index a853f0932..44e03aafc 100644 --- a/src/rai_core/rai/communication/ros2/api.py +++ b/src/rai_core/rai/communication/ros2/api.py @@ -334,7 +334,6 @@ def __post_init__(self): class ConfigurableROS2TopicAPI(ROS2TopicAPI): - def __init__(self, node: rclpy.node.Node): super().__init__(node) self._subscribtions: dict[str, rclpy.node.Subscription] = {} @@ -562,7 +561,7 @@ def is_goal_done(self, handle: str) -> bool: raise ValueError(f"Invalid action handle: {handle}") if self.actions[handle]["result_future"] is None: raise ValueError( - f"Result future is None for handle: {handle}. " "Was the goal accepted?" + f"Result future is None for handle: {handle}. Was the goal accepted?" ) return self.actions[handle]["result_future"].done() diff --git a/src/rai_core/rai/communication/ros2/connectors.py b/src/rai_core/rai/communication/ros2/connectors.py index f01a08258..2b4c94097 100644 --- a/src/rai_core/rai/communication/ros2/connectors.py +++ b/src/rai_core/rai/communication/ros2/connectors.py @@ -76,7 +76,6 @@ def send_message( qos_profile: Optional[QoSProfile] = None, **kwargs: Any, ): - self._topic_api.publish( topic=target, msg_content=message.payload, diff --git a/src/rai_core/rai/communication/sound_device/__init__.py b/src/rai_core/rai/communication/sound_device/__init__.py index 450926768..503c274d9 100644 --- a/src/rai_core/rai/communication/sound_device/__init__.py +++ b/src/rai_core/rai/communication/sound_device/__init__.py @@ -18,6 +18,6 @@ __all__ = [ "SoundDeviceAPI", "SoundDeviceConfig", - "SoundDeviceError", "SoundDeviceConnector", + "SoundDeviceError", ] diff --git a/src/rai_core/rai/communication/sound_device/api.py b/src/rai_core/rai/communication/sound_device/api.py index e8aa9ae88..98d554e4b 100644 --- a/src/rai_core/rai/communication/sound_device/api.py +++ b/src/rai_core/rai/communication/sound_device/api.py @@ -49,7 +49,6 @@ def __post_init__(self): class SoundDeviceAPI: - def __init__(self, config: SoundDeviceConfig): self.device_name = "" diff --git a/src/rai_core/rai/messages/__init__.py b/src/rai_core/rai/messages/__init__.py index 929e04c5c..f5af4f43d 100644 --- a/src/rai_core/rai/messages/__init__.py +++ b/src/rai_core/rai/messages/__init__.py @@ -23,10 +23,10 @@ from .utils import preprocess_image __all__ = [ - "HumanMultimodalMessage", "AiMultimodalMessage", + "HumanMultimodalMessage", + "MultimodalArtifact", "SystemMultimodalMessage", "ToolMultimodalMessage", - "MultimodalArtifact", "preprocess_image", ] diff --git a/src/rai_core/rai/messages/multimodal.py b/src/rai_core/rai/messages/multimodal.py index 8db862bee..33e0ff35c 100644 --- a/src/rai_core/rai/messages/multimodal.py +++ b/src/rai_core/rai/messages/multimodal.py @@ -72,7 +72,7 @@ def __repr_args__(self) -> Any: v = [c for c in v if c["type"] != "image_url"] elif k == "images": imgs_summary = [image[0:10] + "..." for image in v] - v = f'{len(v)} base64 encoded images: [{", ".join(imgs_summary)}]' + v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]" new_args.append((k, v)) return new_args diff --git a/src/rai_core/rai/node.py b/src/rai_core/rai/node.py index fcc2d3ae1..e39fe2300 100644 --- a/src/rai_core/rai/node.py +++ b/src/rai_core/rai/node.py @@ -303,7 +303,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): "callbacks": get_tracing_callbacks(), }, ): - graph_node_name = list(state.keys())[0] if graph_node_name == "reporter": continue diff --git a/src/rai_core/rai/tools/ros/__init__.py b/src/rai_core/rai/tools/ros/__init__.py index 71e488752..c3be30b81 100644 --- a/src/rai_core/rai/tools/ros/__init__.py +++ b/src/rai_core/rai/tools/ros/__init__.py @@ -29,15 +29,15 @@ ) __all__ = [ + "AddDescribedWaypointToDatabaseTool", + "GetCurrentPositionTool", + "GetOccupancyGridTool", + "Ros2BaseInput", + "Ros2BaseTool", "ros2_action", "ros2_interface", "ros2_node", - "ros2_topic", "ros2_param", "ros2_service", - "Ros2BaseTool", - "Ros2BaseInput", - "AddDescribedWaypointToDatabaseTool", - "GetOccupancyGridTool", - "GetCurrentPositionTool", + "ros2_topic", ] diff --git a/src/rai_core/rai/tools/ros/deprecated.py b/src/rai_core/rai/tools/ros/deprecated.py index 98c0e8ccf..ec8e3ba8e 100644 --- a/src/rai_core/rai/tools/ros/deprecated.py +++ b/src/rai_core/rai/tools/ros/deprecated.py @@ -104,7 +104,9 @@ def __init__( def postprocess(self, msg: Image) -> str: bridge = CvBridge() - cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore + cv_image = cast( + cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough") + ) # type: ignore if cv_image.shape[-1] == 4: cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB) base64_image = base64.b64encode( diff --git a/src/rai_core/rai/tools/ros/native_actions.py b/src/rai_core/rai/tools/ros/native_actions.py index ba70a5364..c3833a3d7 100644 --- a/src/rai_core/rai/tools/ros/native_actions.py +++ b/src/rai_core/rai/tools/ros/native_actions.py @@ -60,9 +60,7 @@ class Ros2BaseActionTool(Ros2BaseTool): class Ros2RunActionSync(Ros2BaseTool): name: str = "Ros2RunAction" - description: str = ( - "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" - ) + description: str = "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput @@ -176,9 +174,7 @@ def _run(self) -> bool: class Ros2GetLastActionFeedback(Ros2BaseActionTool): name: str = "Ros2GetLastActionFeedback" - description: str = ( - "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action." - ) + description: str = "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action." args_schema: Type[Ros2BaseInput] = Ros2BaseInput diff --git a/src/rai_core/rai/tools/ros/tools.py b/src/rai_core/rai/tools/ros/tools.py index 7148e8eef..508c174a3 100644 --- a/src/rai_core/rai/tools/ros/tools.py +++ b/src/rai_core/rai/tools/ros/tools.py @@ -98,9 +98,7 @@ class GetOccupancyGridTool(BaseTool): """Get the current map as an image with the robot's position marked on it (red dot).""" name: str = "GetOccupancyGridTool" - description: str = ( - "A tool for getting the current map as an image with the robot's position marked on it." - ) + description: str = "A tool for getting the current map as an image with the robot's position marked on it." args_schema: Type[TopicInput] = TopicInput diff --git a/src/rai_core/rai/tools/ros2/__init__.py b/src/rai_core/rai/tools/ros2/__init__.py index ea7d939f5..460f4d100 100644 --- a/src/rai_core/rai/tools/ros2/__init__.py +++ b/src/rai_core/rai/tools/ros2/__init__.py @@ -24,13 +24,13 @@ ) __all__ = [ - "StartROS2ActionTool", - "GetROS2ImageTool", - "PublishROS2MessageTool", - "ReceiveROS2MessageTool", "CallROS2ServiceTool", "CancelROS2ActionTool", - "GetROS2TopicsNamesAndTypesTool", + "GetROS2ImageTool", "GetROS2MessageInterfaceTool", + "GetROS2TopicsNamesAndTypesTool", "GetROS2TransformTool", + "PublishROS2MessageTool", + "ReceiveROS2MessageTool", + "StartROS2ActionTool", ] diff --git a/src/rai_core/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py index fce3a8bae..82dbfb347 100644 --- a/src/rai_core/rai/tools/ros2/topics.py +++ b/src/rai_core/rai/tools/ros2/topics.py @@ -98,7 +98,9 @@ def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]: raise ValueError( f"Unsupported message type: {message.metadata['msg_type']}" ) - return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore + return "Image received successfully", MultimodalArtifact( + images=[preprocess_image(image)] + ) # type: ignore class GetROS2TopicsNamesAndTypesTool(BaseTool): diff --git a/src/rai_core/rai/tools/utils.py b/src/rai_core/rai/tools/utils.py index 58da02bc6..37c3c8fc3 100644 --- a/src/rai_core/rai/tools/utils.py +++ b/src/rai_core/rai/tools/utils.py @@ -249,7 +249,9 @@ def __init__( def postprocess(self, msg: Image) -> str: bridge = CvBridge() - cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore + cv_image = cast( + cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough") + ) # type: ignore if cv_image.shape[-1] == 4: cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB) base64_image = base64.b64encode( diff --git a/src/rai_core/rai/utils/configurator.py b/src/rai_core/rai/utils/configurator.py index 87be321fb..522b98224 100644 --- a/src/rai_core/rai/utils/configurator.py +++ b/src/rai_core/rai/utils/configurator.py @@ -316,19 +316,19 @@ def on_model_vendor_change(model_type: str): ) def on_langfuse_change(): - st.session_state.config["tracing"]["langfuse"][ - "use_langfuse" - ] = st.session_state.langfuse_checkbox + st.session_state.config["tracing"]["langfuse"]["use_langfuse"] = ( + st.session_state.langfuse_checkbox + ) def on_langfuse_host_change(): - st.session_state.config["tracing"]["langfuse"][ - "host" - ] = st.session_state.langfuse_host_input + st.session_state.config["tracing"]["langfuse"]["host"] = ( + st.session_state.langfuse_host_input + ) def on_langsmith_change(): - st.session_state.config["tracing"]["langsmith"][ - "use_langsmith" - ] = st.session_state.langsmith_checkbox + st.session_state.config["tracing"]["langsmith"]["use_langsmith"] = ( + st.session_state.langsmith_checkbox + ) # Ensure tracing config exists if "tracing" not in st.session_state.config: @@ -397,9 +397,9 @@ def on_langsmith_change(): elif st.session_state.current_step == 4: def on_recording_device_change(): - st.session_state.config["asr"][ - "recording_device_name" - ] = st.session_state.recording_device_select + st.session_state.config["asr"]["recording_device_name"] = ( + st.session_state.recording_device_select + ) def on_asr_vendor_change(): vendor = ( @@ -413,29 +413,29 @@ def on_language_change(): st.session_state.config["asr"]["language"] = st.session_state.language_input def on_silence_grace_change(): - st.session_state.config["asr"][ - "silence_grace_period" - ] = st.session_state.silence_grace_input + st.session_state.config["asr"]["silence_grace_period"] = ( + st.session_state.silence_grace_input + ) def on_vad_threshold_change(): - st.session_state.config["asr"][ - "vad_threshold" - ] = st.session_state.vad_threshold_input + st.session_state.config["asr"]["vad_threshold"] = ( + st.session_state.vad_threshold_input + ) def on_wake_word_change(): - st.session_state.config["asr"][ - "use_wake_word" - ] = st.session_state.wake_word_checkbox + st.session_state.config["asr"]["use_wake_word"] = ( + st.session_state.wake_word_checkbox + ) def on_wake_word_model_change(): - st.session_state.config["asr"][ - "wake_word_model" - ] = st.session_state.wake_word_model_input + st.session_state.config["asr"]["wake_word_model"] = ( + st.session_state.wake_word_model_input + ) def on_wake_word_threshold_change(): - st.session_state.config["asr"][ - "wake_word_threshold" - ] = st.session_state.wake_word_threshold_input + st.session_state.config["asr"]["wake_word_threshold"] = ( + st.session_state.wake_word_threshold_input + ) # Ensure asr config exists if "asr" not in st.session_state.config: @@ -588,9 +588,9 @@ def on_tts_vendor_change(): st.session_state.config["tts"]["vendor"] = vendor def on_keep_speaker_busy_change(): - st.session_state.config["tts"][ - "keep_speaker_busy" - ] = st.session_state.keep_speaker_busy_checkbox + st.session_state.config["tts"]["keep_speaker_busy"] = ( + st.session_state.keep_speaker_busy_checkbox + ) # Ensure tts config exists if "tts" not in st.session_state.config: diff --git a/src/rai_extensions/rai_nomad/rai_nomad/nomad.py b/src/rai_extensions/rai_nomad/rai_nomad/nomad.py index 5095d51ca..6e87a0be7 100644 --- a/src/rai_extensions/rai_nomad/rai_nomad/nomad.py +++ b/src/rai_extensions/rai_nomad/rai_nomad/nomad.py @@ -330,9 +330,9 @@ def pd_controller(self, waypoint: np.ndarray) -> Tuple[float]: angular_vel = ( self.get_parameter("angular_vel").get_parameter_value().double_value ) - assert ( - len(waypoint) == 2 or len(waypoint) == 4 - ), "waypoint must be a 2D or 4D vector" + assert len(waypoint) == 2 or len(waypoint) == 4, ( + "waypoint must be a 2D or 4D vector" + ) if len(waypoint) == 2: dx, dy = waypoint else: diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py index fa3562537..3b7a71e20 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/__init__.py @@ -17,8 +17,8 @@ from .tools import GetDetectionTool, GetDistanceToObjectsTool __all__ = [ - "GetDistanceToObjectsTool", - "GetDetectionTool", "GDINO_NODE_NAME", "GDINO_SERVICE_NAME", + "GetDetectionTool", + "GetDistanceToObjectsTool", ] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py index 90b75351d..45fe9eb52 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounded_sam.py @@ -21,10 +21,10 @@ import rclpy from ament_index_python.packages import get_package_share_directory from cv_bridge import CvBridge -from rai_open_set_vision.vision_markup.segmenter import GDSegmenter from rclpy.node import Node from rai_interfaces.srv import RAIGroundedSam +from rai_open_set_vision.vision_markup.segmenter import GDSegmenter GSAM_NODE_NAME = "grounded_sam" GSAM_SERVICE_NAME = "grounded_sam_segment" diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py index 4fe1ae8c9..7927d7cf8 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/services/grounding_dino.py @@ -20,12 +20,12 @@ import rclpy from ament_index_python.packages import get_package_share_directory -from rai_open_set_vision.vision_markup.boxer import GDBoxer from rclpy.node import Node from sensor_msgs.msg import Image from rai_interfaces.msg import RAIDetectionArray from rai_interfaces.srv import RAIGroundingDino +from rai_open_set_vision.vision_markup.boxer import GDBoxer class GDRequest(TypedDict): diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py index 4091ab0d0..52330070b 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/__init__.py @@ -16,8 +16,8 @@ from .segmentation_tools import GetGrabbingPointTool, GetSegmentationTool __all__ = [ - "GetDistanceToObjectsTool", "GetDetectionTool", - "GetSegmentationTool", + "GetDistanceToObjectsTool", "GetGrabbingPointTool", + "GetSegmentationTool", ] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 2fa76fba1..336ec4bc7 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -19,19 +19,19 @@ import rclpy.qos import sensor_msgs.msg from pydantic import BaseModel, Field -from rai_open_set_vision import GDINO_SERVICE_NAME +from rai.node import RaiBaseNode +from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +from rai.tools.ros.utils import convert_ros_img_to_ndarray +from rai.tools.utils import wait_for_message +from rai.utils.ros_async import get_future_result from rclpy.exceptions import ( ParameterNotDeclaredException, ParameterUninitializedException, ) from rclpy.task import Future -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool -from rai.tools.ros.utils import convert_ros_img_to_ndarray -from rai.tools.utils import wait_for_message -from rai.utils.ros_async import get_future_result from rai_interfaces.srv import RAIGroundingDino +from rai_open_set_vision import GDINO_SERVICE_NAME # --------------------- Inputs --------------------- @@ -147,9 +147,7 @@ def _parse_detection_array( class GetDetectionTool(GroundingDinoBaseTool): name: str = "GetDetectionTool" - description: str = ( - "A tool for detecting specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result." - ) + description: str = "A tool for detecting specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result." args_schema: Type[Ros2GetDetectionInput] = Ros2GetDetectionInput @@ -177,9 +175,7 @@ def _run( class GetDistanceToObjectsTool(GroundingDinoBaseTool): name: str = "GetDistanceToObjectsTool" - description: str = ( - "A tool for calculating distance to specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result." - ) + description: str = "A tool for calculating distance to specified objects using a ros2 action. The tool call might take some time to execute and is blocking - you will not be able to check their feedback, only will be informed about the result." args_schema: Type[GetDistanceToObjectsInput] = GetDistanceToObjectsInput diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index c264bd885..29ff0fe18 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -19,18 +19,18 @@ import rclpy import sensor_msgs.msg from pydantic import Field -from rai_open_set_vision import GDINO_SERVICE_NAME +from rai.node import RaiBaseNode +from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray +from rai.utils.ros_async import get_future_result from rclpy import Future from rclpy.exceptions import ( ParameterNotDeclaredException, ParameterUninitializedException, ) -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool -from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray -from rai.utils.ros_async import get_future_result from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino +from rai_open_set_vision import GDINO_SERVICE_NAME # --------------------- Inputs --------------------- @@ -186,7 +186,6 @@ def depth_to_point_cloud( class GetGrabbingPointTool(GetSegmentationTool): - name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = [] diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py index e9c44c087..e91563a83 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/vision_markup/segmenter.py @@ -20,13 +20,12 @@ import numpy as np import torch from cv_bridge import CvBridge +from rai.tools.ros.utils import convert_ros_img_to_ndarray from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor from sensor_msgs.msg import Image from vision_msgs.msg import BoundingBox2D -from rai.tools.ros.utils import convert_ros_img_to_ndarray - class GDSegmenter: def __init__( diff --git a/src/rai_hmi/rai_hmi/agent.py b/src/rai_hmi/rai_hmi/agent.py index 7243b699f..c54ead3da 100644 --- a/src/rai_hmi/rai_hmi/agent.py +++ b/src/rai_hmi/rai_hmi/agent.py @@ -17,11 +17,11 @@ from typing import List from langchain.tools import tool - from rai.agents.conversational_agent import create_conversational_agent from rai.node import RaiBaseNode from rai.tools.ros.native import GetCameraImage, Ros2GetRobotInterfaces from rai.utils.model_initialization import get_llm_model + from rai_hmi.base import BaseHMINode from rai_hmi.chat_msgs import MissionMessage from rai_hmi.task import Task, TaskInput diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index 9502a2255..41bf656d6 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -22,13 +22,13 @@ from langchain_core.documents import Document from langchain_core.tools import BaseTool from pydantic import UUID4 +from rai.node import append_whoami_info_to_prompt +from rai.utils.model_initialization import get_embeddings_model from rclpy.action import ActionClient from rclpy.node import Node from std_msgs.msg import String from std_srvs.srv import Trigger -from rai.node import append_whoami_info_to_prompt -from rai.utils.model_initialization import get_embeddings_model from rai_hmi.chat_msgs import ( MissionAcceptanceMessage, MissionDoneMessage, diff --git a/src/rai_hmi/rai_hmi/ros.py b/src/rai_hmi/rai_hmi/ros.py index 181132f3e..eed05149e 100644 --- a/src/rai_hmi/rai_hmi/ros.py +++ b/src/rai_hmi/rai_hmi/ros.py @@ -19,9 +19,9 @@ from typing import Optional, Tuple import rclpy +from rai.node import RaiBaseNode from rclpy.executors import MultiThreadedExecutor -from rai.node import RaiBaseNode from rai_hmi.base import BaseHMINode diff --git a/src/rai_hmi/rai_hmi/text_hmi.py b/src/rai_hmi/rai_hmi/text_hmi.py index 955888062..8f89b5bf7 100644 --- a/src/rai_hmi/rai_hmi/text_hmi.py +++ b/src/rai_hmi/rai_hmi/text_hmi.py @@ -33,12 +33,12 @@ ) from PIL import Image from pydantic import BaseModel -from rclpy.node import Node -from streamlit.delta_generator import DeltaGenerator - from rai.messages import HumanMultimodalMessage from rai.node import RaiBaseNode from rai.utils.artifacts import get_stored_artifacts +from rclpy.node import Node +from streamlit.delta_generator import DeltaGenerator + from rai_hmi.agent import initialize_agent from rai_hmi.base import BaseHMINode from rai_hmi.chat_msgs import EMOJIS, MissionMessage diff --git a/src/rai_hmi/rai_hmi/voice_hmi.py b/src/rai_hmi/rai_hmi/voice_hmi.py index e380020f2..71b9e04a3 100644 --- a/src/rai_hmi/rai_hmi/voice_hmi.py +++ b/src/rai_hmi/rai_hmi/voice_hmi.py @@ -22,12 +22,12 @@ import rclpy from langchain_core.messages import HumanMessage +from rai.node import RaiBaseNode from rclpy.callback_groups import ReentrantCallbackGroup from rclpy.executors import MultiThreadedExecutor from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy from std_msgs.msg import String -from rai.node import RaiBaseNode from rai_hmi.agent import initialize_agent from rai_hmi.base import BaseHMINode from rai_hmi.text_hmi_utils import Memory diff --git a/src/rai_tts/pyproject.toml b/src/rai_tts/pyproject.toml index ec6ce5d28..a80efc23d 100644 --- a/src/rai_tts/pyproject.toml +++ b/src/rai_tts/pyproject.toml @@ -23,4 +23,4 @@ elevenlabs = "^1.4.1" sounddevice = "^0.4.7" [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" diff --git a/src/rai_whoami/rai_whoami/rai_whoami_node.py b/src/rai_whoami/rai_whoami/rai_whoami_node.py index c98f17b46..d592fb54d 100644 --- a/src/rai_whoami/rai_whoami/rai_whoami_node.py +++ b/src/rai_whoami/rai_whoami/rai_whoami_node.py @@ -18,21 +18,20 @@ import rclpy from ament_index_python.packages import get_package_share_directory from langchain_community.vectorstores import FAISS -from rai_interfaces.srv._vector_store_retrieval import ( - VectorStoreRetrieval_Request, - VectorStoreRetrieval_Response, -) +from rai.utils.model_initialization import get_embeddings_model from rclpy.node import Node from rclpy.parameter import Parameter from std_srvs.srv import Trigger from std_srvs.srv._trigger import Trigger_Request, Trigger_Response -from rai.utils.model_initialization import get_embeddings_model from rai_interfaces.srv import VectorStoreRetrieval +from rai_interfaces.srv._vector_store_retrieval import ( + VectorStoreRetrieval_Request, + VectorStoreRetrieval_Response, +) class WhoAmI(Node): - def __init__(self): super().__init__("rai_whoami_node") self.declare_parameter("robot_description_package", Parameter.Type.STRING) diff --git a/tests/communication/ros2/__init__.py b/tests/communication/ros2/__init__.py index 4764660ff..efe9277e1 100644 --- a/tests/communication/ros2/__init__.py +++ b/tests/communication/ros2/__init__.py @@ -21,9 +21,9 @@ ) __all__ = [ - "shutdown_executors_and_threads", - "multi_threaded_spinner", - "MessagePublisher", "ActionServer", + "MessagePublisher", "MessageReceiver", + "multi_threaded_spinner", + "shutdown_executors_and_threads", ] diff --git a/tests/communication/ros2/test_api.py b/tests/communication/ros2/test_api.py index d0ba3ec6b..cdae3f98f 100644 --- a/tests/communication/ros2/test_api.py +++ b/tests/communication/ros2/test_api.py @@ -18,9 +18,6 @@ import pytest from action_msgs.msg import GoalStatus from action_msgs.srv import CancelGoal -from rclpy.executors import MultiThreadedExecutor -from rclpy.node import Node - from rai.communication.ros2.api import ( ConfigurableROS2TopicAPI, ROS2ActionAPI, @@ -28,6 +25,8 @@ ROS2TopicAPI, TopicConfig, ) +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node from .helpers import ActionServer_ as ActionServer from .helpers import ( diff --git a/tests/communication/ros2/test_connectors.py b/tests/communication/ros2/test_connectors.py index f8a3211dd..3045d655a 100644 --- a/tests/communication/ros2/test_connectors.py +++ b/tests/communication/ros2/test_connectors.py @@ -16,11 +16,10 @@ from typing import Any, List import pytest +from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage from std_msgs.msg import String from std_srvs.srv import SetBool -from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage - from .helpers import ActionServer_ as ActionServer from .helpers import ( MessagePublisher, diff --git a/tests/communication/sounds_device/test_api.py b/tests/communication/sounds_device/test_api.py index 1df6f464d..b5a62a54a 100644 --- a/tests/communication/sounds_device/test_api.py +++ b/tests/communication/sounds_device/test_api.py @@ -16,7 +16,6 @@ import numpy as np import pytest import sounddevice - from rai.communication.sound_device import ( SoundDeviceAPI, SoundDeviceConfig, @@ -34,14 +33,13 @@ def mock_sd(): mock_stop = MagicMock() mock_wait = MagicMock() - with patch.object(sounddevice, "play", mock_play), patch.object( - sounddevice, "rec", mock_rec - ), patch.object(sounddevice, "open", mock_open), patch.object( - sounddevice, "stop", mock_stop - ), patch.object( - sounddevice, "wait", mock_wait + with ( + patch.object(sounddevice, "play", mock_play), + patch.object(sounddevice, "rec", mock_rec), + patch.object(sounddevice, "open", mock_open), + patch.object(sounddevice, "stop", mock_stop), + patch.object(sounddevice, "wait", mock_wait), ): - yield { "play": mock_play, "rec": mock_rec, diff --git a/tests/communication/sounds_device/test_connector.py b/tests/communication/sounds_device/test_connector.py index ba0daf9c4..a7f86caa9 100644 --- a/tests/communication/sounds_device/test_connector.py +++ b/tests/communication/sounds_device/test_connector.py @@ -18,14 +18,13 @@ import numpy as np import pytest import sounddevice -from scipy.io import wavfile - from rai.communication import HRIPayload from rai.communication.sound_device import SoundDeviceConfig, SoundDeviceError from rai.communication.sound_device.connector import ( # Replace with actual module name SoundDeviceConnector, SoundDeviceMessage, ) +from scipy.io import wavfile @pytest.fixture diff --git a/tests/communication/test_hri_message.py b/tests/communication/test_hri_message.py index bba976da6..87a4385fb 100644 --- a/tests/communication/test_hri_message.py +++ b/tests/communication/test_hri_message.py @@ -18,7 +18,6 @@ from langchain_core.messages import HumanMessage from PIL import Image from pydub import AudioSegment - from rai.communication import HRIMessage, HRIPayload from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage diff --git a/tests/conftest.py b/tests/conftest.py index ce9dfd6df..74089aea0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,9 +18,8 @@ import pytest from _pytest.terminal import TerminalReporter -from tabulate import tabulate - from rai.config.models import BEDROCK_CLAUDE_HAIKU, OPENAI_MINI +from tabulate import tabulate @pytest.fixture diff --git a/tests/core/test_rai_cli.py b/tests/core/test_rai_cli.py index 19f16f97a..fa7aa3b24 100644 --- a/tests/core/test_rai_cli.py +++ b/tests/core/test_rai_cli.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock, patch import pytest - from rai.cli.rai_cli import create_rai_ws diff --git a/tests/core/test_ros2_tools.py b/tests/core/test_ros2_tools.py index 10f7c5a65..6fc48ddda 100644 --- a/tests/core/test_ros2_tools.py +++ b/tests/core/test_ros2_tools.py @@ -19,13 +19,12 @@ import rclpy from langchain_core.messages import HumanMessage, SystemMessage from langchain_openai.chat_models import ChatOpenAI +from rai.agents.state_based import create_state_based_agent +from rai.tools.ros.native import Ros2PubMessageTool from rclpy.executors import MultiThreadedExecutor from rclpy.node import Node from std_msgs.msg import String -from rai.agents.state_based import create_state_based_agent -from rai.tools.ros.native import Ros2PubMessageTool - class Subscriber(Node): def __init__(self) -> None: diff --git a/tests/core/test_tool_runner.py b/tests/core/test_tool_runner.py index 793eb8998..0c6971d26 100644 --- a/tests/core/test_tool_runner.py +++ b/tests/core/test_tool_runner.py @@ -16,7 +16,6 @@ from langchain_core.messages import AIMessage, ToolCall, ToolMessage from langchain_core.tools import tool - from rai.agents.tool_runner import ToolRunner from rai.messages import HumanMultimodalMessage, ToolMultimodalMessage from rai.messages.utils import preprocess_image @@ -36,12 +35,12 @@ def test_tool_runner_invalid_call(): tool_call = ToolCall(name="bad_fn", args={"command": "list"}, id="12345") state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]} output = runner.invoke(state) - assert isinstance( - output["messages"][0], AIMessage - ), "First message is not an AIMessage" - assert isinstance( - output["messages"][1], ToolMessage - ), "Tool output is not a tool message" + assert isinstance(output["messages"][0], AIMessage), ( + "First message is not an AIMessage" + ) + assert isinstance(output["messages"][1], ToolMessage), ( + "Tool output is not a tool message" + ) assert output["messages"][1].status == "error" @@ -51,15 +50,15 @@ def test_tool_runner(): tool_call = ToolCall(name="ros2_topic", args={"command": "list"}, id="12345") state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]} output = runner.invoke(state) - assert isinstance( - output["messages"][0], AIMessage - ), "First message is not an AIMessage" - assert isinstance( - output["messages"][1], ToolMessage - ), "Tool output is not a tool message" - assert ( - len(output["messages"][-1].content) > 0 - ), "Tool output is empty. At least rosout should be visible." + assert isinstance(output["messages"][0], AIMessage), ( + "First message is not an AIMessage" + ) + assert isinstance(output["messages"][1], ToolMessage), ( + "Tool output is not a tool message" + ) + assert len(output["messages"][-1].content) > 0, ( + "Tool output is empty. At least rosout should be visible." + ) def test_tool_runner_multimodal(): @@ -71,12 +70,12 @@ def test_tool_runner_multimodal(): state = {"messages": [AIMessage(content="", tool_calls=[tool_call])]} output = runner.invoke(state) - assert isinstance( - output["messages"][0], AIMessage - ), "First message is not an AIMessage" - assert isinstance( - output["messages"][1], ToolMultimodalMessage - ), "Tool output is not a multimodal message" - assert isinstance( - output["messages"][2], HumanMultimodalMessage - ), "Human output is not a multimodal message" + assert isinstance(output["messages"][0], AIMessage), ( + "First message is not an AIMessage" + ) + assert isinstance(output["messages"][1], ToolMultimodalMessage), ( + "Tool output is not a multimodal message" + ) + assert isinstance(output["messages"][2], HumanMultimodalMessage), ( + "Human output is not a multimodal message" + ) diff --git a/tests/messages/test_multimodal.py b/tests/messages/test_multimodal.py index 3ada5ec40..667d506fa 100644 --- a/tests/messages/test_multimodal.py +++ b/tests/messages/test_multimodal.py @@ -29,7 +29,6 @@ from langfuse.callback import CallbackHandler from pydantic import BaseModel, Field from pytest import FixtureRequest - from rai.messages import HumanMultimodalMessage from rai.tools.utils import run_requested_tools @@ -39,7 +38,6 @@ class GetImageToolInput(BaseModel): class GetImageTool(BaseTool): - name: str = "GetImageTool" description: str = "Get an image from the user" diff --git a/tests/messages/test_transport.py b/tests/messages/test_transport.py index 93a31e5e6..42b20869c 100644 --- a/tests/messages/test_transport.py +++ b/tests/messages/test_transport.py @@ -20,14 +20,13 @@ import numpy as np import pytest import rclpy +from rai.node import RaiBaseNode from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node from rclpy.qos import QoSPresetProfiles, QoSProfile from sensor_msgs.msg import Image from std_msgs.msg import String -from rai.node import RaiBaseNode - def get_qos_profiles() -> List[str]: ros_distro = os.environ.get("ROS_DISTRO") diff --git a/tests/messages/test_utils.py b/tests/messages/test_utils.py index 268aa4c7a..baa727cba 100644 --- a/tests/messages/test_utils.py +++ b/tests/messages/test_utils.py @@ -18,7 +18,6 @@ import numpy as np import pytest from PIL import Image - from rai.messages.utils import preprocess_image diff --git a/tests/smoke/import_test.py b/tests/smoke/import_test.py index 6b7ccf202..9ea1e42e2 100644 --- a/tests/smoke/import_test.py +++ b/tests/smoke/import_test.py @@ -35,9 +35,7 @@ def rai_python_modules(): @pytest.mark.parametrize("module", rai_python_modules()) def test_can_import_all_modules_pathlib(module: ModuleType) -> None: - def import_submodules(package: ModuleType) -> None: - package_path = pathlib.Path(package.__file__).parent # type: ignore importables = set() diff --git a/tests/tools/ros2/test_action_tools.py b/tests/tools/ros2/test_action_tools.py index 3d46c69b6..e8aa6945c 100644 --- a/tests/tools/ros2/test_action_tools.py +++ b/tests/tools/ros2/test_action_tools.py @@ -24,6 +24,7 @@ from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros2 import StartROS2ActionTool + from tests.communication.ros2.helpers import ActionServer_ as ActionServer from tests.communication.ros2.helpers import ( multi_threaded_spinner, diff --git a/tests/tools/ros2/test_service_tools.py b/tests/tools/ros2/test_service_tools.py index ee813b327..ca60515ad 100644 --- a/tests/tools/ros2/test_service_tools.py +++ b/tests/tools/ros2/test_service_tools.py @@ -24,6 +24,7 @@ from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros2 import CallROS2ServiceTool + from tests.communication.ros2.helpers import ( ServiceServer, multi_threaded_spinner, diff --git a/tests/tools/ros2/test_topic_tools.py b/tests/tools/ros2/test_topic_tools.py index 4d615597e..5e320bf2e 100644 --- a/tests/tools/ros2/test_topic_tools.py +++ b/tests/tools/ros2/test_topic_tools.py @@ -26,7 +26,6 @@ import time from PIL import Image - from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.ros2 import ( GetROS2ImageTool, @@ -36,6 +35,7 @@ PublishROS2MessageTool, ReceiveROS2MessageTool, ) + from tests.communication.ros2.helpers import ( ImagePublisher, MessagePublisher, diff --git a/tests/tools/test_tool_utils.py b/tests/tools/test_tool_utils.py index fb62fb21a..1efec5425 100644 --- a/tests/tools/test_tool_utils.py +++ b/tests/tools/test_tool_utils.py @@ -16,11 +16,10 @@ import pytest from geometry_msgs.msg import Point, TransformStamped from nav2_msgs.action import NavigateToPose +from rai.tools.ros2.utils import ros2_message_to_dict from sensor_msgs.msg import Image from tf2_msgs.msg import TFMessage -from rai.tools.ros2.utils import ros2_message_to_dict - # TODO(`maciejmajek`): Add custom RAI messages? @pytest.mark.parametrize( From 008252cddee67c17e2ed0b278880c65b294f32c7 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 6 Feb 2025 20:25:26 +0100 Subject: [PATCH 6/8] fix: tests --- poetry.lock | 66 ++++++++++++++++++++++++++++++++- pyproject.toml | 16 ++++---- src/rai_asr/rai_asr/__init__.py | 4 ++ src/rai_core/README.md | 0 src/rai_core/rai/cli/rai_cli.py | 2 +- src/rai_tts/README.md | 0 6 files changed, 78 insertions(+), 10 deletions(-) create mode 100644 src/rai_core/README.md create mode 100644 src/rai_tts/README.md diff --git a/poetry.lock b/poetry.lock index d6da53572..93059b345 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5653,6 +5653,70 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "rai" +version = "1.0.0" +description = "Core functionality for RAI framework" +optional = false +python-versions = "^3.10, <3.13" +files = [] +develop = true + +[package.dependencies] +coloredlogs = "^15.0.1" +deprecated = "^1.2.14" +langchain = "*" +langchain-core = "^0.3" +langgraph = "*" +markdown = "^3.6" +requests = "^2.32.2" +rich = "^13.7.1" +tomli = "^2.0.1" +tomli-w = "^1.1.0" +tqdm = "^4.66.4" + +[package.source] +type = "directory" +url = "src/rai_core" + +[[package]] +name = "rai-asr" +version = "1.0.0" +description = "Automatic Speech Recognition module for RAI framework" +optional = false +python-versions = "^3.10, <3.13" +files = [] +develop = true + +[package.dependencies] +faster-whisper = "^1.1.1" +openai-whisper = "^20231117" +pydub = "^0.25.1" +scipy = "^1.14.0" +sounddevice = "^0.4.7" +torchaudio = "^2.3.1" + +[package.source] +type = "directory" +url = "src/rai_asr" + +[[package]] +name = "rai-tts" +version = "1.0.0" +description = "Text-to-Speech module for RAI framework" +optional = false +python-versions = "^3.10, <3.13" +files = [] +develop = true + +[package.dependencies] +elevenlabs = "^1.4.1" +sounddevice = "^0.4.7" + +[package.source] +type = "directory" +url = "src/rai_tts" + [[package]] name = "ray" version = "2.41.0" @@ -8238,4 +8302,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "158877f96c27f3c9beb75aff8a9608a41b83740afe17ceda2a471378c4bb3545" +content-hash = "242e440c4ce4b31fa629d198a3d79b0854e84284d013d819a4b7a24e633a1706" diff --git a/pyproject.toml b/pyproject.toml index 0705512b8..0d3627c83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,15 +9,15 @@ classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", ] - -packages = [ - { include = "rai_core", from = "src" }, - { include = "rai_asr", from = "src" }, - { include = "rai_tts", from = "src" }, -] +package-mode = false [tool.poetry.dependencies] python = "^3.10, <3.13" + +rai = {path = "src/rai_core", develop = true} +rai_asr = {path = "src/rai_asr", develop = true} +rai_tts = {path = "src/rai_tts", develop = true} + langchain-core = "^0.3" langchain = "*" langgraph = "*" @@ -84,8 +84,8 @@ visualnav_transformer = { git = "https://github.com/RobotecAI/visualnav-transfor gdown = "^5.2.0" [build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" [tool.isort] profile = "black" diff --git a/src/rai_asr/rai_asr/__init__.py b/src/rai_asr/rai_asr/__init__.py index ef74fc891..499007aa9 100644 --- a/src/rai_asr/rai_asr/__init__.py +++ b/src/rai_asr/rai_asr/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""RAI ASR package.""" + +__version__ = "0.1.0" diff --git a/src/rai_core/README.md b/src/rai_core/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/src/rai_core/rai/cli/rai_cli.py b/src/rai_core/rai/cli/rai_cli.py index 4df87cd17..70afb99cf 100644 --- a/src/rai_core/rai/cli/rai_cli.py +++ b/src/rai_core/rai/cli/rai_cli.py @@ -194,7 +194,7 @@ def create_rai_ws(): (package_path / "generated" / "robot_constitution.txt").touch() default_constitution_path = ( - "src/rai/rai/cli/resources/default_robot_constitution.txt" + "src/rai_core/rai/cli/resources/default_robot_constitution.txt" ) with open(default_constitution_path, "r") as file: default_constitution = file.read() diff --git a/src/rai_tts/README.md b/src/rai_tts/README.md new file mode 100644 index 000000000..e69de29bb From e29ec7e041addb31bda8ffee2f417c83741fb432 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 7 Feb 2025 11:37:20 +0100 Subject: [PATCH 7/8] chore: remove flake8 & isort references in source code --- pyproject.toml | 2 -- src/rai_asr/pyproject.toml | 3 --- src/rai_core/pyproject.toml | 3 --- src/rai_extensions/rai_nomad/package.xml | 1 - src/rai_extensions/rai_open_set_vision/package.xml | 1 - src/rai_tts/pyproject.toml | 3 --- src/rai_whoami/package.xml | 1 - 7 files changed, 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0d3627c83..b313dd299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,8 +87,6 @@ gdown = "^5.2.0" requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" -[tool.isort] -profile = "black" [tool.pytest.ini_options] markers = [ diff --git a/src/rai_asr/pyproject.toml b/src/rai_asr/pyproject.toml index dfe16c043..925525951 100644 --- a/src/rai_asr/pyproject.toml +++ b/src/rai_asr/pyproject.toml @@ -25,6 +25,3 @@ scipy = "^1.14.0" torchaudio = "^2.3.1" faster-whisper = "^1.1.1" pydub = "^0.25.1" - -[tool.isort] -profile = "black" diff --git a/src/rai_core/pyproject.toml b/src/rai_core/pyproject.toml index 48dbc1771..48809e4a8 100644 --- a/src/rai_core/pyproject.toml +++ b/src/rai_core/pyproject.toml @@ -30,6 +30,3 @@ rich = "^13.7.1" deprecated = "^1.2.14" tomli = "^2.0.1" tomli-w = "^1.1.0" - -[tool.isort] -profile = "black" diff --git a/src/rai_extensions/rai_nomad/package.xml b/src/rai_extensions/rai_nomad/package.xml index 0ab802cad..c07d612ef 100644 --- a/src/rai_extensions/rai_nomad/package.xml +++ b/src/rai_extensions/rai_nomad/package.xml @@ -8,7 +8,6 @@ Apache-2.0 ament_copyright - ament_flake8 ament_pep257 python3-pytest diff --git a/src/rai_extensions/rai_open_set_vision/package.xml b/src/rai_extensions/rai_open_set_vision/package.xml index f348db3d6..9b22f49ee 100644 --- a/src/rai_extensions/rai_open_set_vision/package.xml +++ b/src/rai_extensions/rai_open_set_vision/package.xml @@ -8,7 +8,6 @@ Apache-2.0 ament_copyright - ament_flake8 ament_pep257 python3-pytest diff --git a/src/rai_tts/pyproject.toml b/src/rai_tts/pyproject.toml index a80efc23d..0c4bface2 100644 --- a/src/rai_tts/pyproject.toml +++ b/src/rai_tts/pyproject.toml @@ -21,6 +21,3 @@ build-backend = "poetry.core.masonry.api" python = "^3.10, <3.13" elevenlabs = "^1.4.1" sounddevice = "^0.4.7" - -[tool.isort] -profile = "black" diff --git a/src/rai_whoami/package.xml b/src/rai_whoami/package.xml index fbbc8118f..3508fcc6f 100644 --- a/src/rai_whoami/package.xml +++ b/src/rai_whoami/package.xml @@ -8,7 +8,6 @@ Apache-2.0 ament_copyright - ament_flake8 ament_pep257 python3-pytest rai_interfaces From 6bf65976e00941cb5af6767ee7cb00ddc2071854 Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 7 Feb 2025 15:48:54 +0100 Subject: [PATCH 8/8] fix: tool tests discovery --- tests/tools/test_tool_input_args_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tools/test_tool_input_args_compatibility.py b/tests/tools/test_tool_input_args_compatibility.py index 6ab845037..9845dc595 100644 --- a/tests/tools/test_tool_input_args_compatibility.py +++ b/tests/tools/test_tool_input_args_compatibility.py @@ -23,7 +23,7 @@ def get_all_tool_classes() -> set[BaseTool]: """Recursively find all classes that inherit from pydantic.BaseModel in src/rai/rai/tools""" tools = [] - tools_path = Path("src/rai/rai/tools") + tools_path = Path("src/rai_core/rai/tools") # Recursively find all .py files for py_file in tools_path.rglob("*.py"):