From a32bcc2bdce2e63246d4673705f6adf46c971ec0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81lvarez?= Date: Tue, 29 Aug 2023 16:37:24 +0200 Subject: [PATCH] use v3-beta-5, add configurable mfcc size and vad --- Cargo.lock | 173 +++++++++++++++++++------------------------ Cargo.toml | 4 +- src/cli/build_ref.rs | 18 ++--- src/cli/spot.rs | 31 +++++++- src/cli/test.rs | 6 +- src/cli/train.rs | 4 + 6 files changed, 124 insertions(+), 112 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 601fba3..5195ccf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -35,24 +35,23 @@ dependencies = [ [[package]] name = "anstream" -version = "0.3.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "15c4c2c83f81532e5845a733998b6971faca23490340a418e9b72a3ec9de12ea" [[package]] name = "anstyle-parse" @@ -74,9 +73,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.2" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c677ab05e09154296dd37acecd46420c17b9713e8366facafa8fc0885167cf4c" +checksum = "58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" dependencies = [ "anstyle", "windows-sys", @@ -146,9 +145,9 @@ checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "candle-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e56d08f7794036648d7ba5448c82ab7c3f38c25e90cdc4032afd246d1292a42c" +checksum = "1a14e585d4632a3c278c03d69db4e3595801cceeb72f329da099bb687025b05c" dependencies = [ "byteorder", "candle-gemm", @@ -291,9 +290,9 @@ dependencies = [ [[package]] name = "candle-nn" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fba63698e86c8d82c9851181d7f2fa8cede3956a0c9172d74e5ea6ee36ff6b6a" +checksum = "fdf20b6d0b76ee97b9438763972b17e0f1204bffb31c9f1fd4f7a30446a4b4b1" dependencies = [ "candle-core", "safetensors", @@ -302,9 +301,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.82" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "305fe645edc1442a0fa8b6726ba61d422798d37a52e12eaecf4b022ebbb88f01" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", "libc", @@ -371,9 +370,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.22" +version = "4.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b417ae4361bca3f5de378294fc7472d3c4ed86a5ef9f49e93ae722f432aae8d2" +checksum = "7c8d502cbaec4595d2e7d5f61e318f05417bd2b66fdc3809498f0d3fdf0bea27" dependencies = [ "clap_builder", "clap_derive", @@ -382,9 +381,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.22" +version = "4.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c90dc0f0e42c64bff177ca9d7be6fcc9ddb0f26a6e062174a61c84dd6c644d4" +checksum = "5891c7bc0edb3e1c2204fc5e94009affabeb1821c9e5fdc3959536c5c0bb984d" dependencies = [ "anstream", "anstyle", @@ -394,9 +393,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = 
"c9fd1a5729c4548118d7d70ff234a44868d00489a4b6597b0b020918a0e91a1a" dependencies = [ "heck", "proc-macro2", @@ -406,9 +405,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" @@ -547,7 +546,7 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" dependencies = [ - "nix 0.26.2", + "nix 0.26.4", "windows-sys", ] @@ -559,9 +558,9 @@ checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f" [[package]] name = "deranged" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929" +checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" [[package]] name = "dyn-stack" @@ -587,9 +586,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b30f669a7961ef1631673d2766cc92f52d64f7ef354d4fe0ddfd30ed52f0f4f" +checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" dependencies = [ "errno-dragonfly", "libc", @@ -703,17 +702,6 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "is-terminal" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" -dependencies = [ - "hermit-abi", - "rustix", - "windows-sys", -] - [[package]] name = "itoa" version = "1.0.9" @@ -839,9 +827,9 @@ dependencies = [ [[package]] name = "memchr" 
-version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "76fc44e2588d5b436dbc3c6cf62aef290f90dab6235744a93dfe1cc18f451e2c" [[package]] name = "memmap2" @@ -909,14 +897,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" dependencies = [ "bitflags 1.3.2", "cfg-if", "libc", - "static_assertions", ] [[package]] @@ -1049,7 +1036,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.48.3", + "windows-targets 0.48.5", ] [[package]] @@ -1216,9 +1203,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" dependencies = [ "aho-corasick", "memchr", @@ -1228,9 +1215,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" dependencies = [ "aho-corasick", "memchr", @@ -1239,9 +1226,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "rubato" @@ -1278,9 +1265,9 @@ dependencies 
= [ [[package]] name = "rustix" -version = "0.38.8" +version = "0.38.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" +checksum = "ed6248e1caa625eb708e266e06159f135e8c26f2bb7ceb72dc4b2766d0340964" dependencies = [ "bitflags 2.4.0", "errno", @@ -1291,9 +1278,9 @@ dependencies = [ [[package]] name = "rustpotter" -version = "3.0.0-beta.4" +version = "3.0.0-beta.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e73cbed04283cc5af1dd62c019c9bc25c2637eb63a9cbca8839e2a3c740489" +checksum = "f222bac1ebe5c020312fce48ee2d8c489e5247327d4246356e5a07738705c488" dependencies = [ "candle-core", "candle-nn", @@ -1306,7 +1293,7 @@ dependencies = [ [[package]] name = "rustpotter-cli" -version = "3.0.0-beta.4" +version = "3.0.0-beta.5" dependencies = [ "clap", "cpal", @@ -1325,9 +1312,9 @@ checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "safetensors" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8cbd90c388a0b028565d8ad22e090101599d951c6b5f105b4f7772721a9d5f" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" dependencies = [ "serde", "serde_json", @@ -1356,18 +1343,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.183" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.183" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +checksum = 
"4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", @@ -1397,12 +1384,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "strength_reduce" version = "0.2.4" @@ -1439,9 +1420,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", @@ -1472,9 +1453,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.25" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea" +checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" dependencies = [ "deranged", "serde", @@ -1670,7 +1651,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.3", + "windows-targets 0.48.5", ] [[package]] @@ -1690,17 +1671,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27f51fb4c64f8b770a823c043c7fad036323e1c48f55287b7bbb7987b2fcdf3b" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.3", - "windows_aarch64_msvc 0.48.3", - 
"windows_i686_gnu 0.48.3", - "windows_i686_msvc 0.48.3", - "windows_x86_64_gnu 0.48.3", - "windows_x86_64_gnullvm 0.48.3", - "windows_x86_64_msvc 0.48.3", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -1711,9 +1692,9 @@ checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fde1bb55ae4ce76a597a8566d82c57432bc69c039449d61572a7a353da28f68c" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" @@ -1723,9 +1704,9 @@ checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_aarch64_msvc" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1513e8d48365a78adad7322fd6b5e4c4e99d92a69db8df2d435b25b1f1f286d4" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" @@ -1735,9 +1716,9 @@ checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_gnu" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60587c0265d2b842298f5858e1a5d79d146f9ee0c37be5782e92a6eb5e1d7a83" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" @@ -1747,9 +1728,9 @@ checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_i686_msvc" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"224fe0e0ffff5d2ea6a29f82026c8f43870038a0ffc247aa95a52b47df381ac4" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" @@ -1759,9 +1740,9 @@ checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnu" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62fc52a0f50a088de499712cbc012df7ebd94e2d6eb948435449d76a6287e7ad" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" @@ -1771,9 +1752,9 @@ checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2093925509d91ea3d69bcd20238f4c2ecdb1a29d3c281d026a09705d0dd35f3d" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" @@ -1783,15 +1764,15 @@ checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "windows_x86_64_msvc" -version = "0.48.3" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6ade45bc8bf02ae2aa34a9d54ba660a1a58204da34ba793c00d83ca3730b5f1" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.12" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83817bbecf72c73bad717ee86820ebf286203d2e04c3951f3cd538869c897364" +checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index aabf090..f8d1b59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpotter-cli" -version = "3.0.0-beta.4" +version = 
"3.0.0-beta.5" edition = "2021" license = "Apache-2.0" description = "CLI for Rustpotter, an open source wakeword spotter forged in rust." @@ -8,7 +8,7 @@ authors = ["Miguel Álvarez Díez "] repository = "https://github.com/GiviMAD/rustpotter" exclude = ["tools/**",".github",".gitignore"] [dependencies] -rustpotter = { version = "3.0.0-beta.4", features = ["debug", "audio", "record"] } +rustpotter = { version = "3.0.0-beta.5", features = ["debug", "audio", "record"] } ctrlc = "3.2.2" clap = { version = "4.1.6", features = ["derive"] } hound = "3.4.0" diff --git a/src/cli/build_ref.rs b/src/cli/build_ref.rs index b002d5e..ce072d6 100644 --- a/src/cli/build_ref.rs +++ b/src/cli/build_ref.rs @@ -23,6 +23,9 @@ pub struct BuildCommand { #[clap(short = 'a', long)] /// Averaged threshold to configure in the generated model, overwrites the detector averaged threshold averaged_threshold: Option, + #[clap(short = 'c', long, default_value_t = 16)] + /// Number of extracted mel-frequency cepstral coefficients + mfcc_size: u16, } pub fn build_ref(command: BuildCommand) -> Result<(), String> { println!("Start building {}!", command.model_path); @@ -39,18 +42,9 @@ pub fn build_ref(command: BuildCommand) -> Result<(), String> { command.threshold, command.averaged_threshold, command.sample_path, + command.mfcc_size, )?; - match wakeword.save_to_file(&command.model_path) { - Ok(_) => { - println!("{} created!", command.model_name); - } - Err(error) => { - clap::Error::raw( - clap::error::ErrorKind::InvalidValue, - error.to_string() + "\n", - ) - .exit(); - } - }; + wakeword.save_to_file(&command.model_path)?; + println!("{} created!", command.model_name); Ok(()) } diff --git a/src/cli/spot.rs b/src/cli/spot.rs index f16e7e5..9df2faf 100644 --- a/src/cli/spot.rs +++ b/src/cli/spot.rs @@ -8,7 +8,7 @@ use cpal::{ }; use gag::Gag; use rustpotter::{ - Rustpotter, RustpotterConfig, RustpotterDetection, Sample, SampleFormat, ScoreMode, + Rustpotter, RustpotterConfig, RustpotterDetection, 
Sample, SampleFormat, ScoreMode, VADMode, }; use time::OffsetDateTime; @@ -49,6 +49,9 @@ pub struct SpotCommand { #[clap(short = 's', long, default_value_t = ClapScoreMode::Max)] /// How to calculate a unified score score_mode: ClapScoreMode, + #[clap(short = 'v', long)] + /// Enabled vad detection. + vad_mode: Option<ClapVADMode>, #[clap(short = 'g', long)] /// Enables a gain-normalizer audio filter. gain_normalizer: bool, @@ -129,6 +132,7 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { config.detector.min_scores = command.min_scores; config.detector.score_mode = command.score_mode.into(); config.detector.score_ref = command.score_ref; + config.detector.vad_mode = command.vad_mode.map(|v| v.into()); config.detector.record_path = command.record_path; config.filters.gain_normalizer.enabled = command.gain_normalizer; config.filters.gain_normalizer.gain_ref = command.gain_ref; @@ -328,6 +332,31 @@ pub(crate) fn print_detection( }; } +#[derive(clap::ValueEnum, Clone, Debug)] +pub(crate) enum ClapVADMode { + Easy, + Medium, + Hard, +} +impl std::fmt::Display for ClapVADMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ClapVADMode::Easy => write!(f, "easy"), + ClapVADMode::Medium => write!(f, "medium"), + ClapVADMode::Hard => write!(f, "hard"), + } + } +} + +impl From<ClapVADMode> for VADMode { + fn from(value: ClapVADMode) -> Self { + match value { + ClapVADMode::Easy => VADMode::Easy, + ClapVADMode::Medium => VADMode::Medium, + ClapVADMode::Hard => VADMode::Hard, + } + } +} #[derive(clap::ValueEnum, Clone, Debug)] pub(crate) enum ClapScoreMode { Max, diff --git a/src/cli/test.rs b/src/cli/test.rs index c54ae8f..06fcc38 100644 --- a/src/cli/test.rs +++ b/src/cli/test.rs @@ -3,7 +3,7 @@ use hound::{SampleFormat, WavReader}; use rustpotter::{Rustpotter, RustpotterConfig, Sample}; use std::{fs::File, io::BufReader}; -use super::spot::{print_detection, ClapScoreMode}; +use super::spot::{print_detection, ClapScoreMode, ClapVADMode};
#[derive(Args, Debug)] /// Test wakeword file against a wav sample, detector is automatically configured according to the sample spec @@ -27,6 +27,9 @@ pub struct TestCommand { #[clap(short = 's', long, default_value_t = ClapScoreMode::Max)] /// How to calculate a unified score, no applies to wakeword models. score_mode: ClapScoreMode, + #[clap(short = 'v', long)] + /// Enabled vad detection. + vad_mode: Option<ClapVADMode>, #[clap(short = 'g', long)] /// Enables a gain-normalizer audio filter. gain_normalizer: bool, @@ -80,6 +83,7 @@ pub fn test(command: TestCommand) -> Result<(), String> { config.detector.min_scores = command.min_scores; config.detector.score_mode = command.score_mode.into(); config.detector.score_ref = command.score_ref; + config.detector.vad_mode = command.vad_mode.map(|v| v.into()); config.detector.record_path = command.record_path; config.filters.gain_normalizer.enabled = command.gain_normalizer; config.filters.gain_normalizer.gain_ref = command.gain_ref; diff --git a/src/cli/train.rs b/src/cli/train.rs index dbd6a8d..23398c4 100644 --- a/src/cli/train.rs +++ b/src/cli/train.rs @@ -23,6 +23,9 @@ pub struct TrainCommand { #[clap(short = 'e', long, default_value_t = 1000)] /// Number of backward and forward cycles to run epochs: usize, + #[clap(short = 'c', long, default_value_t = 16)] + /// Number of extracted mel-frequency cepstral coefficients + mfcc_size: u16, #[clap(short = 'm')] /// Model to continue training from wakeword_model: Option, @@ -40,6 +43,7 @@ pub fn train(command: TrainCommand) -> Result<(), String> { command.test_dir, command.learning_rate, command.epochs, + command.mfcc_size, model, ) .map_err(|err| err.to_string())?;