linera_rpc/simple/
transport.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2// Copyright (c) Zefchain Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{
6    collections::HashMap,
7    io, mem,
8    net::SocketAddr,
9    pin::{pin, Pin},
10    sync::Arc,
11};
12
13use async_trait::async_trait;
14use futures::{
15    future,
16    stream::{self, FuturesUnordered, SplitSink, SplitStream},
17    Sink, SinkExt, Stream, StreamExt, TryStreamExt,
18};
19use linera_base::identifiers::ChainId;
20use linera_core::{JoinSetExt as _, TaskHandle};
21use serde::{Deserialize, Serialize};
22use tokio::{
23    io::AsyncWriteExt,
24    net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
25    sync::Mutex,
26    task::JoinSet,
27};
28use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
29use tracing::{error, warn};
30
31use crate::{
32    simple::{codec, codec::Codec},
33    RpcMessage,
34};
35
/// Suggested buffer size for datagrams, expressed as a decimal string.
///
/// NOTE(review): 65507 equals the largest UDP payload that fits in one IPv4 datagram
/// (65535 bytes minus a 20-byte IP header and an 8-byte UDP header). The value is a string,
/// presumably so that it can serve as the textual default of a configuration option — confirm
/// against the callers.
pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";
38
/// Number of tasks to spawn before attempting to reap some finished tasks to prevent memory leaks.
///
/// The UDP server compares it against the number of tracked per-peer handler tasks, while the
/// TCP server uses it as a countdown of accepted connections between two reaping passes.
const REAP_TASKS_THRESHOLD: usize = 100;
41
// Supported transport protocols.
//
// The enum can be parsed from a command-line value (through `clap::ValueEnum`) and can be
// serialized and deserialized with `serde`.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum TransportProtocol {
    // Messages are exchanged as UDP datagrams.
    Udp,
    // Messages are exchanged over TCP connections.
    Tcp,
}
48
49impl std::str::FromStr for TransportProtocol {
50    type Err = String;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        clap::ValueEnum::from_str(s, true)
54    }
55}
56
57impl std::fmt::Display for TransportProtocol {
58    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59        write!(f, "{:?}", self)
60    }
61}
62
63impl TransportProtocol {
64    pub fn scheme(&self) -> &'static str {
65        match self {
66            TransportProtocol::Udp => "udp",
67            TransportProtocol::Tcp => "tcp",
68        }
69    }
70}
71
/// A pool of (outgoing) data streams.
///
/// A pool is used to send [`RpcMessage`]s to peers that are identified by a textual address.
pub trait ConnectionPool: Send {
    /// Sends `message` to the peer at `address`.
    ///
    /// The returned boxed future borrows both the pool and `address` for `'a`. It resolves to
    /// `Ok(())` on success and to a [`codec::Error`] otherwise. How `address` is interpreted is
    /// up to the implementation.
    fn send_message_to<'a>(
        &'a mut self,
        message: RpcMessage,
        address: &'a str,
    ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
}
80
/// The handler required to create a service.
///
/// The implementation needs to implement [`Clone`] because a seed instance is used to generate
/// cloned instances, where each cloned instance handles a single request. Multiple cloned instances
/// may exist at the same time and handle separate requests concurrently.
#[async_trait]
pub trait MessageHandler: Clone {
    /// Handles one incoming `message`.
    ///
    /// Returns the reply to send back to the sender, or `None` when nothing should be sent.
    async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;

    /// Handle a notification subscription request. Returns a stream of notification
    /// messages if supported, or `None` if subscriptions are not supported.
    ///
    /// The default implementation ignores the requested chains and returns `None`.
    async fn handle_subscribe(
        &mut self,
        _chains: Vec<ChainId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }
}
99
/// The result of spawning a server: a handle to the task that runs it.
///
/// Use [`ServerHandle::join`] to wait for the server to finish and to obtain its result.
pub struct ServerHandle {
    /// Handle to the spawned server task; it yields the server's final result.
    pub handle: TaskHandle<Result<(), std::io::Error>>,
}
105
106impl ServerHandle {
107    pub async fn join(self) -> Result<(), std::io::Error> {
108        self.handle.await.map_err(|_| {
109            std::io::Error::new(
110                std::io::ErrorKind::Interrupted,
111                "Server task did not finish successfully",
112            )
113        })?
114    }
115}
116
/// A trait alias for a protocol transport.
///
/// A transport is an active connection that can be used to send and receive
/// [`RpcMessage`]s: it is a [`Stream`] of decoded messages (or [`codec::Error`]s) and a
/// [`Sink`] that accepts messages to send.
pub trait Transport:
    Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}
125
// Blanket implementation: every type with the required `Stream` and `Sink` bounds is a
// `Transport`, which is what lets the trait act as an alias.
impl<T> Transport for T where
    T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}
130
131impl TransportProtocol {
132    /// Creates a transport for this protocol.
133    pub async fn connect(
134        self,
135        address: impl ToSocketAddrs,
136    ) -> Result<impl Transport, std::io::Error> {
137        let mut addresses = lookup_host(address)
138            .await
139            .expect("Invalid address to connect to");
140        let address = addresses
141            .next()
142            .expect("Couldn't resolve address to connect to");
143
144        let stream: futures::future::Either<_, _> = match self {
145            TransportProtocol::Udp => {
146                let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
147
148                UdpFramed::new(socket, Codec)
149                    .with(move |message| future::ready(Ok((message, address))))
150                    .map_ok(|(message, _address)| message)
151                    .left_stream()
152            }
153            TransportProtocol::Tcp => {
154                let stream = TcpStream::connect(address).await?;
155
156                Framed::new(stream, Codec).right_stream()
157            }
158        };
159
160        Ok(stream)
161    }
162
163    /// Creates a [`ConnectionPool`] for this protocol.
164    pub async fn make_outgoing_connection_pool(
165        self,
166    ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
167        let pool: Box<dyn ConnectionPool> = match self {
168            Self::Udp => Box::new(UdpConnectionPool::new().await?),
169            Self::Tcp => Box::new(TcpConnectionPool::new()),
170        };
171        Ok(pool)
172    }
173
174    /// Runs a server for this protocol and the given message handler.
175    pub fn spawn_server<S>(
176        self,
177        address: impl ToSocketAddrs + Send + 'static,
178        state: S,
179        shutdown_signal: CancellationToken,
180        join_set: &mut JoinSet<()>,
181    ) -> ServerHandle
182    where
183        S: MessageHandler + Send + 'static,
184    {
185        let handle = match self {
186            Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
187            Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
188        };
189        ServerHandle { handle }
190    }
191}
192
/// An implementation of [`ConnectionPool`] based on UDP.
///
/// All messages are sent through a single local socket.
struct UdpConnectionPool {
    /// The framed local UDP socket used to encode and send messages.
    transport: UdpFramed<Codec>,
}
197
198impl UdpConnectionPool {
199    async fn new() -> Result<Self, std::io::Error> {
200        let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
201        let transport = UdpFramed::new(socket, Codec);
202        Ok(Self { transport })
203    }
204}
205
206impl ConnectionPool for UdpConnectionPool {
207    fn send_message_to<'a>(
208        &'a mut self,
209        message: RpcMessage,
210        address: &'a str,
211    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
212        Box::pin(async move {
213            let address = address.parse().map_err(std::io::Error::other)?;
214            self.transport.send((message, address)).await
215        })
216    }
217}
218
/// Server implementation for UDP.
///
/// Each incoming message is handled in its own task, using a clone of the seed handler.
pub struct UdpServer<State> {
    /// Seed handler that is cloned for every incoming message.
    handler: State,
    /// Sending half of the socket, shared with the handler tasks that send replies.
    udp_sink: SharedUdpSink,
    /// Receiving half of the socket, yielding decoded messages and their sender.
    udp_stream: SplitStream<UdpFramed<Codec>>,
    /// The most recently spawned handler task for each peer.
    active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
    /// The set on which all handler tasks are spawned.
    join_set: JoinSet<()>,
}
227
/// Type alias for the outgoing endpoint of UDP messages.
///
/// The sink accepts `(message, destination)` pairs and is wrapped in an `Arc<Mutex<_>>` so
/// that it can be shared between handler tasks.
type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
230
231impl<State> UdpServer<State>
232where
233    State: MessageHandler + Send + 'static,
234{
235    /// Runs the UDP server implementation.
236    pub async fn run(
237        address: impl ToSocketAddrs,
238        state: State,
239        shutdown_signal: CancellationToken,
240    ) -> Result<(), std::io::Error> {
241        let mut server = Self::bind(address, state).await?;
242
243        loop {
244            tokio::select! { biased;
245                _ = shutdown_signal.cancelled() => {
246                    server.shutdown().await;
247                    return Ok(());
248                }
249                result = server.udp_stream.next() => match result {
250                    Some(Ok((message, peer))) => server.handle_message(message, peer),
251                    Some(Err(error)) => server.handle_error(error).await?,
252                    None => unreachable!("`UdpFramed` should never return `None`"),
253                },
254            }
255        }
256    }
257
258    /// Creates a [`UdpServer`] bound to the provided `address`, handling messages using the
259    /// provided `handler`.
260    async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
261        let socket = UdpSocket::bind(address).await?;
262        let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
263
264        Ok(UdpServer {
265            handler,
266            udp_sink: Arc::new(Mutex::new(udp_sink)),
267            udp_stream,
268            active_handlers: HashMap::new(),
269            join_set: JoinSet::new(),
270        })
271    }
272
273    /// Spawns a task to handle a single incoming message.
274    fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
275        let previous_task = self.active_handlers.remove(&peer);
276        let mut state = self.handler.clone();
277        let udp_sink = self.udp_sink.clone();
278
279        let new_task = self.join_set.spawn_task(async move {
280            if let Some(reply) = state.handle_message(message).await {
281                if let Some(task) = previous_task {
282                    if let Err(error) = task.await {
283                        warn!("Message handler task panicked: {}", error);
284                    }
285                }
286                let status = udp_sink.lock().await.send((reply, peer)).await;
287                if let Err(error) = status {
288                    error!("Failed to send query response: {}", error);
289                }
290            }
291        });
292
293        self.active_handlers.insert(peer, new_task);
294
295        if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
296            // Collect finished tasks to avoid leaking memory.
297            self.active_handlers.retain(|_, task| task.is_running());
298            self.join_set.reap_finished_tasks();
299        }
300    }
301
302    /// Handles an error while receiving a message.
303    async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
304        match error {
305            codec::Error::IoError(io_error) => {
306                error!("I/O error in UDP server: {io_error}");
307                self.shutdown().await;
308                Err(io_error)
309            }
310            other_error => {
311                warn!("Received an invalid message: {other_error}");
312                Ok(())
313            }
314        }
315    }
316
317    /// Gracefully shuts down the server, waiting for existing tasks to finish.
318    async fn shutdown(&mut self) {
319        let handlers = mem::take(&mut self.active_handlers);
320        let mut handler_results = handlers.into_values().collect::<FuturesUnordered<_>>();
321
322        while let Some(result) = handler_results.next().await {
323            if let Err(error) = result {
324                warn!("Message handler panicked: {}", error);
325            }
326        }
327
328        self.join_set.await_all_tasks().await;
329    }
330}
331
/// An implementation of [`ConnectionPool`] based on TCP.
///
/// Connections are opened on first use and kept for later messages to the same address.
struct TcpConnectionPool {
    /// Open framed connections, keyed by the address string they were opened with.
    streams: HashMap<String, Framed<TcpStream, Codec>>,
}
336
337impl TcpConnectionPool {
338    fn new() -> Self {
339        let streams = HashMap::new();
340        Self { streams }
341    }
342
343    async fn get_stream(
344        &mut self,
345        address: &str,
346    ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
347        if !self.streams.contains_key(address) {
348            match TcpStream::connect(address).await {
349                Ok(s) => {
350                    self.streams
351                        .insert(address.to_string(), Framed::new(s, Codec));
352                }
353                Err(error) => {
354                    error!("Failed to open connection to {}: {}", address, error);
355                    return Err(error);
356                }
357            };
358        };
359        Ok(self.streams.get_mut(address).unwrap())
360    }
361}
362
363impl ConnectionPool for TcpConnectionPool {
364    fn send_message_to<'a>(
365        &'a mut self,
366        message: RpcMessage,
367        address: &'a str,
368    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
369        Box::pin(async move {
370            let stream = self.get_stream(address).await?;
371            let result = stream.send(message).await;
372            if result.is_err() {
373                self.streams.remove(address);
374            }
375            result
376        })
377    }
378}
379
/// Server implementation for TCP.
///
/// One instance serves a single client connection.
pub struct TcpServer<State> {
    /// The framed TCP connection to the client.
    connection: Framed<TcpStream, Codec>,
    /// The handler used for the messages received on this connection.
    handler: State,
    /// Signal that asks this connection to shut down.
    shutdown_signal: CancellationToken,
}
386
387impl<State> TcpServer<State>
388where
389    State: MessageHandler + Send + 'static,
390{
391    /// Runs the TCP server implementation.
392    ///
393    /// Listens for connections and spawns a task with a new [`TcpServer`] instance to serve that
394    /// client.
395    pub async fn run(
396        address: impl ToSocketAddrs,
397        handler: State,
398        shutdown_signal: CancellationToken,
399    ) -> Result<(), std::io::Error> {
400        let listener = TcpListener::bind(address).await?;
401
402        let accept_stream = stream::try_unfold(listener, |listener| async move {
403            let (socket, _) = listener.accept().await?;
404            Ok::<_, io::Error>(Some((socket, listener)))
405        });
406        let mut accept_stream = pin!(accept_stream);
407
408        let connection_shutdown_signal = shutdown_signal.child_token();
409        let mut join_set = JoinSet::new();
410        let mut reap_countdown = REAP_TASKS_THRESHOLD;
411
412        loop {
413            tokio::select! { biased;
414                _ = shutdown_signal.cancelled() => {
415                    join_set.await_all_tasks().await;
416                    return Ok(());
417                }
418                maybe_socket = accept_stream.next() => match maybe_socket {
419                    Some(Ok(socket)) => {
420                        let server = TcpServer::new_connection(
421                            socket,
422                            handler.clone(),
423                            connection_shutdown_signal.clone(),
424                        );
425                        join_set.spawn_task(server.serve());
426                        reap_countdown -= 1;
427                    }
428                    Some(Err(error)) => {
429                        join_set.await_all_tasks().await;
430                        return Err(error);
431                    }
432                    None => unreachable!(
433                        "The `accept_stream` should never finish unless there's an error",
434                    ),
435                },
436            }
437
438            if reap_countdown == 0 {
439                join_set.reap_finished_tasks();
440                reap_countdown = REAP_TASKS_THRESHOLD;
441            }
442        }
443    }
444
445    /// Creates a new [`TcpServer`] to serve a single connection established on the provided
446    /// [`TcpStream`].
447    fn new_connection(
448        tcp_stream: TcpStream,
449        handler: State,
450        shutdown_signal: CancellationToken,
451    ) -> Self {
452        TcpServer {
453            connection: Framed::new(tcp_stream, Codec),
454            handler,
455            shutdown_signal,
456        }
457    }
458
459    /// Serves a client through a single connection.
460    async fn serve(mut self) {
461        loop {
462            tokio::select! { biased;
463                _ = self.shutdown_signal.cancelled() => {
464                    let mut tcp_stream = self.connection.into_inner();
465                    if let Err(error) = tcp_stream.shutdown().await {
466                        let peer = tcp_stream
467                            .peer_addr()
468                            .map_or_else(|_| "an unknown peer".to_owned(), |address| address.to_string());
469                        warn!("Failed to close connection to {peer}: {error:?}");
470                    }
471                    return;
472                }
473                result = self.connection.next() => match result {
474                    Some(Ok(RpcMessage::SubscribeNotifications(chains))) => {
475                        self.handle_subscription(chains).await;
476                        return;
477                    }
478                    Some(Ok(message)) => self.handle_message(message).await,
479                    Some(Err(error)) => {
480                        Self::handle_error(&error);
481                        return;
482                    }
483                    None => break,
484                },
485            }
486        }
487    }
488
489    /// Handles a single request message from a client.
490    async fn handle_message(&mut self, message: RpcMessage) {
491        if let Some(reply) = self.handler.handle_message(message).await {
492            if let Err(error) = self.connection.send(reply).await {
493                error!("Failed to send query response: {error}");
494            }
495        }
496    }
497
498    /// Handles a notification subscription request by switching to streaming mode.
499    async fn handle_subscription(&mut self, chains: Vec<ChainId>) {
500        let Some(mut stream) = self.handler.handle_subscribe(chains).await else {
501            return;
502        };
503        loop {
504            tokio::select! { biased;
505                _ = self.shutdown_signal.cancelled() => break,
506                msg = stream.next() => match msg {
507                    Some(notification) => {
508                        if let Err(error) = self.connection.send(notification).await {
509                            error!("Failed to send notification: {error}");
510                            break;
511                        }
512                    }
513                    None => break,
514                }
515            }
516        }
517    }
518
519    /// Handles an error received while attempting to receive from the connection.
520    ///
521    /// Ignores a successful connection termination, while logging an unexpected connection
522    /// termination or any other error.
523    fn handle_error(error: &codec::Error) {
524        if !matches!(
525            error,
526            codec::Error::IoError(error)
527                if error.kind() == io::ErrorKind::UnexpectedEof
528                || error.kind() == io::ErrorKind::ConnectionReset
529        ) {
530            error!("Error while reading TCP stream: {error}");
531        }
532    }
533}