From 96c34ea7ada113e40470e94121cbc51e7b811462 Mon Sep 17 00:00:00 2001
From: Albert Garde
Date: Thu, 11 Apr 2024 17:09:55 +0200
Subject: [PATCH] :bug: Fix unsoundness

The static variable `USED_PARALLELISM` was previously modified and
accessed in safe functions with no guarantee of soundness. This is fixed
by putting it into a Mutex.
---
 tokenizers/src/utils/parallelism.rs | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs
index 2ac2eaa80..cef517bf5 100644
--- a/tokenizers/src/utils/parallelism.rs
+++ b/tokenizers/src/utils/parallelism.rs
@@ -2,6 +2,8 @@
 //! This module defines helpers to allow optional Rayon usage.
 //!
 
+use std::sync::Mutex;
+
 use rayon::iter::IterBridge;
 use rayon::prelude::*;
 use rayon_cond::CondIterator;
@@ -12,7 +14,7 @@ pub use rayon::current_num_threads;
 pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
 
 // Reading/Writing this variable should always happen on the main thread
-static mut USED_PARALLELISM: bool = false;
+static USED_PARALLELISM: Mutex<bool> = Mutex::new(false);
 
 /// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
 pub fn is_parallelism_configured() -> bool {
@@ -21,7 +23,9 @@ pub fn is_parallelism_configured() -> bool {
 
 /// Check if at some point we used a parallel iterator
 pub fn has_parallelism_been_used() -> bool {
-    unsafe { USED_PARALLELISM }
+    *USED_PARALLELISM
+        .lock()
+        .expect("`USED_PARALLELISM` should only be accessed on the main thread.")
 }
 
 /// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
@@ -70,7 +74,9 @@ where
     fn into_maybe_par_iter(self) -> CondIterator<P, S> {
         let parallelism = get_parallelism();
         if parallelism {
-            unsafe { USED_PARALLELISM = true };
+            *USED_PARALLELISM
+                .lock()
+                .expect("`USED_PARALLELISM` should only be accessed on the main thread.") = true;
         }
         CondIterator::new(self, parallelism)
     }
@@ -159,7 +165,9 @@ where
         let iter = CondIterator::from_serial(self);
 
         if get_parallelism() {
-            unsafe { USED_PARALLELISM = true };
+            *USED_PARALLELISM
+                .lock()
+                .expect("`USED_PARALLELISM` should only be accessed on the main thread.") = true;
             CondIterator::from_parallel(iter.into_parallel().right().unwrap())
         } else {
             iter