diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs
index ba76ee31b42dd..5ce38f530e3c7 100644
--- a/library/std/src/lib.rs
+++ b/library/std/src/lib.rs
@@ -287,6 +287,7 @@
 #![feature(prelude_2024)]
 #![feature(ptr_as_uninit)]
 #![feature(raw_os_nonzero)]
+#![feature(cfg_sanitize)]
 #![feature(slice_internals)]
 #![feature(slice_ptr_get)]
 #![feature(std_internals)]
diff --git a/library/std/src/thread/mod.rs b/library/std/src/thread/mod.rs
index c70ac8c9806d6..89445788399be 100644
--- a/library/std/src/thread/mod.rs
+++ b/library/std/src/thread/mod.rs
@@ -166,7 +166,7 @@ use crate::num::NonZeroUsize;
 use crate::panic;
 use crate::panicking;
 use crate::pin::Pin;
-use crate::ptr::addr_of_mut;
+use crate::ptr::{addr_of_mut, NonNull};
 use crate::str;
 use crate::sync::Arc;
 use crate::sys::thread as imp;
@@ -463,7 +463,7 @@ impl Builder {
     unsafe fn spawn_unchecked_<'a, 'scope, F, T>(
         self,
         f: F,
-        scope_data: Option<Arc<scoped::ScopeData>>,
+        scope_data: Option<Pin<&'scope scoped::ScopeData>>,
     ) -> io::Result<JoinInner<'scope, T>>
     where
         F: FnOnce() -> T,
@@ -481,7 +481,7 @@ impl Builder {
         let their_thread = my_thread.clone();

         let my_packet: Arc<Packet<'scope, T>> = Arc::new(Packet {
-            scope: scope_data,
+            scope: scope_data.map(|data| NonNull::from(data.get_ref())),
             result: UnsafeCell::new(None),
             _marker: PhantomData,
         });
@@ -511,8 +511,8 @@ impl Builder {
             unsafe { *their_packet.result.get() = Some(try_result) };
         };

-        if let Some(scope_data) = &my_packet.scope {
-            scope_data.increment_num_running_threads();
+        if let Some(scope_data) = my_packet.scope {
+            unsafe { scope_data.as_ref().increment_num_running_threads() };
         }

         Ok(JoinInner {
@@ -1302,7 +1302,7 @@ pub type Result<T> = crate::result::Result<T, Box<dyn Any + Send + 'static>>;
 // An Arc to the packet is stored into a `JoinInner` which in turn is placed
 // in `JoinHandle`.
 struct Packet<'scope, T> {
-    scope: Option<Arc<scoped::ScopeData>>,
+    scope: Option<NonNull<scoped::ScopeData>>,
     result: UnsafeCell<Option<Result<T>>>,
     _marker: PhantomData<Option<&'scope scoped::ScopeData>>,
 }
@@ -1335,12 +1335,23 @@ impl<'scope, T> Drop for Packet<'scope, T> {
             rtabort!("thread result panicked on drop");
         }
         // Book-keeping so the scope knows when it's done.
-        if let Some(scope) = &self.scope {
+        if let Some(scope_data) = self.scope {
             // Now that there will be no more user code running on this thread
             // that can use 'scope, mark the thread as 'finished'.
             // It's important we only do this after the `result` has been dropped,
             // since dropping it might still use things it borrowed from 'scope.
-            scope.decrement_num_running_threads(unhandled_panic);
+            //
+            // A static method is used for the decrement to keep `ScopeData` behind a raw pointer.
+            // Using a reference risks the decrement waking the `scope()` thread,
+            // invalidating our `ScopeData`, and leaving us with a dangling, dereferenceable `&ScopeData`.
+            // This avoids issue #55005.
+            //
+            // SAFETY:
+            // Given the thread has been spawned,
+            // there was a matching call to `ScopeData::increment_num_running_threads()`.
+            unsafe {
+                scoped::ScopeData::decrement_num_running_threads(scope_data, unhandled_panic);
+            }
         }
     }
 }
diff --git a/library/std/src/thread/scoped.rs b/library/std/src/thread/scoped.rs
index e6dbf35bd0286..a39650757e1e7 100644
--- a/library/std/src/thread/scoped.rs
+++ b/library/std/src/thread/scoped.rs
@@ -1,17 +1,21 @@
 use super::{current, park, Builder, JoinInner, Result, Thread};
+use crate::cell::UnsafeCell;
 use crate::fmt;
 use crate::io;
-use crate::marker::PhantomData;
+use crate::marker::{PhantomData, PhantomPinned};
 use crate::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
-use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+use crate::pin::Pin;
+use crate::ptr::NonNull;
+use crate::sync::atomic::{fence, AtomicBool, AtomicUsize, Ordering};
 use crate::sync::Arc;
+use core::intrinsics::{atomic_store_rel, atomic_xsub_rel};

 /// A scope to spawn scoped threads in.
 ///
 /// See [`scope`] for details.
 #[stable(feature = "scoped_threads", since = "1.63.0")]
 pub struct Scope<'scope, 'env: 'scope> {
-    data: Arc<ScopeData>,
+    data: Pin<&'scope ScopeData>,
     /// Invariance over 'scope, to make sure 'scope cannot shrink,
     /// which is necessary for soundness.
     ///
@@ -35,28 +39,133 @@
 #[stable(feature = "scoped_threads", since = "1.63.0")]
 pub struct ScopedJoinHandle<'scope, T>(JoinInner<'scope, T>);

+const WAITING_BIT: usize = 1;
+const ONE_RUNNING: usize = 2;
+
+/// Artificial limit on the maximum number of concurrently running threads in a scope.
+/// This is used to preemptively avoid hitting an overflow condition in the running-thread count.
+const MAX_RUNNING: usize = usize::MAX / 2;
+
+#[derive(Default)]
 pub(super) struct ScopeData {
-    num_running_threads: AtomicUsize,
-    a_thread_panicked: AtomicBool,
-    main_thread: Thread,
+    sync_state: AtomicUsize,
+    thread_panicked: AtomicBool,
+    scope_thread: UnsafeCell<Option<Thread>>,
+    _pinned: PhantomPinned,
 }

+unsafe impl Send for ScopeData {} // SAFETY: ScopeData needs to be sent to the spawned threads in the scope.
+unsafe impl Sync for ScopeData {} // SAFETY: ScopeData is shared between the spawned threads and the scope thread.
+
 impl ScopeData {
+    /// Issues an Acquire fence which synchronizes with the `sync_state` Release sequence.
+    fn fence_acquire_sync_state(&self) {
+        // ThreadSanitizer doesn't properly support fences,
+        // so use an atomic load instead to avoid false-positive data-race reports.
+        if cfg!(sanitize = "thread") {
+            self.sync_state.load(Ordering::Acquire);
+        } else {
+            fence(Ordering::Acquire);
+        }
+    }
+
     pub(super) fn increment_num_running_threads(&self) {
-        // We check for 'overflow' with usize::MAX / 2, to make sure there's no
-        // chance it overflows to 0, which would result in unsoundness.
-        if self.num_running_threads.fetch_add(1, Ordering::Relaxed) > usize::MAX / 2 {
-            // This can only reasonably happen by mem::forget()'ing many many ScopedJoinHandles.
-            self.decrement_num_running_threads(false);
-            panic!("too many running threads in thread scope");
+        // No need for any memory barriers, as this just increments the running count
+        // under the assumption that the ScopeData remains valid before and after this call.
+        let state = self.sync_state.fetch_add(ONE_RUNNING, Ordering::Relaxed);
+
+        // Make sure we're not spawning too many threads on the scope.
+        // `MAX_RUNNING` is intentionally lower than `usize::MAX` to detect overflow
+        // conditions on the running count earlier, even in the presence of multiple threads.
+        let running_threads = state / ONE_RUNNING;
+        assert!(running_threads <= MAX_RUNNING, "too many running threads in thread scope");
+    }
+
+    /// Decrements the number of running threads, with the assumption that one was running before.
+    /// Once the number of running threads becomes zero, this wakes up the scope thread if it's waiting.
+    /// The running-thread count hitting zero "happens before" the scope thread returns from waiting.
+    ///
+    /// SAFETY:
+    /// The caller must ensure that there was a matching call to `increment_num_running_threads()` prior.
+    pub(super) unsafe fn decrement_num_running_threads(data: NonNull<Self>, panicked: bool) {
+        unsafe {
+            if panicked {
+                data.as_ref().thread_panicked.store(true, Ordering::Relaxed);
+            }
+
+            // Decrement the running count with a Release barrier.
+            // This ensures that all data accesses and side effects before the decrement
+            // "happen before" the scope thread observes the running count to be zero.
+            let state_ptr = data.as_ref().sync_state.as_mut_ptr();
+            let state = atomic_xsub_rel(state_ptr, ONE_RUNNING);
+
+            let running_threads = state / ONE_RUNNING;
+            assert_ne!(
+                running_threads, 0,
+                "decrement_num_running_threads called when not incremented"
+            );
+
+            // Wake up the scope thread if it's waiting and we're the last running thread.
+            if state == (ONE_RUNNING | WAITING_BIT) {
+                // The Acquire barrier ensures that both the scope_thread store and the WAITING_BIT set,
+                // along with the data accesses and decrements from previous threads,
+                // "happen before" we start to wake up the scope thread.
+                data.as_ref().fence_acquire_sync_state();
+
+                let scope_thread = {
+                    let thread_ref = &mut *data.as_ref().scope_thread.get();
+                    thread_ref.take().expect("ScopeData has no thread even though WAITING_BIT is set")
+                };
+
+                // Wake up the scope thread by removing the WAITING_BIT and unparking the thread.
+                // The Release barrier ensures the consumption of `scope_thread` "happens before" the
+                // waiting scope thread observes 0 and returns to invalidate our data pointer.
+                atomic_store_rel(state_ptr, 0);
+                scope_thread.unpark();
+            }
         }
     }
-    pub(super) fn decrement_num_running_threads(&self, panic: bool) {
-        if panic {
-            self.a_thread_panicked.store(true, Ordering::Relaxed);
+
+    /// Blocks the caller's thread until all running threads have called `decrement_num_running_threads()`.
+    ///
+    /// SAFETY:
+    /// The caller must ensure that it is the sole scope thread calling this function,
+    /// and that there are no future calls to `increment_num_running_threads()` at this point.
+    unsafe fn wait_for_running_threads(&self) {
+        // Fast check to see if no threads are running.
+        // The Acquire barrier ensures the running-thread count updates
+        // and previous side effects on those threads "happen before" we observe 0 and return.
+        if self.sync_state.load(Ordering::Acquire) == 0 {
+            return;
         }
-        if self.num_running_threads.fetch_sub(1, Ordering::Release) == 1 {
-            self.main_thread.unpark();
+
+        // Register our Thread object to be unparked.
+        unsafe {
+            let thread_ref = &mut *self.scope_thread.get();
+            let old_scope_thread = thread_ref.replace(current());
+            assert!(old_scope_thread.is_none(), "multiple threads waiting on the same ScopeData");
+        }
+
+        // Set the WAITING_BIT on the state to indicate there's a waiter.
+        // `fetch_add` is used over `fetch_or`, as the former compiles to accelerated instructions on modern CPUs.
+        // The Release barrier ensures the Thread registration above "happens before" the WAITING_BIT is observed by the last running thread.
+        let state = self.sync_state.fetch_add(WAITING_BIT, Ordering::Release);
+        assert_eq!(state & WAITING_BIT, 0, "multiple threads waiting on the same ScopeData");
+
+        // Don't wait if all running threads completed while we were trying to set the WAITING_BIT.
+        // The Acquire fence ensures all running-thread count updates and related side effects "happen before" we return.
+        if state / ONE_RUNNING == 0 {
+            self.fence_acquire_sync_state();
+            return;
+        }
+
+        // Block the thread until the last running thread sees the WAITING_BIT and resets the state to zero.
+        // The Acquire barrier ensures all running-thread count updates and related side effects "happen before" we return.
+        loop {
+            park();
+            if self.sync_state.load(Ordering::Acquire) == 0 {
+                return;
+            }
         }
     }
 }
@@ -130,30 +239,26 @@ pub fn scope<'env, F, T>(f: F) -> T
 where
     F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
 {
-    // We put the `ScopeData` into an `Arc` so that other threads can finish their
-    // `decrement_num_running_threads` even after this function returns.
-    let scope = Scope {
-        data: Arc::new(ScopeData {
-            num_running_threads: AtomicUsize::new(0),
-            main_thread: current(),
-            a_thread_panicked: AtomicBool::new(false),
-        }),
-        env: PhantomData,
-        scope: PhantomData,
-    };
+    // We can store the ScopeData on the stack, as we're careful about accessing it intrusively.
+    let data = ScopeData::default();
+
+    // Make sure to store the ScopeData as pinned, to document in the type system
+    // that it must remain valid until it is dropped at the end of this function.
+    // SAFETY: the ScopeData is stored on the stack.
+    let scope =
+        Scope { data: unsafe { Pin::new_unchecked(&data) }, env: PhantomData, scope: PhantomData };

     // Run `f`, but catch panics so we can make sure to wait for all the threads to join.
     let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));

     // Wait until all the threads are finished.
-    while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
-        park();
-    }
+    // SAFETY: this is the only thread that calls ScopeData::wait_for_running_threads().
+    unsafe { scope.data.wait_for_running_threads() };

     // Throw any panic from `f`, or the return value of `f` if no thread panicked.
     match result {
         Err(e) => resume_unwind(e),
-        Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
+        Ok(_) if scope.data.thread_panicked.load(Ordering::Relaxed) => {
             panic!("a scoped thread panicked")
         }
         Ok(result) => result,
@@ -252,7 +357,7 @@ impl Builder {
     where
         F: FnOnce() -> T + Send + 'scope,
         T: Send + 'scope,
     {
-        Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data.clone())) }?))
+        Ok(ScopedJoinHandle(unsafe { self.spawn_unchecked_(f, Some(scope.data)) }?))
     }
 }
@@ -327,10 +432,14 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> {
 #[stable(feature = "scoped_threads", since = "1.63.0")]
 impl fmt::Debug for Scope<'_, '_> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let state = self.data.sync_state.load(Ordering::Relaxed);
+        let num_running_threads = state / ONE_RUNNING;
+        let main_thread_waiting = state & WAITING_BIT != 0;
+
         f.debug_struct("Scope")
-            .field("num_running_threads", &self.data.num_running_threads.load(Ordering::Relaxed))
-            .field("a_thread_panicked", &self.data.a_thread_panicked.load(Ordering::Relaxed))
-            .field("main_thread", &self.data.main_thread)
+            .field("num_running_threads", &num_running_threads)
+            .field("thread_panicked", &self.data.thread_panicked.load(Ordering::Relaxed))
+            .field("main_thread_waiting", &main_thread_waiting)
             .finish_non_exhaustive()
     }
 }
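For readers following the `sync_state` protocol above, here is a minimal, self-contained model of it: the running-thread count packed into the upper bits of one `usize` alongside a waiter bit, drained with `park`/`unpark`. This sketch is not part of the PR. It deliberately keeps the state in an `Arc` rather than a pinned stack slot, and it substitutes stable `fetch_sub`/`store` atomics (with `AcqRel` on the waiter registration) for the `atomic_xsub_rel`/`atomic_store_rel` intrinsics and the separate Acquire fence used in the actual code.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

const WAITING_BIT: usize = 1; // low bit: the scope thread has registered itself as a waiter
const ONE_RUNNING: usize = 2; // counter step: the running-thread count lives in the upper bits

fn main() {
    let sync_state = Arc::new(AtomicUsize::new(0));
    let scope_thread = thread::current();

    // Model of `increment_num_running_threads`: bump the count before each worker starts.
    let mut handles = Vec::new();
    for _ in 0..4 {
        sync_state.fetch_add(ONE_RUNNING, Ordering::Relaxed);
        let sync_state = Arc::clone(&sync_state);
        let scope_thread = scope_thread.clone();
        handles.push(thread::spawn(move || {
            // Model of `decrement_num_running_threads`: the Release ordering ensures this
            // thread's work "happens before" the waiter observes the count hitting zero.
            let state = sync_state.fetch_sub(ONE_RUNNING, Ordering::Release);
            // Only the last runner, and only if a waiter has registered, performs a wakeup.
            if state == (ONE_RUNNING | WAITING_BIT) {
                sync_state.store(0, Ordering::Release); // clears the WAITING_BIT
                scope_thread.unpark();
            }
        }));
    }

    // Model of `wait_for_running_threads`: announce the waiter, then park until drained.
    let state = sync_state.fetch_add(WAITING_BIT, Ordering::AcqRel);
    if state / ONE_RUNNING != 0 {
        while sync_state.load(Ordering::Acquire) != 0 {
            thread::park(); // park() tolerates spurious wakeups, hence the re-check loop
        }
    }

    for handle in handles {
        handle.join().unwrap();
    }
    println!("all workers finished");
}
```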
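For context, the code below shows the stable `std::thread::scope` API whose internals this diff reworks; the example is adapted from the standard library documentation. With this change, the `ScopeData` that tracks the two spawned threads lives in `scope()`'s stack frame rather than in a separate `Arc` allocation.

```rust
use std::thread;

fn main() {
    let mut a = vec![1, 2, 3];
    let mut x = 0;

    thread::scope(|s| {
        s.spawn(|| {
            // Scoped threads may borrow non-'static data from the enclosing scope.
            println!("hello from the first scoped thread: {a:?}");
        });
        s.spawn(|| {
            x += a[0] + a[2];
        });
    }); // scope() blocks here until both threads have finished

    a.push(4);
    assert_eq!(x, a.len());
}
```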