Skip to content

Add spawn_local and clarify what the schedule function can do #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/panic-propagation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use futures::executor;
use futures::future::FutureExt;
use lazy_static::lazy_static;

type Task = async_task::Task<()>;

/// Spawns a future on the executor.
fn spawn<F, R>(future: F) -> JoinHandle<R>
where
Expand All @@ -19,8 +21,8 @@ where
{
lazy_static! {
// A channel that holds scheduled tasks.
static ref QUEUE: Sender<async_task::Task<()>> = {
let (sender, receiver) = unbounded::<async_task::Task<()>>();
static ref QUEUE: Sender<Task> = {
let (sender, receiver) = unbounded::<Task>();

// Start the executor thread.
thread::spawn(|| {
Expand Down
9 changes: 6 additions & 3 deletions examples/panic-result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ use futures::executor;
use futures::future::FutureExt;
use lazy_static::lazy_static;

type Task = async_task::Task<()>;
type JoinHandle<T> = async_task::JoinHandle<T, ()>;

/// Spawns a future on the executor.
fn spawn<F, R>(future: F) -> async_task::JoinHandle<thread::Result<R>, ()>
fn spawn<F, R>(future: F) -> JoinHandle<thread::Result<R>>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
lazy_static! {
// A channel that holds scheduled tasks.
static ref QUEUE: Sender<async_task::Task<()>> = {
let (sender, receiver) = unbounded::<async_task::Task<()>>();
static ref QUEUE: Sender<Task> = {
let (sender, receiver) = unbounded::<Task>();

// Start the executor thread.
thread::spawn(|| {
Expand Down
76 changes: 76 additions & 0 deletions examples/spawn-local.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//! A simple single-threaded executor that can spawn non-`Send` futures.

use std::cell::Cell;
use std::future::Future;
use std::rc::Rc;

use crossbeam::channel::{unbounded, Receiver, Sender};

type Task = async_task::Task<()>;
type JoinHandle<T> = async_task::JoinHandle<T, ()>;

thread_local! {
// A channel that holds scheduled tasks.
static QUEUE: (Sender<Task>, Receiver<Task>) = unbounded();
}

/// Spawns a future on the executor.
fn spawn<F, R>(future: F) -> JoinHandle<R>
where
F: Future<Output = R> + 'static,
R: 'static,
{
// Create a task that is scheduled by sending itself into the channel.
let schedule = |t| QUEUE.with(|(s, _)| s.send(t).unwrap());
let (task, handle) = async_task::spawn_local(future, schedule, ());

// Schedule the task by sending it into the queue.
task.schedule();

handle
}

/// Runs a future to completion.
fn run<F, R>(future: F) -> R
where
F: Future<Output = R> + 'static,
R: 'static,
{
// Spawn a task that sends its result through a channel.
let (s, r) = unbounded();
spawn(async move { s.send(future.await).unwrap() });

loop {
// If the original task has completed, return its result.
if let Ok(val) = r.try_recv() {
return val;
}

// Otherwise, take a task from the queue and run it.
QUEUE.with(|(_, r)| r.recv().unwrap().run());
}
}

fn main() {
let val = Rc::new(Cell::new(0));

// Run a future that increments a non-`Send` value.
run({
let val = val.clone();
async move {
// Spawn a future that increments the value.
let handle = spawn({
let val = val.clone();
async move {
val.set(dbg!(val.get()) + 1);
}
});

val.set(dbg!(val.get()) + 1);
handle.await;
}
});

// The value should be 2 at the end of the program.
dbg!(val.get());
}
4 changes: 3 additions & 1 deletion examples/spawn-on-thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use std::thread;
use crossbeam::channel;
use futures::executor;

type JoinHandle<T> = async_task::JoinHandle<T, ()>;

/// Spawns a future on a new dedicated thread.
///
/// The returned handle can be used to await the output of the future.
fn spawn_on_thread<F, R>(future: F) -> async_task::JoinHandle<R, ()>
fn spawn_on_thread<F, R>(future: F) -> JoinHandle<R>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
Expand Down
9 changes: 6 additions & 3 deletions examples/spawn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@ use crossbeam::channel::{unbounded, Sender};
use futures::executor;
use lazy_static::lazy_static;

type Task = async_task::Task<()>;
type JoinHandle<T> = async_task::JoinHandle<T, ()>;

/// Spawns a future on the executor.
fn spawn<F, R>(future: F) -> async_task::JoinHandle<R, ()>
fn spawn<F, R>(future: F) -> JoinHandle<R>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
lazy_static! {
// A channel that holds scheduled tasks.
static ref QUEUE: Sender<async_task::Task<()>> = {
let (sender, receiver) = unbounded::<async_task::Task<()>>();
static ref QUEUE: Sender<Task> = {
let (sender, receiver) = unbounded::<Task>();

// Start the executor thread.
thread::spawn(|| {
Expand Down
9 changes: 6 additions & 3 deletions examples/task-id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use lazy_static::lazy_static;
#[derive(Clone, Copy, Debug)]
struct TaskId(usize);

type Task = async_task::Task<TaskId>;
type JoinHandle<T> = async_task::JoinHandle<T, TaskId>;

thread_local! {
/// The ID of the current task.
static TASK_ID: Cell<Option<TaskId>> = Cell::new(None);
Expand All @@ -26,15 +29,15 @@ fn task_id() -> Option<TaskId> {
}

/// Spawns a future on the executor.
fn spawn<F, R>(future: F) -> async_task::JoinHandle<R, TaskId>
fn spawn<F, R>(future: F) -> JoinHandle<R>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
{
lazy_static! {
// A channel that holds scheduled tasks.
static ref QUEUE: Sender<async_task::Task<TaskId>> = {
let (sender, receiver) = unbounded::<async_task::Task<TaskId>>();
static ref QUEUE: Sender<Task> = {
let (sender, receiver) = unbounded::<Task>();

// Start the executor thread.
thread::spawn(|| {
Expand Down
2 changes: 1 addition & 1 deletion src/join_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub struct JoinHandle<R, T> {
pub(crate) _marker: PhantomData<(R, T)>,
}

unsafe impl<R, T> Send for JoinHandle<R, T> {}
unsafe impl<R: Send, T> Send for JoinHandle<R, T> {}
unsafe impl<R, T> Sync for JoinHandle<R, T> {}

impl<R, T> Unpin for JoinHandle<R, T> {}
Expand Down
5 changes: 3 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
//! # let (task, handle) = async_task::spawn(future, schedule, ());
//! ```
//!
//! A task is constructed using the [`spawn`] function:
//! A task is constructed using either [`spawn`] or [`spawn_local`]:
//!
//! ```
//! # let (sender, receiver) = crossbeam::channel::unbounded();
Expand Down Expand Up @@ -93,6 +93,7 @@
//! union of the future and its output.
//!
//! [`spawn`]: fn.spawn.html
//! [`spawn_local`]: fn.spawn_local.html
//! [`Task`]: struct.Task.html
//! [`JoinHandle`]: struct.JoinHandle.html

Expand All @@ -108,4 +109,4 @@ mod task;
mod utils;

pub use crate::join_handle::JoinHandle;
pub use crate::task::{spawn, Task};
pub use crate::task::{spawn, spawn_local, Task};
16 changes: 5 additions & 11 deletions src/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,13 @@ impl<F, R, S, T> Clone for RawTask<F, R, S, T> {

impl<F, R, S, T> RawTask<F, R, S, T>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
F: Future<Output = R> + 'static,
S: Fn(Task<T>) + Send + Sync + 'static,
T: Send + 'static,
{
/// Allocates a task with the given `future` and `schedule` function.
///
/// It is assumed that initially only the `Task` reference and the `JoinHandle` exist.
pub(crate) fn allocate(tag: T, future: F, schedule: S) -> NonNull<()> {
pub(crate) fn allocate(future: F, schedule: S, tag: T) -> NonNull<()> {
// Compute the layout of the task for allocation. Abort if the computation fails.
let task_layout = abort_on_panic(|| Self::task_layout());

Expand Down Expand Up @@ -592,17 +590,13 @@ where
/// A guard that closes the task if polling its future panics.
struct Guard<F, R, S, T>(RawTask<F, R, S, T>)
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
S: Fn(Task<T>) + Send + Sync + 'static,
T: Send + 'static;
F: Future<Output = R> + 'static,
S: Fn(Task<T>) + Send + Sync + 'static;

impl<F, R, S, T> Drop for Guard<F, R, S, T>
where
F: Future<Output = R> + Send + 'static,
R: Send + 'static,
F: Future<Output = R> + 'static,
S: Fn(Task<T>) + Send + Sync + 'static,
T: Send + 'static,
{
fn drop(&mut self) {
let raw = self.0;
Expand Down
106 changes: 104 additions & 2 deletions src/task.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::mem;
use std::mem::{self, ManuallyDrop};
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Context, Poll};
use std::thread::{self, ThreadId};

use crate::header::Header;
use crate::raw::RawTask;
Expand All @@ -16,8 +19,16 @@ use crate::JoinHandle;
/// When run, the task polls `future`. When woken up, it gets scheduled for running by the
/// `schedule` function. Argument `tag` is an arbitrary piece of data stored inside the task.
///
/// The schedule function should not attempt to run the task nor to drop it. Instead, it should
/// push the task into some kind of queue so that it can be processed later.
///
/// If you need to spawn a future that does not implement [`Send`], consider using the
/// [`spawn_local`] function instead.
///
/// [`Task`]: struct.Task.html
/// [`JoinHandle`]: struct.JoinHandle.html
/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html
/// [`spawn_local`]: fn.spawn_local.html
///
/// # Examples
///
Expand All @@ -43,7 +54,98 @@ where
S: Fn(Task<T>) + Send + Sync + 'static,
T: Send + Sync + 'static,
{
let raw_task = RawTask::<F, R, S, T>::allocate(tag, future, schedule);
let raw_task = RawTask::<F, R, S, T>::allocate(future, schedule, tag);
let task = Task {
raw_task,
_marker: PhantomData,
};
let handle = JoinHandle {
raw_task,
_marker: PhantomData,
};
(task, handle)
}

/// Creates a new local task.
///
/// This constructor returns a [`Task`] reference that runs the future and a [`JoinHandle`] that
/// awaits its result.
///
/// When run, the task polls `future`. When woken up, it gets scheduled for running by the
/// `schedule` function. Argument `tag` is an arbitrary piece of data stored inside the task.
///
/// The schedule function should not attempt to run the task nor to drop it. Instead, it should
/// push the task into some kind of queue so that it can be processed later.
///
/// Unlike [`spawn`], this function does not require the future to implement [`Send`]. If the
/// [`Task`] reference is run or dropped on a thread it was not created on, a panic will occur.
///
/// [`Task`]: struct.Task.html
/// [`JoinHandle`]: struct.JoinHandle.html
/// [`spawn`]: fn.spawn.html
/// [`Send`]: https://doc.rust-lang.org/std/marker/trait.Send.html
///
/// # Examples
///
/// ```
/// use crossbeam::channel;
///
/// // The future inside the task.
/// let future = async {
/// println!("Hello, world!");
/// };
///
/// // If the task gets woken up, it will be sent into this channel.
/// let (s, r) = channel::unbounded();
/// let schedule = move |task| s.send(task).unwrap();
///
/// // Create a task with the future and the schedule function.
/// let (task, handle) = async_task::spawn_local(future, schedule, ());
/// ```
pub fn spawn_local<F, R, S, T>(future: F, schedule: S, tag: T) -> (Task<T>, JoinHandle<R, T>)
where
F: Future<Output = R> + 'static,
R: 'static,
S: Fn(Task<T>) + Send + Sync + 'static,
T: Send + Sync + 'static,
{
thread_local! {
static ID: ThreadId = thread::current().id();
}

struct Checked<F> {
id: ThreadId,
inner: ManuallyDrop<F>,
}

impl<F> Drop for Checked<F> {
fn drop(&mut self) {
if ID.with(|id| *id) != self.id {
panic!("local task dropped by a thread that didn't spawn it");
}
unsafe {
ManuallyDrop::drop(&mut self.inner);
}
}
}

impl<F: Future> Future for Checked<F> {
type Output = F::Output;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if ID.with(|id| *id) != self.id {
panic!("local task polled by a thread that didn't spawn it");
}
unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
}
}

let future = Checked {
id: ID.with(|id| *id),
inner: ManuallyDrop::new(future),
};

let raw_task = RawTask::<_, R, S, T>::allocate(future, schedule, tag);
let task = Task {
raw_task,
_marker: PhantomData,
Expand Down