Skip to content

Commit

Permalink
use rustpotter 0.7.0; add averaged_threshold option
Browse files Browse the repository at this point in the history
  • Loading branch information
GiviMAD committed May 8, 2022
1 parent 7679a49 commit 18f34a0
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 36 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[package]
name = "rustpotter-cli"
version = "0.6.0"
version = "0.7.0"
edition = "2021"
license = "Apache-2.0"
description = "CLI for rustpotter, a personal keywords spotter written in Rust"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
rustpotter = "0.6.0"
rustpotter = "0.7.0"
log = "0.4.6"
pv_recorder = "1.0.2"
ctrlc = "3.2.2"
Expand Down
10 changes: 8 additions & 2 deletions src/cli/build_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pub struct BuildModelCommand {
#[clap(min_values = 1, required = true)]
/// List of sample record paths
sample_path: Vec<String>,
#[clap(short = 't', long)]
/// Threshold to configure in the generated model, overwrites the detector threshold
threshold: Option<f32>,
#[clap(short = 'a', long)]
/// Averaged threshold to configure in the generated model, overwrites the detector averaged threshold
averaged_threshold: Option<f32>,
}
pub fn build(command: BuildModelCommand) -> Result<(), String> {
println!("Start building {}!", command.model_path);
Expand All @@ -33,8 +39,8 @@ pub fn build(command: BuildModelCommand) -> Result<(), String> {
word_detector.add_keyword(
command.model_name.clone(),
false,
true,
None,
command.averaged_threshold,
command.threshold,
command.sample_path,
);
match word_detector.generate_wakeword_model_file(command.model_name.clone(), command.model_path)
Expand Down
18 changes: 9 additions & 9 deletions src/cli/spot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicBool, Ordering};

#[cfg(not(debug_assertions))]
use crate::pv_recorder_utils::_get_pv_recorder_lib;
use crate::utils::enable_rustpotter_log;
use clap::Args;
use pv_recorder::RecorderBuilder;
use rustpotter::{VadMode, WakewordDetectorBuilder};
Expand All @@ -15,12 +16,12 @@ pub struct SpotCommand {
#[clap(short, long, default_value_t = 0.5)]
/// Default detection threshold, only applies to models without threshold
threshold: f32,
#[clap(short, long)]
/// Default detection averaged threshold, only applies to models without averaged threshold, defaults to threshold/2.
averaged_threshold: Option<f32>,
#[clap(short, long, default_value_t = 0)]
/// Input device index used for record
device_index: usize,
#[clap(short = 'a', long)]
/// Enables template averaging
average_templates: bool,
#[clap(short = 'e', long)]
/// Enables eager mode
eager_mode: bool,
Expand All @@ -44,13 +45,13 @@ pub struct SpotCommand {
pub fn spot(command: SpotCommand) -> Result<(), String> {
println!("Spotting using models: {:?}!", command.model_path);
if command.debug {
simple_logger::SimpleLogger::new()
.with_level(log::LevelFilter::Warn)
.with_module_level("rustpotter", log::LevelFilter::Debug)
.init().unwrap();
enable_rustpotter_log();
}
let mut detector_builder = WakewordDetectorBuilder::new();
detector_builder.set_threshold(command.threshold);
if command.averaged_threshold.is_some() {
detector_builder.set_averaged_threshold(command.averaged_threshold.unwrap());
}
detector_builder.set_sample_rate(16000);
detector_builder.set_eager_mode(command.eager_mode);
detector_builder.set_single_thread(command.single_thread);
Expand All @@ -61,8 +62,7 @@ pub fn spot(command: SpotCommand) -> Result<(), String> {
}
let mut word_detector = detector_builder.build();
for path in command.model_path {
let result =
word_detector.add_keyword_from_model_file(path, command.average_templates, true);
let result = word_detector.add_keyword_from_model_file(path, true);
if result.is_err() {
clap::Error::raw(clap::ErrorKind::InvalidValue, result.unwrap_err() + "\n").exit();
}
Expand Down
44 changes: 25 additions & 19 deletions src/cli/test_model.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use clap::Args;
use hound::{SampleFormat, WavReader};
use rustpotter::WakewordDetectorBuilder;
use std::{
fs::{File},
io::BufReader,
};
use std::{fs::File, io::BufReader};

use crate::utils::enable_rustpotter_log;

#[derive(Args, Debug)]
/// Test model file against a wav sample, detector is automatically configured according to the sample spec
/// Test model file against a wav sample, detector is automatically configured according to the sample spec
#[clap()]
pub struct TestModelCommand {
#[clap()]
Expand All @@ -16,35 +15,38 @@ pub struct TestModelCommand {
#[clap()]
/// Sample record path
sample_path: String,
#[clap(short = 'a', long)]
/// Enables template averaging
average_templates: bool,
#[clap(short = 't', long, default_value_t = 0.)]
/// Customize detection threshold
#[clap(short, long, default_value_t = 0.5)]
/// Default detection threshold, only applies to models without threshold
threshold: f32,
#[clap(short, long, default_value_t = 0.2)]
/// Default detection averaged threshold, only applies to models without averaged threshold
averaged_threshold: f32,
#[clap(long)]
/// Enables rustpotter debug log
debug: bool,
}
pub fn test(command: TestModelCommand) -> Result<(), String> {
println!(
"Testing file {} against model {}!",
command.sample_path, command.model_path,
);
if command.debug {
enable_rustpotter_log();
}
let mut detector_builder = WakewordDetectorBuilder::new();
let reader =
BufReader::new(File::open(command.sample_path).or_else(|err| Err(err.to_string()))?);
let mut wav_reader = WavReader::new(reader).or_else(|err| Err(err.to_string()))?;
let wav_specs = wav_reader.spec();
detector_builder.set_averaged_threshold(command.averaged_threshold);
detector_builder.set_threshold(command.threshold);
detector_builder.set_sample_rate(wav_specs.sample_rate as usize);
detector_builder.set_bits_per_sample(wav_specs.bits_per_sample);
detector_builder.set_sample_format(wav_specs.sample_format);
// multi-channel still not supported
assert!(wav_specs.channels == 1);
let mut word_detector = detector_builder.build();
let add_wakeword_result = word_detector.add_keyword_from_model_file(
command.model_path,
command.average_templates,
true,
);
let add_wakeword_result = word_detector.add_keyword_from_model_file(command.model_path, true);
if add_wakeword_result.is_err() {
clap::Error::raw(
clap::ErrorKind::InvalidValue,
Expand All @@ -54,10 +56,12 @@ pub fn test(command: TestModelCommand) -> Result<(), String> {
}
match wav_specs.sample_format {
SampleFormat::Int => {
wav_reader
let mut buffer = wav_reader
.samples::<i32>()
.map(Result::unwrap)
.collect::<Vec<_>>()
.collect::<Vec<_>>();
buffer.append(&mut vec![0; word_detector.get_samples_per_frame()]);
buffer
.chunks_exact(word_detector.get_samples_per_frame())
.filter_map(|chunk| word_detector.process(chunk))
.for_each(|detection| {
Expand All @@ -70,10 +74,12 @@ pub fn test(command: TestModelCommand) -> Result<(), String> {
Ok(())
}
SampleFormat::Float => {
wav_reader
let mut buffer = wav_reader
.samples::<f32>()
.map(Result::unwrap)
.collect::<Vec<_>>()
.collect::<Vec<_>>();
buffer.append(&mut vec![0.; word_detector.get_samples_per_frame()]);
buffer
.chunks_exact(word_detector.get_samples_per_frame())
.filter_map(|chunk| word_detector.process_f32(chunk))
.for_each(|detection| {
Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod cli;
mod pv_recorder_utils;
mod utils;
use cli::run_cli;
fn main() {
run_cli();
}
}
7 changes: 7 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub fn enable_rustpotter_log() {
simple_logger::SimpleLogger::new()
.with_level(log::LevelFilter::Warn)
.with_module_level("rustpotter", log::LevelFilter::Debug)
.init()
.unwrap();
}

0 comments on commit 18f34a0

Please sign in to comment.