Skip to content

Commit

Permalink
test: add sound device connector tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Jan 30, 2025
1 parent 6bbc11d commit 4018cec
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 15 deletions.
14 changes: 12 additions & 2 deletions src/rai/rai/communication/sound_device/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
# limitations under the License.


import base64
import io
from typing import Callable, Literal, Optional, Sequence, Tuple

from scipy.io import wavfile

try:
import sounddevice as sd
except ImportError as e:
Expand Down Expand Up @@ -90,7 +94,10 @@ def send_message(self, message: SoundDeviceMessage, target: str, **kwargs) -> No
)
else:
if message.audios is not None:
self.devices[target].write(message.audios[0])
wav_bytes = base64.b64decode(message.audios[0])
wav_buffer = io.BytesIO(wav_bytes)
_, audio_data = wavfile.read(wav_buffer)
self.devices[target].write(audio_data)
else:
raise SoundDeviceError("Failed to provice audios in message to play")

Expand Down Expand Up @@ -118,7 +125,10 @@ def service_call(
)
ret = SoundDeviceMessage(payload)
else:
self.devices[target].write(message.payload.audio, blocking=True)
if message.audios is not None:
self.devices[target].write(message.audios[0], blocking=True)
else:
raise SoundDeviceError("Failed to provice audios in message to play")
ret = SoundDeviceMessage()
return ret

Expand Down
45 changes: 32 additions & 13 deletions tests/communication/sounds_device/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import sounddevice

Expand All @@ -29,7 +29,7 @@
def mock_sound_device_api():
with patch("rai.communication.sound_device.api.SoundDeviceAPI") as mock:
mock_instance = MagicMock()
mock_instance.play = MagicMock()
mock_instance.write = MagicMock()
mock_instance.rec = MagicMock()
mock_instance.stop = MagicMock()
mock.return_value = mock_instance
Expand Down Expand Up @@ -58,21 +58,40 @@ def connector(mock_sound_device_api):
)
targets = [("speaker", config)]
sources = [("microphone", config)]
return SoundDeviceConnector(targets, sources)
ret = SoundDeviceConnector(targets, sources)
ret.devices["speaker"] = mock_sound_device_api
ret.devices["microphone"] = mock_sound_device_api

return ret

def test_send_message_play_audio(connector, mock_sound_device_api):

@pytest.fixture
def base64_audio():
# load audio file
audio_file = "tests/resources/sine_wave.wav"
with open(audio_file, "rb") as wav_file:
wav_bytes = wav_file.read()
base64_string = base64.b64encode(wav_bytes).decode(
"utf-8"
) # Encode and convert to string
return base64_string


def test_send_message_play_audio(connector, mock_sound_device_api, base64_audio):
message = SoundDeviceMessage(
payload=HRIPayload(text="", audios=[np.array([1, 2, 3])])
payload=HRIPayload(
text="",
audios=[base64_audio],
)
)
connector.send_message(message, "speaker")
connector.devices["speaker"].assert_called_once_with(b"test_audio")
connector.devices["speaker"].write.assert_called_once()


def test_send_message_stop_audio(connector, mock_sound_device_api):
message = SoundDeviceMessage(stop=True)
connector.send_message(message, "speaker")
connector.devices["speaker"].assert_called_once()
connector.devices["speaker"].stop.assert_called_once()


def test_send_message_read_error(connector):
Expand All @@ -84,19 +103,19 @@ def test_send_message_read_error(connector):
connector.send_message(message, "speaker")


def test_service_call_play_audio(connector, mock_sound_device_api):
message = SoundDeviceMessage(payload=HRIPayload(text="", audios=["test_audio"]))
def test_service_call_play_audio(connector, mock_sound_device_api, base64_audio):
message = SoundDeviceMessage(payload=HRIPayload(text="", audios=[base64_audio]))
result = connector.service_call(message, "speaker")
mock_sound_device_api.play.assert_called_once_with(b"test_audio", blocking=True)
mock_sound_device_api.write.assert_called_once()
assert isinstance(result, SoundDeviceMessage)


def test_service_call_read_audio(connector, mock_sound_device_api):
mock_sound_device_api.record.return_value = b"recorded_audio"
def test_service_call_read_audio(connector, mock_sound_device_api, base64_audio):
mock_sound_device_api.record.return_value = base64_audio
message = SoundDeviceMessage(read=True, duration=2.0)
result = connector.service_call(message, "microphone")
mock_sound_device_api.record.assert_called_once_with(2.0, blocking=True)
assert result.payload.audios == [b"recorded_audio"]
assert result.audios == [base64_audio]


def test_service_call_stop_error(connector):
Expand Down
Binary file added tests/resources/sine_wave.wav
Binary file not shown.

0 comments on commit 4018cec

Please sign in to comment.