Skip to content

Commit

Permalink
feat: add ros2 hri connector (#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk authored Feb 6, 2025
1 parent 6bae47f commit 47d7090
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/rai/rai/communication/hri_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import base64
from dataclasses import dataclass, field
from io import BytesIO
from typing import Generic, Literal, Sequence, TypeVar, get_args
from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
Expand Down Expand Up @@ -54,9 +54,11 @@ class HRIMessage(BaseMessage):
def __init__(
self,
payload: HRIPayload,
message_author: Literal["ai", "human"],
metadata: Optional[Dict[str, Any]] = None,
message_author: Literal["ai", "human"] = "ai",
**kwargs,
):
super().__init__(payload, metadata)
self.message_author = message_author
self.text = payload.text
self.images = payload.images
Expand Down
99 changes: 96 additions & 3 deletions src/rai/rai/communication/ros2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from typing import (
Annotated,
Expand Down Expand Up @@ -105,9 +106,14 @@ def adapt_requests_to_offers(publisher_info: List[TopicEndpointInfo]) -> QoSProf
return request_qos


def build_ros2_msg(msg_type: str, msg_args: Dict[str, Any]) -> object:
"""Build a ROS2 message instance from type string and content dictionary."""
msg_cls = import_message_from_str(msg_type)
def build_ros2_msg(
msg_type: str | type[rclpy.node.MsgType], msg_args: Dict[str, Any]
) -> object:
"""Build a ROS2 message instance from string or MsgType and content dictionary."""
if isinstance(msg_type, str):
msg_cls = import_message_from_str(msg_type)
else:
msg_cls = msg_type
msg = msg_cls()
rosidl_runtime_py.set_message.set_message_fields(msg, msg_args)
return msg
Expand Down Expand Up @@ -311,6 +317,93 @@ def shutdown(self) -> None:
publisher.destroy()


@dataclass
class TopicConfig:
name: str
msg_type: str
auto_qos_matching: bool = True
qos_profile: Optional[QoSProfile] = None
is_subscriber: bool = False
subscriber_callback: Optional[Callable[[Any], None]] = None

def __post_init__(self):
if not self.auto_qos_matching and self.qos_profile is None:
raise ValueError(
"Either 'auto_qos_matching' must be True or 'qos_profile' must be set."
)


class ConfigurableROS2TopicAPI(ROS2TopicAPI):

def __init__(self, node: rclpy.node.Node):
super().__init__(node)
self._subscribtions: dict[str, rclpy.node.Subscription] = {}

def configure_publisher(self, topic: str, config: TopicConfig):
if config.is_subscriber:
raise ValueError(
"Can't reconfigure publisher with subscriber config! Set config.is_subscriber to False"
)
qos_profile = self._resolve_qos_profile(
topic, config.auto_qos_matching, config.qos_profile, for_publisher=True
)
if topic in self._publishers:
flag = self._node.destroy_publisher(self._publishers[topic].handle)
if not flag:
raise ValueError(f"Failed to reconfigure existing publisher to {topic}")

self._publishers[topic] = self._node.create_publisher(
import_message_from_str(config.msg_type),
topic=topic,
qos_profile=qos_profile,
)

def configure_subscriber(
self,
topic: str,
config: TopicConfig,
):
if not config.is_subscriber:
raise ValueError(
"Can't reconfigure subscriber with publisher config! Set config.is_subscriber to True"
)
qos_profile = self._resolve_qos_profile(
topic, config.auto_qos_matching, config.qos_profile, for_publisher=False
)
if topic in self._subscribtions:
flag = self._node.destroy_subscription(self._subscribtions[topic])
if not flag:
raise ValueError(
f"Failed to reconfigure existing subscriber to {topic}"
)

assert config.subscriber_callback is not None
self._subscribtions[topic] = self._node.create_subscription(
msg_type=import_message_from_str(config.msg_type),
topic=topic,
callback=config.subscriber_callback,
qos_profile=qos_profile,
)

def publish_configured(self, topic: str, msg_content: dict[str, Any]) -> None:
"""Publish a message to a ROS2 topic.
Args:
topic: Name of the topic to publish to
msg_content: Dictionary containing the message content
Raises:
ValueError: If topic has not been configured for publishing
"""
try:
publisher = self._publishers[topic]
except Exception as e:
raise ValueError(f"{topic} has not been configured for publishing") from e
msg_type = publisher.msg_type
msg = build_ros2_msg(msg_type, msg_content) # type: ignore
publisher.publish(msg)


class ROS2ServiceAPI:
"""Handles ROS2 service operations including calling services."""

Expand Down
108 changes: 105 additions & 3 deletions src/rai/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import rclpy
import rclpy.executors
Expand All @@ -27,8 +27,20 @@
from rclpy.qos import QoSProfile
from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped

from rai.communication.ari_connector import ARIConnector, ARIMessage
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
from rai.communication import (
ARIConnector,
ARIMessage,
HRIConnector,
HRIMessage,
HRIPayload,
)
from rai.communication.ros2.api import (
ConfigurableROS2TopicAPI,
ROS2ActionAPI,
ROS2ServiceAPI,
ROS2TopicAPI,
TopicConfig,
)


class ROS2ARIMessage(ARIMessage):
Expand Down Expand Up @@ -183,3 +195,93 @@ def shutdown(self):
self._actions_api.shutdown()
self._topic_api.shutdown()
self._node.destroy_node()


class ROS2HRIMessage(HRIMessage):
def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
super().__init__(payload, message_author)


class ROS2HRIConnector(HRIConnector[ROS2HRIMessage]):
def __init__(
self,
node_name: str = f"rai_ros2_hri_connector_{str(uuid.uuid4())[-12:]}",
targets: List[Tuple[str, TopicConfig]] = [],
sources: List[Tuple[str, TopicConfig]] = [],
):
configured_targets = [target[0] for target in targets]
configured_sources = [source[0] for source in sources]

self._configure_publishers(targets)
self._configure_subscribers(sources)

super().__init__(configured_targets, configured_sources)
self._node = Node(node_name)
self._topic_api = ConfigurableROS2TopicAPI(self._node)
self._service_api = ROS2ServiceAPI(self._node)
self._actions_api = ROS2ActionAPI(self._node)

self._executor = MultiThreadedExecutor()
self._executor.add_node(self._node)
self._thread = threading.Thread(target=self._executor.spin)
self._thread.start()

def _configure_publishers(self, targets: List[Tuple[str, TopicConfig]]):
for target in targets:
self._topic_api.configure_publisher(target[0], target[1])

def _configure_subscribers(self, sources: List[Tuple[str, TopicConfig]]):
for source in sources:
self._topic_api.configure_subscriber(source[0], source[1])

def send_message(self, message: ROS2HRIMessage, target: str, **kwargs):
self._topic_api.publish_configured(
topic=target,
msg_content=message.payload,
)

def receive_message(
self,
source: str,
timeout_sec: float = 1.0,
*,
message_author: Literal["human", "ai"],
msg_type: Optional[str] = None,
auto_topic_type: bool = True,
**kwargs: Any,
) -> ROS2HRIMessage:
if msg_type != "std_msgs/msg/String":
raise ValueError("ROS2HRIConnector only supports receiving sting messages")
msg = self._topic_api.receive(
topic=source,
timeout_sec=timeout_sec,
msg_type=msg_type,
auto_topic_type=auto_topic_type,
)
payload = HRIPayload(msg.data)
return ROS2HRIMessage(payload=payload, message_author=message_author)

def service_call(
self, message: ROS2HRIMessage, target: str, timeout_sec: float, **kwargs: Any
) -> ROS2HRIMessage:
raise NotImplementedError(
f"{self.__class__.__name__} doesn't support service calls"
)

def start_action(
self,
action_data: Optional[ROS2HRIMessage],
target: str,
on_feedback: Callable,
on_done: Callable,
timeout_sec: float,
**kwargs: Any,
) -> str:
raise NotImplementedError(
f"{self.__class__.__name__} doesn't support action calls"
)

def terminate_action(self, action_handle: str, **kwargs: Any):
raise NotImplementedError(
f"{self.__class__.__name__} doesn't support action calls"
)
89 changes: 88 additions & 1 deletion tests/communication/ros2/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node

from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
from rai.communication.ros2.api import (
ConfigurableROS2TopicAPI,
ROS2ActionAPI,
ROS2ServiceAPI,
ROS2TopicAPI,
TopicConfig,
)

from .helpers import ActionServer_ as ActionServer
from .helpers import (
Expand Down Expand Up @@ -59,6 +65,87 @@ def test_ros2_single_message_publish(
shutdown_executors_and_threads(executors, threads)


def test_ros2_configure_publisher(ros_setup: None, request: pytest.FixtureRequest):
topic_name = f"{request.node.originalname}_topic" # type: ignore
node_name = f"{request.node.originalname}_node" # type: ignore
node = Node(node_name)
executors, threads = multi_threaded_spinner([node])
try:
topic_api = ConfigurableROS2TopicAPI(node)
cfg = TopicConfig(topic_name, "std_msgs/msg/String")
topic_api.configure_publisher(topic_name, cfg)
assert topic_api._publishers[topic_name] is not None
finally:
shutdown_executors_and_threads(executors, threads)


def test_ros2_configre_subscriber(ros_setup, request: pytest.FixtureRequest):
topic_name = f"{request.node.originalname}_topic" # type: ignore
node_name = f"{request.node.originalname}_node" # type: ignore
node = Node(node_name)
executors, threads = multi_threaded_spinner([node])
try:
topic_api = ConfigurableROS2TopicAPI(node)
cfg = TopicConfig(
topic_name,
"std_msgs/msg/String",
is_subscriber=True,
subscriber_callback=lambda _: None,
)
topic_api.configure_subscriber(topic_name, cfg)
assert topic_api._subscribtions[topic_name] is not None
finally:
shutdown_executors_and_threads(executors, threads)


def test_ros2_single_message_publish_configured(
ros_setup: None, request: pytest.FixtureRequest
) -> None:
topic_name = f"{request.node.originalname}_topic" # type: ignore
node_name = f"{request.node.originalname}_node" # type: ignore
message_receiver = MessageReceiver(topic_name)
node = Node(node_name)
executors, threads = multi_threaded_spinner([message_receiver, node])

try:
topic_api = ConfigurableROS2TopicAPI(node)
cfg = TopicConfig(
topic_name,
"std_msgs/msg/String",
is_subscriber=False,
)
topic_api.configure_publisher(topic_name, cfg)
topic_api.publish_configured(
topic_name,
{"data": "Hello, ROS2!"},
)
time.sleep(1)
assert len(message_receiver.received_messages) == 1
assert message_receiver.received_messages[0].data == "Hello, ROS2!"
finally:
shutdown_executors_and_threads(executors, threads)


def test_ros2_single_message_publish_configured_no_config(
ros_setup: None, request: pytest.FixtureRequest
) -> None:
topic_name = f"{request.node.originalname}_topic" # type: ignore
node_name = f"{request.node.originalname}_node" # type: ignore
message_receiver = MessageReceiver(topic_name)
node = Node(node_name)
executors, threads = multi_threaded_spinner([message_receiver, node])

try:
topic_api = ConfigurableROS2TopicAPI(node)
with pytest.raises(ValueError):
topic_api.publish_configured(
topic_name,
{"data": "Hello, ROS2!"},
)
finally:
shutdown_executors_and_threads(executors, threads)


def test_ros2_single_message_publish_wrong_msg_type(
ros_setup: None, request: pytest.FixtureRequest
) -> None:
Expand Down

0 comments on commit 47d7090

Please sign in to comment.