diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index 2ac2eaa80..b955731d1 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -5,14 +5,15 @@ use rayon::iter::IterBridge; use rayon::prelude::*; use rayon_cond::CondIterator; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; // Re-export rayon current_num_threads pub use rayon::current_num_threads; pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM"; -// Reading/Writing this variable should always happen on the main thread -static mut USED_PARALLELISM: bool = false; +static USED_PARALLELISM: AtomicBool = AtomicBool::new(false); /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set pub fn is_parallelism_configured() -> bool { @@ -21,7 +22,7 @@ pub fn is_parallelism_configured() -> bool { /// Check if at some point we used a parallel iterator pub fn has_parallelism_been_used() -> bool { - unsafe { USED_PARALLELISM } + USED_PARALLELISM.load(Ordering::SeqCst) } /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable @@ -70,7 +71,7 @@ where fn into_maybe_par_iter(self) -> CondIterator { let parallelism = get_parallelism(); if parallelism { - unsafe { USED_PARALLELISM = true }; + USED_PARALLELISM.store(true, Ordering::SeqCst); } CondIterator::new(self, parallelism) } @@ -159,7 +160,7 @@ where let iter = CondIterator::from_serial(self); if get_parallelism() { - unsafe { USED_PARALLELISM = true }; + USED_PARALLELISM.store(true, Ordering::SeqCst); CondIterator::from_parallel(iter.into_parallel().right().unwrap()) } else { iter