From ce27797806ce498975c14dabeae4da47040a1a5d Mon Sep 17 00:00:00 2001 From: Swarnim Arun Date: Thu, 16 Nov 2023 13:02:19 +0530 Subject: [PATCH] feat: use process groups for swarms handle swarm processes as a child process group to ensure the exit happens consistently using group kill command --- src-tauri/Cargo.lock | 36 ++++++++++++++++++++++++++++++++++++ src-tauri/Cargo.toml | 1 + src-tauri/src/main.rs | 2 ++ src-tauri/src/swarm.rs | 35 ++++++++++++++++++++++++++++++----- 4 files changed, 69 insertions(+), 5 deletions(-) diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index df7f4fda..b9f962cc 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -62,6 +62,17 @@ version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +[[package]] +name = "async-trait" +version = "0.1.74" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "atk" version = "0.15.1" @@ -350,6 +361,18 @@ dependencies = [ "memchr", ] +[[package]] +name = "command-group" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e389ace313e22a2ac5a0e54f5805616a783077943b001edabbdb8640d52c490e" +dependencies = [ + "async-trait", + "nix", + "tokio", + "winapi", +] + [[package]] name = "concurrent-queue" version = "2.3.0" @@ -2499,6 +2522,7 @@ name = "prem-app" version = "0.2.1" dependencies = [ "chrono", + "command-group", "ctrlc", "futures", "log", @@ -3948,9 +3972,21 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "socket2 0.5.5", + "tokio-macros", "windows-sys 0.48.0", ] +[[package]] +name = "tokio-macros" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 0a69c97b..ad6b0841 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -7,6 +7,7 @@ [dependencies] chrono = "0.4.31" + command-group = { version = "4.1.0", features = ["with-tokio"] } ctrlc = "3.4.1" log = "0.4.20" pretty_env_logger = "0.5.0" diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 99250cbd..e5af497c 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -11,6 +11,7 @@ use crate::controller_binaries::stop_all_services; use std::{collections::HashMap, env, ops::Deref, str, sync::Arc}; +use command_group::AsyncGroupChild; use sentry_tauri::sentry; use serde::{Deserialize, Serialize}; use tauri::{ @@ -222,6 +223,7 @@ fn main() { .plugin(sentry_tauri::plugin()) .plugin(tauri_plugin_store::Builder::default().build()) .manage(state.clone()) + .manage(Option::>::None) .invoke_handler(tauri::generate_handler![ controller_binaries::start_service, controller_binaries::stop_service, diff --git a/src-tauri/src/swarm.rs b/src-tauri/src/swarm.rs index 9cb1b184..3f0e3980 100644 --- a/src-tauri/src/swarm.rs +++ b/src-tauri/src/swarm.rs @@ -1,3 +1,4 @@ +use command_group::AsyncGroupChild; use reqwest::get; use serde::Deserialize; use std::{ @@ -9,6 +10,7 @@ use std::{ time::Duration, }; use tauri::api::process::Command; +use tokio::sync::Mutex; #[derive(Deserialize)] struct PetalsModelInfo { @@ -138,15 +140,25 @@ pub fn delete_environment(handle: tauri::AppHandle) { } #[tauri::command(async)] -pub fn run_swarm(handle: tauri::AppHandle, num_blocks: i32, model: String, public_name: String) { +pub async fn run_swarm( + handle: tauri::AppHandle, + num_blocks: i32, + model: String, + public_name: String, + state: tauri::State<'_, Mutex>>, +) -> crate::errors::Result<()> { let petals_path = get_petals_path(handle.clone()); let config = Config::new(); let mut env = HashMap::new(); env.insert("PREM_PYTHON".to_string(), config.python); + use command_group::AsyncCommandGroup; + use tokio::process::Command; + println!("🚀 Starting the Swarm..."); - let _ = Command::new("sh") + // start the command as a (sub/child) process group + let group = Command::new("sh") .args([ format!("{petals_path}/run_swarm.sh").as_str(), &num_blocks.to_string(), @@ -154,8 +166,12 @@ pub fn run_swarm(handle: tauri::AppHandle, num_blocks: i32, model: String, publi &model, ]) .envs(env) - .spawn() + .group_spawn() .expect("🙈 Failed to run swarm"); + + _ = state.lock().await.insert(group); + + Ok(()) } fn get_petals_path(handle: tauri::AppHandle) -> String { @@ -218,8 +234,10 @@ pub fn get_swarm_processes() -> Vec { processes } -#[tauri::command] -pub fn stop_swarm_mode() { +#[tauri::command(async)] +pub async fn stop_swarm_mode( + state: tauri::State<'_, Mutex>>, +) -> crate::errors::Result<()> { println!("🛑 Stopping the Swarm..."); let processes = get_swarm_processes(); println!("🛑 Stopping Processes: {:?}", processes); @@ -256,5 +274,12 @@ pub fn stop_swarm_mode() { println!("🛑 Stopping Process with SIGTERM: {}", process); } } + + // attempt to kill the process group, but don't wait for it unnecessarily + if let Some(group) = &mut *state.lock().await { + _ = group.start_kill(); + } + println!("🛑 Stopped all the Swarm Processes."); + Ok(()) }