linera_rpc/simple/
transport.rs

1// Copyright (c) Facebook, Inc. and its affiliates.
2// Copyright (c) Zefchain Labs, Inc.
3// SPDX-License-Identifier: Apache-2.0
4
5use std::{collections::HashMap, io, mem, net::SocketAddr, pin::pin, sync::Arc};
6
7use async_trait::async_trait;
8use futures::{
9    future,
10    stream::{self, FuturesUnordered, SplitSink, SplitStream},
11    Sink, SinkExt, Stream, StreamExt, TryStreamExt,
12};
13use linera_core::{JoinSetExt as _, TaskHandle};
14use serde::{Deserialize, Serialize};
15use tokio::{
16    io::AsyncWriteExt,
17    net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
18    sync::Mutex,
19    task::JoinSet,
20};
21use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
22use tracing::{error, warn};
23
24use crate::{
25    simple::{codec, codec::Codec},
26    RpcMessage,
27};
28
29/// Suggested buffer size, in bytes, expressed as a string.
///
/// `65507` is the largest payload of a UDP datagram over IPv4: 65535 bytes minus the 8-byte UDP
/// header and the 20-byte IPv4 header.
// NOTE(review): kept as a string rather than an integer, presumably so that it can be used
// directly as the default value of a command-line option -- confirm against the callers.
30pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";
31
32/// Number of tasks to spawn before attempting to reap some finished tasks to prevent memory leaks.
///
/// The UDP and TCP servers in this file use it to decide when to call `reap_finished_tasks` on
/// the `JoinSet` that owns their spawned tasks.
33const REAP_TASKS_THRESHOLD: usize = 100;
34
35/// Supported transport protocols.
///
/// The derives make a protocol selectable as a command-line value (`clap::ValueEnum`) and
/// serializable (`Serialize`/`Deserialize`).
36#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
37pub enum TransportProtocol {
    // Datagram transport; served by `UdpServer` and used by `UdpConnectionPool`.
38    Udp,
    // Stream transport; served by `TcpServer` and used by `TcpConnectionPool`.
39    Tcp,
40}
41
42impl std::str::FromStr for TransportProtocol {
43    type Err = String;
44
45    fn from_str(s: &str) -> Result<Self, Self::Err> {
46        clap::ValueEnum::from_str(s, true)
47    }
48}
49
50impl std::fmt::Display for TransportProtocol {
51    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52        write!(f, "{:?}", self)
53    }
54}
55
56impl TransportProtocol {
57    pub fn scheme(&self) -> &'static str {
58        match self {
59            TransportProtocol::Udp => "udp",
60            TransportProtocol::Tcp => "tcp",
61        }
62    }
63}
64
65/// A pool of (outgoing) data streams.
///
/// This file provides one implementation per [`TransportProtocol`]: a UDP pool that sends every
/// message through a single socket, and a TCP pool that keeps one connection per address.
66pub trait ConnectionPool: Send {
    /// Sends `message` to the peer identified by `address`.
    ///
    /// The UDP implementation in this file requires `address` to parse as a socket address,
    /// while the TCP implementation passes it to `TcpStream::connect`.
    ///
    /// The returned future borrows both the pool and `address` for `'a`, and resolves to
    /// `Ok(())` once the message was handed to the transport or to a [`codec::Error`] otherwise.
67    fn send_message_to<'a>(
68        &'a mut self,
69        message: RpcMessage,
70        address: &'a str,
71    ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
72}
73
74/// The handler required to create a service.
75///
76/// The implementation needs to implement [`Clone`] because a seed instance is used to generate
77/// cloned instances, where each cloned instance handles a single request. Multiple cloned instances
78/// may exist at the same time and handle separate requests concurrently.
///
/// In this file, the UDP server clones the seed once per received message, while the TCP server
/// clones it once per accepted connection and uses that clone for every message received on the
/// connection.
79#[async_trait]
80pub trait MessageHandler: Clone {
    /// Handles one incoming `message` and optionally produces a reply.
    ///
    /// The servers in this file send a `Some` reply back to the peer that sent `message`, and
    /// send nothing when `None` is returned.
81    async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;
82}
83
84/// The result of spawning a server is oneshot channel to track completion, and the set of
85/// executing tasks.
86pub struct ServerHandle {
87    pub handle: TaskHandle<Result<(), std::io::Error>>,
88}
89
90impl ServerHandle {
91    pub async fn join(self) -> Result<(), std::io::Error> {
92        self.handle.await.map_err(|_| {
93            std::io::Error::new(
94                std::io::ErrorKind::Interrupted,
95                "Server task did not finish successfully",
96            )
97        })?
98    }
99}
100
101/// A trait alias for a protocol transport.
102///
103/// A transport is an active connection that can be used to send and receive
104/// [`RpcMessage`]s.
///
/// Receiving goes through the [`Stream`] side, which yields either a decoded message or a
/// [`codec::Error`]; sending goes through the [`Sink`] side, which fails with a [`codec::Error`].
105pub trait Transport:
106    Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
107{
108}
109
// Blanket implementation that makes `Transport` behave as an alias: any type that satisfies
// the `Stream` and `Sink` bounds is automatically a `Transport`.
110impl<T> Transport for T where
111    T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
112{
113}
114
115impl TransportProtocol {
116    /// Creates a transport for this protocol.
117    pub async fn connect(
118        self,
119        address: impl ToSocketAddrs,
120    ) -> Result<impl Transport, std::io::Error> {
121        let mut addresses = lookup_host(address)
122            .await
123            .expect("Invalid address to connect to");
124        let address = addresses
125            .next()
126            .expect("Couldn't resolve address to connect to");
127
128        let stream: futures::future::Either<_, _> = match self {
129            TransportProtocol::Udp => {
130                let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
131
132                UdpFramed::new(socket, Codec)
133                    .with(move |message| future::ready(Ok((message, address))))
134                    .map_ok(|(message, _address)| message)
135                    .left_stream()
136            }
137            TransportProtocol::Tcp => {
138                let stream = TcpStream::connect(address).await?;
139
140                Framed::new(stream, Codec).right_stream()
141            }
142        };
143
144        Ok(stream)
145    }
146
147    /// Creates a [`ConnectionPool`] for this protocol.
148    pub async fn make_outgoing_connection_pool(
149        self,
150    ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
151        let pool: Box<dyn ConnectionPool> = match self {
152            Self::Udp => Box::new(UdpConnectionPool::new().await?),
153            Self::Tcp => Box::new(TcpConnectionPool::new().await?),
154        };
155        Ok(pool)
156    }
157
158    /// Runs a server for this protocol and the given message handler.
159    pub fn spawn_server<S>(
160        self,
161        address: impl ToSocketAddrs + Send + 'static,
162        state: S,
163        shutdown_signal: CancellationToken,
164        join_set: &mut JoinSet<()>,
165    ) -> ServerHandle
166    where
167        S: MessageHandler + Send + 'static,
168    {
169        let handle = match self {
170            Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
171            Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
172        };
173        ServerHandle { handle }
174    }
175}
176
177/// An implementation of [`ConnectionPool`] based on UDP.
178struct UdpConnectionPool {
179    transport: UdpFramed<Codec>,
180}
181
182impl UdpConnectionPool {
183    async fn new() -> Result<Self, std::io::Error> {
184        let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
185        let transport = UdpFramed::new(socket, Codec);
186        Ok(Self { transport })
187    }
188}
189
190impl ConnectionPool for UdpConnectionPool {
191    fn send_message_to<'a>(
192        &'a mut self,
193        message: RpcMessage,
194        address: &'a str,
195    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
196        Box::pin(async move {
197            let address = address.parse().map_err(std::io::Error::other)?;
198            self.transport.send((message, address)).await
199        })
200    }
201}
202
203/// Server implementation for UDP.
///
/// A single socket receives the messages of every peer; each message is handled in its own
/// spawned task using a clone of `handler`.
204pub struct UdpServer<State> {
    /// Seed handler, cloned for every received message.
205    handler: State,
    /// Sending half of the socket, shared with the handler tasks so they can send replies.
206    udp_sink: SharedUdpSink,
    /// Receiving half of the socket, polled by the server loop.
207    udp_stream: SplitStream<UdpFramed<Codec>>,
    /// The most recently spawned handler task of each peer.
208    active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
    /// The set on which every handler task is spawned.
209    join_set: JoinSet<()>,
210}
211
212/// Type alias for the outgoing endpoint of UDP messages.
///
/// The sink is wrapped in an `Arc<Mutex<_>>` because it is cloned into every handler task and
/// locked while a reply is being sent.
213type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
214
215impl<State> UdpServer<State>
216where
217    State: MessageHandler + Send + 'static,
218{
219    /// Runs the UDP server implementation.
220    pub async fn run(
221        address: impl ToSocketAddrs,
222        state: State,
223        shutdown_signal: CancellationToken,
224    ) -> Result<(), std::io::Error> {
225        let mut server = Self::bind(address, state).await?;
226
227        loop {
228            tokio::select! { biased;
229                _ = shutdown_signal.cancelled() => {
230                    server.shutdown().await;
231                    return Ok(());
232                }
233                result = server.udp_stream.next() => match result {
234                    Some(Ok((message, peer))) => server.handle_message(message, peer),
235                    Some(Err(error)) => server.handle_error(error).await?,
236                    None => unreachable!("`UdpFramed` should never return `None`"),
237                },
238            }
239        }
240    }
241
242    /// Creates a [`UdpServer`] bound to the provided `address`, handling messages using the
243    /// provided `handler`.
244    async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
245        let socket = UdpSocket::bind(address).await?;
246        let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
247
248        Ok(UdpServer {
249            handler,
250            udp_sink: Arc::new(Mutex::new(udp_sink)),
251            udp_stream,
252            active_handlers: HashMap::new(),
253            join_set: JoinSet::new(),
254        })
255    }
256
257    /// Spawns a task to handle a single incoming message.
258    fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
259        let previous_task = self.active_handlers.remove(&peer);
260        let mut state = self.handler.clone();
261        let udp_sink = self.udp_sink.clone();
262
263        let new_task = self.join_set.spawn_task(async move {
264            if let Some(reply) = state.handle_message(message).await {
265                if let Some(task) = previous_task {
266                    if let Err(error) = task.await {
267                        warn!("Message handler task panicked: {}", error);
268                    }
269                }
270                let status = udp_sink.lock().await.send((reply, peer)).await;
271                if let Err(error) = status {
272                    error!("Failed to send query response: {}", error);
273                }
274            }
275        });
276
277        self.active_handlers.insert(peer, new_task);
278
279        if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
280            // Collect finished tasks to avoid leaking memory.
281            self.active_handlers.retain(|_, task| task.is_running());
282            self.join_set.reap_finished_tasks();
283        }
284    }
285
286    /// Handles an error while receiving a message.
287    async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
288        match error {
289            codec::Error::IoError(io_error) => {
290                error!("I/O error in UDP server: {io_error}");
291                self.shutdown().await;
292                Err(io_error)
293            }
294            other_error => {
295                warn!("Received an invalid message: {other_error}");
296                Ok(())
297            }
298        }
299    }
300
301    /// Gracefully shuts down the server, waiting for existing tasks to finish.
302    async fn shutdown(&mut self) {
303        let handlers = mem::take(&mut self.active_handlers);
304        let mut handler_results = FuturesUnordered::from_iter(handlers.into_values());
305
306        while let Some(result) = handler_results.next().await {
307            if let Err(error) = result {
308                warn!("Message handler panicked: {}", error);
309            }
310        }
311
312        self.join_set.await_all_tasks().await;
313    }
314}
315
316/// An implementation of [`ConnectionPool`] based on TCP.
317struct TcpConnectionPool {
318    streams: HashMap<String, Framed<TcpStream, Codec>>,
319}
320
321impl TcpConnectionPool {
322    async fn new() -> Result<Self, std::io::Error> {
323        let streams = HashMap::new();
324        Ok(Self { streams })
325    }
326
327    async fn get_stream(
328        &mut self,
329        address: &str,
330    ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
331        if !self.streams.contains_key(address) {
332            match TcpStream::connect(address).await {
333                Ok(s) => {
334                    self.streams
335                        .insert(address.to_string(), Framed::new(s, Codec));
336                }
337                Err(error) => {
338                    error!("Failed to open connection to {}: {}", address, error);
339                    return Err(error);
340                }
341            };
342        };
343        Ok(self.streams.get_mut(address).unwrap())
344    }
345}
346
347impl ConnectionPool for TcpConnectionPool {
348    fn send_message_to<'a>(
349        &'a mut self,
350        message: RpcMessage,
351        address: &'a str,
352    ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
353        Box::pin(async move {
354            let stream = self.get_stream(address).await?;
355            let result = stream.send(message).await;
356            if result.is_err() {
357                self.streams.remove(address);
358            }
359            result
360        })
361    }
362}
363
364/// Server implementation for TCP.
///
/// One instance serves a single accepted connection.
365pub struct TcpServer<State> {
    /// The framed connection to the client, used to receive requests and send replies.
366    connection: Framed<TcpStream, Codec>,
    /// The handler that processes every message received on this connection.
367    handler: State,
    /// Signal that makes the instance close its connection and stop serving.
368    shutdown_signal: CancellationToken,
369}
370
371impl<State> TcpServer<State>
372where
373    State: MessageHandler + Send + 'static,
374{
375    /// Runs the TCP server implementation.
376    ///
377    /// Listens for connections and spawns a task with a new [`TcpServer`] instance to serve that
378    /// client.
379    pub async fn run(
380        address: impl ToSocketAddrs,
381        handler: State,
382        shutdown_signal: CancellationToken,
383    ) -> Result<(), std::io::Error> {
384        let listener = TcpListener::bind(address).await?;
385
386        let mut accept_stream = stream::try_unfold(listener, |listener| async move {
387            let (socket, _) = listener.accept().await?;
388            Ok::<_, io::Error>(Some((socket, listener)))
389        });
390        let mut accept_stream = pin!(accept_stream);
391
392        let connection_shutdown_signal = shutdown_signal.child_token();
393        let mut join_set = JoinSet::new();
394        let mut reap_countdown = REAP_TASKS_THRESHOLD;
395
396        loop {
397            tokio::select! { biased;
398                _ = shutdown_signal.cancelled() => {
399                    join_set.await_all_tasks().await;
400                    return Ok(());
401                }
402                maybe_socket = accept_stream.next() => match maybe_socket {
403                    Some(Ok(socket)) => {
404                        let server = TcpServer::new_connection(
405                            socket,
406                            handler.clone(),
407                            connection_shutdown_signal.clone(),
408                        );
409                        join_set.spawn_task(server.serve());
410                        reap_countdown -= 1;
411                    }
412                    Some(Err(error)) => {
413                        join_set.await_all_tasks().await;
414                        return Err(error);
415                    }
416                    None => unreachable!(
417                        "The `accept_stream` should never finish unless there's an error",
418                    ),
419                },
420            }
421
422            if reap_countdown == 0 {
423                join_set.reap_finished_tasks();
424                reap_countdown = REAP_TASKS_THRESHOLD;
425            }
426        }
427    }
428
429    /// Creates a new [`TcpServer`] to serve a single connection established on the provided
430    /// [`TcpStream`].
431    fn new_connection(
432        tcp_stream: TcpStream,
433        handler: State,
434        shutdown_signal: CancellationToken,
435    ) -> Self {
436        TcpServer {
437            connection: Framed::new(tcp_stream, Codec),
438            handler,
439            shutdown_signal,
440        }
441    }
442
443    /// Serves a client through a single connection.
444    async fn serve(mut self) {
445        loop {
446            tokio::select! { biased;
447                _ = self.shutdown_signal.cancelled() => {
448                    let mut tcp_stream = self.connection.into_inner();
449                    if let Err(error) = tcp_stream.shutdown().await {
450                        let peer = tcp_stream
451                            .peer_addr()
452                            .map(|address| address.to_string())
453                            .unwrap_or_else(|_| "an unknown peer".to_owned());
454                        warn!("Failed to close connection to {peer}: {error:?}");
455                    }
456                    return;
457                }
458                result = self.connection.next() => match result {
459                    Some(Ok(message)) => self.handle_message(message).await,
460                    Some(Err(error)) => {
461                        self.handle_error(error);
462                        return;
463                    }
464                    None => break,
465                },
466            }
467        }
468    }
469
470    /// Handles a single request message from a client.
471    async fn handle_message(&mut self, message: RpcMessage) {
472        if let Some(reply) = self.handler.handle_message(message).await {
473            if let Err(error) = self.connection.send(reply).await {
474                error!("Failed to send query response: {error}");
475            }
476        }
477    }
478
479    /// Handles an error received while attempting to receive from the connection.
480    ///
481    /// Ignores a successful connection termination, while logging an unexpected connection
482    /// termination or any other error.
483    fn handle_error(&self, error: codec::Error) {
484        if !matches!(
485            &error,
486            codec::Error::IoError(error)
487                if error.kind() == io::ErrorKind::UnexpectedEof
488                || error.kind() == io::ErrorKind::ConnectionReset
489        ) {
490            error!("Error while reading TCP stream: {error}");
491        }
492    }
493}