linera_core/
notifier.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
12// TODO(#2171): replace this with a Tokio broadcast channel
13
14/// A `Notifier` holds references to clients waiting to receive notifications
15/// from the validator.
16/// Clients will be evicted if their connections are terminated.
17pub struct ChannelNotifier<N> {
18    inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
19}
20
21impl<N> Default for ChannelNotifier<N> {
22    fn default() -> Self {
23        Self {
24            inner: papaya::HashMap::default(),
25        }
26    }
27}
28
29impl<N> ChannelNotifier<N> {
30    /// Registers a sender for notifications on the given chain IDs.
31    pub fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32        let pinned = self.inner.pin();
33        for id in chain_ids {
34            pinned.update_or_insert_with(
35                id,
36                |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
37                || vec![sender.clone()],
38            );
39        }
40    }
41
42    /// Creates a subscription given a collection of chain IDs and a sender to the client.
43    pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
44        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
45        self.add_sender(chain_ids, &tx);
46        rx
47    }
48
49    /// Creates a subscription given a collection of chain IDs and a sender to the client.
50    /// Immediately posts a first notification as an ACK.
51    pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
52        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
53        self.add_sender(chain_ids, &tx);
54        tx.send(ack)
55            .expect("pushing to a new channel should succeed");
56        rx
57    }
58}
59
60impl<N> ChannelNotifier<N>
61where
62    N: Clone,
63{
64    /// Notifies all the clients waiting for a notification from a given chain.
65    pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
66        let pinned = self.inner.pin();
67
68        // Read senders outside of `compute` to avoid side effects in a
69        // retriable closure. papaya's `compute` may call its closure
70        // multiple times on CAS contention, so `send()` must not happen
71        // inside it.
72        let Some(senders) = pinned.get(chain_id).cloned() else {
73            trace!("Chain {chain_id} has no subscribers.");
74            return;
75        };
76
77        // Send notifications (side effect — must happen exactly once).
78        let mut has_dead = false;
79        for sender in &senders {
80            if sender.send(notification.clone()).is_err() {
81                has_dead = true;
82            }
83        }
84
85        // Clean up dead senders. The closure is pure: `is_closed()` is
86        // idempotent and has no side effects, so retries are safe.
87        if has_dead {
88            pinned.compute(*chain_id, |entry| {
89                let Some((_key, current_senders)) = entry else {
90                    return papaya::Operation::Abort(());
91                };
92                let live: Vec<_> = current_senders
93                    .iter()
94                    .filter(|s| !s.is_closed())
95                    .cloned()
96                    .collect();
97                if live.is_empty() {
98                    trace!("No more subscribers for chain {chain_id}. Removing entry.");
99                    papaya::Operation::Remove
100                } else {
101                    papaya::Operation::Insert(live)
102                }
103            });
104        }
105    }
106}
107
108pub trait Notifier: Clone + Send + 'static {
109    fn notify(&self, notifications: &[worker::Notification]);
110}
111
112impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
113    fn notify(&self, notifications: &[worker::Notification]) {
114        for notification in notifications {
115            self.notify_chain(&notification.chain_id, notification);
116        }
117    }
118}
119
120impl Notifier for () {
121    fn notify(&self, _notifications: &[worker::Notification]) {}
122}
123
/// Notifier that appends every notification it receives to a shared vector,
/// in delivery order. Only compiled with the `with_testing` cfg.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
131
#[cfg(test)]
pub mod tests {
    use std::{
        sync::{atomic::Ordering, Arc},
        time::{Duration, Instant},
    };

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Polls `condition` every few milliseconds until it returns `true` or
    /// `timeout` elapses. Callers assert on the observed state afterwards.
    fn wait_until(timeout: Duration, condition: impl Fn() -> bool) {
        let deadline = Instant::now() + timeout;
        while !condition() && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(5));
        }
    }

    #[test]
    fn test_concurrent() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let a_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let b_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let a_b_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let mut rx_a = notifier.subscribe(vec![chain_a]);
        let mut rx_b = notifier.subscribe(vec![chain_b]);
        let mut rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let a_rec_clone = a_rec.clone();
        let b_rec_clone = b_rec.clone();
        let a_b_rec_clone = a_b_rec.clone();

        let notifier = Arc::new(notifier);

        std::thread::spawn(move || {
            while rx_a.blocking_recv().is_some() {
                a_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        std::thread::spawn(move || {
            while rx_b.blocking_recv().is_some() {
                b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        std::thread::spawn(move || {
            while rx_a_b.blocking_recv().is_some() {
                a_b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let a_notifier = notifier.clone();
        let handle_a = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_A {
                a_notifier.notify_chain(&chain_a, &());
            }
        });

        let handle_b = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_B {
                notifier.notify_chain(&chain_b, &());
            }
        });

        // finish sending all the messages
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // The receiver threads drain their channels asynchronously. Wait for them
        // with a generous upper bound instead of a fixed sleep, which is flaky on
        // loaded machines.
        wait_until(Duration::from_secs(10), || {
            a_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_A
                && b_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_B
                && a_b_rec.load(Ordering::Relaxed) >= NOTIFICATIONS_A + NOTIFICATIONS_B
        });

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        // Chain A -> Notify A, Notify B
        // Chain B -> Notify A, Notify B
        // Chain C -> Notify C
        // Chain D -> Notify A, Notify B, Notify C, Notify D

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}