From 3161f3bff188f28dfc1a54cddb42b6a003f9173b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81lvarez?= Date: Tue, 5 Sep 2023 00:37:23 +0200 Subject: [PATCH] version 3 --- .github/workflows/main.yml | 16 +- Cargo.lock | 668 ++++++++++++++++++++++----- Cargo.toml | 4 +- README.md | 122 +++-- src/cli/{build_model.rs => build.rs} | 36 +- src/cli/devices.rs | 10 +- src/cli/filter.rs | 87 ++-- src/cli/mod.rs | 73 ++- src/cli/record.rs | 149 ++++-- src/cli/spot.rs | 274 ++++++----- src/cli/test.rs | 187 ++++++++ src/cli/test_model.rs | 155 ------- src/cli/train.rs | 59 +++ tools/Dockerfile | 2 + tools/create_tag.sh | 2 +- tools/install.sh | 4 + 16 files changed, 1237 insertions(+), 611 deletions(-) rename src/cli/{build_model.rs => build.rs} (61%) create mode 100644 src/cli/test.rs delete mode 100644 src/cli/test_model.rs create mode 100644 src/cli/train.rs create mode 100755 tools/install.sh diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d1954d6..bc70ddf 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,15 +31,15 @@ jobs: steps: - uses: actions/checkout@v3 - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - name: Docker Setup Buildx - uses: docker/setup-buildx-action@v1.6.0 + uses: docker/setup-qemu-action@v2 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 with: install: true - name: Build binaries run: | mkdir -p output - docker build -f tools/Dockerfile . -t rustpotter-cli_binary:arm --platform arm --load + docker build -f tools/Dockerfile --build-arg RUSTFLAGS="-C target-feature=+fp16" . -t rustpotter-cli_binary:arm --platform arm --load DOCKER_BUILDKIT=1 docker run --platform=arm -v $(pwd)/output:/out rustpotter-cli_binary:arm bash -c "cp /code/output/* /out/" - name: artifact debian arm uses: actions/upload-artifact@v3 @@ -52,15 +52,15 @@ jobs: steps: - uses: actions/checkout@v3 - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - name: Docker Setup Buildx - uses: docker/setup-buildx-action@v1.6.0 + uses: docker/setup-qemu-action@v2 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v2 with: install: true - name: Build binaries run: | mkdir -p output - docker build -f tools/Dockerfile . -t rustpotter-cli_binary:arm64 --platform arm64 --load + docker build -f tools/Dockerfile --build-arg RUSTFLAGS="-C target-feature=+fp16" . -t rustpotter-cli_binary:arm64 --platform arm64 --load DOCKER_BUILDKIT=1 docker run --platform=arm64 -v $(pwd)/output:/out rustpotter-cli_binary:arm64 bash -c "cp /code/output/* /out/" - name: artifact debian arm64 uses: actions/upload-artifact@v3 diff --git a/Cargo.lock b/Cargo.lock index 4426083..8acc456 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,18 +4,18 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" dependencies = [ "memchr", ] [[package]] name = "alsa" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8512c9117059663fb5606788fbca3619e2a91dac0e3fe516242eab1fa6be5e44" +checksum = "e2562ad8dcf0f789f65c6fdaad8a8a9708ed6b488e649da28c01656ad66b8b47" dependencies = [ "alsa-sys", "bitflags 1.3.2", @@ -35,24 +35,23 @@ dependencies = [ [[package]] name = "anstream" -version = "0.3.2" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "b1f58811cfac344940f1a400b6e6231ce35171f614f26439e80f8c1465c5cc0c" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "15c4c2c83f81532e5845a733998b6971faca23490340a418e9b72a3ec9de12ea" [[package]] name = "anstyle-parse" @@ -74,9 +73,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +checksum = "58f54d10c6dfa51283a066ceab3ec1ab78d13fae00aa49243a45e4571fb79dfd" dependencies = [ "anstyle", "windows-sys", @@ -116,9 +115,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] name = "bumpalo" @@ -126,19 +125,188 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +[[package]] +name = "bytemuck" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "bytes" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +[[package]] +name = "candle-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27bf6ecb02090cbc47aaecd6df7ec25b4a00b2b01076b77563dc7092fba64bc1" +dependencies = [ + "byteorder", + "candle-gemm", + "half 2.3.1", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "zip", +] + +[[package]] +name = "candle-gemm" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b726a1f6cdd7ff080e95e3d91694701b1e04a58acd198e4a78c39428b2274e" +dependencies = [ + "candle-gemm-c32", + "candle-gemm-c64", + "candle-gemm-common", + "candle-gemm-f16", + "candle-gemm-f32", + "candle-gemm-f64", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c32" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "661470663389f0c99fd8449e620bfae630a662739f830a323eda4dcf80888843" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c64" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a111ddf61db562854a6d2ff4dfe1e8a84066431b7bc68d3afae4bf60874fda0" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-common" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6dd93783ead7eeef14361667ea32014dc6f716a2fc956b075fe78729e10dd5" +dependencies = [ + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f16" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b76499bf4b858cacc526c5c8f948bc7152774247dce8568f174b743ab1363fa4" +dependencies = [ + "candle-gemm-common", + "candle-gemm-f32", + "dyn-stack", + "half 2.3.1", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f32" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bec152e7d36339d3785e0d746d75ee94a4e92968fbb12ddcc91b536b938d016" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f64" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f59ac68a5521e2ff71431bb7f1b22126ff0b60c5e66599b1f4676433da6e69" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-nn" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac23e21d440c67867bb8ee4542ab76c83428daad73866632714b99072fcc59e" +dependencies = [ + "candle-core", + "safetensors", + "thiserror", +] + [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", + "libc", ] [[package]] @@ -186,7 +354,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" dependencies = [ "ciborium-io", - "half", + "half 1.8.2", ] [[package]] @@ -202,20 +370,19 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.19" +version = "4.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd304a20bff958a57f04c4e96a2e7594cc4490a0e809cbd48bb6437edaa452d" +checksum = "6a13b88d2c62ff462f88e4a121f17a82c1af05693a2f192b5c38d14de73c19f6" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.3.19" +version = "4.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01c6a3f08f1fe5662a35cfe393aec09c4df95f60ee93b7556505260f75eee9e1" +checksum = "2bb9faaa7c2ef94b2743a21f5a29e6f0010dff4caa69ac8e9d6cf8b6fa74da08" dependencies = [ "anstream", "anstyle", @@ -225,21 +392,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.31", ] [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" @@ -314,13 +481,71 @@ dependencies = [ "windows", ] +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "ctrlc" -version = "3.4.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a011bbe2c35ce9c1f143b7af6f94f29a167beb4cd1d29e6740ce836f723120e" +checksum = "82e95fbd621905b854affdc67943b043a0fbb6ed7385fd5a25650d19a8a6cfdf" dependencies = [ - "nix 0.26.2", + "nix 0.27.1", "windows-sys", ] @@ -330,6 +555,28 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c87e182de0887fd5361989c677c4e8f5000cd9491d6d563161a8f3a5519fc7f" +[[package]] +name = "deranged" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2696e8a945f658fd14dc3b87242e6b80cd0f36ff04ea560fa39082368847946" + +[[package]] +name = "dyn-stack" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24269739c7c175bc12130622ef1a60b9ab2d5b30c0b9ce5110cd406d7fd497bc" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "equivalent" version = "1.0.1" @@ -338,9 +585,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" dependencies = [ "errno-dragonfly", "libc", @@ -384,6 +631,17 @@ dependencies = [ "tempfile", ] +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "glob" version = "0.3.1" @@ -396,6 +654,19 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "hashbrown" version = "0.14.0" @@ -431,15 +702,10 @@ dependencies = [ ] [[package]] -name = "is-terminal" -version = "0.4.9" +name = "itoa" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" -dependencies = [ - "hermit-abi", - "rustix", - "windows-sys", -] +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "jni" @@ -521,11 +787,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" [[package]] name = "lock_api" @@ -539,9 +811,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "mach2" @@ -554,9 +826,27 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" + +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] [[package]] name = "minimal-lexical" @@ -606,14 +896,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.2" +version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", "cfg-if", "libc", - "static_assertions", ] [[package]] @@ -628,9 +917,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" dependencies = [ "num-traits", ] @@ -663,6 +952,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" dependencies = [ "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", ] [[package]] @@ -735,9 +1035,15 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.48.1", + "windows-targets 0.48.5", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "peeking_take_while" version = "0.1.2" @@ -750,6 +1056,12 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "primal-check" version = "0.3.3" @@ -780,19 +1092,90 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "raw-window-handle" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + [[package]] name = "realfft" version = "3.3.0" @@ -802,6 +1185,12 @@ dependencies = [ "rustfft", ] +[[package]] +name = "reborrow" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2962bf2e1f971c53ef59b2d7ca51d6a5e5c4a9d2be47eb1f661a321a4da85888" + [[package]] name = "redox_syscall" version = "0.3.5" @@ -813,9 +1202,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.1" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -825,9 +1214,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.3" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39354c10dd07468c2e73926b23bb9c2caca74c5501e38a35da70406f1d923310" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -836,9 +1225,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "rubato" @@ -875,11 +1264,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.4" +version = "0.38.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" +checksum = "c0c3dde1fc030af041adc40e79c0e7fbcf431dd24870053d187d7c66e4b87453" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.0", "errno", "libc", "linux-raw-sys", @@ -888,10 +1277,12 @@ dependencies = [ [[package]] name = "rustpotter" -version = "2.0.1" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "588856a71149ec53e769d2cd9e54e765be6a4664824d7306dc33e295aa8465b9" +checksum = "4b362aaf5f873dae48230ae6e0ccf41f270d3728132cf154070a07143589cb33" dependencies = [ + "candle-core", + "candle-nn", "ciborium", "hound", "rubato", @@ -901,7 +1292,7 @@ dependencies = [ [[package]] name = "rustpotter-cli" -version = "2.0.6" +version = "3.0.0" dependencies = [ "clap", "cpal", @@ -912,6 +1303,22 @@ dependencies = [ "time", ] +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -927,24 +1334,41 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" -version = "1.0.175" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d25439cd7397d044e2748a6fe2432b5e85db703d6d097bd014b3c0ad1ebff0b" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.175" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b23f7ade6f110613c0d63858ddb8b94c1041f550eab58a16b371bdf2c9c80ab4" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.31", +] + +[[package]] +name = "serde_json" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +dependencies = [ + "itoa", + "ryu", + "serde", ] [[package]] @@ -959,12 +1383,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "strength_reduce" version = "0.2.4" @@ -990,9 +1408,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.27" +version = "2.0.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0" +checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" dependencies = [ "proc-macro2", "quote", @@ -1001,9 +1419,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.7.0" +version = "3.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5486094ee78b2e5038a6382ed7645bc084dc2ec433426ca4c3cb61e2007b8998" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", @@ -1014,30 +1432,31 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.44" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.44" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.31", ] [[package]] name = "time" -version = "0.3.23" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446" +checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" dependencies = [ + "deranged", "serde", "time-core", ] @@ -1103,6 +1522,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + [[package]] name = "wasm-bindgen" version = "0.2.87" @@ -1124,7 +1549,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.31", "wasm-bindgen-shared", ] @@ -1158,7 +1583,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.31", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -1225,7 +1650,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.1", + "windows-targets 0.48.5", ] [[package]] @@ -1245,17 +1670,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.1" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", ] [[package]] @@ -1266,9 +1691,9 @@ checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_msvc" @@ -1278,9 +1703,9 @@ checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" [[package]] name = "windows_aarch64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_i686_gnu" @@ -1290,9 +1715,9 @@ checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" [[package]] name = "windows_i686_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_msvc" @@ -1302,9 +1727,9 @@ checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" [[package]] name = "windows_i686_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_x86_64_gnu" @@ -1314,9 +1739,9 @@ checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" [[package]] name = "windows_x86_64_gnu" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnullvm" @@ -1326,9 +1751,9 @@ checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" [[package]] name = "windows_x86_64_gnullvm" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_msvc" @@ -1338,15 +1763,26 @@ checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" [[package]] name = "windows_x86_64_msvc" -version = "0.48.0" +version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.1" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25b5872fa2e10bd067ae946f927e726d7d603eaeb6e02fa6a350e0722d2b8c11" +checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" dependencies = [ "memchr", ] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", +] diff --git a/Cargo.toml b/Cargo.toml index 7adc84a..5d1a2f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rustpotter-cli" -version = "2.0.6" +version = "3.0.0" edition = "2021" license = "Apache-2.0" description = "CLI for Rustpotter, an open source wakeword spotter forged in rust." @@ -8,7 +8,7 @@ authors = ["Miguel Álvarez Díez "] repository = "https://github.com/GiviMAD/rustpotter" exclude = ["tools/**",".github",".gitignore"] [dependencies] -rustpotter = { version = "2.0.1", features = ["debug", "internals"] } +rustpotter = { version = "3.0.0", features = ["debug", "display", "audio", "record"] } ctrlc = "3.2.2" clap = { version = "4.1.6", features = ["derive"] } hound = "3.4.0" diff --git a/README.md b/README.md index 9bb514f..a672b4b 100644 --- a/README.md +++ b/README.md @@ -10,36 +10,19 @@ CLI for Rustpotter, an open source wakeword spotter forged in rust This is a client for using the [rustpotter](https://github.com/GiviMAD/rustpotter) library on the command line. -You can use it to record wav samples, build rustpotter models and tests them. +You can use it to record wav samples, create rustpotter wakeword files and test them. # Installation -You can download rustpotter-cli binaries for the supported platforms from the 'Assets' tab of the [releases](https://github.com/GiviMAD/rustpotter-cli/releases), it can be installed like any other executable. +Some pre-build executables for the supported platforms can be found on the 'Assets' tab of the [releases](https://github.com/GiviMAD/rustpotter-cli/releases). -For instance on debian it can be installed like: - -```bash -# This command print your arch, just in case you don't remember it. -$ uname -m -# Here I used the armv7l binary -$ curl -OL https://github.com/GiviMAD/rustpotter-cli/releases/download/v2.0.5/rustpotter-cli_debian_armv7l -# Make executable -$ chmod +x rustpotter-cli_debian_armv7l -# Check simple execution -$ ./rustpotter-cli_debian_armv7l --version -# Make available as ruspotter-cli -$ sudo mv ./rustpotter-cli_debian_armv7l /usr/local/bin/rustpotter-cli -``` - -# How to used it. +# Basic usage. ## Listing available audio input devices and formats. Your can list the available audio sources with the `devices` command, the `--configs` option can be added to display the default and available record formats for each source. -Host warnings are hidden by default, you can enable them by providing the `--host-warnings` option. - Every device and config has a numerical id to the left which is the one you can use on the other commands (`record` and `spot`) to change its audio source and format. @@ -67,7 +50,7 @@ Available Devices: ## Recording audio samples The `record` command allows to record audio samples. -You pass the id returned by the `devices` commands using the `--device-index` option to change the input device. +To use a different input device provide the `--device-index` argument with the id returned by the `devices` commands. You pass the configuration id returned by the `devices` commands using the `--config-index` option to change the audio format. Once executed you need to press the `Ctrl + c` key combination to finish the record. @@ -91,30 +74,76 @@ WAKEWORD_FILENAME="${WAKEWORD// /_}" for i in {0..9}; do (rustpotter-cli record $WAKEWORD_FILENAME$i.wav && sleep 1); done ``` -## Filter a file. +## Creating a Wakeword Model + +The `train` command allows to create wakeword models. + +It's required to setup a training and testing folders containing wav records which need to be tagged (contains [label] in its file name, where 'label' is the tag the network should predict for that audio segment) or untagged (equivalent to contain [none] on the filename). + +The size and cpu usage of a wakeword model is based on the model type you choose, and the audio duration it was trained on (which is defined by the max audio duration found on the training set). + +Example run: + +```sh +$ rustpotter-cli train -t small --train-dir train.wav/train --test-dir train.wav/test --test-epochs 10 --epochs 2500 -l 0.017 trained-small.rpw +Start training trained-small.rpw! +Model type: small. +Labels: ["none", "ok_casa"]. +Training with 2042 records. +Testing with 119 records. +Training on 1950ms of audio. + 10 train loss: 0.12944 test acc: 90.76% + 20 train loss: 0.06484 test acc: 93.28% + 30 train loss: 0.04454 test acc: 94.12% + 40 train loss: 0.03361 test acc: 94.12% + 50 train loss: 0.02687 test acc: 94.12% + 60 train loss: 0.02227 test acc: 94.12% + 70 train loss: 0.01916 test acc: 94.12% + 80 train loss: 0.01681 test acc: 94.12% + 90 train loss: 0.01499 test acc: 94.12% + 100 train loss: 0.01354 test acc: 94.12% + 110 train loss: 0.01232 test acc: 94.96% +... + 160 train loss: 0.00822 test acc: 94.96% + 170 train loss: 0.00766 test acc: 94.96% + 180 train loss: 0.00717 test acc: 95.80% + 190 train loss: 0.00673 test acc: 95.80% +... + 470 train loss: 0.00234 test acc: 95.80% + 480 train loss: 0.00229 test acc: 95.80% + 490 train loss: 0.00224 test acc: 96.64% + 500 train loss: 0.00219 test acc: 96.64% +... +1180 train loss: 0.00083 test acc: 96.64% +1190 train loss: 0.00082 test acc: 96.64% +1200 train loss: 0.00081 test acc: 97.48% +1210 train loss: 0.00081 test acc: 97.48% +... +2340 train loss: 0.00034 test acc: 97.48% +2350 train loss: 0.00034 test acc: 97.48% +2360 train loss: 0.00034 test acc: 98.32% +2370 train loss: 0.00033 test acc: 98.32% +... +2480 train loss: 0.00031 test acc: 98.32% +2490 train loss: 0.00031 test acc: 98.32% +2500 train loss: 0.00031 test acc: 98.32% +trained-small.rpw created! +``` -The available audio filters in rustpotter can be applied to a wav file using the `filter` command. -To enable the gain normalizer filter you can use the `--gain-normalizer` option. -To enable the band-pass filter you can use the `--band-pass` option. -To display the full command options you can run `rustpotter-cli filter -h`. +Note that you can obtain different results on different executions with the same training set as the initialization of the weights is not constant. -This is an example run on macOS: +To get a correct idea about the accuracy of the model, do not share records between the train and test folders. -```bash -$ rustpotter-cli filter test_noise.wav -g --gain-ref 0.005 -b --low-cutoff 1000 --high-cutoff 2000 -Creating new file test_noise-gain0.005-bandpass1000_2000.wav -``` +## Creating a Wakeword Reference -## Creating a wakeword model +The `build` command allows to create a wakeword reference file from some records. -The `build-model` command allows to create a wakeword file (also referred as model in this document). -You just need to provide the `--model-name` option (which defines the detection name), -the `--model-path` with the desired output path for the file, and a list of wav audio files. +This wakeword type requires a low number of records to be created but offers more inconsistent results than the wakeword models. -For example: +As an example example: ``` -rustpotter-cli build-model --model-name "ok home" --model-path ok_home.rpw ok_home1.wav ok_home2.wav +rustpotter-cli build --model-name "ok home" --model-path ok_home.rpw ok_home1.wav ok_home2.wav ``` This is an example run on macOS: @@ -122,7 +151,7 @@ This is an example run on macOS: ```bash $ WAKEWORD="ok home" $ WAKEWORD_FILENAME="${WAKEWORD// /_}" -$ rustpotter-cli build-model --model-name "$WAKEWORD" --model-path $WAKEWORD_FILENAME.rpw $WAKEWORD_FILENAME*.wav +$ rustpotter-cli build --model-name "$WAKEWORD" --model-path $WAKEWORD_FILENAME.rpw $WAKEWORD_FILENAME*.wav ok_home1.wav: WavSpec { channels: 2, sample_rate: 44100, bits_per_sample: 32, sample_format: Float } ok home created! ``` @@ -133,23 +162,20 @@ You can use the commands `spot` to test a model in real time using the available or `test_model` to do it against an audio file. Both expose similar options to make change from one to the other simpler. -So it's recommended to record an example file using the record command and try to tune the options there to then test those for real. +This way you can record an example record and tune the options there to then test those on real time. This is an example run on macOS: ```bash -$ rustpotter-cli test-model -g -b --low-cutoff 500 --high-cutoff 1500 ok_home_test.rpw test_noise.wav -Testing file test_noise.wav against model ok_home_test.rpw! +$ rustpotter-cli test -g --gain-ref 0.004 ok_home_test.rpw test_audio.wav +Testing file test_audio.wav against model ok_home_test.rpw! Wakeword detection: [11:06:11] RustpotterDetection { name: "ok_home_test", avg_score: 0.0, score: 0.5261932, scores: {"ok_home1-bandpass1000_2000.wav": 0.5261932}, counter: 12, gain: 0.9 } ``` -The more relevant options for the `spot` and `test-model` commands are: +The more relevant options for the `spot` and `test` commands are: * `-d` parameter enables the called 'debug mode' so you can see the partial detections. * `-t` sets the threshold value (defaults to 0.5). -* `-a` configures the `averaged threshold`, recommended as reduces the cpu usage. (set to half of the threshold or similar) * `-m 6` require at least 6 frames of positive scoring (compared against the detection `counter` field). -* `-s` the comparison strategy used, defines the way the detection score is calculated from the different scores. -* `-g` enables gain normalization. To debug the gain normalization you can use `--debug-gain`. -* `--gain-ref` changes the gain normalization reference (the default value is printed at the beginning when `--debug-gain` is provided, changes with the model) -* `-b --low-cutoff 500 --high-cutoff 1500` the band-pass filter configuration, helps to attenuate background noises. - +* `-e` enables the eager mode so detection is emitted as soon as possible (on min positive scores). +* `-g` enables gain normalization. To debug the gain normalization you can use `--debug-gain`, or look at the gain reflected on the detection. +* `--gain-ref` changes the gain normalization reference. (the default value is printed at the beginning when `--debug-gain` is provided, depends on the wakeword) \ No newline at end of file diff --git a/src/cli/build_model.rs b/src/cli/build.rs similarity index 61% rename from src/cli/build_model.rs rename to src/cli/build.rs index 03d8914..b725d0b 100644 --- a/src/cli/build_model.rs +++ b/src/cli/build.rs @@ -2,18 +2,18 @@ use std::{fs::File, io::BufReader}; use clap::Args; use hound::WavReader; -use rustpotter::Wakeword; +use rustpotter::{WakewordRef, WakewordRefBuildFromFiles, WakewordSave}; #[derive(Args, Debug)] -/// Creates a wakeword using RIFF wav audio files. +/// Creates a wakeword reference using wav audio files. #[clap()] -pub struct BuildModelCommand { +pub struct BuildCommand { #[clap(short = 'n', long)] /// The term emitted on the spot event - model_name: String, + name: String, #[clap(short = 'p', long)] /// Generated model path - model_path: String, + path: String, #[clap(num_args = 1.., required = true)] /// List of sample record paths sample_path: Vec, @@ -23,9 +23,12 @@ pub struct BuildModelCommand { #[clap(short = 'a', long)] /// Averaged threshold to configure in the generated model, overwrites the detector averaged threshold averaged_threshold: Option, + #[clap(short = 'c', long, default_value_t = 16)] + /// Number of extracted mel-frequency cepstral coefficients + mfcc_size: u16, } -pub fn build(command: BuildModelCommand) -> Result<(), String> { - println!("Start building {}!", command.model_path); +pub fn build_ref(command: BuildCommand) -> Result<(), String> { + println!("Start building {}!", command.path); println!("From samples:"); for path in &command.sample_path { let reader = BufReader::new(File::open(path).map_err(|err| err.to_string())?); @@ -34,23 +37,14 @@ pub fn build(command: BuildModelCommand) -> Result<(), String> { .spec(); println!("{}: {:?}", path, wav_spec); } - let wakeword = Wakeword::new_from_sample_files( - command.model_name.clone(), + let wakeword = WakewordRef::new_from_sample_files( + command.name.clone(), command.threshold, command.averaged_threshold, command.sample_path, + command.mfcc_size, )?; - match wakeword.save_to_file(&command.model_path) { - Ok(_) => { - println!("{} created!", command.model_name); - } - Err(error) => { - clap::Error::raw( - clap::error::ErrorKind::InvalidValue, - error.to_string() + "\n", - ) - .exit(); - } - }; + wakeword.save_to_file(&command.path)?; + println!("{} created!", command.name); Ok(()) } diff --git a/src/cli/devices.rs b/src/cli/devices.rs index 78c4910..29109fb 100644 --- a/src/cli/devices.rs +++ b/src/cli/devices.rs @@ -14,7 +14,7 @@ pub struct DevicesCommand { #[clap(long, short)] /// Filter device configs by max channel number max_channels: Option, - #[clap(short='w', long)] + #[clap(short = 'w', long)] /// Display host warnings host_warnings: bool, } @@ -25,13 +25,17 @@ pub fn devices(command: DevicesCommand) -> Result<(), String> { drop(stderr_gag); } println!("Audio hosts:\n - {:?}", default_host.id()); - let default_in = default_host.default_input_device().map(|e| e.name().unwrap()); + let default_in = default_host + .default_input_device() + .map(|e| e.name().unwrap()); if let Some(def_in) = default_in { println!("Default input device:\n - {}", def_in); } else { println!("No default input device"); } - let devices = default_host.input_devices().map_err(|err| err.to_string())?; + let devices = default_host + .input_devices() + .map_err(|err| err.to_string())?; println!("Available Devices: "); for (device_index, device) in devices.enumerate() { println!( diff --git a/src/cli/filter.rs b/src/cli/filter.rs index a1a030d..963dc6f 100644 --- a/src/cli/filter.rs +++ b/src/cli/filter.rs @@ -1,8 +1,8 @@ use clap::Args; -use hound::{SampleFormat, WavReader}; +use hound::WavReader; use rustpotter::{ - BandPassFilter, Endianness, GainNormalizerFilter, WAVEncoder, WavFmt, - DETECTOR_INTERNAL_SAMPLE_RATE, FEATURE_EXTRACTOR_FRAME_LENGTH_MS, + BandPassFilter, GainNormalizerFilter, Sample, SampleFormat, AudioEncoder, AudioFmt, + DETECTOR_INTERNAL_SAMPLE_RATE, MFCCS_EXTRACTOR_FRAME_LENGTH_MS, }; use std::{fs::File, io::BufReader, path::Path}; @@ -60,17 +60,11 @@ pub fn filter(command: FilterCommand) -> Result<(), String> { // Read wav file let file_reader = BufReader::new(File::open(command.sample_path).map_err(|err| err.to_string())?); - let wav_reader = WavReader::new(file_reader).map_err(|err| err.to_string())?; - let wav_spec = WavFmt { - sample_rate: wav_reader.spec().sample_rate as usize, - sample_format: wav_reader.spec().sample_format, - bits_per_sample: wav_reader.spec().bits_per_sample, - channels: wav_reader.spec().channels, - endianness: Endianness::Little, - }; - let mut encoder = WAVEncoder::new( + let mut wav_reader = WavReader::new(file_reader).map_err(|err| err.to_string())?; + let wav_spec: AudioFmt = wav_reader.spec().try_into()?; + let mut encoder = AudioEncoder::new( &wav_spec, - FEATURE_EXTRACTOR_FRAME_LENGTH_MS, + MFCCS_EXTRACTOR_FRAME_LENGTH_MS, DETECTOR_INTERNAL_SAMPLE_RATE, ) .unwrap(); @@ -87,38 +81,41 @@ pub fn filter(command: FilterCommand) -> Result<(), String> { command.low_cutoff, command.high_cutoff, ); - if wav_reader.spec().sample_format == SampleFormat::Float { - wav_reader - .into_samples::() - .map(|chunk| *chunk.as_ref().unwrap()) - .collect::>() - .chunks_exact(encoder.get_input_frame_length()) - .map(|chuck| encoder.reencode_float(chuck)) - .collect::>>() - } else { - wav_reader - .into_samples::() - .map(|chunk| *chunk.as_ref().unwrap()) - .collect::>() - .chunks_exact(encoder.get_input_frame_length()) - .map(|chuck| encoder.reencode_int(chuck)) - .collect::>>() - }.into_iter() - .map(|mut chunk| { - if command.gain_normalizer { - let rms_level = GainNormalizerFilter::get_rms_level(&chunk); - gain_filter.filter(&mut chunk, rms_level); - } - if command.band_pass { - bandpass_filter.filter(&mut chunk); - } - chunk - }) - .for_each(|encoded_chunk| { - for sample in encoded_chunk { - writer.write_sample(sample).ok(); - } - }); + match wav_spec.sample_format { + SampleFormat::I8 => get_encoded_chucks::(&mut wav_reader, &mut encoder), + SampleFormat::I16 => get_encoded_chucks::(&mut wav_reader, &mut encoder), + SampleFormat::I32 => get_encoded_chucks::(&mut wav_reader, &mut encoder), + SampleFormat::F32 => get_encoded_chucks::(&mut wav_reader, &mut encoder), + } + .into_iter() + .map(|mut chunk| { + if command.gain_normalizer { + let rms_level = GainNormalizerFilter::get_rms_level(&chunk); + gain_filter.filter(&mut chunk, rms_level); + } + if command.band_pass { + bandpass_filter.filter(&mut chunk); + } + chunk + }) + .for_each(|encoded_chunk| { + for sample in encoded_chunk { + writer.write_sample(sample).ok(); + } + }); writer.finalize().expect("Unable to save file"); Ok(()) } + +fn get_encoded_chucks( + wav_reader: &mut WavReader>, + encoder: &mut AudioEncoder, +) -> Vec> { + wav_reader + .samples::() + .map(|chunk| *chunk.as_ref().unwrap()) + .collect::>() + .chunks_exact(encoder.get_input_frame_length()) + .map(|chuck| encoder.rencode_and_resample(chuck.to_vec())) + .collect::>>() +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index eba5fcc..9d16f8f 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,52 +1,83 @@ use clap::{Parser, Subcommand}; -mod build_model; +mod build; mod devices; +mod filter; mod record; mod spot; -mod test_model; -mod filter; +mod test; +mod train; use self::{ - build_model::{build, BuildModelCommand}, + build::{build_ref, BuildCommand}, devices::{devices, DevicesCommand}, + filter::{filter, FilterCommand}, record::{record, RecordCommand}, spot::{spot, SpotCommand}, - test_model::{test, TestModelCommand}, - filter::{filter, FilterCommand}, + test::{test, TestCommand}, + train::{train, TrainCommand}, }; #[derive(Parser, Debug)] /// CLI for RustPotter: an open source wakeword spotter forged in rust #[clap(author, version, about, long_about = None, arg_required_else_help = true)] -struct CLI { +struct Cli { #[clap(subcommand)] - command: Option, + command: Option, } #[derive(Subcommand, Debug)] -enum Commands { - /// Build wakeword model from wav audio files - BuildModel(BuildModelCommand), - /// List audio devices and configurations +enum Command { + /// Build wakeword reference from wav audio files. + /// + /// These wakewords offers worst quality detection than the wakeword models but requires low number of records (recommended 3 to 8). + /// + /// The file size and the cpu consumption depends on the number of sample files for built it. + /// + Build(BuildCommand), + /// Train wakeword model from wav audio files. + /// + /// This wakeword type requires a more sample files to be created, but produces high fidelity results. + /// + /// The file size and the cpu consumption depends on the model type and the duration on the longer audio sample on the training folder. + /// + /// It's required to setup a train and test folders containing wav files labeled as something (for example "[ok_casa]20:44:32.wav") and + /// others without any tag ("20:46:32.wav.wav" and "[none]20:46:32.wav" is equivalent). + /// + /// The will train a basic classification neural network for the available labels, that the tool can use emit detections when a label other + /// than "none" is predicted. + /// + /// The weight initialization is not fixed and can produce different results per execution but the + /// + /// Tested with a training set of 155 affirmative samples and 1355 noise/ambient samples over a test set of 108 samples. + /// I obtain a round 96% of accuracy using the different model types, and all work nice in real live, + /// the small and medium models can require setting a higher threshold or min partial detections to avoid false detections, + /// but other than that all seems to be reliable. + /// + Train(TrainCommand), + /// List available audio devices and configurations + /// + /// Useful in order to know how to configure the input and format + /// for the "record" and "spot" commands. Devices(DevicesCommand), - /// Apply available filters to a wav audio file. + /// Apply available filters to a wav audio file Filter(FilterCommand), /// Record wav audio file Record(RecordCommand), /// Spot wakewords in real time Spot(SpotCommand), /// Spot wakewords against a wav file - TestModel(TestModelCommand), + Test(TestCommand), } pub(crate) fn run_cli() { - let cli = CLI::parse(); + let cli = Cli::parse(); match cli.command.unwrap() { - Commands::Record(command) => record(command), - Commands::Filter(command) => filter(command), - Commands::BuildModel(command) => build(command), - Commands::TestModel(command) => test(command), - Commands::Spot(command) => spot(command), - Commands::Devices(command) => devices(command), + Command::Build(command) => build_ref(command), + Command::Devices(command) => devices(command), + Command::Filter(command) => filter(command), + Command::Record(command) => record(command), + Command::Spot(command) => spot(command), + Command::Test(command) => test(command), + Command::Train(command) => train(command), } .expect("Command failed"); } diff --git a/src/cli/record.rs b/src/cli/record.rs index 367cb8a..5b9ea16 100644 --- a/src/cli/record.rs +++ b/src/cli/record.rs @@ -1,10 +1,11 @@ use std::fs::File; use std::io::BufWriter; +use std::sync::mpsc::Sender; use std::sync::{mpsc, Arc, Mutex}; use clap::Args; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use cpal::{FromSample, Sample, SampleRate}; +use cpal::{FromSample, Sample, SampleRate, SizedSample}; use gag::Gag; #[derive(Args, Debug)] /// Record wav audio @@ -19,7 +20,7 @@ pub struct RecordCommand { #[clap(short, long)] /// Input device configuration index used for record. config_index: Option, - #[clap(short='w', long)] + #[clap(short = 'w', long)] /// Display host warnings host_warnings: bool, #[clap(long, default_value_t = 16000)] @@ -28,6 +29,9 @@ pub struct RecordCommand { #[clap(short, long, default_value_t = 1.)] /// Adjust the recording volume. value > 1.0 amplifies, value < 1.0 attenuates gain: f32, + #[clap(long = "ms")] + /// Max record duration in milliseconds + duration_ms: Option, } pub fn record(command: RecordCommand) -> Result<(), String> { let mut stderr_gag = None; @@ -47,51 +51,66 @@ pub fn record(command: RecordCommand) -> Result<(), String> { device.name().map_err(|err| err.to_string())? ); let device_config = get_config(command.config_index, &device, command.sample_rate); - println!("Input device config: Sample Rate: {}, Channels: {}, Format: {}", device_config.sample_rate().0, device_config.channels(), device_config.sample_format()); + println!( + "Input device config: Sample Rate: {}, Channels: {}, Format: {}", + device_config.sample_rate().0, + device_config.channels(), + device_config.sample_format() + ); // disable gag after device config - if stderr_gag.is_some() { - drop(stderr_gag.unwrap()); + if let Some(stderr_gag) = stderr_gag { + drop(stderr_gag); } // Create wav spec let spec = wav_spec_from_config(&device_config); - let writer = hound::WavWriter::create(command.output_path.to_string(), spec).unwrap(); + let writer = hound::WavWriter::create(&command.output_path, spec).unwrap(); let writer = Arc::new(Mutex::new(Some(writer))); println!("Begin recording..."); // Run the input stream on a separate thread. let writer_2 = writer.clone(); - let err_fn = move |err| { - eprintln!("an error occurred on stream: {}", err); - }; - let err_cb = move |err: cpal::BuildStreamError| err.to_string(); + let (tx, rx) = mpsc::channel(); + let remaining_samples = command + .duration_ms + .map(|ms| ((spec.sample_rate as f32 / 1000.) * (ms as f32) * spec.channels as f32) as u64); let stream = match device_config.sample_format() { - cpal::SampleFormat::I16 => device - .build_input_stream( - &device_config.into(), - move |data, _: &_| write_input_data::(data, &writer_2, command.gain), - err_fn, - None, - ) - .map_err(err_cb)?, - cpal::SampleFormat::I32 => device - .build_input_stream( - &device_config.into(), - move |data, _: &_| write_input_data::(data, &writer_2, command.gain), - err_fn, - None, - ) - .map_err(err_cb)?, - cpal::SampleFormat::F32 => device - .build_input_stream( - &device_config.into(), - move |data, _: &_| write_input_data::(data, &writer_2, command.gain), - err_fn, - None, - ) - .map_err(err_cb)?, + cpal::SampleFormat::I8 => new_record_stream::( + &device, + device_config, + writer_2, + &tx, + command.gain, + remaining_samples, + )?, + cpal::SampleFormat::I16 => new_record_stream::( + &device, + device_config, + writer_2, + &tx, + command.gain, + remaining_samples, + )?, + cpal::SampleFormat::I32 => new_record_stream::( + &device, + device_config, + writer_2, + &tx, + command.gain, + remaining_samples, + )?, + cpal::SampleFormat::F32 => new_record_stream::( + &device, + device_config, + writer_2, + &tx, + command.gain, + remaining_samples, + )?, _ => return Err("Only support sample formats: i16, i32, f32".to_string())?, }; stream.play().expect("Unable to record"); - let (tx, rx) = mpsc::channel(); + if let Some(duration_ms) = command.duration_ms { + println!("Stopping in {}ms.", duration_ms); + } ctrlc::set_handler(move || tx.send(()).expect("Could not send signal on channel.")) .expect("Unable to listen keyboard"); println!("Press 'Ctrl + c' to stop."); @@ -108,20 +127,61 @@ pub fn record(command: RecordCommand) -> Result<(), String> { Ok(()) } +fn new_record_stream( + device: &cpal::Device, + device_config: cpal::SupportedStreamConfig, + writer_2: Arc>>>>, + tx: &Sender<()>, + gain: f32, + mut remaining_samples: Option, +) -> Result +where + T: Sample + SizedSample, + U: Sample + hound::Sample + FromSample, +{ + let err_fn = move |err| { + eprintln!("an error occurred on stream: {}", err); + }; + let err_cb = move |err: cpal::BuildStreamError| err.to_string(); + let tx_clone = tx.clone(); + device + .build_input_stream( + &device_config.into(), + move |data, _: &_| { + write_input_data::(data, &writer_2, gain, &tx_clone, &mut remaining_samples) + }, + err_fn, + None, + ) + .map_err(err_cb) +} + fn write_input_data( - input: &[T], + data: &[T], writer: &Arc>>>>, gain: f32, + tx: &Sender<()>, + remaining_samples: &mut Option, ) where T: Sample, U: Sample + hound::Sample + FromSample, { + if remaining_samples.is_some() && remaining_samples.as_ref().unwrap().eq(&0) { + return; + } if let Ok(mut guard) = writer.try_lock() { if let Some(writer) = guard.as_mut() { let gain_sample = Sample::from_sample(gain); - for &sample in input.iter() { + for &sample in data.iter() { let sample: U = U::from_sample(sample.mul_amp(gain_sample)); writer.write_sample(sample).ok(); + if let Some(remaining_samples) = remaining_samples.as_mut() { + *remaining_samples -= 1; + if *remaining_samples == 0 { + tx.send(()).ok(); + break; + } + } } } } @@ -143,12 +203,10 @@ fn wav_spec_from_config(config: &cpal::SupportedStreamConfig) -> hound::WavSpec // cpal utils for selecting input device and stream config pub(crate) fn is_compatible_format(format: &cpal::SampleFormat) -> bool { - match format { - cpal::SampleFormat::I16 => true, - cpal::SampleFormat::I32 => true, - cpal::SampleFormat::F32 => true, - _ => false, - } + matches!( + format, + cpal::SampleFormat::I16 | cpal::SampleFormat::I32 | cpal::SampleFormat::F32 + ) } pub(crate) fn is_compatible_buffer_size( supported_buffer_size: &cpal::SupportedBufferSize, @@ -214,7 +272,7 @@ fn try_get_config_with_sample_rate( } pub(crate) fn get_device(device_index: Option, host: cpal::Host) -> cpal::Device { - let device = device_index + device_index .map_or_else( || host.default_input_device(), |device_index| { @@ -224,6 +282,5 @@ pub(crate) fn get_device(device_index: Option, host: cpal::Host) -> cpal: .find_map(|(i, d)| if i == device_index { Some(d) } else { None }) }, ) - .expect("Failed to find input device"); - device + .expect("Failed to find input device") } diff --git a/src/cli/spot.rs b/src/cli/spot.rs index 5151e25..916f3bd 100644 --- a/src/cli/spot.rs +++ b/src/cli/spot.rs @@ -2,9 +2,14 @@ use std::{sync::mpsc, time::SystemTime}; use crate::cli::record::{self, is_compatible_buffer_size}; use clap::Args; -use cpal::traits::{DeviceTrait, StreamTrait}; +use cpal::{ + traits::{DeviceTrait, StreamTrait}, + SizedSample, +}; use gag::Gag; -use rustpotter::{Rustpotter, RustpotterConfig, RustpotterDetection, ScoreMode}; +use rustpotter::{ + Rustpotter, RustpotterConfig, RustpotterDetection, Sample, SampleFormat, ScoreMode, VADMode, +}; use time::OffsetDateTime; #[derive(Args, Debug)] @@ -41,9 +46,15 @@ pub struct SpotCommand { #[clap(short, long, default_value_t = 10)] /// Minimum number of partial detections min_scores: usize, - #[clap(short = 's', long, default_value_t = ClapScoreMode::Max)] + #[clap(short, long)] + /// Emit detection on min scores. + eager: bool, + #[clap(short = 's', long, default_value_t = ScoreMode::Max)] /// How to calculate a unified score - score_mode: ClapScoreMode, + score_mode: ScoreMode, + #[clap(short = 'v', long)] + /// Enabled vad detection. + vad_mode: Option, #[clap(short = 'g', long)] /// Enables a gain-normalizer audio filter. gain_normalizer: bool, @@ -66,18 +77,18 @@ pub struct SpotCommand { #[clap(long, default_value_t = 400.)] /// Band-pass audio filter high cutoff. high_cutoff: f32, - #[clap(long, default_value_t = 5)] - /// Band size of the comparison. (Advanced) - comparator_band_size: u16, #[clap(long, default_value_t = 0.22)] /// Used to express the result as a probability. (Advanced) - comparator_ref: f32, + score_ref: f32, #[clap(short, long)] /// Log partial detections. debug: bool, #[clap(long)] /// Log rms level ref, gain applied per frame and frame rms level. debug_gain: bool, + #[clap(short, long)] + /// Path to create records, one on the first partial detection and another each one that scores better. + record_path: Option, } pub fn spot(command: SpotCommand) -> Result<(), String> { @@ -105,25 +116,28 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { device_config.sample_format() ); // disable gag after device config - if stderr_gag.is_some() { - drop(stderr_gag.unwrap()); + if let Some(stderr_gag) = stderr_gag { + drop(stderr_gag); } + let bits_per_sample = (device_config.sample_format().sample_size() * 8) as u16; // configure rustpotter let mut config = RustpotterConfig::default(); config.fmt.sample_rate = device_config.sample_rate().0 as _; - config.fmt.bits_per_sample = (device_config.sample_format().sample_size() * 8) as _; config.fmt.channels = device_config.channels(); config.fmt.sample_format = if device_config.sample_format().is_float() { - hound::SampleFormat::Float + SampleFormat::float_of_size(bits_per_sample) } else { - hound::SampleFormat::Int - }; + SampleFormat::int_of_size(bits_per_sample) + } + .expect("Unsupported wav format"); config.detector.avg_threshold = command.averaged_threshold; config.detector.threshold = command.threshold; config.detector.min_scores = command.min_scores; - config.detector.score_mode = command.score_mode.into(); - config.detector.comparator_band_size = command.comparator_band_size; - config.detector.comparator_ref = command.comparator_ref; + config.detector.eager = command.eager; + config.detector.score_mode = command.score_mode; + config.detector.score_ref = command.score_ref; + config.detector.vad_mode = command.vad_mode; + config.detector.record_path = command.record_path; config.filters.gain_normalizer.enabled = command.gain_normalizer; config.filters.gain_normalizer.gain_ref = command.gain_ref; config.filters.gain_normalizer.min_gain = command.min_gain; @@ -144,9 +158,9 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { .unwrap_or(rustpotter.get_samples_per_frame() as u32); if host_name == "ALSA" && required_buffer_size % 2 != 0 { // force even buffer size to workaround issue mentioned here https://github.com/RustAudio/cpal/pull/582#pullrequestreview-1095655011 - required_buffer_size = required_buffer_size + 1; + required_buffer_size += 1; } - if !is_compatible_buffer_size(&device_config.buffer_size(), required_buffer_size) { + if !is_compatible_buffer_size(device_config.buffer_size(), required_buffer_size) { clap::Error::raw( clap::error::ErrorKind::Io, "Required buffer size does not matches device configuration, try selecting other.\n", @@ -158,10 +172,8 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { None }; for path in command.model_path { - let result = rustpotter.add_wakeword_from_file(&path); - if let Err(error) = result { - clap::Error::raw(clap::error::ErrorKind::InvalidValue, error + "\n").exit(); - } + println!("Loading wakeword file: {}", path); + rustpotter.add_wakeword_from_file("w", &path)?; } if command.debug_gain { println!( @@ -170,94 +182,52 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { ); } println!("Begin recording..."); - let err_fn = move |err| { - eprintln!("an error occurred on stream: {}", err); - }; - let err_cb = move |err: cpal::BuildStreamError| err.to_string(); let stream_config = cpal::StreamConfig { channels: device_config.channels(), sample_rate: device_config.sample_rate(), buffer_size: required_buffer_size - .map_or(cpal::BufferSize::Default, |v| cpal::BufferSize::Fixed(v)), + .map_or(cpal::BufferSize::Default, cpal::BufferSize::Fixed), }; if command.debug { println!("Audio stream config: {:?}", stream_config); } - let mut partial_detection_counter = 0; - let mut buffer_i16: Vec = Vec::new(); - let mut buffer_i32: Vec = Vec::new(); - let mut buffer_f32: Vec = Vec::new(); - let rustpotter_samples_per_frame = rustpotter.get_samples_per_frame(); + let buffer_i8: Vec = Vec::new(); + let buffer_i16: Vec = Vec::new(); + let buffer_i32: Vec = Vec::new(); + let buffer_f32: Vec = Vec::new(); let stream = match device_config.sample_format() { - cpal::SampleFormat::I16 => device - .build_input_stream( - &stream_config, - move |data: &[i16], _: &_| { - buffer_i16.extend_from_slice(data); - while buffer_i16.len() >= rustpotter_samples_per_frame { - let detection = rustpotter.process_i16( - buffer_i16.drain(0..rustpotter_samples_per_frame).as_slice(), - ); - print_detection( - &rustpotter, - detection, - &mut partial_detection_counter, - command.debug, - command.debug_gain, - get_time_string, - ); - } - }, - err_fn, - None, - ) - .map_err(err_cb)?, - cpal::SampleFormat::I32 => device - .build_input_stream( - &stream_config, - move |data: &[i32], _: &_| { - buffer_i32.extend_from_slice(data); - while buffer_i32.len() >= rustpotter_samples_per_frame { - let detection = rustpotter.process_i32( - &buffer_i32.drain(0..rustpotter_samples_per_frame).as_slice(), - ); - print_detection( - &rustpotter, - detection, - &mut partial_detection_counter, - command.debug, - command.debug_gain, - get_time_string, - ); - } - }, - err_fn, - None, - ) - .map_err(err_cb)?, - cpal::SampleFormat::F32 => device - .build_input_stream( - &stream_config, - move |data: &[f32], _: &_| { - buffer_f32.extend_from_slice(data); - while buffer_f32.len() >= rustpotter_samples_per_frame { - let detection = rustpotter.process_f32( - buffer_f32.drain(0..rustpotter_samples_per_frame).as_slice(), - ); - print_detection( - &rustpotter, - detection, - &mut partial_detection_counter, - command.debug, - command.debug_gain, - get_time_string, - ); - } - }, - err_fn, - None, - ) - .map_err(err_cb)?, + cpal::SampleFormat::I8 => init_spot_stream( + &device, + &stream_config, + rustpotter, + buffer_i8, + command.debug, + command.debug_gain, + )?, + cpal::SampleFormat::I16 => init_spot_stream( + &device, + &stream_config, + rustpotter, + buffer_i16, + command.debug, + command.debug_gain, + )?, + cpal::SampleFormat::I32 => init_spot_stream( + &device, + &stream_config, + rustpotter, + buffer_i32, + command.debug, + command.debug_gain, + )?, + cpal::SampleFormat::F32 => init_spot_stream( + &device, + &stream_config, + rustpotter, + buffer_f32, + command.debug, + command.debug_gain, + )?, _ => return Err("Only support sample formats: i16, i32, f32".to_string())?, }; stream.play().expect("Unable to record"); @@ -271,6 +241,63 @@ pub fn spot(command: SpotCommand) -> Result<(), String> { Ok(()) } +fn init_spot_stream( + device: &cpal::Device, + stream_config: &cpal::StreamConfig, + mut rustpotter: Rustpotter, + mut buffer: Vec, + debug: bool, + debug_gain: bool, +) -> Result { + let error_callback = move |err| { + eprintln!("an error occurred on stream: {}", err); + }; + let mut partial_detection_counter = 0; + let rustpotter_samples_per_frame = rustpotter.get_samples_per_frame(); + let data_callback = move |data: &[S], _: &_| { + run_detection( + &mut rustpotter, + data, + &mut buffer, + rustpotter_samples_per_frame, + &mut partial_detection_counter, + debug, + debug_gain, + ) + }; + device + .build_input_stream(stream_config, data_callback, error_callback, None) + .map_err(|err: cpal::BuildStreamError| err.to_string()) +} + +fn run_detection( + rustpotter: &mut Rustpotter, + data: &[T], + buffer: &mut Vec, + rustpotter_samples_per_frame: usize, + partial_detection_counter: &mut usize, + debug: bool, + debug_gain: bool, +) { + buffer.extend_from_slice(data); + while buffer.len() >= rustpotter_samples_per_frame { + let detection = rustpotter.process_samples( + buffer + .drain(0..rustpotter_samples_per_frame) + .as_slice() + .into(), + ); + print_detection( + &*rustpotter, + detection, + partial_detection_counter, + debug, + debug_gain, + get_time_string, + ); + } +} + pub(crate) fn print_detection( rustpotter: &Rustpotter, detection: Option, @@ -308,49 +335,6 @@ pub(crate) fn print_detection( ), }; } - -#[derive(clap::ValueEnum, Clone, Debug)] -pub(crate) enum ClapScoreMode { - Max, - Avg, - Median, - P25, - P50, - P75, - P80, - P90, - P95, -} -impl std::fmt::Display for ClapScoreMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ClapScoreMode::Avg => write!(f, "avg"), - ClapScoreMode::Max => write!(f, "max"), - ClapScoreMode::Median => write!(f, "median"), - ClapScoreMode::P25 => write!(f, "p25"), - ClapScoreMode::P50 => write!(f, "p50"), - ClapScoreMode::P75 => write!(f, "p75"), - ClapScoreMode::P80 => write!(f, "p80"), - ClapScoreMode::P90 => write!(f, "p90"), - ClapScoreMode::P95 => write!(f, "p95"), - } - } -} -impl From for ScoreMode { - fn from(value: ClapScoreMode) -> Self { - match value { - ClapScoreMode::Avg => ScoreMode::Average, - ClapScoreMode::Max => ScoreMode::Max, - ClapScoreMode::Median => ScoreMode::Median, - ClapScoreMode::P25 => ScoreMode::P25, - ClapScoreMode::P50 => ScoreMode::P50, - ClapScoreMode::P75 => ScoreMode::P75, - ClapScoreMode::P80 => ScoreMode::P80, - ClapScoreMode::P90 => ScoreMode::P90, - ClapScoreMode::P95 => ScoreMode::P95, - } - } -} fn get_time_string() -> String { let dt: OffsetDateTime = SystemTime::now().into(); format!("{:02}:{:02}:{:02}", dt.hour(), dt.minute(), dt.second()) diff --git a/src/cli/test.rs b/src/cli/test.rs new file mode 100644 index 0000000..da00fdc --- /dev/null +++ b/src/cli/test.rs @@ -0,0 +1,187 @@ +use clap::Args; +use hound::{SampleFormat, WavReader}; +use rustpotter::{Rustpotter, RustpotterConfig, Sample, ScoreMode, VADMode}; +use std::{fs::File, io::BufReader}; + +use super::spot::print_detection; + +#[derive(Args, Debug)] +/// Test wakeword file against a wav sample, detector is automatically configured according to the sample spec +#[clap()] +pub struct TestCommand { + #[clap()] + /// Model to test. + model_path: String, + #[clap()] + /// Wav record to test. + sample_path: String, + #[clap(short, long, default_value_t = 0.5)] + /// Default detection threshold, only applies to models without threshold. + threshold: f32, + #[clap(short, long, default_value_t = 0.2)] + /// Default detection averaged threshold, only applies to models without averaged threshold. + averaged_threshold: f32, + #[clap(short, long, default_value_t = 10)] + /// Minimum number of partial detections + min_scores: usize, + #[clap(short, long)] + /// Emit detection on min scores. + eager: bool, + #[clap(short = 's', long, default_value_t = ScoreMode::Max)] + /// How to calculate a unified score, no applies to wakeword models. + score_mode: ScoreMode, + #[clap(short = 'v', long)] + /// Enabled vad detection. + vad_mode: Option, + #[clap(short = 'g', long)] + /// Enables a gain-normalizer audio filter. + gain_normalizer: bool, + #[clap(long, default_value_t = 0.1)] + /// Min gain applied by the gain-normalizer filter. + min_gain: f32, + #[clap(long, default_value_t = 1.)] + /// Max gain applied by the gain-normalizer filter. + max_gain: f32, + #[clap(long)] + /// Set the rms level reference used by the gain normalizer filter. + /// If unset the max wakeword rms level is used. + gain_ref: Option, + #[clap(short, long)] + /// Enables a band-pass audio filter. + band_pass: bool, + #[clap(long, default_value_t = 80.)] + /// Band-pass audio filter low cutoff. + low_cutoff: f32, + #[clap(long, default_value_t = 400.)] + /// Band-pass audio filter high cutoff. + high_cutoff: f32, + #[clap(long, default_value_t = 0.22)] + /// Used to express the score as value in range 0 - 1. + score_ref: f32, + #[clap(short, long)] + /// Log partial detections. + debug: bool, + #[clap(long)] + /// Log rms level ref, gain applied per frame and frame rms level. + debug_gain: bool, + #[clap(short, long)] + /// Path to create records, one on the first partial detection and another each one that scores better. + record_path: Option, +} +pub fn test(command: TestCommand) -> Result<(), String> { + println!( + "Testing file {} against model {}!", + command.sample_path, command.model_path, + ); + // Read wav file + let file_reader = + BufReader::new(File::open(command.sample_path).map_err(|err| err.to_string())?); + let mut wav_reader = WavReader::new(file_reader).map_err(|err| err.to_string())?; + let wav_specs = wav_reader.spec(); + let mut config = RustpotterConfig::default(); + let sample_rate = wav_specs.sample_rate as usize; + config.fmt = wav_specs.try_into()?; + config.detector.avg_threshold = command.averaged_threshold; + config.detector.threshold = command.threshold; + config.detector.min_scores = command.min_scores; + config.detector.eager = command.eager; + config.detector.score_mode = command.score_mode; + config.detector.score_ref = command.score_ref; + config.detector.vad_mode = command.vad_mode; + config.detector.record_path = command.record_path; + config.filters.gain_normalizer.enabled = command.gain_normalizer; + config.filters.gain_normalizer.gain_ref = command.gain_ref; + config.filters.gain_normalizer.min_gain = command.min_gain; + config.filters.gain_normalizer.max_gain = command.max_gain; + config.filters.band_pass.enabled = command.band_pass; + config.filters.band_pass.low_cutoff = command.low_cutoff; + config.filters.band_pass.high_cutoff = command.high_cutoff; + if command.debug { + println!("Rustpotter config:\n{:?}", config); + } + let mut rustpotter = Rustpotter::new(&config)?; + println!("Loading wakeword file: {}", command.model_path); + rustpotter.add_wakeword_from_file("_", &command.model_path)?; + let mut partial_detection_counter = 0; + let mut chunk_counter = 0; + match wav_specs.sample_format { + SampleFormat::Int => match wav_specs.bits_per_sample { + 8 => run_detection::( + &mut wav_reader, + &mut rustpotter, + &mut chunk_counter, + &mut partial_detection_counter, + sample_rate, + command.debug, + command.debug_gain, + ), + 16 => run_detection::( + &mut wav_reader, + &mut rustpotter, + &mut chunk_counter, + &mut partial_detection_counter, + sample_rate, + command.debug, + command.debug_gain, + ), + 32 => run_detection::( + &mut wav_reader, + &mut rustpotter, + &mut chunk_counter, + &mut partial_detection_counter, + sample_rate, + command.debug, + command.debug_gain, + ), + _ => panic!("Unsupported wav format"), + }, + SampleFormat::Float => match wav_specs.bits_per_sample { + 32 => run_detection::( + &mut wav_reader, + &mut rustpotter, + &mut chunk_counter, + &mut partial_detection_counter, + sample_rate, + command.debug, + command.debug_gain, + ), + _ => panic!("Unsupported wav format"), + }, + }; + Ok(()) +} + +fn run_detection( + wav_reader: &mut WavReader>, + rustpotter: &mut Rustpotter, + chunk_counter: &mut usize, + partial_detection_counter: &mut usize, + sample_rate: usize, + debug: bool, + debug_gain: bool, +) { + let chunk_size = rustpotter.get_samples_per_frame(); + let mut buffer = wav_reader + .samples::() + .map(Result::unwrap) + .collect::>(); + buffer.append(&mut vec![T::get_zero(); chunk_size * 100]); + buffer.chunks_exact(chunk_size).for_each(|chunk| { + *chunk_counter += 1; + let detection = rustpotter.process_samples(chunk.into()); + print_detection( + rustpotter, + detection, + partial_detection_counter, + debug, + debug_gain, + || get_time_string(*chunk_counter, chunk_size, sample_rate), + ); + }); +} +fn get_time_string(chunk_number: usize, chunk_size: usize, sample_rate: usize) -> String { + let total_seconds = (chunk_number * chunk_size) as f32 / sample_rate as f32; + let minutes = (total_seconds / 60.).floor() as i32; + let seconds = (total_seconds % 60.).floor() as i32; + format!("00:{:02}:{:02}", minutes, seconds) +} diff --git a/src/cli/test_model.rs b/src/cli/test_model.rs deleted file mode 100644 index 05046b8..0000000 --- a/src/cli/test_model.rs +++ /dev/null @@ -1,155 +0,0 @@ -use clap::Args; -use hound::{SampleFormat, WavReader}; -use rustpotter::{Rustpotter, RustpotterConfig}; -use std::{fs::File, io::BufReader}; - -use super::spot::{print_detection, ClapScoreMode}; - -#[derive(Args, Debug)] -/// Test model file against a wav sample, detector is automatically configured according to the sample spec -#[clap()] -pub struct TestModelCommand { - #[clap()] - /// Model to test. - model_path: String, - #[clap()] - /// Wav record to test. - sample_path: String, - #[clap(short, long, default_value_t = 0.5)] - /// Default detection threshold, only applies to models without threshold. - threshold: f32, - #[clap(short, long, default_value_t = 0.2)] - /// Default detection averaged threshold, only applies to models without averaged threshold. - averaged_threshold: f32, - #[clap(short, long, default_value_t = 10)] - /// Minimum number of partial detections - min_scores: usize, - #[clap(short = 's', long, default_value_t = ClapScoreMode::Max)] - /// How to calculate a unified score - score_mode: ClapScoreMode, - #[clap(short = 'g', long)] - /// Enables a gain-normalizer audio filter. - gain_normalizer: bool, - #[clap(long, default_value_t = 0.1)] - /// Min gain applied by the gain-normalizer filter. - min_gain: f32, - #[clap(long, default_value_t = 1.)] - /// Max gain applied by the gain-normalizer filter. - max_gain: f32, - #[clap(long)] - /// Set the rms level reference used by the gain normalizer filter. - /// If unset the max wakeword rms level is used. - gain_ref: Option, - #[clap(short, long)] - /// Enables a band-pass audio filter. - band_pass: bool, - #[clap(long, default_value_t = 80.)] - /// Band-pass audio filter low cutoff. - low_cutoff: f32, - #[clap(long, default_value_t = 400.)] - /// Band-pass audio filter high cutoff. - high_cutoff: f32, - #[clap(long, default_value_t = 5)] - /// Band size of the comparison. (Advanced) - comparator_band_size: u16, - #[clap(long, default_value_t = 0.22)] - /// Used to express the result as a probability. (Advanced) - comparator_ref: f32, - #[clap(short, long)] - /// Log partial detections. - debug: bool, - #[clap(long)] - /// Log rms level ref, gain applied per frame and frame rms level. - debug_gain: bool, -} -pub fn test(command: TestModelCommand) -> Result<(), String> { - println!( - "Testing file {} against model {}!", - command.sample_path, command.model_path, - ); - // Read wav file - let file_reader = - BufReader::new(File::open(command.sample_path).map_err(|err| err.to_string())?); - let mut wav_reader = WavReader::new(file_reader).map_err(|err| err.to_string())?; - let wav_specs = wav_reader.spec(); - let mut config = RustpotterConfig::default(); - let sample_rate = wav_specs.sample_rate as usize; - config.fmt.sample_rate = sample_rate; - config.fmt.bits_per_sample = wav_specs.bits_per_sample; - config.fmt.channels = wav_specs.channels; - config.detector.avg_threshold = command.averaged_threshold; - config.detector.threshold = command.threshold; - config.detector.min_scores = command.min_scores; - config.detector.score_mode = command.score_mode.into(); - config.detector.comparator_band_size = command.comparator_band_size; - config.detector.comparator_ref = command.comparator_ref; - config.filters.gain_normalizer.enabled = command.gain_normalizer; - config.filters.gain_normalizer.gain_ref = command.gain_ref; - config.filters.gain_normalizer.min_gain = command.min_gain; - config.filters.gain_normalizer.max_gain = command.max_gain; - config.filters.band_pass.enabled = command.band_pass; - config.filters.band_pass.low_cutoff = command.low_cutoff; - config.filters.band_pass.high_cutoff = command.high_cutoff; - let mut rustpotter = Rustpotter::new(&config)?; - if let Err(error) = rustpotter.add_wakeword_from_file(&command.model_path) { - clap::Error::raw( - clap::error::ErrorKind::InvalidValue, - error.to_string() + "\n", - ) - .exit(); - } - let mut partial_detection_counter = 0; - let mut chunk_counter = 0; - let chunk_size = rustpotter.get_samples_per_frame(); - match wav_specs.sample_format { - SampleFormat::Int => { - let mut buffer = wav_reader - .samples::() - .map(Result::unwrap) - .collect::>(); - buffer.append(&mut vec![0; chunk_size * 100]); - buffer - .chunks_exact(chunk_size) - .for_each(|chunk| { - chunk_counter+=1; - let detection = rustpotter.process_i32(chunk); - print_detection( - &rustpotter, - detection, - &mut partial_detection_counter, - command.debug, - command.debug_gain, - || { get_time_string(chunk_counter, chunk_size, sample_rate) }, - ); - }); - } - SampleFormat::Float => { - let mut buffer = wav_reader - .samples::() - .map(Result::unwrap) - .collect::>(); - buffer.append(&mut vec![0.; chunk_size * 100]); - buffer - .chunks_exact(chunk_size) - .for_each(|chunk| { - chunk_counter+=1; - let detection = rustpotter.process_f32(chunk); - print_detection( - &rustpotter, - detection, - &mut partial_detection_counter, - command.debug, - command.debug_gain, - || { get_time_string(chunk_counter, chunk_size, sample_rate) }, - ); - }); - } - }; - Ok(()) -} -fn get_time_string(chunk_number: usize, chunk_size: usize, sample_rate: usize) -> String { - let total_seconds = (chunk_number * chunk_size) as f32 / sample_rate as f32; - let minutes = (total_seconds / 60.).floor() as i32; - let seconds = (total_seconds % 60.).floor() as i32; - format!("00:{:02}:{:02}", minutes, seconds) -} \ No newline at end of file diff --git a/src/cli/train.rs b/src/cli/train.rs new file mode 100644 index 0000000..33bc73d --- /dev/null +++ b/src/cli/train.rs @@ -0,0 +1,59 @@ +use clap::Args; +use rustpotter::{ + ModelType, WakewordLoad, WakewordModel, WakewordModelTrain, WakewordModelTrainOptions, + WakewordSave, +}; + +#[derive(Args, Debug)] +/// Train wakeword model, using wav audio files +#[clap()] +pub struct TrainCommand { + #[clap()] + /// Generated model path + model_path: String, + #[clap(short = 't', long, default_value_t = ModelType::Medium)] + /// Generated model type + model_type: ModelType, + #[clap(long, required = true)] + /// Train data directory path + train_dir: String, + #[clap(long, required = true)] + /// Test data directory path + test_dir: String, + #[clap(short = 'l', long, default_value_t = 0.03)] + /// Training learning rate + learning_rate: f64, + #[clap(short = 'e', long, default_value_t = 1000)] + /// Number of backward and forward cycles to run + epochs: usize, + #[clap(long, default_value_t = 1)] + /// Number of epochs for testing the model and print the progress. + test_epochs: usize, + #[clap(short = 'c', long, default_value_t = 16)] + /// Number of extracted mel-frequency cepstral coefficients + mfcc_size: u16, + #[clap(short = 'm')] + /// Model to continue training from + wakeword_model: Option, +} +pub fn train(command: TrainCommand) -> Result<(), String> { + println!("Start training {}!", command.model_path); + let model: Option = if let Some(wakeword_model_path) = command.wakeword_model { + Some(WakewordModel::load_from_file(&wakeword_model_path)?) + } else { + None + }; + let options = WakewordModelTrainOptions::new( + command.model_type, + command.learning_rate, + command.epochs, + command.test_epochs, + command.mfcc_size, + ); + let wakeword = + WakewordModel::train_from_dirs(options, command.train_dir, command.test_dir, model) + .map_err(|err| err.to_string())?; + wakeword.save_to_file(&command.model_path)?; + println!("{} created!", command.model_path); + Ok(()) +} diff --git a/tools/Dockerfile b/tools/Dockerfile index 3bc3f63..d8ec255 100644 --- a/tools/Dockerfile +++ b/tools/Dockerfile @@ -8,6 +8,8 @@ RUN mkdir -p /code/.cargo \ && cargo vendor > /code/.cargo/config # build FROM rust:buster +ARG RUSTFLAGS "" +ENV RUSTFLAGS $RUSTFLAGS COPY src /code/src COPY Cargo.* /code/ COPY --from=rust_vendor /code/.cargo /code/.cargo diff --git a/tools/create_tag.sh b/tools/create_tag.sh index d7c02ce..5c05eba 100755 --- a/tools/create_tag.sh +++ b/tools/create_tag.sh @@ -1,7 +1,7 @@ #!/bin/sh set -e cargo publish -VERSION=$(cat Cargo.toml | grep ^version | egrep -i -o '\d*\.\d*(\.\d*)?') +VERSION=$(cat Cargo.toml | grep ^version | cut -d'"' -f 2) TAG_NAME="v$VERSION" echo "creating $TAG_NAME" git tag -a $TAG_NAME -m "version $VERSION" diff --git a/tools/install.sh b/tools/install.sh new file mode 100755 index 0000000..c456914 --- /dev/null +++ b/tools/install.sh @@ -0,0 +1,4 @@ +#!/bin/sh +set -e +cargo build --release +mv target/release/rustpotter-cli /usr/local/bin/ \ No newline at end of file