Skip to content

Commit

Permalink
add docs, define API
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Sep 24, 2024
1 parent a63b312 commit 37d715d
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/rai/rai/tools/ros/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ def import_message_from_str(msg_type: str) -> Type[object]:
return import_message_from_namespaced_type(msg_namespaced_type)


def convert_ros_img_to_ndarray(msg: sensor_msgs.msg.Image) -> np.ndarray:
encoding = msg.encoding.lower()
def convert_ros_img_to_ndarray(
msg: sensor_msgs.msg.Image, encoding: str = ""
) -> np.ndarray:
if encoding == "":
encoding = msg.encoding.lower()

if encoding == "rgb8":
image_data = np.frombuffer(msg.data, np.uint8)
Expand All @@ -52,7 +55,7 @@ def convert_ros_img_to_ndarray(msg: sensor_msgs.msg.Image) -> np.ndarray:
image_data = np.frombuffer(msg.data, np.uint16)
image = image_data.reshape((msg.height, msg.width))
else:
raise ValueError(f"Unsupported encoding: {msg.encoding}")
raise ValueError(f"Unsupported encoding: {encoding}")

return image

Expand Down
43 changes: 43 additions & 0 deletions src/rai_extensions/rai_grounding_dino/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ ros2 launch rai_grounding_dino gdino_launch.xml [weights_path:=PATH/TO/WEIGHTS]
> By default the weights will be downloaded to `$(ros2 pkg prefix rai_grounding_dino)/share/weights/`.
> You can change this path if you downloaded the weights manually or moved them.
### RAI Tools

This package provides the following tools:

- `GetDetectionTool`
This tool calls the grounding dino service, which uses the model to check whether the message from the provided camera topic contains any of the objects named in a comma-separated prompt.

**Example call**

```
x = GetDetectionTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture chair, human
```

- `GetDistanceToObjectsTool`
This tool calls the grounding dino service, which uses the model to check whether the message from the provided camera topic contains any of the objects named in a comma-separated prompt. It then uses messages from the depth camera to estimate the distance to each detected object.

**Example call**

```
x = GetDistanceToObjectsTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
depth_topic="/camera/camera/depth/image_rect_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture human: 1.68m away, chair: 2.20m away
```

### Example

An example client is provided with the package as `rai_grounding_dino/talker.py`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .grounding_dino import GDINO_NODE_NAME, GDINO_SERVICE_NAME
from .tools import GetDetectionTool, GetDistanceToObjectsTool

__all__ = [
"GetDistanceToObjectsTool",
"GetDetectionTool",
"GDINO_NODE_NAME",
"GDINO_SERVICE_NAME",
]
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ class GDRequest(TypedDict):
source_img: Image


GDINO_NODE_NAME = "grounding_dino"
GDINO_SERVICE_NAME = "grounding_dino_classify"


class GDinoService(Node):
def __init__(self):
super().__init__(node_name="grounding_dino", parameter_overrides=[])
super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[])
self.srv = self.create_service(
RAIGroundingDino, "grounding_dino_classify", self.classify_callback
RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback
)
self.declare_parameter("weights_path", "")
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sensor_msgs.msg
from langchain_core.pydantic_v1 import Field
from pydantic import BaseModel
from rai_grounding_dino import GDINO_SERVICE_NAME
from rclpy import Future
from rclpy.exceptions import (
ParameterNotDeclaredException,
Expand Down Expand Up @@ -98,7 +99,7 @@ def _spin(self, future: Future) -> Optional[RAIGroundingDino.Response]:
def _call_gdino_node(
self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str]
) -> Future:
cli = self.node.create_client(RAIGroundingDino, "grounding_dino_classify")
cli = self.node.create_client(RAIGroundingDino, GDINO_SERVICE_NAME)
while not cli.wait_for_service(timeout_sec=1.0):
self.node.get_logger().info("service not available, waiting again...")
req = RAIGroundingDino.Request()
Expand Down Expand Up @@ -156,7 +157,7 @@ def _run(
resolved = self._spin(future)
if resolved is not None:
detected = self._parse_detection_array(resolved)
names = [det.class_name for det in detected]
names = ", ".join([det.class_name for det in detected])
return f"I have detected the following items in the picture {names}"

return "Failed to get detection"
Expand Down Expand Up @@ -241,7 +242,7 @@ def _run(
measurements = self._get_distance_from_detections(
depth_img_msg, detected, threshold, conversion_ratio
)
measurement_string = " ".join(
measurement_string = ", ".join(
[
f"{measurement[0]}: {measurement[1]:.2f}m away"
for measurement in measurements
Expand Down

0 comments on commit 37d715d

Please sign in to comment.