Skip to content

Commit

Permalink
feat: add ruff as formatting tool, remove unnecessary ones
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Feb 4, 2025
1 parent 6bae47f commit 95ec6fc
Show file tree
Hide file tree
Showing 41 changed files with 151 additions and 169 deletions.
29 changes: 8 additions & 21 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,12 @@ repos:
hooks:
- id: shellcheck

- repo: https://github.com/pycqa/autoflake
rev: v2.3.1
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.4
hooks:
- id: autoflake
args: ["--remove-all-unused-imports", "--in-place"]

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black

- repo: https://github.com/pycqa/flake8
rev: 7.1.0
hooks:
- id: flake8
args: ["--ignore=E501,E731,W503,W504,E203"]
# Run the linter.
- id: ruff
args: [--extend-select, "I,RUF022", --fix, --ignore, E731]
# Run the formatter.
- id: ruff-format
2 changes: 1 addition & 1 deletion src/rai/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

__all__ = [
"ToolRunner",
"VoiceRecognitionAgent",
"create_conversational_agent",
"create_state_based_agent",
"VoiceRecognitionAgent",
]
2 changes: 1 addition & 1 deletion src/rai/rai/agents/integrations/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def on_tool_end(self, output: Any, **kwargs: Any) -> Any:

# Decorator function to add the Streamlit execution context to a function
def add_streamlit_context(
fn: Callable[..., fn_return_type]
fn: Callable[..., fn_return_type],
) -> Callable[..., fn_return_type]:
"""
Decorator to ensure that the decorated function runs within the Streamlit execution context.
Expand Down
4 changes: 3 additions & 1 deletion src/rai/rai/agents/voice_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ def should_record(
def transcription_thread(self, identifier: str):
self.logger.info(f"transcription thread {identifier} started")
audio_data = np.concatenate(self.transcription_buffers[identifier])
with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow
with (
self.transcription_lock
): # this is only necessary for the local model... TODO: fix this somehow
transcription = self.transcription_model.transcribe(audio_data)
assert isinstance(self.connectors["ros2"], ROS2ARIConnector)
self.connectors["ros2"].send_message(
Expand Down
1 change: 0 additions & 1 deletion src/rai/rai/apps/state_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def robot_state_analyzer(
state: str,
state_analyzer_prompt: str = STATE_ANALYZER_PROMPT,
) -> State:

template = ChatPromptTemplate.from_messages(
[
("system", state_analyzer_prompt),
Expand Down
4 changes: 3 additions & 1 deletion src/rai/rai/apps/talk_to_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def talk_to_docs(documentation_root: str, llm: BaseChatModel):

agent = create_tool_calling_agent(llm, [query_docs], prompt) # type: ignore
agent_executor = AgentExecutor(
agent=agent, tools=[query_docs], return_intermediate_steps=True # type: ignore
agent=agent,
tools=[query_docs],
return_intermediate_steps=True, # type: ignore
)

def input_node(state: State) -> State:
Expand Down
2 changes: 1 addition & 1 deletion src/rai/rai/communication/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
__all__ = [
"ARIConnector",
"ARIMessage",
"BaseMessage",
"BaseConnector",
"BaseMessage",
"HRIConnector",
"HRIMessage",
"HRIPayload",
Expand Down
1 change: 0 additions & 1 deletion src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def __init__(


class BaseConnector(Generic[T]):

def _generate_handle(self) -> str:
return str(uuid4())

Expand Down
4 changes: 1 addition & 3 deletions src/rai/rai/communication/hri_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@
from io import BytesIO
from typing import Generic, Literal, Sequence, TypeVar, get_args

from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
from langchain_core.messages import HumanMessage
from PIL import Image
from PIL.Image import Image as ImageType
from pydub import AudioSegment
Expand Down Expand Up @@ -164,7 +163,6 @@ def _build_message(
self,
message: LangchainBaseMessage | RAIMultimodalMessage,
) -> T:

return self.T_class.from_langchain(message)

def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
Expand Down
2 changes: 1 addition & 1 deletion src/rai/rai/communication/ros2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def is_goal_done(self, handle: str) -> bool:
raise ValueError(f"Invalid action handle: {handle}")
if self.actions[handle]["result_future"] is None:
raise ValueError(
f"Result future is None for handle: {handle}. " "Was the goal accepted?"
f"Result future is None for handle: {handle}. Was the goal accepted?"
)
return self.actions[handle]["result_future"].done()

Expand Down
1 change: 0 additions & 1 deletion src/rai/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def send_message(
qos_profile: Optional[QoSProfile] = None,
**kwargs: Any,
):

self._topic_api.publish(
topic=target,
msg_content=message.payload,
Expand Down
2 changes: 1 addition & 1 deletion src/rai/rai/communication/sound_device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@
__all__ = [
"SoundDeviceAPI",
"SoundDeviceConfig",
"SoundDeviceError",
"SoundDeviceConnector",
"SoundDeviceError",
]
1 change: 0 additions & 1 deletion src/rai/rai/communication/sound_device/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def __post_init__(self):


class SoundDeviceAPI:

def __init__(self, config: SoundDeviceConfig):
self.device_name = ""

Expand Down
4 changes: 2 additions & 2 deletions src/rai/rai/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from .utils import preprocess_image

__all__ = [
"HumanMultimodalMessage",
"AiMultimodalMessage",
"HumanMultimodalMessage",
"MultimodalArtifact",
"SystemMultimodalMessage",
"ToolMultimodalMessage",
"MultimodalArtifact",
"preprocess_image",
]
2 changes: 1 addition & 1 deletion src/rai/rai/messages/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __repr_args__(self) -> Any:
v = [c for c in v if c["type"] != "image_url"]
elif k == "images":
imgs_summary = [image[0:10] + "..." for image in v]
v = f'{len(v)} base64 encoded images: [{", ".join(imgs_summary)}]'
v = f"{len(v)} base64 encoded images: [{', '.join(imgs_summary)}]"
new_args.append((k, v))
return new_args

Expand Down
1 change: 0 additions & 1 deletion src/rai/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):
"callbacks": get_tracing_callbacks(),
},
):

graph_node_name = list(state.keys())[0]
if graph_node_name == "reporter":
continue
Expand Down
12 changes: 6 additions & 6 deletions src/rai/rai/tools/ros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
)

__all__ = [
"AddDescribedWaypointToDatabaseTool",
"GetCurrentPositionTool",
"GetOccupancyGridTool",
"Ros2BaseInput",
"Ros2BaseTool",
"ros2_action",
"ros2_interface",
"ros2_node",
"ros2_topic",
"ros2_param",
"ros2_service",
"Ros2BaseTool",
"Ros2BaseInput",
"AddDescribedWaypointToDatabaseTool",
"GetOccupancyGridTool",
"GetCurrentPositionTool",
"ros2_topic",
]
4 changes: 3 additions & 1 deletion src/rai/rai/tools/ros/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def __init__(

def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
cv_image = cast(
cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
Expand Down
8 changes: 2 additions & 6 deletions src/rai/rai/tools/ros/native_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ class Ros2BaseActionTool(Ros2BaseTool):

class Ros2RunActionSync(Ros2BaseTool):
name: str = "Ros2RunAction"
description: str = (
"A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"
)
description: str = "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result"

args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput

Expand Down Expand Up @@ -176,9 +174,7 @@ def _run(self) -> bool:

class Ros2GetLastActionFeedback(Ros2BaseActionTool):
name: str = "Ros2GetLastActionFeedback"
description: str = (
"Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."
)
description: str = "Action feedback is an optional intermediate information from ros2 action. With this tool you can get the last feedback of running action."

args_schema: Type[Ros2BaseInput] = Ros2BaseInput

Expand Down
4 changes: 1 addition & 3 deletions src/rai/rai/tools/ros/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ class GetOccupancyGridTool(BaseTool):
"""Get the current map as an image with the robot's position marked on it (red dot)."""

name: str = "GetOccupancyGridTool"
description: str = (
"A tool for getting the current map as an image with the robot's position marked on it."
)
description: str = "A tool for getting the current map as an image with the robot's position marked on it."

args_schema: Type[TopicInput] = TopicInput

Expand Down
10 changes: 5 additions & 5 deletions src/rai/rai/tools/ros2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
)

__all__ = [
"StartROS2ActionTool",
"GetROS2ImageTool",
"PublishROS2MessageTool",
"ReceiveROS2MessageTool",
"CallROS2ServiceTool",
"CancelROS2ActionTool",
"GetROS2TopicsNamesAndTypesTool",
"GetROS2ImageTool",
"GetROS2MessageInterfaceTool",
"GetROS2TopicsNamesAndTypesTool",
"GetROS2TransformTool",
"PublishROS2MessageTool",
"ReceiveROS2MessageTool",
"StartROS2ActionTool",
]
4 changes: 3 additions & 1 deletion src/rai/rai/tools/ros2/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def _run(self, topic: str) -> Tuple[str, MultimodalArtifact]:
raise ValueError(
f"Unsupported message type: {message.metadata['msg_type']}"
)
return "Image received successfully", MultimodalArtifact(images=[preprocess_image(image)]) # type: ignore
return "Image received successfully", MultimodalArtifact(
images=[preprocess_image(image)]
) # type: ignore


class GetROS2TopicsNamesAndTypesTool(BaseTool):
Expand Down
4 changes: 3 additions & 1 deletion src/rai/rai/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def __init__(

def postprocess(self, msg: Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
cv_image = cast(
cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
base64_image = base64.b64encode(
Expand Down
60 changes: 30 additions & 30 deletions src/rai/rai/utils/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,19 @@ def on_model_vendor_change(model_type: str):
)

def on_langfuse_change():
st.session_state.config["tracing"]["langfuse"][
"use_langfuse"
] = st.session_state.langfuse_checkbox
st.session_state.config["tracing"]["langfuse"]["use_langfuse"] = (
st.session_state.langfuse_checkbox
)

def on_langfuse_host_change():
st.session_state.config["tracing"]["langfuse"][
"host"
] = st.session_state.langfuse_host_input
st.session_state.config["tracing"]["langfuse"]["host"] = (
st.session_state.langfuse_host_input
)

def on_langsmith_change():
st.session_state.config["tracing"]["langsmith"][
"use_langsmith"
] = st.session_state.langsmith_checkbox
st.session_state.config["tracing"]["langsmith"]["use_langsmith"] = (
st.session_state.langsmith_checkbox
)

# Ensure tracing config exists
if "tracing" not in st.session_state.config:
Expand Down Expand Up @@ -397,9 +397,9 @@ def on_langsmith_change():
elif st.session_state.current_step == 4:

def on_recording_device_change():
st.session_state.config["asr"][
"recording_device_name"
] = st.session_state.recording_device_select
st.session_state.config["asr"]["recording_device_name"] = (
st.session_state.recording_device_select
)

def on_asr_vendor_change():
vendor = (
Expand All @@ -413,29 +413,29 @@ def on_language_change():
st.session_state.config["asr"]["language"] = st.session_state.language_input

def on_silence_grace_change():
st.session_state.config["asr"][
"silence_grace_period"
] = st.session_state.silence_grace_input
st.session_state.config["asr"]["silence_grace_period"] = (
st.session_state.silence_grace_input
)

def on_vad_threshold_change():
st.session_state.config["asr"][
"vad_threshold"
] = st.session_state.vad_threshold_input
st.session_state.config["asr"]["vad_threshold"] = (
st.session_state.vad_threshold_input
)

def on_wake_word_change():
st.session_state.config["asr"][
"use_wake_word"
] = st.session_state.wake_word_checkbox
st.session_state.config["asr"]["use_wake_word"] = (
st.session_state.wake_word_checkbox
)

def on_wake_word_model_change():
st.session_state.config["asr"][
"wake_word_model"
] = st.session_state.wake_word_model_input
st.session_state.config["asr"]["wake_word_model"] = (
st.session_state.wake_word_model_input
)

def on_wake_word_threshold_change():
st.session_state.config["asr"][
"wake_word_threshold"
] = st.session_state.wake_word_threshold_input
st.session_state.config["asr"]["wake_word_threshold"] = (
st.session_state.wake_word_threshold_input
)

# Ensure asr config exists
if "asr" not in st.session_state.config:
Expand Down Expand Up @@ -588,9 +588,9 @@ def on_tts_vendor_change():
st.session_state.config["tts"]["vendor"] = vendor

def on_keep_speaker_busy_change():
st.session_state.config["tts"][
"keep_speaker_busy"
] = st.session_state.keep_speaker_busy_checkbox
st.session_state.config["tts"]["keep_speaker_busy"] = (
st.session_state.keep_speaker_busy_checkbox
)

# Ensure tts config exists
if "tts" not in st.session_state.config:
Expand Down
4 changes: 1 addition & 3 deletions src/rai_asr/rai_asr/asr_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,7 @@ def initialize_sounddevice_stream(self):
sd.default.latency = ("low", "low")
self.device_sample_rate = sd.query_devices(
device=self.recording_device_number, kind="input"
)[
"default_samplerate"
] # type: ignore
)["default_samplerate"] # type: ignore
self.window_size_samples = int(
DEFAULT_BLOCKSIZE * self.device_sample_rate / VAD_SAMPLING_RATE
)
Expand Down
Loading

0 comments on commit 95ec6fc

Please sign in to comment.