1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! An extension trait to allow determining at compile time how tasks are spawned on the Tokio
//! runtime.
//!
//! In most cases the [`Future`] task to be spawned should implement [`Send`], but that's
//! not possible when compiling for the Web. In that case, the task is spawned on the
//! browser event loop.

use futures::channel::oneshot;

#[cfg(web)]
mod implementation {
    pub use futures::future::AbortHandle;
    use futures::future;

    use super::*;

    /// A set of tasks spawned on the browser event loop.
    ///
    /// Each entry is a receiver that completes once the corresponding task has stopped running,
    /// either normally or because it was aborted.
    #[derive(Default)]
    pub struct JoinSet(Vec<oneshot::Receiver<()>>);

    /// An extension trait for the [`JoinSet`] type.
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on the browser event loop using
        /// [`wasm_bindgen_futures::spawn_local`], and tracks it in this [`JoinSet`].
        ///
        /// Returns a [`TaskHandle`] that resolves to the `future`'s output and can be used to
        /// abort the task.
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`], removing them from the set.
        fn await_all_tasks(&mut self) -> impl Future<Output = ()>;

        /// Removes the tasks that have finished from this [`JoinSet`], without waiting for the
        /// tasks that are still running.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F: Future + 'static>(&mut self, future: F) -> TaskHandle<F::Output> {
            let (abort_handle, abort_registration) = AbortHandle::new_pair();
            let (send_done, recv_done) = oneshot::channel();
            let (send_output, recv_output) = oneshot::channel();
            let future = async move {
                let _ = send_output.send(future.await);
                // Completion is signalled only after the output has been delivered. If the task
                // is aborted, `send_done` is dropped along with this future instead, which also
                // completes `recv_done` (with `Canceled`).
                let _ = send_done.send(());
            };
            self.0.push(recv_done);
            wasm_bindgen_futures::spawn_local(
                future::Abortable::new(future, abort_registration).map(drop),
            );

            TaskHandle {
                output_receiver: recv_output,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Each receiver is popped only after it has completed, so if this future is dropped
            // part-way through, the tasks that have not been awaited yet stay in the set.
            while let Some(done) = self.0.last_mut() {
                let _ = done.await;
                self.0.pop();
            }
        }

        fn reap_finished_tasks(&mut self) {
            // `Ok(None)` means the task has not completed yet, so it is kept. Anything else
            // (completed, or sender dropped because the task was aborted) is removed. This must
            // not block: the browser event loop cannot run other tasks while we wait.
            self.0.retain_mut(|done| matches!(done.try_recv(), Ok(None)));
        }
    }
}

#[cfg(not(web))]
mod implementation {
    pub use tokio::task::AbortHandle;

    use super::*;

    /// A set of tasks spawned on the Tokio runtime.
    pub type JoinSet = tokio::task::JoinSet<()>;

    /// An extension trait for the [`JoinSet`] type.
    #[trait_variant::make(Send)]
    pub trait JoinSetExt: Sized {
        /// Spawns a `future` task on this [`JoinSet`] using [`JoinSet::spawn`].
        ///
        /// Returns a [`TaskHandle`] that resolves to the `future`'s output and can be used to
        /// abort the task.
        fn spawn_task<F: Future<Output: Send> + Send + 'static>(
            &mut self,
            future: F,
        ) -> TaskHandle<F::Output>;

        /// Awaits all tasks spawned in this [`JoinSet`], removing them from the set.
        async fn await_all_tasks(&mut self);

        /// Removes the tasks that have finished from this [`JoinSet`], without waiting for the
        /// tasks that are still running.
        fn reap_finished_tasks(&mut self);
    }

    impl JoinSetExt for JoinSet {
        fn spawn_task<F>(&mut self, future: F) -> TaskHandle<F::Output>
        where
            F: Future + Send + 'static,
            F::Output: Send,
        {
            let (output_sender, output_receiver) = oneshot::channel();

            // The spawned task itself yields `()`; its output travels through the oneshot
            // channel so that it can be awaited through the returned handle. If the receiving
            // side has already been dropped, the output is discarded.
            let abort_handle = self.spawn(async move {
                let _ = output_sender.send(future.await);
            });

            TaskHandle {
                output_receiver,
                abort_handle,
            }
        }

        async fn await_all_tasks(&mut self) {
            // Join errors (e.g. from panicked or aborted tasks) are ignored here.
            while self.join_next().await.is_some() {}
        }

        fn reap_finished_tasks(&mut self) {
            // `try_join_next` returns `None` once no finished task remains, without waiting.
            while self.try_join_next().is_some() {}
        }
    }
}

use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

use futures::FutureExt as _;
pub use implementation::*;

/// A handle to a task spawned with [`JoinSetExt`].
///
/// Awaiting the handle yields the task's output, or [`oneshot::Canceled`] if the task stopped
/// without producing one.
///
/// Dropping a handle detaches its respective task.
pub struct TaskHandle<Output> {
    // Receives the task's output once the spawned future has completed.
    output_receiver: oneshot::Receiver<Output>,
    // Used by `TaskHandle::abort` to cancel the task.
    abort_handle: AbortHandle,
}

impl<Output> Future for TaskHandle<Output> {
    /// The task's output, or [`oneshot::Canceled`] if the task stopped without sending one.
    type Output = Result<Output, oneshot::Canceled>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        // `TaskHandle` only holds a receiver and an abort handle, both `Unpin`, so the pin can
        // be unwrapped to reach the receiver directly.
        let this = Pin::get_mut(self);
        this.output_receiver.poll_unpin(context)
    }
}

impl<Output> TaskHandle<Output> {
    /// Aborts the task.
    pub fn abort(&self) {
        self.abort_handle.abort();
    }

    /// Returns [`true`] if the task is still running.
    pub fn is_running(&mut self) -> bool {
        self.output_receiver.try_recv().is_err()
    }
}