1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0
//! An extension trait to allow determining at compile time how tasks are spawned on the Tokio
//! runtime.
//!
//! In most cases the [`Future`] task to be spawned should implement [`Send`], but that's
//! not possible when compiling for the Web. In that case, the task is spawned on the
//! browser event loop.
use futures::channel::oneshot;
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// A set of tasks spawned on the browser event loop.
    ///
    /// Each entry is a receiver that resolves once the corresponding task has run to
    /// completion, or fails with [`oneshot::Canceled`] if the task was dropped before
    /// completing (for example, because it was aborted).
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// An extension trait for the [`JoinSet`] type.
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on this [`JoinSet`] using
        /// [`wasm_bindgen_futures::spawn_local`].
        ///
        /// Returns a [`TaskHandle`] that can be awaited to receive the `future`'s output, and
        /// that can be used to abort execution of the task.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`], leaving it empty.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Reaps tasks that have finished, keeping track of the ones still running.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                // The receivers may already be gone (the `TaskHandle` was dropped to detach the
                // task, or the `JoinSet` stopped tracking it), so send failures are ignored.
                let _ = send_output.send(future.await);
                let _ = send_done.send(());
            };
            self.0.push(recv_done);
            wasm_bindgen_futures::spawn_local(
                // `spawn_local` expects a future with `()` output, so whether the task was
                // aborted is discarded here.
                future::Abortable::new(future, abort_registration).map(drop),
            );
            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            for done in &mut self.0 {
                // `Err(Canceled)` means the task was dropped (e.g. aborted) before completing,
                // which also means it is no longer running.
                let _ = done.await;
            }
            // Every tracked task has finished, so its receiver carries no more information.
            self.0.clear();
        }

        fn reap_finished_tasks(&mut self) {
            // Keep only the receivers whose task is still pending (`Ok(None)`). A received `()`
            // or a `Canceled` error both mean the task is no longer running. This never waits,
            // which matters on the single-threaded browser event loop.
            self.0.retain_mut(|done| matches!(done.try_recv(), Ok(None)));
        }
    }
}
#[cfg(not(web))]
mod implementation {
    pub use tokio::task::AbortHandle;

    use super::*;

    /// A set of tasks spawned on the Tokio runtime.
    pub type JoinSet = tokio::task::JoinSet<()>;

    /// An extension trait for the [`JoinSet`] type.
    #[trait_variant::make(Send)]
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on this [`JoinSet`] using [`JoinSet::spawn`].
        ///
        /// Returns a [`TaskHandle`] that can be awaited to receive the `future`'s output, and
        /// that can be used to abort execution of the task.
        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
            &mut self,
            future: F,
        ) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`].
        async fn await_all_tasks(&mut self);

        /// Reaps tasks that have finished, without waiting for the ones still running.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
        where
            F: Future + Send + 'static,
            F::Output: Send,
        {
            let (output_sender, output_receiver) = oneshot::channel();
            let abort_handle = self.spawn(async move {
                // The `TaskHandle` may have been dropped to detach the task, in which case
                // nobody is waiting for the output and the send error is irrelevant.
                let _ = output_sender.send(future.await);
            });
            TaskHandle {
                output_receiver,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // `join_next` removes one completed task per call and returns `None` once the set
            // is empty; per-task results (including `JoinError`s) are deliberately ignored.
            while self.join_next().await.is_some() {}
        }

        fn reap_finished_tasks(&mut self) {
            // `try_join_next` only yields tasks that have already completed and returns `None`
            // otherwise, so this never waits for running tasks.
            while self.try_join_next().is_some() {}
        }
    }
}
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures::FutureExt as _;
pub use implementation::*;
/// A handle to a task spawned with [`JoinSetExt`].
///
/// The handle is a [`Future`] resolving to the task's output, and it can also be used to abort
/// the task.
///
/// Dropping a handle detaches its respective task.
#[derive(Debug)]
pub struct TaskHandle<Output> {
    /// Receives the task's output once the task completes.
    output_receiver: oneshot::Receiver<Output>,
    /// Used to abort the task.
    abort_handle: AbortHandle,
}
impl<Output> Future for TaskHandle<Output> {
    /// The task's output, or [`oneshot::Canceled`] if the task's output sender was dropped
    /// without sending a value.
    type Output = Result<Output, oneshot::Canceled>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        // The handle is `Unpin`, so it can be unwrapped from the `Pin` directly.
        let handle = Pin::into_inner(self);
        handle.output_receiver.poll_unpin(context)
    }
}
impl<Output> TaskHandle<Output> {
    /// Aborts the task.
    pub fn abort(&self) {
        self.abort_handle.abort();
    }

    /// Returns [`true`] if the task is still running, i.e. it has neither sent its output nor
    /// dropped its output sender (which happens, for instance, when it is aborted).
    ///
    /// Note: if the task has already finished, this call takes its output out of the channel,
    /// so awaiting this handle afterwards resolves to `Err(oneshot::Canceled)`.
    pub fn is_running(&mut self) -> bool {
        // `try_recv` yields `Ok(None)` while the sender is alive and has not sent yet,
        // `Ok(Some(_))` once the output was sent, and `Err(Canceled)` if the sender was dropped
        // without sending. Only the first case means the task is still running.
        matches!(self.output_receiver.try_recv(), Ok(None))
    }
}