linera_core/
notifier.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
12// TODO(#2171): replace this with a Tokio broadcast channel
13
/// A `ChannelNotifier` holds the sending halves of the channels of clients waiting to
/// receive notifications, grouped by the chain they subscribed to.
///
/// Clients are evicted once their receiving side is gone: `notify_chain` drops every
/// sender it can no longer deliver to, and removes a chain's entry when none is left.
pub struct ChannelNotifier<N> {
    // For each chain, one sender per subscription that includes this chain.
    inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
}
20
21impl<N> Default for ChannelNotifier<N> {
22    fn default() -> Self {
23        Self {
24            inner: papaya::HashMap::default(),
25        }
26    }
27}
28
29impl<N> ChannelNotifier<N> {
30    fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
31        let pinned = self.inner.pin();
32        for id in chain_ids {
33            pinned.update_or_insert_with(
34                id,
35                |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
36                || vec![sender.clone()],
37            );
38        }
39    }
40
41    /// Creates a subscription given a collection of chain IDs and a sender to the client.
42    pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
43        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
44        self.add_sender(chain_ids, &tx);
45        rx
46    }
47
48    /// Creates a subscription given a collection of chain IDs and a sender to the client.
49    /// Immediately posts a first notification as an ACK.
50    pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
51        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
52        self.add_sender(chain_ids, &tx);
53        tx.send(ack)
54            .expect("pushing to a new channel should succeed");
55        rx
56    }
57}
58
59impl<N> ChannelNotifier<N>
60where
61    N: Clone,
62{
63    /// Notifies all the clients waiting for a notification from a given chain.
64    pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
65        self.inner.pin().compute(*chain_id, |senders| {
66            let Some((_key, senders)) = senders else {
67                trace!("Chain {chain_id:?} has no subscribers.");
68                return papaya::Operation::Abort(());
69            };
70            let live_senders = senders
71                .iter()
72                .filter(|sender| sender.send(notification.clone()).is_ok())
73                .cloned()
74                .collect::<Vec<_>>();
75            if live_senders.is_empty() {
76                trace!("No more subscribers for chain {chain_id:?}. Removing entry.");
77                return papaya::Operation::Remove;
78            }
79            papaya::Operation::Insert(live_senders)
80        });
81    }
82}
83
/// A sink for worker notifications.
///
/// Implementors must be cloneable, sendable across threads and `'static`.
pub trait Notifier: Clone + Send + 'static {
    /// Handles a batch of notifications.
    fn notify(&self, notifications: &[worker::Notification]);
}
87
88impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
89    fn notify(&self, notifications: &[worker::Notification]) {
90        for notification in notifications {
91            self.notify_chain(&notification.chain_id, notification);
92        }
93    }
94}
95
/// A no-op notifier: every notification is discarded.
impl Notifier for () {
    // Intentionally empty; the notifications are ignored.
    fn notify(&self, _notifications: &[worker::Notification]) {}
}
99
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    /// Appends a clone of every notification to the shared vector.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn notify(&self, notifications: &[worker::Notification]) {
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
107
#[cfg(test)]
pub mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Spawns a thread that drains `receiver` until its channel is closed, and returns
    /// a counter of the notifications received so far.
    fn spawn_counter(mut receiver: tokio::sync::mpsc::UnboundedReceiver<()>) -> Arc<AtomicUsize> {
        let counter = Arc::new(AtomicUsize::new(0));
        let thread_counter = Arc::clone(&counter);
        thread::spawn(move || {
            while receiver.blocking_recv().is_some() {
                thread_counter.fetch_add(1, Ordering::Relaxed);
            }
        });
        counter
    }

    /// Two threads notify two different chains concurrently; every subscriber must
    /// receive exactly the notifications of the chains it subscribed to.
    #[test]
    fn test_concurrent() {
        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        // Register all three subscriptions before any receiver thread starts.
        let rx_a = notifier.subscribe(vec![chain_a]);
        let rx_b = notifier.subscribe(vec![chain_b]);
        let rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let notifier = Arc::new(notifier);

        let a_rec = spawn_counter(rx_a);
        let b_rec = spawn_counter(rx_b);
        let a_b_rec = spawn_counter(rx_a_b);

        let notifier_for_a = Arc::clone(&notifier);
        let handle_a = thread::spawn(move || {
            (0..NOTIFICATIONS_A).for_each(|_| notifier_for_a.notify_chain(&chain_a, &()));
        });
        let handle_b = thread::spawn(move || {
            (0..NOTIFICATIONS_B).for_each(|_| notifier.notify_chain(&chain_b, &()));
        });

        // Wait until every notification has been sent.
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Leave the receiver threads some time to drain their channels.
        thread::sleep(Duration::from_millis(100));

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    /// Closed receivers are evicted lazily, when their chain is next notified, and a
    /// chain's entry disappears once its last subscriber is gone.
    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        // Subscribers per chain (receivers `rx_a` .. `rx_d` below):
        //   chain A: rx_a, rx_b
        //   chain B: rx_a, rx_b
        //   chain C: rx_c
        //   chain D: rx_a, rx_b, rx_c, rx_d
        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        let entries = || notifier.inner.len();
        assert_eq!(entries(), 4);

        // Chain C loses its only subscriber, so its entry goes away.
        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(entries(), 3);

        // Chain A still has `rx_b`, so its entry stays.
        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(entries(), 3);

        // Both subscribers of chain B are now closed.
        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(entries(), 2);

        // Chain A only had `rx_b` left, which is closed by now.
        notifier.notify_chain(&chain_a, &());
        assert_eq!(entries(), 1);

        // Finally, every subscriber of chain D is closed.
        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(entries(), 0);
    }
}