diff --git a/Cargo.lock b/Cargo.lock index b18c570..a2a8e16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -462,9 +462,9 @@ dependencies = [ [[package]] name = "rustpotter" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09231bcd1551093331af3e5aa38ab0b236b15b8c3159140412087542f12f3260" +checksum = "1d00f112181e60cc8b869792c3780c821e9f514d947a628dd9498d7eceb929f3" dependencies = [ "hound", "log", @@ -479,7 +479,7 @@ dependencies = [ [[package]] name = "rustpotter-cli" -version = "0.6.0" +version = "0.7.0" dependencies = [ "clap", "ctrlc", diff --git a/Cargo.toml b/Cargo.toml index 7306bbf..e8e6f69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpotter-cli" -version = "0.6.0" +version = "0.7.0" edition = "2021" license = "Apache-2.0" description = "CLI for rustpotter, a personal keywords spotter written in Rust" @@ -8,7 +8,7 @@ description = "CLI for rustpotter, a personal keywords spotter written in Rust" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -rustpotter = "0.6.0" +rustpotter = "0.7.0" log = "0.4.6" pv_recorder = "1.0.2" ctrlc = "3.2.2" diff --git a/src/cli/build_model.rs b/src/cli/build_model.rs index e667fbd..dddfb0f 100644 --- a/src/cli/build_model.rs +++ b/src/cli/build_model.rs @@ -17,6 +17,12 @@ pub struct BuildModelCommand { #[clap(min_values = 1, required = true)] /// List of sample record paths sample_path: Vec, + #[clap(short = 't', long)] + /// Threshold to configure in the generated model, overwrites the detector threshold + threshold: Option, + #[clap(short = 'a', long)] + /// Averaged threshold to configure in the generated model, overwrites the detector averaged threshold + averaged_threshold: Option, } pub fn build(command: BuildModelCommand) -> Result<(), String> { println!("Start building {}!", command.model_path); @@ -33,8 +39,8 @@ pub fn build(command: BuildModelCommand) -> Result<(), String> { word_detector.add_keyword( command.model_name.clone(), false, - true, - None, + command.averaged_threshold, + command.threshold, command.sample_path, ); match word_detector.generate_wakeword_model_file(command.model_name.clone(), command.model_path) diff --git a/src/cli/spot.rs b/src/cli/spot.rs index 295f682..c45730d 100644 --- a/src/cli/spot.rs +++ b/src/cli/spot.rs @@ -2,6 +2,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; #[cfg(not(debug_assertions))] use crate::pv_recorder_utils::_get_pv_recorder_lib; +use crate::utils::enable_rustpotter_log; use clap::Args; use pv_recorder::RecorderBuilder; use rustpotter::{VadMode, WakewordDetectorBuilder}; @@ -15,12 +16,12 @@ pub struct SpotCommand { #[clap(short, long, default_value_t = 0.5)] /// Default detection threshold, only applies to models without threshold threshold: f32, + #[clap(short, long)] + /// Default detection averaged threshold, only applies to models without averaged threshold, defaults to threshold/2. + averaged_threshold: Option, #[clap(short, long, default_value_t = 0)] /// Input device index used for record device_index: usize, - #[clap(short = 'a', long)] - /// Enables template averaging - average_templates: bool, #[clap(short = 'e', long)] /// Enables eager mode eager_mode: bool, @@ -44,13 +45,13 @@ pub struct SpotCommand { pub fn spot(command: SpotCommand) -> Result<(), String> { println!("Spotting using models: {:?}!", command.model_path); if command.debug { - simple_logger::SimpleLogger::new() - .with_level(log::LevelFilter::Warn) - .with_module_level("rustpotter", log::LevelFilter::Debug) - .init().unwrap(); + enable_rustpotter_log(); } let mut detector_builder = WakewordDetectorBuilder::new(); detector_builder.set_threshold(command.threshold); + if command.averaged_threshold.is_some() { + detector_builder.set_averaged_threshold(command.averaged_threshold.unwrap()); + } detector_builder.set_sample_rate(16000); detector_builder.set_eager_mode(command.eager_mode); detector_builder.set_single_thread(command.single_thread); @@ -61,8 +62,7 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { } let mut word_detector = detector_builder.build(); for path in command.model_path { - let result = - word_detector.add_keyword_from_model_file(path, command.average_templates, true); + let result = word_detector.add_keyword_from_model_file(path, true); if result.is_err() { clap::Error::raw(clap::ErrorKind::InvalidValue, result.unwrap_err() + "\n").exit(); } diff --git a/src/cli/test_model.rs b/src/cli/test_model.rs index 6745607..70c4708 100644 --- a/src/cli/test_model.rs +++ b/src/cli/test_model.rs @@ -1,13 +1,12 @@ use clap::Args; use hound::{SampleFormat, WavReader}; use rustpotter::WakewordDetectorBuilder; -use std::{ - fs::{File}, - io::BufReader, -}; +use std::{fs::File, io::BufReader}; + +use crate::utils::enable_rustpotter_log; #[derive(Args, Debug)] -/// Test model file against a wav sample, detector is automatically configured according to the sample spec +/// Test model file against a wav sample, detector is automatically configured according to the sample spec #[clap()] pub struct TestModelCommand { #[clap()] @@ -16,23 +15,30 @@ pub struct TestModelCommand { #[clap()] /// Sample record path sample_path: String, - #[clap(short = 'a', long)] - /// Enables template averaging - average_templates: bool, - #[clap(short = 't', long, default_value_t = 0.)] - /// Customize detection threshold + #[clap(short, long, default_value_t = 0.5)] + /// Default detection threshold, only applies to models without threshold threshold: f32, + #[clap(short, long, default_value_t = 0.2)] + /// Default detection averaged threshold, only applies to models without averaged threshold + averaged_threshold: f32, + #[clap(long)] + /// Enables rustpotter debug log + debug: bool, } pub fn test(command: TestModelCommand) -> Result<(), String> { println!( "Testing file {} against model {}!", command.sample_path, command.model_path, ); + if command.debug { + enable_rustpotter_log(); + } let mut detector_builder = WakewordDetectorBuilder::new(); let reader = BufReader::new(File::open(command.sample_path).or_else(|err| Err(err.to_string()))?); let mut wav_reader = WavReader::new(reader).or_else(|err| Err(err.to_string()))?; let wav_specs = wav_reader.spec(); + detector_builder.set_averaged_threshold(command.averaged_threshold); detector_builder.set_threshold(command.threshold); detector_builder.set_sample_rate(wav_specs.sample_rate as usize); detector_builder.set_bits_per_sample(wav_specs.bits_per_sample); @@ -40,11 +46,7 @@ pub fn test(command: TestModelCommand) -> Result<(), String> { // multi-channel still not supported assert!(wav_specs.channels == 1); let mut word_detector = detector_builder.build(); - let add_wakeword_result = word_detector.add_keyword_from_model_file( - command.model_path, - command.average_templates, - true, - ); + let add_wakeword_result = word_detector.add_keyword_from_model_file(command.model_path, true); if add_wakeword_result.is_err() { clap::Error::raw( clap::ErrorKind::InvalidValue, @@ -54,10 +56,12 @@ pub fn test(command: TestModelCommand) -> Result<(), String> { } match wav_specs.sample_format { SampleFormat::Int => { - wav_reader + let mut buffer = wav_reader .samples::() .map(Result::unwrap) - .collect::>() + .collect::>(); + buffer.append(&mut vec![0; word_detector.get_samples_per_frame()]); + buffer .chunks_exact(word_detector.get_samples_per_frame()) .filter_map(|chunk| word_detector.process(chunk)) .for_each(|detection| { @@ -70,10 +74,12 @@ pub fn test(command: TestModelCommand) -> Result<(), String> { Ok(()) } SampleFormat::Float => { - wav_reader + let mut buffer = wav_reader .samples::() .map(Result::unwrap) - .collect::>() + .collect::>(); + buffer.append(&mut vec![0.; word_detector.get_samples_per_frame()]); + buffer .chunks_exact(word_detector.get_samples_per_frame()) .filter_map(|chunk| word_detector.process_f32(chunk)) .for_each(|detection| { diff --git a/src/main.rs b/src/main.rs index b83a9c0..9b827ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod cli; mod pv_recorder_utils; +mod utils; use cli::run_cli; fn main() { run_cli(); -} \ No newline at end of file +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..b37eb5b --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,7 @@ +pub fn enable_rustpotter_log() { + simple_logger::SimpleLogger::new() + .with_level(log::LevelFilter::Warn) + .with_module_level("rustpotter", log::LevelFilter::Debug) + .init() + .unwrap(); +}