linera_rpc/simple/
transport.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2// Copyright (c) Zefchain Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    io, mem,
8    net::SocketAddr,
9    pin::{pin, Pin},
10    sync::Arc,
11};
12
13use async_trait::async_trait;
14use futures::{
15    future,
16    stream::{self, FuturesUnordered, SplitSink, SplitStream},
17    Sink, SinkExt, Stream, StreamExt, TryStreamExt,
18};
19use linera_base::identifiers::{BlobId, ChainId};
20use linera_core::{JoinSetExt as _, TaskHandle};
21use serde::{Deserialize, Serialize};
22use tokio::{
23    io::AsyncWriteExt,
24    net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
25    sync::Mutex,
26    task::JoinSet,
27};
28use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
29use tracing::{error, warn};
30
31use crate::{
32    simple::{codec, codec::Codec},
33    RpcMessage,
34};
35
/// Suggested maximum datagram (buffer) size, in bytes, kept as a string.
///
/// `65507` equals `65535 - 20 - 8`, i.e. the largest UDP payload that fits in a single IPv4
/// packet once the minimal IPv4 header and the UDP header are accounted for.
// NOTE(review): presumably a string so that it can be used directly as the default value of
// a command-line option — confirm against the callers.
pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";
38
/// Number of tasks to spawn before attempting to reap some finished tasks to prevent memory leaks.
///
/// The UDP server compares it against the number of per-peer handler tasks it is tracking,
/// while the TCP server uses it as the number of accepted connections between two reaping
/// passes.
const REAP_TASKS_THRESHOLD: usize = 100;
41
/// Supported transport protocols.
///
/// The enum derives [`clap::ValueEnum`], which is also what its [`std::str::FromStr`]
/// implementation relies on for parsing.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum TransportProtocol {
    // Datagram-based transport, served by `UdpServer`.
    Udp,
    // Connection-based transport, served by `TcpServer`.
    Tcp,
}
48
49impl std::str::FromStr for TransportProtocol {
50    type Err = String;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        clap::ValueEnum::from_str(s, true)
54    }
55}
56
57impl std::fmt::Display for TransportProtocol {
58    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59        write!(f, "{self:?}")
60    }
61}
62
63impl TransportProtocol {
64    pub fn scheme(&self) -> &'static str {
65        match self {
66            TransportProtocol::Udp => "udp",
67            TransportProtocol::Tcp => "tcp",
68        }
69    }
70}
71
/// A pool of (outgoing) data streams.
///
/// This module provides one implementation based on UDP and one based on TCP; both are
/// created through `TransportProtocol::make_outgoing_connection_pool`.
pub trait ConnectionPool: Send {
    /// Sends `message` to the peer at `address`.
    ///
    /// The returned future borrows both the pool and `address` until it completes, and
    /// resolves to a [`codec::Error`] if the message could not be sent.
    fn send_message_to<'a>(
        &'a mut self,
        message: RpcMessage,
        address: &'a str,
    ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
}
80
/// The handler required to create a service.
///
/// The implementation needs to implement [`Clone`] because a seed instance is used to generate
/// cloned instances, where each cloned instance handles a single request. Multiple cloned instances
/// may exist at the same time and handle separate requests concurrently.
#[async_trait]
pub trait MessageHandler: Clone {
    /// Handles a single incoming message.
    ///
    /// Returns the reply that the server should send back to the sender, or `None` if there
    /// is nothing to reply.
    async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;

    /// Handles a notification subscription request. Returns a stream of notification
    /// messages if supported, or `None` if subscriptions are not supported.
    ///
    /// The default implementation does not support subscriptions and returns `None`.
    /// In this module, only the TCP server calls this method.
    async fn handle_subscribe(
        &mut self,
        _chains: Vec<ChainId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }

    /// Handles a batch blob download request by streaming one
    /// `RpcMessage::DownloadBlobResponse` per requested blob ID.
    /// Returns `None` if not supported.
    ///
    /// The default implementation does not support batch downloads and returns `None`.
    /// In this module, only the TCP server calls this method.
    async fn handle_download_blobs(
        &mut self,
        _blob_ids: Vec<BlobId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }
}
109
110/// The result of spawning a server is oneshot channel to track completion, and the set of
111/// executing tasks.
112pub struct ServerHandle {
113    pub handle: TaskHandle<Result<(), std::io::Error>>,
114}
115
116impl ServerHandle {
117    pub async fn join(self) -> Result<(), std::io::Error> {
118        self.handle.await.map_err(|_| {
119            std::io::Error::new(
120                std::io::ErrorKind::Interrupted,
121                "Server task did not finish successfully",
122            )
123        })?
124    }
125}
126
/// A trait alias for a protocol transport.
///
/// A transport is an active connection that can be used to send and receive
/// [`RpcMessage`]s: it is a [`Stream`] of incoming messages (or [`codec::Error`]s) and a
/// [`Sink`] for outgoing messages.
pub trait Transport:
    Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}

// Blanket implementation: every type that satisfies the bounds is automatically a `Transport`,
// which is what makes the trait behave like an alias.
impl<T> Transport for T where
    T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}
140
141impl TransportProtocol {
142    /// Creates a transport for this protocol.
143    pub async fn connect(
144        self,
145        address: impl ToSocketAddrs,
146    ) -> Result<impl Transport, std::io::Error> {
147        let mut addresses = lookup_host(address)
148            .await
149            .expect("Invalid address to connect to");
150        let address = addresses
151            .next()
152            .expect("Couldn't resolve address to connect to");
153
154        let stream: futures::future::Either<_, _> = match self {
155            TransportProtocol::Udp => {
156                let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
157
158                UdpFramed::new(socket, Codec)
159                    .with(move |message| future::ready(Ok((message, address))))
160                    .map_ok(|(message, _address)| message)
161                    .left_stream()
162            }
163            TransportProtocol::Tcp => {
164                let stream = TcpStream::connect(address).await?;
165
166                Framed::new(stream, Codec).right_stream()
167            }
168        };
169
170        Ok(stream)
171    }
172
173    /// Creates a [`ConnectionPool`] for this protocol.
174    pub async fn make_outgoing_connection_pool(
175        self,
176    ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
177        let pool: Box<dyn ConnectionPool> = match self {
178            Self::Udp => Box::new(UdpConnectionPool::new().await?),
179            Self::Tcp => Box::new(TcpConnectionPool::new()),
180        };
181        Ok(pool)
182    }
183
184    /// Runs a server for this protocol and the given message handler.
185    pub fn spawn_server<S>(
186        self,
187        address: impl ToSocketAddrs + Send + 'static,
188        state: S,
189        shutdown_signal: CancellationToken,
190        join_set: &mut JoinSet<()>,
191    ) -> ServerHandle
192    where
193        S: MessageHandler + Send + 'static,
194    {
195        let handle = match self {
196            Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
197            Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
198        };
199        ServerHandle { handle }
200    }
201}
202
203/// An implementation of [`ConnectionPool`] based on UDP.
204struct UdpConnectionPool {
205    transport: UdpFramed<Codec>,
206}
207
208impl UdpConnectionPool {
209    async fn new() -> Result<Self, std::io::Error> {
210        let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
211        let transport = UdpFramed::new(socket, Codec);
212        Ok(Self { transport })
213    }
214}
215
216impl ConnectionPool for UdpConnectionPool {
217    fn send_message_to<'a>(
218        &'a mut self,
219        message: RpcMessage,
220        address: &'a str,
221    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
222        Box::pin(async move {
223            let address = address.parse().map_err(std::io::Error::other)?;
224            self.transport.send((message, address)).await
225        })
226    }
227}
228
229/// Server implementation for UDP.
230pub struct UdpServer<State> {
231    handler: State,
232    udp_sink: SharedUdpSink,
233    udp_stream: SplitStream<UdpFramed<Codec>>,
234    active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
235    join_set: JoinSet<()>,
236}
237
238/// Type alias for the outgoing endpoint of UDP messages.
239type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
240
241impl<State> UdpServer<State>
242where
243    State: MessageHandler + Send + 'static,
244{
245    /// Runs the UDP server implementation.
246    pub async fn run(
247        address: impl ToSocketAddrs,
248        state: State,
249        shutdown_signal: CancellationToken,
250    ) -> Result<(), std::io::Error> {
251        let mut server = Self::bind(address, state).await?;
252
253        loop {
254            tokio::select! { biased;
255                _ = shutdown_signal.cancelled() => {
256                    server.shutdown().await;
257                    return Ok(());
258                }
259                result = server.udp_stream.next() => match result {
260                    Some(Ok((message, peer))) => server.handle_message(message, peer),
261                    Some(Err(error)) => server.handle_error(error).await?,
262                    None => unreachable!("`UdpFramed` should never return `None`"),
263                },
264            }
265        }
266    }
267
268    /// Creates a [`UdpServer`] bound to the provided `address`, handling messages using the
269    /// provided `handler`.
270    async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
271        let socket = UdpSocket::bind(address).await?;
272        let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
273
274        Ok(UdpServer {
275            handler,
276            udp_sink: Arc::new(Mutex::new(udp_sink)),
277            udp_stream,
278            active_handlers: HashMap::new(),
279            join_set: JoinSet::new(),
280        })
281    }
282
283    /// Spawns a task to handle a single incoming message.
284    fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
285        let previous_task = self.active_handlers.remove(&peer);
286        let mut state = self.handler.clone();
287        let udp_sink = self.udp_sink.clone();
288
289        let new_task = self.join_set.spawn_task(async move {
290            if let Some(reply) = state.handle_message(message).await {
291                if let Some(task) = previous_task {
292                    if let Err(error) = task.await {
293                        warn!("Message handler task panicked: {}", error);
294                    }
295                }
296                let status = udp_sink.lock().await.send((reply, peer)).await;
297                if let Err(error) = status {
298                    error!("Failed to send query response: {}", error);
299                }
300            }
301        });
302
303        self.active_handlers.insert(peer, new_task);
304
305        if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
306            // Collect finished tasks to avoid leaking memory.
307            self.active_handlers.retain(|_, task| task.is_running());
308            self.join_set.reap_finished_tasks();
309        }
310    }
311
312    /// Handles an error while receiving a message.
313    async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
314        match error {
315            codec::Error::IoError(io_error) => {
316                error!("I/O error in UDP server: {io_error}");
317                self.shutdown().await;
318                Err(io_error)
319            }
320            other_error => {
321                warn!("Received an invalid message: {other_error}");
322                Ok(())
323            }
324        }
325    }
326
327    /// Gracefully shuts down the server, waiting for existing tasks to finish.
328    async fn shutdown(&mut self) {
329        let handlers = mem::take(&mut self.active_handlers);
330        let mut handler_results = handlers.into_values().collect::<FuturesUnordered<_>>();
331
332        while let Some(result) = handler_results.next().await {
333            if let Err(error) = result {
334                warn!("Message handler panicked: {}", error);
335            }
336        }
337
338        self.join_set.await_all_tasks().await;
339    }
340}
341
342/// An implementation of [`ConnectionPool`] based on TCP.
343struct TcpConnectionPool {
344    streams: HashMap<String, Framed<TcpStream, Codec>>,
345}
346
347impl TcpConnectionPool {
348    fn new() -> Self {
349        let streams = HashMap::new();
350        Self { streams }
351    }
352
353    async fn get_stream(
354        &mut self,
355        address: &str,
356    ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
357        if !self.streams.contains_key(address) {
358            match TcpStream::connect(address).await {
359                Ok(s) => {
360                    self.streams
361                        .insert(address.to_string(), Framed::new(s, Codec));
362                }
363                Err(error) => {
364                    error!("Failed to open connection to {}: {}", address, error);
365                    return Err(error);
366                }
367            };
368        };
369        Ok(self.streams.get_mut(address).unwrap())
370    }
371}
372
373impl ConnectionPool for TcpConnectionPool {
374    fn send_message_to<'a>(
375        &'a mut self,
376        message: RpcMessage,
377        address: &'a str,
378    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
379        Box::pin(async move {
380            let stream = self.get_stream(address).await?;
381            let result = stream.send(message).await;
382            if result.is_err() {
383                self.streams.remove(address);
384            }
385            result
386        })
387    }
388}
389
390/// Server implementation for TCP.
391pub struct TcpServer<State> {
392    connection: Framed<TcpStream, Codec>,
393    handler: State,
394    shutdown_signal: CancellationToken,
395}
396
397impl<State> TcpServer<State>
398where
399    State: MessageHandler + Send + 'static,
400{
401    /// Runs the TCP server implementation.
402    ///
403    /// Listens for connections and spawns a task with a new [`TcpServer`] instance to serve that
404    /// client.
405    pub async fn run(
406        address: impl ToSocketAddrs,
407        handler: State,
408        shutdown_signal: CancellationToken,
409    ) -> Result<(), std::io::Error> {
410        let listener = TcpListener::bind(address).await?;
411
412        let accept_stream = stream::try_unfold(listener, |listener| async move {
413            let (socket, _) = listener.accept().await?;
414            Ok::<_, io::Error>(Some((socket, listener)))
415        });
416        let mut accept_stream = pin!(accept_stream);
417
418        let connection_shutdown_signal = shutdown_signal.child_token();
419        let mut join_set = JoinSet::new();
420        let mut reap_countdown = REAP_TASKS_THRESHOLD;
421
422        loop {
423            tokio::select! { biased;
424                _ = shutdown_signal.cancelled() => {
425                    join_set.await_all_tasks().await;
426                    return Ok(());
427                }
428                maybe_socket = accept_stream.next() => match maybe_socket {
429                    Some(Ok(socket)) => {
430                        let server = TcpServer::new_connection(
431                            socket,
432                            handler.clone(),
433                            connection_shutdown_signal.clone(),
434                        );
435                        join_set.spawn_task(server.serve());
436                        reap_countdown -= 1;
437                    }
438                    Some(Err(error)) => {
439                        join_set.await_all_tasks().await;
440                        return Err(error);
441                    }
442                    None => unreachable!(
443                        "The `accept_stream` should never finish unless there's an error",
444                    ),
445                },
446            }
447
448            if reap_countdown == 0 {
449                join_set.reap_finished_tasks();
450                reap_countdown = REAP_TASKS_THRESHOLD;
451            }
452        }
453    }
454
455    /// Creates a new [`TcpServer`] to serve a single connection established on the provided
456    /// [`TcpStream`].
457    fn new_connection(
458        tcp_stream: TcpStream,
459        handler: State,
460        shutdown_signal: CancellationToken,
461    ) -> Self {
462        TcpServer {
463            connection: Framed::new(tcp_stream, Codec),
464            handler,
465            shutdown_signal,
466        }
467    }
468
469    /// Serves a client through a single connection.
470    async fn serve(mut self) {
471        loop {
472            tokio::select! { biased;
473                _ = self.shutdown_signal.cancelled() => {
474                    let mut tcp_stream = self.connection.into_inner();
475                    if let Err(error) = tcp_stream.shutdown().await {
476                        let peer = tcp_stream
477                            .peer_addr()
478                            .map_or_else(|_| "an unknown peer".to_owned(), |address| address.to_string());
479                        warn!("Failed to close connection to {peer}: {error:?}");
480                    }
481                    return;
482                }
483                result = self.connection.next() => match result {
484                    Some(Ok(RpcMessage::SubscribeNotifications(chains))) => {
485                        self.handle_subscription(chains).await;
486                        return;
487                    }
488                    Some(Ok(RpcMessage::DownloadBlobs(blob_ids))) => {
489                        self.handle_download_blobs(blob_ids).await;
490                        return;
491                    }
492                    Some(Ok(message)) => self.handle_message(message).await,
493                    Some(Err(error)) => {
494                        Self::handle_error(&error);
495                        return;
496                    }
497                    None => break,
498                },
499            }
500        }
501    }
502
503    /// Handles a single request message from a client.
504    async fn handle_message(&mut self, message: RpcMessage) {
505        if let Some(reply) = self.handler.handle_message(message).await {
506            if let Err(error) = self.connection.send(reply).await {
507                error!("Failed to send query response: {error}");
508            }
509        }
510    }
511
512    /// Handles a notification subscription request by switching to streaming mode.
513    async fn handle_subscription(&mut self, chains: Vec<ChainId>) {
514        let Some(mut stream) = self.handler.handle_subscribe(chains).await else {
515            return;
516        };
517        loop {
518            tokio::select! { biased;
519                _ = self.shutdown_signal.cancelled() => break,
520                msg = stream.next() => match msg {
521                    Some(notification) => {
522                        if let Err(error) = self.connection.send(notification).await {
523                            error!("Failed to send notification: {error}");
524                            break;
525                        }
526                    }
527                    None => break,
528                }
529            }
530        }
531    }
532
533    /// Handles a batch blob download request by streaming one response per blob.
534    async fn handle_download_blobs(&mut self, blob_ids: Vec<BlobId>) {
535        let Some(mut stream) = self.handler.handle_download_blobs(blob_ids).await else {
536            return;
537        };
538        loop {
539            tokio::select! { biased;
540                _ = self.shutdown_signal.cancelled() => break,
541                msg = stream.next() => match msg {
542                    Some(message) => {
543                        if let Err(error) = self.connection.send(message).await {
544                            error!("Failed to send blob response: {error}");
545                            break;
546                        }
547                    }
548                    None => break,
549                }
550            }
551        }
552    }
553
554    /// Handles an error received while attempting to receive from the connection.
555    ///
556    /// Ignores a successful connection termination, while logging an unexpected connection
557    /// termination or any other error.
558    fn handle_error(error: &codec::Error) {
559        if !matches!(
560            error,
561            codec::Error::IoError(error)
562                if error.kind() == io::ErrorKind::UnexpectedEof
563                || error.kind() == io::ErrorKind::ConnectionReset
564        ) {
565            error!("Error while reading TCP stream: {error}");
566        }
567    }
568}