Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make HRIMessage use PIL.Image and pydub.AudioSegment #401

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ openwakeword = { git = "https://github.com/maciejmajek/openWakeWord.git", branch
pytest-timeout = "^2.3.1"
tomli-w = "^1.1.0"
faster-whisper = "^1.1.1"
pydub = "^0.25.1"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.4"

Expand Down
72 changes: 52 additions & 20 deletions src/rai/rai/communication/hri_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import (
Annotated,
Generic,
List,
Literal,
Optional,
Sequence,
TypeVar,
get_args,
)
import base64
from dataclasses import dataclass, field
from io import BytesIO
from typing import Generic, Literal, Sequence, TypeVar, get_args

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel
from PIL import Image
from PIL.Image import Image as ImageType
from pydub import AudioSegment

from rai.messages import AiMultimodalMessage, HumanMultimodalMessage
from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage
Expand All @@ -39,10 +35,19 @@ def __init__(self, msg):
super().__init__(msg)


@dataclass
class HRIPayload:
    """Payload carried by an HRIMessage: text plus optional media attachments.

    Container types are validated eagerly in ``__post_init__`` so malformed
    payloads fail at construction time rather than later at serialization.
    """

    # Natural-language content of the message.
    text: str
    # Decoded PIL images attached to the message (may be empty).
    images: list[ImageType] = field(default_factory=list)
    # Decoded pydub audio segments attached to the message (may be empty).
    audios: list[AudioSegment] = field(default_factory=list)

    def __post_init__(self):
        # Fail fast on wrong container types; element types are not checked here.
        if not isinstance(self.text, str):
            raise TypeError(f"Text should be of type str, got {type(self.text)}")
        if not isinstance(self.images, list):
            raise TypeError(f"Images should be of type list, got {type(self.images)}")
        if not isinstance(self.audios, list):
            raise TypeError(f"Audios should be of type list, got {type(self.audios)}")


class HRIMessage(BaseMessage):
Expand All @@ -60,19 +65,42 @@ def __init__(
def __repr__(self):
    # Human-readable summary mirroring the payload fields; used in logs/debugging.
    template = "HRIMessage(type={}, text={}, images={}, audios={})"
    return template.format(self.message_author, self.text, self.images, self.audios)

def _image_to_base64(self, image: ImageType) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    raw = BytesIO()
    image.save(raw, "PNG")
    encoded = base64.b64encode(raw.getvalue())
    return encoded.decode("utf-8")

def _audio_to_base64(self, audio: AudioSegment) -> str:
    """Serialize a pydub AudioSegment to a base64-encoded WAV string."""
    raw = BytesIO()
    audio.export(raw, format="wav")
    encoded = base64.b64encode(raw.getvalue())
    return encoded.decode("utf-8")

@classmethod
def _base64_to_image(cls, base64_str: str) -> ImageType:
    """Decode a base64-encoded PNG string back into a PIL image."""
    return Image.open(BytesIO(base64.b64decode(base64_str)))

@classmethod
def _base64_to_audio(cls, base64_str: str) -> AudioSegment:
    """Decode a base64-encoded WAV string back into a pydub AudioSegment."""
    raw = base64.b64decode(base64_str)
    return AudioSegment.from_file(BytesIO(raw), format="wav")

def to_langchain(self) -> LangchainBaseMessage:
base64_images = [self._image_to_base64(image) for image in self.images]
base64_audios = [self._audio_to_base64(audio) for audio in self.audios]
match self.message_author:
case "human":
if self.images is None and self.audios is None:
if self.images == [] and self.audios == []:
return HumanMessage(content=self.text)

return HumanMultimodalMessage(
content=self.text, images=self.images, audios=self.audios
content=self.text, images=base64_images, audios=base64_audios
)
case "ai":
if self.images is None and self.audios is None:
if self.images == [] and self.audios == []:
return AIMessage(content=self.text)
return AiMultimodalMessage(
content=self.text, images=self.images, audios=self.audios
content=self.text, images=base64_images, audios=base64_images
)
case _:
raise ValueError(
Expand All @@ -97,8 +125,12 @@ def from_langchain(
return cls(
payload=HRIPayload(
text=text,
images=images,
audios=audios,
images=(
[cls._base64_to_image(image) for image in images] if images else []
),
audios=(
[cls._base64_to_audio(audio) for audio in audios] if audios else []
),
),
message_author=message.type, # type: ignore
)
Expand Down
36 changes: 25 additions & 11 deletions tests/communication/test_hri_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,53 @@
import pytest
from langchain_core.messages import BaseMessage as LangchainBaseMessage
from langchain_core.messages import HumanMessage
from PIL import Image
from pydub import AudioSegment

from rai.communication import HRIMessage, HRIPayload
from rai.messages.multimodal import MultimodalMessage as RAIMultimodalMessage


def test_initialization():
payload = HRIPayload(text="Hello", images=["image1"], audios=["audio1"])
@pytest.fixture
def image():
    """A solid red 100x100 RGB image used as a multimodal attachment."""
    return Image.new("RGB", (100, 100), color="red")


@pytest.fixture
def audio():
    """One second of silence as a pydub AudioSegment."""
    return AudioSegment.silent(duration=1000)


def test_initialization(image, audio):
    """HRIMessage exposes the payload fields and the author unchanged."""
    # Span contained stale removed-diff asserts (old base64-string payloads)
    # duplicated alongside the new PIL/pydub asserts; only the new ones remain.
    payload = HRIPayload(text="Hello", images=[image], audios=[audio])
    message = HRIMessage(payload=payload, message_author="human")

    assert message.text == "Hello"
    assert message.images == [image]
    assert message.audios == [audio]
    assert message.message_author == "human"


def test_repr():
    """repr shows empty media lists when no attachments are given."""
    # Span contained stale removed-diff lines (old images=None/audios=None
    # construction and old expected repr); only the new form remains.
    payload = HRIPayload(text="Hello")
    message = HRIMessage(payload=payload, message_author="ai")

    assert repr(message) == "HRIMessage(type=ai, text=Hello, images=[], audios=[])"


def test_to_langchain_human():
    """A text-only human message converts to a plain HumanMessage."""
    # Span contained a stale removed-diff construction line (images=None,
    # audios=None); only the new empty-list form remains.
    payload = HRIPayload(text="Hi there", images=[], audios=[])
    message = HRIMessage(payload=payload, message_author="human")
    langchain_message = message.to_langchain()

    assert isinstance(langchain_message, HumanMessage)
    assert langchain_message.content == "Hi there"


def test_to_langchain_ai_multimodal():
payload = HRIPayload(text="Response", images=["img"], audios=["audio"])
def test_to_langchain_ai_multimodal(image, audio):
payload = HRIPayload(text="Response", images=[image], audios=[audio])
message = HRIMessage(payload=payload, message_author="ai")

with pytest.raises(
Expand All @@ -67,8 +81,8 @@ def test_from_langchain_human():
hri_message = HRIMessage.from_langchain(langchain_message)

assert hri_message.text == "Hello"
assert hri_message.images is None
assert hri_message.audios is None
assert hri_message.images == []
assert hri_message.audios == []
assert hri_message.message_author == "human"


Expand Down