Skip to content

Commit

Permalink
chore: pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Feb 6, 2025
1 parent 4d803da commit 843fd9f
Show file tree
Hide file tree
Showing 66 changed files with 164 additions and 195 deletions.
12 changes: 6 additions & 6 deletions examples/agriculture-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
import argparse

import rclpy
from rclpy.action import ActionClient
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from std_srvs.srv import Trigger

from rai.node import RaiStateBasedLlmNode, describe_ros_image
from rai.tools.ros.native import (
GetCameraImage,
Expand All @@ -30,6 +24,12 @@
Ros2ShowMsgInterfaceTool,
)
from rai.tools.time import WaitForSecondsTool
from rclpy.action import ActionClient
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from std_srvs.srv import Trigger

from rai_interfaces.action import Task


Expand Down
1 change: 0 additions & 1 deletion examples/manipulation-demo-streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
from rai.messages import HumanMultimodalMessage

Expand Down
1 change: 0 additions & 1 deletion examples/manipulation-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import rclpy
import rclpy.qos
from langchain_core.messages import HumanMessage

from rai.agents.conversational_agent import create_conversational_agent
from rai.node import RaiBaseNode
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
Expand Down
3 changes: 1 addition & 2 deletions examples/rosbot-xl-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import rclpy
import rclpy.executors
import rclpy.logging
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool

from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetMsgFromTopic,
Expand All @@ -35,6 +33,7 @@
Ros2RunActionAsync,
)
from rai.tools.time import WaitForSecondsTool
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool

p = argparse.ArgumentParser()
p.add_argument("--allowlist", type=Path, required=False, default=None)
Expand Down
4 changes: 2 additions & 2 deletions examples/taxi-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from std_msgs.msg import String

from rai.agents.conversational_agent import create_conversational_agent
from rai.tools.ros.cli import Ros2ServiceTool
from rai.tools.ros.native import Ros2PubMessageTool
from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks
from std_msgs.msg import String

from rai_hmi.api import GenericVoiceNode, split_message

system_prompt = """
Expand Down
1 change: 0 additions & 1 deletion src/examples/turtlebot4/turtlebot_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task

from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetCameraImage,
Expand Down
2 changes: 1 addition & 1 deletion src/rai_asr/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ faster-whisper = "^1.1.1"
pydub = "^0.25.1"

[tool.isort]
profile = "black"
profile = "black"
6 changes: 3 additions & 3 deletions src/rai_asr/rai_asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from rai_asr.models.silero_vad import SileroVAD

__all__ = [
"BaseVoiceDetectionModel",
"SileroVAD",
"OpenWakeWord",
"BaseTranscriptionModel",
"BaseVoiceDetectionModel",
"LocalWhisper",
"OpenAIWhisper",
"OpenWakeWord",
"SileroVAD",
]
1 change: 0 additions & 1 deletion src/rai_asr/rai_asr/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@


class BaseVoiceDetectionModel(ABC):

def __call__(
self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ tomli = "^2.0.1"
tomli-w = "^1.1.0"

[tool.isort]
profile = "black"
profile = "black"
2 changes: 1 addition & 1 deletion src/rai_core/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = [
"ToolRunner",
"VoiceRecognitionAgent",
"create_conversational_agent",
"create_state_based_agent",
"VoiceRecognitionAgent",
]
2 changes: 1 addition & 1 deletion src/rai_core/rai/agents/integrations/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def on_tool_end(self, output: Any, **kwargs: Any) -> Any:

# Decorator function to add the Streamlit execution context to a function
def add_streamlit_context(
fn: Callable[..., fn_return_type]
fn: Callable[..., fn_return_type],
) -> Callable[..., fn_return_type]:
"""
Decorator to ensure that the decorated function runs within the Streamlit execution context.
Expand Down
4 changes: 3 additions & 1 deletion src/rai_core/rai/agents/voice_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def should_record(
def transcription_thread(self, identifier: str):
self.logger.info(f"transcription thread {identifier} started")
audio_data = np.concatenate(self.transcription_buffers[identifier])
with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow
with (
self.transcription_lock
): # this is only necessary for the local model... TODO: fix this somehow
transcription = self.transcription_model.transcribe(audio_data)
assert isinstance(self.connectors["ros2"], ROS2ARIConnector)
self.connectors["ros2"].send_message(
Expand Down
1 change: 0 additions & 1 deletion src/rai_core/rai/apps/state_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def robot_state_analyzer(
state: str,
state_analyzer_prompt: str = STATE_ANALYZER_PROMPT,
) -> State:

template = ChatPromptTemplate.from_messages(
[
("system", state_analyzer_prompt),
Expand Down
4 changes: 3 additions & 1 deletion src/rai_core/rai/apps/talk_to_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def talk_to_docs(documentation_root: str, llm: BaseChatModel):

agent = create_tool_calling_agent(llm, [query_docs], prompt) # type: ignore
agent_executor = AgentExecutor(
agent=agent, tools=[query_docs], return_intermediate_steps=True # type: ignore
agent=agent,
tools=[query_docs],
return_intermediate_steps=True, # type: ignore
)

def input_node(state: State) -> State:
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
__all__ = [
"ARIConnector",
"ARIMessage",
"BaseMessage",
"BaseConnector",
"BaseMessage",
"HRIConnector",
"HRIMessage",
"HRIPayload",
Expand Down
1 change: 0 additions & 1 deletion src/rai_core/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(


class BaseConnector(Generic[T]):

def _generate_handle(self) -> str:
return str(uuid4())

Expand Down
4 changes: 1 addition & 3 deletions src/rai_core/rai/communication/hri_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from io import BytesIO
from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args

from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
from langchain_core.messages import HumanMessage
from PIL import Image
from PIL.Image import Image as ImageType
from pydub import AudioSegment
Expand Down Expand Up @@ -166,7 +165,6 @@ def _build_message(
self,
message: LangchainBaseMessage | RAIMultimodalMessage,
) -> T:

return self.T_class.from_langchain(message)

def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
Expand Down
3 changes: 1 addition & 2 deletions src/rai_core/rai/communication/ros2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ def __post_init__(self):


class ConfigurableROS2TopicAPI(ROS2TopicAPI):

def __init__(self, node: rclpy.node.Node):
super().__init__(node)
self._subscribtions: dict[str, rclpy.node.Subscription] = {}
Expand Down Expand Up @@ -562,7 +561,7 @@ def is_goal_done(self, handle: str) -> bool:
raise ValueError(f"Invalid action handle: {handle}")
if self.actions[handle]["result_future"] is None:
raise ValueError(
f"Result future is None for handle: {handle}. " "Was the goal accepted?"
f"Result future is None for handle: {handle}. Was the goal accepted?"
)
return self.actions[handle]["result_future"].done()

Expand Down
1 change: 0 additions & 1 deletion src/rai_core/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def send_message(
qos_profile: Optional[QoSProfile] = None,
**kwargs: Any,
):

self._topic_api.publish(
topic=target,
msg_content=message.payload,
Expand Down
2 changes: 1 addition & 1 deletion src/rai_core/rai/communication/sound_device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
__all__ = [
"SoundDeviceAPI",
"SoundDeviceConfig",
"SoundDeviceError",
"SoundDeviceConnector",
"SoundDeviceError",
]
1 change: 0 additions & 1 deletion src/rai_core/rai/communication/sound_device/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __post_init__(self):


class SoundDeviceAPI:

def __init__(self, config: SoundDeviceConfig):
self.device_name = ""

Expand Down
4 changes: 2 additions & 2 deletions src/rai_core/rai/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from .utils import preprocess_image

__all__ = [
"HumanMultimodalMessage",
"AiMultimodalMessage",
"HumanMultimodalMessage",
"MultimodalArtifact",
"SystemMultimodalMessage",
"ToolMultimodalMessage",
"MultimodalArtifact",
"preprocess_image",
]
2 changes: 1 addition & 1 deletion src/rai_core/rai/messages/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __repr_args__(self) -> Any:
v = [c for c in v if c["type"] != "image_url"]
elif k == "images":
imgs_summary = [image[0:10] + "..." for image in v]
v = f'{len(v)} base64 encoded images: [{", ".join(imgs_summary)}]'
v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]"
new_args.append((k, v))
return new_args

Expand Down
1 change: 0 additions & 1 deletion src/rai_core/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
"callbacks": get_tracing_callbacks(),
},
):

graph_node_name = list(state.keys())[0]
if graph_node_name == "reporter":
continue
Expand Down
12 changes: 6 additions & 6 deletions src/rai_core/rai/tools/ros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
)

__all__ = [
"AddDescribedWaypointToDatabaseTool",
"GetCurrentPositionTool",
"GetOccupancyGridTool",
"Ros2BaseInput",
"Ros2BaseTool",
"ros2_action",
"ros2_interface",
"ros2_node",
"ros2_topic",
"ros2_param",
"ros2_service",
"Ros2BaseTool",
"Ros2BaseInput",
"AddDescribedWaypointToDatabaseTool",
"GetOccupancyGridTool",
"GetCurrentPositionTool",
"ros2_topic",
]
4 changes: 3 additions & 1 deletion src/rai_core/rai/tools/ros/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def __init__(

def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
cv_image = cast(
cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
Expand Down
8 changes: 2 additions & 6 deletions src/rai_core/rai/tools/ros/native_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ class Ros2BaseActionTool(Ros2BaseTool):

class Ros2RunActionSync(Ros2BaseTool):
name: str = "Ros2RunAction"
description: str = (
"A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"
)
description: str = "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"

args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput

Expand Down Expand Up @@ -176,9 +174,7 @@ def _run(self) -> bool:

class Ros2GetLastActionFeedback(Ros2BaseActionTool):
name: str = "Ros2GetLastActionFeedback"
description: str = (
"Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."
)
description: str = "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."

args_schema: Type[Ros2BaseInput] = Ros2BaseInput

Expand Down
4 changes: 1 addition & 3 deletions src/rai_core/rai/tools/ros/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ class GetOccupancyGridTool(BaseTool):
"""Get the current map as an image with the robot's position marked on it (red dot)."""

name: str = "GetOccupancyGridTool"
description: str = (
"A tool for getting the current map as an image with the robot's position marked on it."
)
description: str = "A tool for getting the current map as an image with the robot's position marked on it."

args_schema: Type[TopicInput] = TopicInput

Expand Down
10 changes: 5 additions & 5 deletions src/rai_core/rai/tools/ros2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
)

__all__ = [
"StartROS2ActionTool",
"GetROS2ImageTool",
"PublishROS2MessageTool",
"ReceiveROS2MessageTool",
"CallROS2ServiceTool",
"CancelROS2ActionTool",
"GetROS2TopicsNamesAndTypesTool",
"GetROS2ImageTool",
"GetROS2MessageInterfaceTool",
"GetROS2TopicsNamesAndTypesTool",
"GetROS2TransformTool",
"PublishROS2MessageTool",
"ReceiveROS2MessageTool",
"StartROS2ActionTool",
]
4 changes: 3 additions & 1 deletion src/rai_core/rai/tools/ros2/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]:
raise ValueError(
f"Unsupported message type: {message.metadata['msg_type']}"
)
return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore
return "Image received successfully", MultimodalArtifact(
images=[preprocess_image(image)]
) # type: ignore


class GetROS2TopicsNamesAndTypesTool(BaseTool):
Expand Down
4 changes: 3 additions & 1 deletion src/rai_core/rai/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def __init__(

def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
cv_image = cast(
cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
Expand Down
Loading

0 comments on commit 843fd9f

Please sign in to comment.