From 5fb55ad2fbabebc5d9dd376e24347144c0cb5f64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Mon, 3 Feb 2025 16:49:23 +0100 Subject: [PATCH 1/2] refactor: make HRIMessage use PIL.Image and pydub.AudioSegment --- poetry.lock | 29 ++++++--- pyproject.toml | 1 + src/rai/rai/communication/hri_connector.py | 73 ++++++++++++++++------ tests/communication/test_hri_message.py | 36 +++++++---- 4 files changed, 99 insertions(+), 40 deletions(-) diff --git a/poetry.lock b/poetry.lock index e654d08d2..d6da53572 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2805,8 +2805,8 @@ langchain-core = ">=0.3.33,<0.4.0" langchain-text-splitters = ">=0.3.3,<0.4.0" langsmith = ">=0.1.17,<0.4" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -2829,8 +2829,8 @@ files = [ boto3 = ">=1.35.74" langchain-core = ">=0.3.27,<0.4.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2,<3" @@ -2853,8 +2853,8 @@ langchain = ">=0.3.16,<0.4.0" langchain-core = ">=0.3.32,<0.4.0" langsmith = ">=0.1.125,<0.4" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -4106,10 +4106,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version 
< \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4130,10 +4130,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4341,9 +4341,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5183,6 +5183,17 @@ numpy = ">=1.16.4" carto = ["pydeck-carto"] jupyter = ["ipykernel (>=5.1.2)", "ipython (>=5.8.0)", "ipywidgets (>=7,<8)", "traitlets (>=4.3.2)"] +[[package]] +name = "pydub" +version = "0.25.1" +description = "Manipulate audio with an simple and easy high level interface" +optional = false +python-versions = "*" +files = [ + {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"}, + {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"}, +] + [[package]] name = "pygments" version = "2.19.1" @@ -5689,8 +5700,8 @@ pandas = {version = "*", optional = true, markers = "extra == \"tune\""} prometheus-client = {version = ">=0.7.1", optional = true, markers = "extra == \"default\""} 
protobuf = ">=3.15.3,<3.19.5 || >3.19.5" py-spy = [ - {version = ">=0.2.0", optional = true, markers = "python_version < \"3.12\" and extra == \"default\""}, {version = ">=0.4.0", optional = true, markers = "python_version >= \"3.12\" and extra == \"default\""}, + {version = ">=0.2.0", optional = true, markers = "python_version < \"3.12\" and extra == \"default\""}, ] pyarrow = [ {version = ">=9.0.0,<18", optional = true, markers = "sys_platform == \"darwin\" and platform_machine == \"x86_64\" and extra == \"tune\""}, @@ -6801,8 +6812,8 @@ files = [ contourpy = {version = ">=1.0.7", markers = "python_version >= \"3.8\" and python_version < \"3.13\""} defusedxml = ">=0.7.1,<0.8.0" matplotlib = [ - {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, {version = ">=3.7.3", markers = "python_version >= \"3.12\""}, + {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, ] numpy = {version = ">=1.21.2", markers = "python_version < \"3.13\""} opencv-python = ">=4.5.5.64" @@ -8227,4 +8238,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "da1a7720082bf43b4efc7cd972b63f39882fc7e0d69340bbc436f18e889e55b2" +content-hash = "158877f96c27f3c9beb75aff8a9608a41b83740afe17ceda2a471378c4bb3545" diff --git a/pyproject.toml b/pyproject.toml index 3534c688d..3224ea498 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ openwakeword = { git = "https://github.com/maciejmajek/openWakeWord.git", branch pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" faster-whisper = "^1.1.1" +pydub = "^0.25.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/src/rai/rai/communication/hri_connector.py b/src/rai/rai/communication/hri_connector.py index 2f834c777..3cbcf1dab 100644 --- a/src/rai/rai/communication/hri_connector.py +++ b/src/rai/rai/communication/hri_connector.py @@ -12,21 +12,17 @@ # See the License for the specific language 
governing permissions and # limitations under the License. -from typing import ( - Annotated, - Generic, - List, - Literal, - Optional, - Sequence, - TypeVar, - get_args, -) +import base64 +from dataclasses import dataclass, field +from io import BytesIO +from typing import Generic, Literal, Sequence, TypeVar, get_args from langchain_core.messages import AIMessage from langchain_core.messages import BaseMessage as LangchainBaseMessage from langchain_core.messages import HumanMessage -from pydantic import BaseModel +from PIL import Image +from PIL.Image import Image as ImageType +from pydub import AudioSegment from rai.messages import AiMultimodalMessage, HumanMultimodalMessage from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage @@ -39,10 +35,19 @@ def __init__(self, msg): super().__init__(msg) -class HRIPayload(BaseModel): +@dataclass +class HRIPayload: text: str - images: Optional[Annotated[List[str], "base64 encoded png images"]] = None - audios: Optional[Annotated[List[str], "base64 encoded wav audio"]] = None + images: list[ImageType] = field(default_factory=list) + audios: list[AudioSegment] = field(default_factory=list) + + def __post_init__(self): + if not isinstance(self.text, str): + raise TypeError(f"Text should be of type str, got {type(self.text)}") + if not isinstance(self.images, list): + raise TypeError(f"Images should be of type list, got {type(self.images)}") + if not isinstance(self.audios, list): + raise TypeError(f"Audios should be of type list, got {type(self.audios)}") class HRIMessage(BaseMessage): @@ -60,19 +65,43 @@ def __init__( def __repr__(self): return f"HRIMessage(type={self.message_author}, text={self.text}, images={self.images}, audios={self.audios})" + def _image_to_base64(self, image: ImageType) -> str: + buffered = BytesIO() + img_format = image.format if image.format else "PNG" + image.save(buffered, img_format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + def _audio_to_base64(self, 
audio: AudioSegment) -> str: + buffered = BytesIO() + audio.export(buffered, format="wav") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + @classmethod + def _base64_to_image(cls, base64_str: str) -> ImageType: + img_data = base64.b64decode(base64_str) + return Image.open(BytesIO(img_data)) + + @classmethod + def _base64_to_audio(cls, base64_str: str) -> AudioSegment: + audio_data = base64.b64decode(base64_str) + return AudioSegment.from_file(BytesIO(audio_data), format="wav") + def to_langchain(self) -> LangchainBaseMessage: + base64_images = [self._image_to_base64(image) for image in self.images] + base64_audios = [self._audio_to_base64(audio) for audio in self.audios] match self.message_author: case "human": - if self.images is None and self.audios is None: + if self.images == [] and self.audios == []: return HumanMessage(content=self.text) + return HumanMultimodalMessage( - content=self.text, images=self.images, audios=self.audios + content=self.text, images=base64_images, audios=base64_audios ) case "ai": - if self.images is None and self.audios is None: + if self.images == [] and self.audios == []: return AIMessage(content=self.text) return AiMultimodalMessage( - content=self.text, images=self.images, audios=self.audios + content=self.text, images=base64_images, audios=base64_audios ) case _: raise ValueError( @@ -97,8 +126,12 @@ def from_langchain( return cls( payload=HRIPayload( text=text, - images=images, - audios=audios, + images=( + [cls._base64_to_image(image) for image in images] if images else [] + ), + audios=( + [cls._base64_to_audio(audio) for audio in audios] if audios else [] + ), ), message_author=message.type, # type: ignore ) diff --git a/tests/communication/test_hri_message.py b/tests/communication/test_hri_message.py index f37b741e9..bba976da6 100644 --- a/tests/communication/test_hri_message.py +++ b/tests/communication/test_hri_message.py @@ -16,30 +16,44 @@ import pytest from langchain_core.messages import BaseMessage 
as LangchainBaseMessage from langchain_core.messages import HumanMessage +from PIL import Image +from pydub import AudioSegment from rai.communication import HRIMessage, HRIPayload from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage -def test_initialization(): - payload = HRIPayload(text="Hello", images=["image1"], audios=["audio1"]) +@pytest.fixture +def image(): + img = Image.new("RGB", (100, 100), color="red") + return img + + +@pytest.fixture +def audio(): + audio = AudioSegment.silent(duration=1000) + return audio + + +def test_initialization(image, audio): + payload = HRIPayload(text="Hello", images=[image], audios=[audio]) message = HRIMessage(payload=payload, message_author="human") assert message.text == "Hello" - assert message.images == ["image1"] - assert message.audios == ["audio1"] + assert message.images == [image] + assert message.audios == [audio] assert message.message_author == "human" def test_repr(): - payload = HRIPayload(text="Hello", images=None, audios=None) + payload = HRIPayload(text="Hello") message = HRIMessage(payload=payload, message_author="ai") - assert repr(message) == "HRIMessage(type=ai, text=Hello, images=None, audios=None)" + assert repr(message) == "HRIMessage(type=ai, text=Hello, images=[], audios=[])" def test_to_langchain_human(): - payload = HRIPayload(text="Hi there", images=None, audios=None) + payload = HRIPayload(text="Hi there", images=[], audios=[]) message = HRIMessage(payload=payload, message_author="human") langchain_message = message.to_langchain() @@ -47,8 +61,8 @@ def test_to_langchain_human(): assert langchain_message.content == "Hi there" -def test_to_langchain_ai_multimodal(): - payload = HRIPayload(text="Response", images=["img"], audios=["audio"]) +def test_to_langchain_ai_multimodal(image, audio): + payload = HRIPayload(text="Response", images=[image], audios=[audio]) message = HRIMessage(payload=payload, message_author="ai") with pytest.raises( @@ -67,8 +81,8 @@ def 
test_from_langchain_human(): hri_message = HRIMessage.from_langchain(langchain_message) assert hri_message.text == "Hello" - assert hri_message.images is None - assert hri_message.audios is None + assert hri_message.images == [] + assert hri_message.audios == [] assert hri_message.message_author == "human" From b6f3516d4a2a206a35d86c6fe3d2db69ce45694d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Mon, 3 Feb 2025 17:02:40 +0100 Subject: [PATCH 2/2] feat: change encoding format to enforced PNG --- src/rai/rai/communication/hri_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/rai/rai/communication/hri_connector.py b/src/rai/rai/communication/hri_connector.py index 3cbcf1dab..07380f638 100644 --- a/src/rai/rai/communication/hri_connector.py +++ b/src/rai/rai/communication/hri_connector.py @@ -67,8 +67,7 @@ def __repr__(self): def _image_to_base64(self, image: ImageType) -> str: buffered = BytesIO() - img_format = image.format if image.format else "PNG" - image.save(buffered, img_format) + image.save(buffered, "PNG") return base64.b64encode(buffered.getvalue()).decode("utf-8") def _audio_to_base64(self, audio: AudioSegment) -> str: