diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index 96daf3814..d8a0ae88a 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -35,8 +35,11 @@ def import_message_from_str(msg_type: str) -> Type[object]: return import_message_from_namespaced_type(msg_namespaced_type) -def convert_ros_img_to_ndarray(msg: sensor_msgs.msg.Image) -> np.ndarray: - encoding = msg.encoding.lower() +def convert_ros_img_to_ndarray( + msg: sensor_msgs.msg.Image, encoding: str = "" +) -> np.ndarray: + if encoding == "": + encoding = msg.encoding.lower() if encoding == "rgb8": image_data = np.frombuffer(msg.data, np.uint8) @@ -52,7 +55,7 @@ def convert_ros_img_to_ndarray(msg: sensor_msgs.msg.Image) -> np.ndarray: image_data = np.frombuffer(msg.data, np.uint16) image = image_data.reshape((msg.height, msg.width)) else: - raise ValueError(f"Unsupported encoding: {msg.encoding}") + raise ValueError(f"Unsupported encoding: {encoding}") return image diff --git a/src/rai_extensions/rai_grounding_dino/README.md b/src/rai_extensions/rai_grounding_dino/README.md index 899cab870..a1f48f1eb 100644 --- a/src/rai_extensions/rai_grounding_dino/README.md +++ b/src/rai_extensions/rai_grounding_dino/README.md @@ -58,6 +58,49 @@ ros2 launch rai_grounding_dino gdino_launch.xml [weights_path:=PATH/TO/WEIGHTS] > By default the weights will be downloaded to `$(ros2 pkg prefix rai_grounding_dino)/share/weights/`. > You can change this path if you downloaded the weights manually or moved them. +### RAI Tools + +This package provides the following tools: + +- `GetDetectionTool` + This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt. + + **Example call** + + ``` + x = GetDetectionTool(node=RaiBaseNode(node_name="test_node"))._run( + camera_topic="/camera/camera/color/image_raw", + object_names=["chair", "human", "plushie", "box", "ball"], + ) + + ``` + + **Example output** + + ``` + I have detected the following items in the picture - chair, human + ``` + +- `GetDistanceToObjectsTool` + This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt. Then it utilises messages from depth camera to create an estimation of distance to a detected object. + + **Example call** + + ``` + x = GetDistanceToObjectsTool(node=RaiBaseNode(node_name="test_node"))._run( + camera_topic="/camera/camera/color/image_raw", + depth_topic="/camera/camera/depth/image_rect_raw", + object_names=["chair", "human", "plushie", "box", "ball"], + ) + + ``` + + **Example output** + + ``` + I have detected the following items in the picture human: 1.68 m away, chair: 2.20 m away + ``` + ### Example An example client is provided with the package as `rai_grounding_dino/talker.py` diff --git a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/__init__.py b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/__init__.py index f138f42af..522b9d61b 100644 --- a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/__init__.py +++ b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/__init__.py @@ -12,3 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from .grounding_dino import GDINO_NODE_NAME, GDINO_SERVICE_NAME +from .tools import GetDetectionTool, GetDistanceToObjectsTool + +__all__ = [ + "GetDistanceToObjectsTool", + "GetDetectionTool", + "GDINO_NODE_NAME", + "GDINO_SERVICE_NAME", +] diff --git a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/grounding_dino.py b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/grounding_dino.py index de32907a0..e610f6d0c 100644 --- a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/grounding_dino.py +++ b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/grounding_dino.py @@ -34,11 +34,15 @@ class GDRequest(TypedDict): source_img: Image +GDINO_NODE_NAME = "grounding_dino" +GDINO_SERVICE_NAME = "grounding_dino_classify" + + class GDinoService(Node): def __init__(self): - super().__init__(node_name="grounding_dino", parameter_overrides=[]) + super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[]) self.srv = self.create_service( - RAIGroundingDino, "grounding_dino_classify", self.classify_callback + RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback ) self.declare_parameter("weights_path", "") try: diff --git a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/tools.py b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/tools.py index 2cbf3cb7f..0f3697bed 100644 --- a/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/tools.py +++ b/src/rai_extensions/rai_grounding_dino/rai_grounding_dino/tools.py @@ -19,6 +19,7 @@ import sensor_msgs.msg from langchain_core.pydantic_v1 import Field from pydantic import BaseModel +from rai_grounding_dino import GDINO_SERVICE_NAME from rclpy import Future from rclpy.exceptions import ( ParameterNotDeclaredException, @@ -98,7 +99,7 @@ def _spin(self, future: Future) -> Optional[RAIGroundingDino.Response]: def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: - cli = self.node.create_client(RAIGroundingDino, "grounding_dino_classify") + cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME) while not cli.wait_for_service(timeout_sec=1.0): self.node.get_logger().info("service not available, waiting again...") req = RAIGroundingDino.Request() @@ -156,7 +157,7 @@ def _run( resolved = self._spin(future) if resolved is not None: detected = self._parse_detection_array(resolved) - names = [det.class_name for det in detected] + names = ", ".join([det.class_name for det in detected]) return f"I have detected the following items in the picture {names}" return "Failed to get detection" @@ -241,7 +242,7 @@ def _run( measurements = self._get_distance_from_detections( depth_img_msg, detected, threshold, conversion_ratio ) - measurement_string = " ".join( + measurement_string = ", ".join( [ f"{measurement[0]}: {measurement[1]:.2f}m away" for measurement in measurements