linera_core/
join_set_ext.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
//! An extension trait to allow determining at compile time how tasks are spawned: on the Tokio
//! runtime natively, or on the browser event loop when compiling for the Web.
6//!
7//! In most cases the [`Future`] task to be spawned should implement [`Send`], but that's
8//! not possible when compiling for the Web. In that case, the task is spawned on the
9//! browser event loop.
10
11use futures::channel::oneshot;
12
#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// A set of tasks spawned on the browser event loop.
    ///
    /// Each entry is a receiver that resolves once its task stops running: with `Ok(())` when the
    /// task ran to completion, or with `Err(Canceled)` when it was dropped before finishing
    /// (e.g. because it was aborted).
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// An extension trait for the [`JoinSet`] type.
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on the browser event loop (via
        /// `wasm_bindgen_futures::spawn_local`) and tracks it in this [`JoinSet`].
        ///
        /// Returns a [`TaskHandle`] that can be awaited to receive the `future`'s output, and
        /// that can be used to abort execution of the task.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`], removing them from the set.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Reaps tasks that have finished.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                // The `TaskHandle` may already have been dropped (detaching the task), in which
                // case the output is simply discarded.
                let _ = send_output.send(future.await);
                let _ = send_done.send(());
            };
            // If the task is aborted, `send_done` is dropped unsent and `recv_done` resolves to
            // `Canceled`, which this set also treats as "no longer running".
            self.0.push(recv_done);
            wasm_bindgen_futures::spawn_local(
                // `spawn_local` needs a `()` output, so the `Result<(), Aborted>` is discarded.
                future::Abortable::new(future, abort_registration).map(drop),
            );

            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Each receiver is removed only after it resolves. If this future is dropped midway,
            // the tasks that have not been awaited yet remain tracked by the set.
            while let Some(task_done) = self.0.last_mut() {
                // `Err(Canceled)` means the task was dropped before signalling completion
                // (e.g. aborted); either way it is no longer running.
                let _ = task_done.await;
                self.0.pop();
            }
        }

        fn reap_finished_tasks(&mut self) {
            // `Ok(None)` means the task has neither finished nor been dropped yet; keep it.
            self.0.retain_mut(|task| task.try_recv() == Ok(None));
        }
    }
}
71
72#[cfg(not(web))]
73mod implementation {
74    pub use tokio::task::AbortHandle;
75
76    use super::*;
77
78    pub type JoinSet = tokio::task::JoinSet<()>;
79
80    /// An extension trait for the [`JoinSet`] type.
81    #[trait_variant::make(Send)]
82    pub trait JoinSetExt: Sized {
83        /// Spawns a `future` task on this [`JoinSet`] using [`JoinSet::spawn`].
84        ///
85        /// Returns a [`oneshot::Receiver`] to receive the `future`'s output, and an
86        /// [`AbortHandle`] to cancel execution of the task.
87        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
88            &mut self,
89            future: F,
90        ) -> TaskHandle<F::Output>;
91
92        /// Awaits all tasks spawned in this [`JoinSet`].
93        async fn await_all_tasks(&mut self);
94
95        /// Reaps tasks that have finished.
96        fn reap_finished_tasks(&mut self);
97    }
98
99    impl JoinSetExt for JoinSet {
100        fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
101        where
102            F: Future + Send + 'static,
103            F::Output: Send,
104        {
105            let (output_sender, output_receiver) = oneshot::channel();
106
107            let abort_handle = self.spawn(async move {
108                let _ = output_sender.send(future.await);
109            });
110
111            TaskHandle {
112                output_receiver,
113                abort_handle,
114            }
115        }
116
117        async fn await_all_tasks(&mut self) {
118            while self.join_next().await.is_some() {}
119        }
120
121        fn reap_finished_tasks(&mut self) {
122            while self.try_join_next().is_some() {}
123        }
124    }
125}
126
127use std::{
128    future::Future,
129    pin::Pin,
130    task::{Context, Poll},
131};
132
133use futures::FutureExt as _;
134pub use implementation::*;
135
136/// A handle to a task spawned with [`JoinSetExt`].
137///
138/// Dropping a handle detaches its respective task.
139pub struct TaskHandle<Output> {
140    output_receiver: oneshot::Receiver<Output>,
141    abort_handle: AbortHandle,
142}
143
144impl<Output> Future for TaskHandle<Output> {
145    type Output = Result<Output, oneshot::Canceled>;
146
147    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
148        self.as_mut().output_receiver.poll_unpin(context)
149    }
150}
151
152impl<Output> TaskHandle<Output> {
153    /// Aborts the task.
154    pub fn abort(&self) {
155        self.abort_handle.abort();
156    }
157
158    /// Returns [`true`] if the task is still running.
159    pub fn is_running(&mut self) -> bool {
160        self.output_receiver.try_recv().is_err()
161    }
162}