linera_core/
notifier.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use dashmap::DashMap;
7use linera_base::identifiers::ChainId;
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9use tracing::trace;
10
11use crate::worker;
12
13// TODO(#2171): replace this with a Tokio broadcast channel
14
15/// A `Notifier` holds references to clients waiting to receive notifications
16/// from the validator.
17/// Clients will be evicted if their connections are terminated.
18pub struct ChannelNotifier<N> {
19    inner: DashMap<ChainId, Vec<UnboundedSender<N>>>,
20}
21
22impl<N> Default for ChannelNotifier<N> {
23    fn default() -> Self {
24        Self {
25            inner: DashMap::default(),
26        }
27    }
28}
29
30impl<N> ChannelNotifier<N> {
31    fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32        for id in chain_ids {
33            let mut senders = self.inner.entry(id).or_default();
34            senders.push(sender.clone());
35        }
36    }
37
38    /// Creates a subscription given a collection of chain IDs and a sender to the client.
39    pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
40        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
41        self.add_sender(chain_ids, &tx);
42        rx
43    }
44
45    /// Creates a subscription given a collection of chain IDs and a sender to the client.
46    /// Immediately posts a first notification as an ACK.
47    pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
48        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
49        self.add_sender(chain_ids, &tx);
50        tx.send(ack)
51            .expect("pushing to a new channel should succeed");
52        rx
53    }
54}
55
56impl<N> ChannelNotifier<N>
57where
58    N: Clone,
59{
60    /// Notifies all the clients waiting for a notification from a given chain.
61    pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
62        let senders_is_empty = {
63            let Some(mut senders) = self.inner.get_mut(chain_id) else {
64                trace!("Chain {chain_id:?} has no subscribers.");
65                return;
66            };
67            let mut dead_senders = vec![];
68            let senders = senders.value_mut();
69
70            for (index, sender) in senders.iter_mut().enumerate() {
71                if sender.send(notification.clone()).is_err() {
72                    dead_senders.push(index);
73                }
74            }
75
76            for index in dead_senders.into_iter().rev() {
77                trace!("Removed dead subscriber for chain {chain_id:?}.");
78                senders.remove(index);
79            }
80
81            senders.is_empty()
82        };
83
84        if senders_is_empty {
85            trace!("No more subscribers for chain {chain_id:?}. Removing entry.");
86            self.inner.remove(chain_id);
87        }
88    }
89}
90
91pub trait Notifier: Clone + Send + 'static {
92    fn notify(&self, notifications: &[worker::Notification]);
93}
94
95impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
96    fn notify(&self, notifications: &[worker::Notification]) {
97        for notification in notifications {
98            self.notify_chain(&notification.chain_id, notification);
99        }
100    }
101}
102
103impl Notifier for () {
104    fn notify(&self, _notifications: &[worker::Notification]) {}
105}
106
/// Testing notifier that appends a copy of every notification to a shared vector.
///
/// Panics if the mutex is poisoned.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
114
#[cfg(test)]
pub mod tests {
    use std::sync::{atomic::Ordering, Arc};

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Notifies two chains concurrently from two threads. Checks that every subscriber receives
    /// exactly the notifications of the chains it subscribed to.
    #[test]
    fn test_concurrent() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let a_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let b_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let a_b_rec = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let mut rx_a = notifier.subscribe(vec![chain_a]);
        let mut rx_b = notifier.subscribe(vec![chain_b]);
        let mut rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let a_rec_clone = a_rec.clone();
        let b_rec_clone = b_rec.clone();
        let a_b_rec_clone = a_b_rec.clone();

        let notifier = Arc::new(notifier);

        let receiver_a = std::thread::spawn(move || {
            while rx_a.blocking_recv().is_some() {
                a_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        let receiver_b = std::thread::spawn(move || {
            while rx_b.blocking_recv().is_some() {
                b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        let receiver_a_b = std::thread::spawn(move || {
            while rx_a_b.blocking_recv().is_some() {
                a_b_rec_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let a_notifier = notifier.clone();
        let handle_a = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_A {
                a_notifier.notify_chain(&chain_a, &());
            }
        });

        // `notifier` is moved into this thread, so the two sender threads own the last `Arc`s.
        let handle_b = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_B {
                notifier.notify_chain(&chain_b, &());
            }
        });

        // Finish sending all the messages.
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Both sender threads have exited, so the notifier has been dropped together with every
        // `UnboundedSender` it held. Each receiver loop therefore ends once its channel is
        // drained. Joining the receivers waits for all deliveries without a timing-based sleep.
        receiver_a.join().unwrap();
        receiver_b.join().unwrap();
        receiver_a_b.join().unwrap();

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    /// Checks that closed subscribers are evicted on the next notification for their chains.
    /// A chain's entry must disappear once it has no live subscriber left.
    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        // Chain A -> Notify A, Notify B
        // Chain B -> Notify A, Notify B
        // Chain C -> Notify C
        // Chain D -> Notify A, Notify B, Notify C, Notify D

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}