Skip to content

Commit

Permalink
feat: allow input device with any sampling rate
Browse files Browse the repository at this point in the history
refactor: use stream callback, use SingleThreadedExecutor
feat: new parameters in launch files
  • Loading branch information
maciejmajek committed Sep 7, 2024
1 parent fe29ce8 commit da13b88
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 65 deletions.
34 changes: 28 additions & 6 deletions src/rai_asr/launch/local.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,28 @@ def generate_launch_description():
),
DeclareLaunchArgument(
"silence_grace_period",
default_value="2.0",
default_value="1.0",
description="Grace period in seconds after silence to stop recording",
),
DeclareLaunchArgument(
"sample_rate",
default_value="0",
description="Sample rate for audio capture (0 for auto-detect)",
"use_wake_word",
default_value="False",
description="Whether to use wake word detection",
),
DeclareLaunchArgument(
"wake_word_model",
default_value="",
description="Wake word model to use",
),
DeclareLaunchArgument(
"wake_word_threshold",
default_value="0.5",
description="Threshold for wake word detection",
),
DeclareLaunchArgument(
"vad_threshold",
default_value="0.5",
description="Threshold for voice activity detection",
),
Node(
package="rai_asr",
Expand All @@ -60,12 +75,19 @@ def generate_launch_description():
emulate_tty=True,
parameters=[
{
"recording_device": LaunchConfiguration("recording_device"),
"language": LaunchConfiguration("language"),
"model": LaunchConfiguration("model"),
"model_name": LaunchConfiguration("model_name"),
"model_vendor": LaunchConfiguration("model_vendor"),
"silence_grace_period": LaunchConfiguration(
"silence_grace_period"
),
"sample_rate": LaunchConfiguration("sample_rate"),
"use_wake_word": LaunchConfiguration("use_wake_word"),
"wake_word_model": LaunchConfiguration("wake_word_model"),
"wake_word_threshold": LaunchConfiguration(
"wake_word_threshold"
),
"vad_threshold": LaunchConfiguration("vad_threshold"),
}
],
),
Expand Down
34 changes: 28 additions & 6 deletions src/rai_asr/launch/openai.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,28 @@ def generate_launch_description():
),
DeclareLaunchArgument(
"silence_grace_period",
default_value="2.0",
default_value="1.0",
description="Grace period in seconds after silence to stop recording",
),
DeclareLaunchArgument(
"sample_rate",
default_value="0",
description="Sample rate for audio capture (0 for auto-detect)",
"use_wake_word",
default_value="False",
description="Whether to use wake word detection",
),
DeclareLaunchArgument(
"wake_word_model",
default_value="",
description="Wake word model to use",
),
DeclareLaunchArgument(
"wake_word_threshold",
default_value="0.5",
description="Threshold for wake word detection",
),
DeclareLaunchArgument(
"vad_threshold",
default_value="0.5",
description="Threshold for voice activity detection",
),
Node(
package="rai_asr",
Expand All @@ -60,12 +75,19 @@ def generate_launch_description():
emulate_tty=True,
parameters=[
{
"recording_device": LaunchConfiguration("recording_device"),
"language": LaunchConfiguration("language"),
"model": LaunchConfiguration("model"),
"model_name": LaunchConfiguration("model_name"),
"model_vendor": LaunchConfiguration("model_vendor"),
"silence_grace_period": LaunchConfiguration(
"silence_grace_period"
),
"sample_rate": LaunchConfiguration("sample_rate"),
"use_wake_word": LaunchConfiguration("use_wake_word"),
"wake_word_model": LaunchConfiguration("wake_word_model"),
"wake_word_threshold": LaunchConfiguration(
"wake_word_threshold"
),
"vad_threshold": LaunchConfiguration("vad_threshold"),
}
],
),
Expand Down
101 changes: 50 additions & 51 deletions src/rai_asr/rai_asr/asr_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
#

import os
import threading
import time
from datetime import datetime, timedelta
from typing import Literal, Optional, cast

import numpy as np
Expand All @@ -28,12 +26,13 @@
from openwakeword.utils import download_models
from rcl_interfaces.msg import ParameterDescriptor, ParameterType
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from scipy.signal import resample
from std_msgs.msg import String

DEFAULT_SAMPLING_RATE = 16000
DEFAULT_BLOCKSIZE = 1280


class ASRNode(Node):
Expand All @@ -48,14 +47,17 @@ def __init__(self):
self.vad_model = self._initialize_vad_model()
self.oww_model = self._initialize_open_wake_word()

self.initialize_sounddevice_stream()

self.is_recording = False
self.audio_buffer = []
self.silence_start_time: Optional[datetime] = None
self.silence_start_time: Optional[float] = None
self.last_transcription_time = 0
self.hmi_lock = False
self.tts_lock = False

self.grace_period = timedelta(seconds=self.silence_grace_period)
self.current_chunk: Optional[NDArray[np.int16]] = None

self.transcription_recording_timeout = 1
self.get_logger().info("ASR Node has been initialized") # type: ignore

Expand Down Expand Up @@ -206,6 +208,7 @@ def _initialize_parameters(self):
self.get_logger().info("Parameters have been initialized") # type: ignore

def _setup_publishers_and_subscribers(self):

self.transcription_publisher = self.create_publisher(String, "/from_human", 10)
self.status_publisher = self.create_publisher(String, "/asr_status", 10)
self.tts_status_subscriber = self.create_subscription(
Expand Down Expand Up @@ -280,57 +283,57 @@ def int2float(sound: NDArray[np.int16]):

return False

def capture_sound(self):
device_sample_rate = sd.query_devices(
def sd_callback(self, indata, frames, _, status):
if status:
self.get_logger().warning(f"Stream status: {status}") # type: ignore
indata = indata.flatten()
sample_time_length = len(indata) / self.device_sample_rate
if self.device_sample_rate != DEFAULT_SAMPLING_RATE:
indata = resample(indata, int(sample_time_length * DEFAULT_SAMPLING_RATE))

asr_lock = (
time.time()
< self.last_transcription_time + self.transcription_recording_timeout
)
if asr_lock or self.hmi_lock or self.tts_lock:
return

if self.should_listen(indata):
self.silence_start_time = time.time()
if not self.is_recording:
self.start_recording()
self.audio_buffer.append(indata)
elif self.is_recording:
self.audio_buffer.append(indata)
if not isinstance(self.silence_start_time, float):
raise ValueError(
"Silence start time is not set, this should not happen"
)
if time.time() - self.silence_start_time > self.silence_grace_period:
self.stop_recording_and_transcribe()

def initialize_sounddevice_stream(self):
sd.default.latency = ("low", "low")
self.device_sample_rate = sd.query_devices(
device=self.recording_device_number, kind="input"
)[
"default_samplerate"
] # type: ignore
window_size_samples = int(1280 * device_sample_rate / DEFAULT_SAMPLING_RATE)
stream = sd.InputStream(
samplerate=device_sample_rate,
self.window_size_samples = int(
DEFAULT_BLOCKSIZE * self.device_sample_rate / DEFAULT_SAMPLING_RATE
)
self.stream = sd.InputStream(
samplerate=self.device_sample_rate,
channels=1,
device=self.recording_device_number,
dtype="int16",
blocksize=self.window_size_samples,
callback=self.sd_callback,
)
stream.start()
self.get_logger().info("Stream started. Waiting for speech...") # type: ignore
while True:
audio_data, _ = stream.read(window_size_samples)
audio_data = audio_data.flatten()

asr_lock = (
time.time()
< self.last_transcription_time + self.transcription_recording_timeout
)
if asr_lock or self.hmi_lock or self.tts_lock:
continue

sample_time_length = len(audio_data) / device_sample_rate
if device_sample_rate != DEFAULT_SAMPLING_RATE:
audio_data = resample(
audio_data, int(sample_time_length * DEFAULT_SAMPLING_RATE)
)

if self.should_listen(audio_data):
self.silence_start_time = datetime.now()
if not self.is_recording:
self.start_recording()
self.audio_buffer.append(audio_data)
elif self.is_recording:
self.audio_buffer.append(audio_data)
if not isinstance(self.silence_start_time, datetime):
raise ValueError(
"Silence start time is not set, this should not happen"
)

if datetime.now() - self.silence_start_time > timedelta(
seconds=self.silence_grace_period
):
self.stop_recording_and_transcribe()
self.stream.start()

def reset_buffer(self):
self.audio_buffer = []
self.audio_buffer.clear()

def start_recording(self):
self.get_logger().info("Recording...") # type: ignore
Expand Down Expand Up @@ -376,13 +379,9 @@ def publish_status(
def main(args=None):
rclpy.init(args=args)
node = ASRNode()

executor = MultiThreadedExecutor()
executor = SingleThreadedExecutor()
executor.add_node(node)

thread = threading.Thread(target=node.capture_sound)
thread.start()

try:
executor.spin()
except KeyboardInterrupt:
Expand Down
2 changes: 1 addition & 1 deletion src/rai_hmi/rai_hmi/hmi_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self):
)

self.create_timer(
0.01, self.status_callback, callback_group=self.callback_group
0.25, self.status_callback, callback_group=self.callback_group
)

self.status_publisher = self.create_publisher(String, "hmi_status", 10) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion src/rai_tts/rai_tts/tts_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self):
self.job_id: int = 0
self.queued_job_id = 0
self.tts_client = self._initialize_client()
self.create_timer(0.01, self.status_callback)
self.create_timer(0.25, self.status_callback)
threading.Thread(target=self._process_queue).start()
self.get_logger().info("TTS Node has been started") # type: ignore
self.threads_number = 0
Expand Down

0 comments on commit da13b88

Please sign in to comment.