Commit a157a01

feat: implement base class for asr models, add local whisper

maciejmajek committed Sep 6, 2024
1 parent ddfd016 commit a157a01

Showing 4 changed files with 177 additions and 31 deletions.
73 changes: 73 additions & 0 deletions src/rai_asr/launch/local.launch.py
@@ -0,0 +1,73 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node


def generate_launch_description():
return LaunchDescription(
[
DeclareLaunchArgument(
"recording_device",
default_value="0",
description="Microphone device number. See available by running python -c 'import sounddevice as sd; print(sd.query_devices())'",
),
DeclareLaunchArgument(
"language",
default_value="en",
description="Language code for the ASR model",
),
DeclareLaunchArgument(
"model_name",
default_value="base",
description="Model name for the ASR model",
),
DeclareLaunchArgument(
"model_vendor",
default_value="whisper",
description="Model vendor of the ASR",
),
DeclareLaunchArgument(
"silence_grace_period",
default_value="2.0",
description="Grace period in seconds after silence to stop recording",
),
DeclareLaunchArgument(
"sample_rate",
default_value="0",
description="Sample rate for audio capture (0 for auto-detect)",
),
Node(
package="rai_asr",
executable="asr_node",
name="rai_asr",
output="screen",
emulate_tty=True,
parameters=[
{
"language": LaunchConfiguration("language"),
"model": LaunchConfiguration("model"),
"silence_grace_period": LaunchConfiguration(
"silence_grace_period"
),
"sample_rate": LaunchConfiguration("sample_rate"),
}
],
),
]
)
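
These launch arguments map onto the node parameters declared in `asr_node.py` below. Assuming the package is built and sourced in the workspace, the local pipeline can then be started with, for example, `ros2 launch rai_asr local.launch.py model_name:=base model_vendor:=whisper language:=en`.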
@@ -33,9 +33,14 @@ def generate_launch_description():
description="Language code for the ASR model",
),
DeclareLaunchArgument(
"model",
"model_name",
default_value="whisper-1",
description="Model type for the ASR model",
description="Model name for the ASR model",
),
DeclareLaunchArgument(
"model_vendor",
default_value="openai",
description="Model vendor of the ASR",
),
DeclareLaunchArgument(
"silence_grace_period",
59 changes: 59 additions & 0 deletions src/rai_asr/rai_asr/asr_clients.py
@@ -0,0 +1,59 @@
import io
import os
from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import whisper
from numpy.typing import NDArray
from openai import OpenAI
from scipy.io import wavfile
from whisper.transcribe import transcribe


class ASRModel(ABC):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
self.model_name = model_name
self.sample_rate = sample_rate
self.language = language

@abstractmethod
def transcribe(self, data: NDArray[np.int16]) -> str:
pass

def __call__(self, data: NDArray[np.int16]) -> str:
return self.transcribe(data)


class OpenAIWhisper(ASRModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
self.api_key = api_key
self.openai_client = OpenAI()
self.model = partial(
self.openai_client.audio.transcriptions.create,
model=self.model_name,
)

def transcribe(self, data: NDArray[np.int16]) -> str:
with io.BytesIO() as temp_wav_buffer:
wavfile.write(temp_wav_buffer, self.sample_rate, data)
temp_wav_buffer.seek(0)
temp_wav_buffer.name = "temp.wav"
response = self.model(file=temp_wav_buffer, language=self.language)
transcription = response.text
return transcription


class LocalWhisper(ASRModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
self.whisper = whisper.load_model(self.model_name)

def transcribe(self, data: NDArray[np.int16]) -> str:
result = transcribe(self.whisper, data.astype(np.float32) / 32768.0)
transcription = result["text"]
return transcription
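
The clients are usable outside the node as well. Below is a minimal sketch of driving `LocalWhisper` directly, assuming the `openai-whisper` package is installed and the "base" checkpoint can be downloaded on first use; the synthetic `audio` array is purely illustrative and stands in for a real microphone capture:

```python
import numpy as np

from rai_asr.asr_clients import LocalWhisper

# Illustrative input: 3 seconds of int16 samples at 16 kHz. A real
# recording would come from sounddevice or a WAV file instead.
sample_rate = 16000
t = np.arange(3 * sample_rate) / sample_rate
audio = (0.1 * 32767 * np.sin(2 * np.pi * 440 * t)).astype(np.int16)

model = LocalWhisper("base", sample_rate, language="en")
print(model(audio))  # __call__ delegates to transcribe()
```

`OpenAIWhisper` exposes the same interface; it only additionally requires `OPENAI_API_KEY` to be set in the environment.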
67 changes: 38 additions & 29 deletions src/rai_asr/rai_asr/asr_node.py
@@ -13,27 +13,23 @@
# limitations under the License.
#

import io
import os
import threading
import time
from datetime import datetime, timedelta
from functools import partial
from typing import Literal, Optional, cast

import numpy as np
import rclpy
import sounddevice as sd
import torch
from numpy.typing import NDArray
from openai import OpenAI
from openwakeword.model import Model as OWWModel
from openwakeword.utils import download_models
from rcl_interfaces.msg import ParameterDescriptor, ParameterType
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from scipy.io import wavfile
from std_msgs.msg import String

SAMPLING_RATE = 16000
@@ -47,7 +43,7 @@ def __init__(self):
self._setup_node_components()
self._setup_publishers_and_subscribers()

self.asr_model = self._load_whisper_model()
self.asr_model = self._initialize_asr_model()
self.vad_model = self._initialize_vad_model()
self.oww_model = self._initialize_open_wake_word()

@@ -106,6 +102,14 @@ def _declare_parameters(self):
),
),
)
self.declare_parameter(
"model_vendor",
"whisper", # openai, whisper
ParameterDescriptor(
type=ParameterType.PARAMETER_STRING,
description="Vendor of the ASR model",
),
)
self.declare_parameter(
"language",
"en",
@@ -115,8 +119,8 @@
),
)
self.declare_parameter(
"model",
"whisper-1",
"model_name",
"base",
ParameterDescriptor(
type=ParameterType.PARAMETER_STRING,
description="Model type for the ASR model",
@@ -162,18 +166,19 @@ def _initialize_parameters(self):
.get_parameter_value()
.double_value,
)
self.whisper_model = cast(
str,
self.get_parameter("model").get_parameter_value().string_value,
)
self.language = cast(
str,
self.get_parameter("language").get_parameter_value().string_value,
)
self.vad_threshold = cast(
float,
self.get_parameter("vad_threshold").get_parameter_value().double_value,
)
) # type: ignore
self.model_name = (
self.get_parameter("model_name").get_parameter_value().string_value
) # type: ignore
self.model_vendor = (
self.get_parameter("model_vendor").get_parameter_value().string_value
) # type: ignore
self.language = (
self.get_parameter("language").get_parameter_value().string_value
) # type: ignore

self.use_wake_word = cast(
bool,
@@ -218,12 +223,17 @@ def _setup_publishers_and_subscribers(self):
callback_group=self.callback_group,
)

def _load_whisper_model(self):
self.openai_client = OpenAI()
model = partial(
self.openai_client.audio.transcriptions.create, model=self.whisper_model
)
return model
def _initialize_asr_model(self):
if self.model_vendor == "openai":
from rai_asr.asr_clients import OpenAIWhisper

self.model = OpenAIWhisper(self.model_name, self.sample_rate, self.language)
elif self.model_vendor == "whisper":
from rai_asr.asr_clients import LocalWhisper

self.model = LocalWhisper(self.model_name, self.sample_rate, self.language)
else:
raise ValueError(f"Unknown model vendor: {self.model_vendor}")
        return self.model

def tts_status_callback(self, msg: String):
if msg.data == "processing":
@@ -325,14 +335,13 @@ def transcribe_audio(self):
combined_audio = np.concatenate(self.audio_buffer)
self.reset_buffer() # consume the buffer, so we don't transcribe the same audio twice

with io.BytesIO() as temp_wav_buffer:
wavfile.write(temp_wav_buffer, self.sample_rate, combined_audio)
temp_wav_buffer.seek(0)
temp_wav_buffer.name = "temp.wav"
transcription = self.model(data=combined_audio)

response = self.asr_model(file=temp_wav_buffer, language=self.language)
transcription = response.text
self.get_logger().debug(f"Transcription: {transcription}") # type: ignore
if transcription.lower() in ["you", ""]:
self.get_logger().info(f"Dropping transcription: '{transcription}'")
self.publish_status("dropping")
else:
self.get_logger().info(f"Transcription: {transcription}")
self.publish_transcription(transcription)

self.last_transcription_time = time.time()
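
With transcription moved behind the `ASRModel` interface, `transcribe_audio` no longer builds an in-memory WAV file itself: that detail now lives in `OpenAIWhisper.transcribe`, and `LocalWhisper` skips it entirely by feeding Whisper normalized float32 samples. The new guard around publishing also drops empty transcriptions and the bare "you" that Whisper commonly produces on near-silent input. To switch backends without editing a launch file, the two new parameters can be overridden directly, e.g. `ros2 run rai_asr asr_node --ros-args -p model_vendor:=openai -p model_name:=whisper-1`.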
