1use std::{collections::HashMap, io, mem, net::SocketAddr, pin::pin, sync::Arc};
6
7use async_trait::async_trait;
8use futures::{
9 future,
10 stream::{self, FuturesUnordered, SplitSink, SplitStream},
11 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
12};
13use linera_core::{JoinSetExt as _, TaskHandle};
14use serde::{Deserialize, Serialize};
15use tokio::{
16 io::AsyncWriteExt,
17 net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
18 sync::Mutex,
19 task::JoinSet,
20};
21use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
22use tracing::{error, warn};
23
24use crate::{
25 simple::{codec, codec::Codec},
26 RpcMessage,
27};
28
/// Default maximum size of a UDP datagram payload: 65535 minus the 8-byte UDP
/// header and the 20-byte IPv4 header.
///
/// Kept as a string — presumably so it can serve directly as a command-line
/// default value (TODO confirm with callers).
pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";

/// Number of tasks the servers let accumulate before dropping bookkeeping for
/// tasks that have already finished.
const REAP_TASKS_THRESHOLD: usize = 100;
34
35#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
37pub enum TransportProtocol {
38 Udp,
39 Tcp,
40}
41
42impl std::str::FromStr for TransportProtocol {
43 type Err = String;
44
45 fn from_str(s: &str) -> Result<Self, Self::Err> {
46 clap::ValueEnum::from_str(s, true)
47 }
48}
49
50impl std::fmt::Display for TransportProtocol {
51 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52 write!(f, "{:?}", self)
53 }
54}
55
56impl TransportProtocol {
57 pub fn scheme(&self) -> &'static str {
58 match self {
59 TransportProtocol::Udp => "udp",
60 TransportProtocol::Tcp => "tcp",
61 }
62 }
63}
64
65pub trait ConnectionPool: Send {
67 fn send_message_to<'a>(
68 &'a mut self,
69 message: RpcMessage,
70 address: &'a str,
71 ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
72}
73
74#[async_trait]
80pub trait MessageHandler: Clone {
81 async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;
82}
83
84pub struct ServerHandle {
87 pub handle: TaskHandle<Result<(), std::io::Error>>,
88}
89
90impl ServerHandle {
91 pub async fn join(self) -> Result<(), std::io::Error> {
92 self.handle.await.map_err(|_| {
93 std::io::Error::new(
94 std::io::ErrorKind::Interrupted,
95 "Server task did not finish successfully",
96 )
97 })?
98 }
99}
100
101pub trait Transport:
106 Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
107{
108}
109
110impl<T> Transport for T where
111 T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
112{
113}
114
115impl TransportProtocol {
116 pub async fn connect(
118 self,
119 address: impl ToSocketAddrs,
120 ) -> Result<impl Transport, std::io::Error> {
121 let mut addresses = lookup_host(address)
122 .await
123 .expect("Invalid address to connect to");
124 let address = addresses
125 .next()
126 .expect("Couldn't resolve address to connect to");
127
128 let stream: futures::future::Either<_, _> = match self {
129 TransportProtocol::Udp => {
130 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
131
132 UdpFramed::new(socket, Codec)
133 .with(move |message| future::ready(Ok((message, address))))
134 .map_ok(|(message, _address)| message)
135 .left_stream()
136 }
137 TransportProtocol::Tcp => {
138 let stream = TcpStream::connect(address).await?;
139
140 Framed::new(stream, Codec).right_stream()
141 }
142 };
143
144 Ok(stream)
145 }
146
147 pub async fn make_outgoing_connection_pool(
149 self,
150 ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
151 let pool: Box<dyn ConnectionPool> = match self {
152 Self::Udp => Box::new(UdpConnectionPool::new().await?),
153 Self::Tcp => Box::new(TcpConnectionPool::new().await?),
154 };
155 Ok(pool)
156 }
157
158 pub fn spawn_server<S>(
160 self,
161 address: impl ToSocketAddrs + Send + 'static,
162 state: S,
163 shutdown_signal: CancellationToken,
164 join_set: &mut JoinSet<()>,
165 ) -> ServerHandle
166 where
167 S: MessageHandler + Send + 'static,
168 {
169 let handle = match self {
170 Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
171 Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
172 };
173 ServerHandle { handle }
174 }
175}
176
177struct UdpConnectionPool {
179 transport: UdpFramed<Codec>,
180}
181
182impl UdpConnectionPool {
183 async fn new() -> Result<Self, std::io::Error> {
184 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
185 let transport = UdpFramed::new(socket, Codec);
186 Ok(Self { transport })
187 }
188}
189
190impl ConnectionPool for UdpConnectionPool {
191 fn send_message_to<'a>(
192 &'a mut self,
193 message: RpcMessage,
194 address: &'a str,
195 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
196 Box::pin(async move {
197 let address = address.parse().map_err(std::io::Error::other)?;
198 self.transport.send((message, address)).await
199 })
200 }
201}
202
203pub struct UdpServer<State> {
205 handler: State,
206 udp_sink: SharedUdpSink,
207 udp_stream: SplitStream<UdpFramed<Codec>>,
208 active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
209 join_set: JoinSet<()>,
210}
211
212type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
214
215impl<State> UdpServer<State>
216where
217 State: MessageHandler + Send + 'static,
218{
219 pub async fn run(
221 address: impl ToSocketAddrs,
222 state: State,
223 shutdown_signal: CancellationToken,
224 ) -> Result<(), std::io::Error> {
225 let mut server = Self::bind(address, state).await?;
226
227 loop {
228 tokio::select! { biased;
229 _ = shutdown_signal.cancelled() => {
230 server.shutdown().await;
231 return Ok(());
232 }
233 result = server.udp_stream.next() => match result {
234 Some(Ok((message, peer))) => server.handle_message(message, peer),
235 Some(Err(error)) => server.handle_error(error).await?,
236 None => unreachable!("`UdpFramed` should never return `None`"),
237 },
238 }
239 }
240 }
241
242 async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
245 let socket = UdpSocket::bind(address).await?;
246 let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
247
248 Ok(UdpServer {
249 handler,
250 udp_sink: Arc::new(Mutex::new(udp_sink)),
251 udp_stream,
252 active_handlers: HashMap::new(),
253 join_set: JoinSet::new(),
254 })
255 }
256
257 fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
259 let previous_task = self.active_handlers.remove(&peer);
260 let mut state = self.handler.clone();
261 let udp_sink = self.udp_sink.clone();
262
263 let new_task = self.join_set.spawn_task(async move {
264 if let Some(reply) = state.handle_message(message).await {
265 if let Some(task) = previous_task {
266 if let Err(error) = task.await {
267 warn!("Message handler task panicked: {}", error);
268 }
269 }
270 let status = udp_sink.lock().await.send((reply, peer)).await;
271 if let Err(error) = status {
272 error!("Failed to send query response: {}", error);
273 }
274 }
275 });
276
277 self.active_handlers.insert(peer, new_task);
278
279 if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
280 self.active_handlers.retain(|_, task| task.is_running());
282 self.join_set.reap_finished_tasks();
283 }
284 }
285
286 async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
288 match error {
289 codec::Error::IoError(io_error) => {
290 error!("I/O error in UDP server: {io_error}");
291 self.shutdown().await;
292 Err(io_error)
293 }
294 other_error => {
295 warn!("Received an invalid message: {other_error}");
296 Ok(())
297 }
298 }
299 }
300
301 async fn shutdown(&mut self) {
303 let handlers = mem::take(&mut self.active_handlers);
304 let mut handler_results = FuturesUnordered::from_iter(handlers.into_values());
305
306 while let Some(result) = handler_results.next().await {
307 if let Err(error) = result {
308 warn!("Message handler panicked: {}", error);
309 }
310 }
311
312 self.join_set.await_all_tasks().await;
313 }
314}
315
316struct TcpConnectionPool {
318 streams: HashMap<String, Framed<TcpStream, Codec>>,
319}
320
321impl TcpConnectionPool {
322 async fn new() -> Result<Self, std::io::Error> {
323 let streams = HashMap::new();
324 Ok(Self { streams })
325 }
326
327 async fn get_stream(
328 &mut self,
329 address: &str,
330 ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
331 if !self.streams.contains_key(address) {
332 match TcpStream::connect(address).await {
333 Ok(s) => {
334 self.streams
335 .insert(address.to_string(), Framed::new(s, Codec));
336 }
337 Err(error) => {
338 error!("Failed to open connection to {}: {}", address, error);
339 return Err(error);
340 }
341 };
342 };
343 Ok(self.streams.get_mut(address).unwrap())
344 }
345}
346
347impl ConnectionPool for TcpConnectionPool {
348 fn send_message_to<'a>(
349 &'a mut self,
350 message: RpcMessage,
351 address: &'a str,
352 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
353 Box::pin(async move {
354 let stream = self.get_stream(address).await?;
355 let result = stream.send(message).await;
356 if result.is_err() {
357 self.streams.remove(address);
358 }
359 result
360 })
361 }
362}
363
364pub struct TcpServer<State> {
366 connection: Framed<TcpStream, Codec>,
367 handler: State,
368 shutdown_signal: CancellationToken,
369}
370
371impl<State> TcpServer<State>
372where
373 State: MessageHandler + Send + 'static,
374{
375 pub async fn run(
380 address: impl ToSocketAddrs,
381 handler: State,
382 shutdown_signal: CancellationToken,
383 ) -> Result<(), std::io::Error> {
384 let listener = TcpListener::bind(address).await?;
385
386 let mut accept_stream = stream::try_unfold(listener, |listener| async move {
387 let (socket, _) = listener.accept().await?;
388 Ok::<_, io::Error>(Some((socket, listener)))
389 });
390 let mut accept_stream = pin!(accept_stream);
391
392 let connection_shutdown_signal = shutdown_signal.child_token();
393 let mut join_set = JoinSet::new();
394 let mut reap_countdown = REAP_TASKS_THRESHOLD;
395
396 loop {
397 tokio::select! { biased;
398 _ = shutdown_signal.cancelled() => {
399 join_set.await_all_tasks().await;
400 return Ok(());
401 }
402 maybe_socket = accept_stream.next() => match maybe_socket {
403 Some(Ok(socket)) => {
404 let server = TcpServer::new_connection(
405 socket,
406 handler.clone(),
407 connection_shutdown_signal.clone(),
408 );
409 join_set.spawn_task(server.serve());
410 reap_countdown -= 1;
411 }
412 Some(Err(error)) => {
413 join_set.await_all_tasks().await;
414 return Err(error);
415 }
416 None => unreachable!(
417 "The `accept_stream` should never finish unless there's an error",
418 ),
419 },
420 }
421
422 if reap_countdown == 0 {
423 join_set.reap_finished_tasks();
424 reap_countdown = REAP_TASKS_THRESHOLD;
425 }
426 }
427 }
428
429 fn new_connection(
432 tcp_stream: TcpStream,
433 handler: State,
434 shutdown_signal: CancellationToken,
435 ) -> Self {
436 TcpServer {
437 connection: Framed::new(tcp_stream, Codec),
438 handler,
439 shutdown_signal,
440 }
441 }
442
443 async fn serve(mut self) {
445 loop {
446 tokio::select! { biased;
447 _ = self.shutdown_signal.cancelled() => {
448 let mut tcp_stream = self.connection.into_inner();
449 if let Err(error) = tcp_stream.shutdown().await {
450 let peer = tcp_stream
451 .peer_addr()
452 .map(|address| address.to_string())
453 .unwrap_or_else(|_| "an unknown peer".to_owned());
454 warn!("Failed to close connection to {peer}: {error:?}");
455 }
456 return;
457 }
458 result = self.connection.next() => match result {
459 Some(Ok(message)) => self.handle_message(message).await,
460 Some(Err(error)) => {
461 self.handle_error(error);
462 return;
463 }
464 None => break,
465 },
466 }
467 }
468 }
469
470 async fn handle_message(&mut self, message: RpcMessage) {
472 if let Some(reply) = self.handler.handle_message(message).await {
473 if let Err(error) = self.connection.send(reply).await {
474 error!("Failed to send query response: {error}");
475 }
476 }
477 }
478
479 fn handle_error(&self, error: codec::Error) {
484 if !matches!(
485 &error,
486 codec::Error::IoError(error)
487 if error.kind() == io::ErrorKind::UnexpectedEof
488 || error.kind() == io::ErrorKind::ConnectionReset
489 ) {
490 error!("Error while reading TCP stream: {error}");
491 }
492 }
493}