diff --git a/Cargo.toml b/Cargo.toml index a282260..62c7dee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,9 +13,13 @@ categories = ["asynchronous", "concurrency"] readme = "README.md" [dependencies] -crossbeam-utils = "0.7.0" +libc = "0.2.66" + +[target.'cfg(windows)'.dependencies] +winapi = { version = "0.3.8", features = ["processthreadsapi"] } [dev-dependencies] crossbeam = "0.7.3" +crossbeam-utils = "0.7.0" futures = "0.3.1" lazy_static = "1.4.0" diff --git a/src/header.rs b/src/header.rs index f9aff4f..5882f85 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,10 +1,8 @@ -use std::alloc::Layout; -use std::cell::Cell; -use std::fmt; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Waker; - -use crossbeam_utils::Backoff; +use core::alloc::Layout; +use core::cell::UnsafeCell; +use core::fmt; +use core::sync::atomic::{AtomicUsize, Ordering}; +use core::task::Waker; use crate::raw::TaskVTable; use crate::state::*; @@ -22,7 +20,7 @@ pub(crate) struct Header { /// The task that is blocked on the `JoinHandle`. /// /// This waker needs to be woken up once the task completes or is closed. - pub(crate) awaiter: Cell>, + pub(crate) awaiter: UnsafeCell>, /// The virtual table. /// @@ -55,7 +53,7 @@ impl Header { Ok(_) => { // Notify the awaiter that the task has been closed. if state & AWAITER != 0 { - self.notify(); + self.notify(None); } break; @@ -67,68 +65,105 @@ impl Header { /// Notifies the awaiter blocked on this task. /// - /// If there is a registered waker, it will be removed from the header and woken up. + /// If the awaiter is the same as the current waker, it will not be notified. #[inline] - pub(crate) fn notify(&self) { - if let Some(waker) = self.swap_awaiter(None) { - // We need a safeguard against panics because waking can panic. - abort_on_panic(|| { - waker.wake(); - }); - } - } + pub(crate) fn notify(&self, current: Option<&Waker>) { + // Mark the awaiter as being notified. + let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel); - /// Notifies the awaiter blocked on this task, unless its waker matches `current`. - /// - /// If there is a registered waker, it will be removed from the header in any case. - #[inline] - pub(crate) fn notify_unless(&self, current: &Waker) { - if let Some(waker) = self.swap_awaiter(None) { - if !waker.will_wake(current) { + // If the awaiter was not being notified nor registered... + if state & (NOTIFYING | REGISTERING) == 0 { + // Take the waker out. + let waker = unsafe { (*self.awaiter.get()).take() }; + + // Mark the state as not being notified anymore nor containing an awaiter. + self.state + .fetch_and(!NOTIFYING & !AWAITER, Ordering::Release); + + if let Some(w) = waker { // We need a safeguard against panics because waking can panic. - abort_on_panic(|| { - waker.wake(); + abort_on_panic(|| match current { + None => w.wake(), + Some(c) if !w.will_wake(c) => w.wake(), + Some(_) => {} }); } } } - /// Swaps the awaiter for a new waker and returns the previous value. + /// Registers a new awaiter blocked on this task. + /// + /// This method is called when `JoinHandle` is polled and the task has not completed. #[inline] - pub(crate) fn swap_awaiter(&self, new: Option) -> Option { - let new_is_none = new.is_none(); + pub(crate) fn register(&self, waker: &Waker) { + // Load the state and synchronize with it. + let mut state = self.state.fetch_or(0, Ordering::Acquire); - // We're about to try acquiring the lock in a loop. If it's already being held by another - // thread, we'll have to spin for a while so it's best to employ a backoff strategy. - let backoff = Backoff::new(); loop { - // Acquire the lock. If we're storing an awaiter, then also set the awaiter flag. - let state = if new_is_none { - self.state.fetch_or(LOCKED, Ordering::Acquire) - } else { - self.state.fetch_or(LOCKED | AWAITER, Ordering::Acquire) - }; + // There can't be two concurrent registrations because `JoinHandle` can only be polled + // by a unique pinned reference. + debug_assert!(state & REGISTERING == 0); + + // If the state is being notified at this moment, just wake and return without + // registering. + if state & NOTIFYING != 0 { + waker.wake_by_ref(); + return; + } - // If the lock was acquired, break from the loop. - if state & LOCKED == 0 { - break; + // Mark the state to let other threads know we're registering a new awaiter. + match self.state.compare_exchange_weak( + state, + state | REGISTERING, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + state |= REGISTERING; + break; + } + Err(s) => state = s, } + } - // Snooze for a little while because the lock is held by another thread. - backoff.snooze(); + // Put the waker into the awaiter field. + unsafe { + abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone())); } - // Replace the awaiter. - let old = self.awaiter.replace(new); + // This variable will contain the newly registered waker if a notification comes in before + // we complete registration. + let mut waker = None; + + loop { + // If there was a notification, take the waker out of the awaiter field. + if state & NOTIFYING != 0 { + if let Some(w) = unsafe { (*self.awaiter.get()).take() } { + waker = Some(w); + } + } + + // The new state is not being notified nor registered, but there might or might not be + // an awaiter depending on whether there was a concurrent notification. + let new = if waker.is_none() { + (state & !NOTIFYING & !REGISTERING) | AWAITER + } else { + state & !NOTIFYING & !REGISTERING & !AWAITER + }; - // Release the lock. If we've cleared the awaiter, then also unset the awaiter flag. - if new_is_none { - self.state.fetch_and(!LOCKED & !AWAITER, Ordering::Release); - } else { - self.state.fetch_and(!LOCKED, Ordering::Release); + match self + .state + .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire) + { + Ok(_) => break, + Err(s) => state = s, + } } - old + // If there was a notification during registration, wake the awaiter now. + if let Some(w) = waker { + abort_on_panic(|| w.wake()); + } } /// Returns the offset at which the tag of type `T` is stored. diff --git a/src/join_handle.rs b/src/join_handle.rs index 9357d32..49d529b 100644 --- a/src/join_handle.rs +++ b/src/join_handle.rs @@ -1,10 +1,10 @@ -use std::fmt; -use std::future::Future; -use std::marker::{PhantomData, Unpin}; -use std::pin::Pin; -use std::ptr::NonNull; -use std::sync::atomic::Ordering; -use std::task::{Context, Poll}; +use core::fmt; +use core::future::Future; +use core::marker::{PhantomData, Unpin}; +use core::pin::Pin; +use core::ptr::NonNull; +use core::sync::atomic::Ordering; +use core::task::{Context, Poll}; use crate::header::Header; use crate::state::*; @@ -71,7 +71,7 @@ impl JoinHandle { // Notify the awaiter that the task has been closed. if state & AWAITER != 0 { - (*header).notify(); + (*header).notify(None); } break; @@ -190,7 +190,7 @@ impl Future for JoinHandle { if state & CLOSED != 0 { // Even though the awaiter is most likely the current task, it could also be // another task. - (*header).notify_unless(cx.waker()); + (*header).notify(Some(cx.waker())); return Poll::Ready(None); } @@ -199,7 +199,7 @@ impl Future for JoinHandle { // Replace the waker with one associated with the current task. We need a // safeguard against panics because dropping the previous waker can panic. abort_on_panic(|| { - (*header).swap_awaiter(Some(cx.waker().clone())); + (*header).register(cx.waker()); }); // Reload the state after registering. It is possible that the task became @@ -230,7 +230,7 @@ impl Future for JoinHandle { // Notify the awaiter. Even though the awaiter is most likely the current // task, it could also be another task. if state & AWAITER != 0 { - (*header).notify_unless(cx.waker()); + (*header).notify(Some(cx.waker())); } // Take the output from the task. diff --git a/src/lib.rs b/src/lib.rs index a265679..5fe858a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,10 +97,13 @@ //! [`Task`]: struct.Task.html //! [`JoinHandle`]: struct.JoinHandle.html +#![no_std] #![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)] #![doc(test(attr(deny(rust_2018_idioms, warnings))))] #![doc(test(attr(allow(unused_extern_crates, unused_variables))))] +extern crate alloc; + mod header; mod join_handle; mod raw; diff --git a/src/raw.rs b/src/raw.rs index c250c02..6af184f 100644 --- a/src/raw.rs +++ b/src/raw.rs @@ -1,12 +1,12 @@ -use std::alloc::{self, Layout}; -use std::cell::Cell; -use std::future::Future; -use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; -use std::pin::Pin; -use std::ptr::NonNull; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; +use alloc::alloc::Layout; +use core::cell::UnsafeCell; +use core::future::Future; +use core::marker::PhantomData; +use core::mem::{self, ManuallyDrop}; +use core::pin::Pin; +use core::ptr::NonNull; +use core::sync::atomic::{AtomicUsize, Ordering}; +use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; use crate::header::Header; use crate::state::*; @@ -107,8 +107,8 @@ where unsafe { // Allocate enough space for the entire task. - let raw_task = match NonNull::new(alloc::alloc(task_layout.layout) as *mut ()) { - None => std::process::abort(), + let raw_task = match NonNull::new(alloc::alloc::alloc(task_layout.layout) as *mut ()) { + None => libc::abort(), Some(p) => p, }; @@ -117,7 +117,7 @@ where // Write the header as the first field of the task. (raw.header as *mut Header).write(Header { state: AtomicUsize::new(SCHEDULED | HANDLE | REFERENCE), - awaiter: Cell::new(None), + awaiter: UnsafeCell::new(None), vtable: &TaskVTable { raw_waker_vtable: RawWakerVTable::new( Self::clone_waker, @@ -307,7 +307,7 @@ where if state & RUNNING == 0 { // If the reference count overflowed, abort. if state > isize::max_value() as usize { - std::process::abort(); + libc::abort(); } // Schedule the task. There is no need to call `Self::schedule(ptr)` @@ -339,7 +339,7 @@ where // If the reference count overflowed, abort. if state > isize::max_value() as usize { - std::process::abort(); + libc::abort(); } RawWaker::new(ptr, raw_waker_vtable) @@ -449,7 +449,7 @@ where }); // Finally, deallocate the memory reserved by the task. - alloc::dealloc(ptr as *mut u8, task_layout.layout); + alloc::alloc::dealloc(ptr as *mut u8, task_layout.layout); } /// Runs a task. @@ -474,7 +474,7 @@ where if state & CLOSED != 0 { // Notify the awaiter that the task has been closed. if state & AWAITER != 0 { - (*raw.header).notify(); + (*raw.header).notify(None); } // Drop the future. @@ -542,7 +542,7 @@ where // Notify the awaiter that the task has been completed. if state & AWAITER != 0 { - (*raw.header).notify(); + (*raw.header).notify(None); } // Drop the task reference. @@ -649,7 +649,7 @@ where // Notify the awaiter that the task has been closed. if state & AWAITER != 0 { - (*raw.header).notify(); + (*raw.header).notify(None); } // Drop the task reference. diff --git a/src/state.rs b/src/state.rs index c03fea3..167a371 100644 --- a/src/state.rs +++ b/src/state.rs @@ -48,10 +48,16 @@ pub(crate) const HANDLE: usize = 1 << 4; /// check that tells us if we need to wake anyone without acquiring the lock inside the task. pub(crate) const AWAITER: usize = 1 << 5; -/// Set if the awaiter is locked. +/// Set if an awaiter is being registered. /// -/// This lock is acquired before a new awaiter is registered or the existing one is woken up. -pub(crate) const LOCKED: usize = 1 << 6; +/// This flag is set when `JoinHandle` is polled and we are registering a new awaiter. +pub(crate) const REGISTERING: usize = 1 << 6; + +/// Set if the awaiter is being notified. +/// +/// This flag is set when notifying the awaiter. If an awaiter is concurrently registered and +/// notified, whichever side came first will take over the reposibility of resolving the race. +pub(crate) const NOTIFYING: usize = 1 << 7; /// A single reference. /// @@ -61,4 +67,4 @@ pub(crate) const LOCKED: usize = 1 << 6; /// /// Note that the reference counter only tracks the `Task` and `Waker`s. The `JoinHandle` is /// tracked separately by the `HANDLE` flag. -pub(crate) const REFERENCE: usize = 1 << 7; +pub(crate) const REFERENCE: usize = 1 << 8; diff --git a/src/task.rs b/src/task.rs index 83cdf79..80953f4 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,11 +1,10 @@ -use std::fmt; -use std::future::Future; -use std::marker::PhantomData; -use std::mem::{self, ManuallyDrop}; -use std::pin::Pin; -use std::ptr::NonNull; -use std::task::{Context, Poll}; -use std::thread::{self, ThreadId}; +use core::fmt; +use core::future::Future; +use core::marker::PhantomData; +use core::mem::{self, ManuallyDrop}; +use core::pin::Pin; +use core::ptr::NonNull; +use core::task::{Context, Poll}; use crate::header::Header; use crate::raw::RawTask; @@ -109,20 +108,29 @@ where S: Fn(Task) + Send + Sync + 'static, T: Send + Sync + 'static, { - thread_local! { - static ID: ThreadId = thread::current().id(); + #[cfg(unix)] + #[inline] + fn thread_id() -> usize { + unsafe { libc::pthread_self() as usize } + } + + #[cfg(windows)] + #[inline] + fn thread_id() -> usize { + unsafe { winapi::um::processthreadsapi::GetCurrentThreadId() as usize } } struct Checked { - id: ThreadId, + id: usize, inner: ManuallyDrop, } impl Drop for Checked { fn drop(&mut self) { - if ID.with(|id| *id) != self.id { - panic!("local task dropped by a thread that didn't spawn it"); - } + assert!( + self.id == thread_id(), + "local task dropped by a thread that didn't spawn it" + ); unsafe { ManuallyDrop::drop(&mut self.inner); } @@ -133,15 +141,16 @@ where type Output = F::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if ID.with(|id| *id) != self.id { - panic!("local task polled by a thread that didn't spawn it"); - } + assert!( + self.id == thread_id(), + "local task polled by a thread that didn't spawn it" + ); unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) } } } let future = Checked { - id: ID.with(|id| *id), + id: thread_id(), inner: ManuallyDrop::new(future), }; diff --git a/src/utils.rs b/src/utils.rs index 441ead1..7c71deb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ -use std::alloc::Layout; -use std::mem; +use core::alloc::Layout; +use core::mem; /// Calls a function and aborts if it panics. /// @@ -10,7 +10,7 @@ pub(crate) fn abort_on_panic(f: impl FnOnce() -> T) -> T { impl Drop for Bomb { fn drop(&mut self) { - std::process::abort(); + unsafe { libc::abort() } } }