1use std::{
6 collections::HashMap,
7 io, mem,
8 net::SocketAddr,
9 pin::{pin, Pin},
10 sync::Arc,
11};
12
13use async_trait::async_trait;
14use futures::{
15 future,
16 stream::{self, FuturesUnordered, SplitSink, SplitStream},
17 Sink, SinkExt, Stream, StreamExt, TryStreamExt,
18};
19use linera_base::identifiers::{BlobId, ChainId};
20use linera_core::{JoinSetExt as _, TaskHandle};
21use serde::{Deserialize, Serialize};
22use tokio::{
23 io::AsyncWriteExt,
24 net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
25 sync::Mutex,
26 task::JoinSet,
27};
28use tokio_util::{codec::Framed, sync::CancellationToken, udp::UdpFramed};
29use tracing::{error, warn};
30
31use crate::{
32 simple::{codec, codec::Codec},
33 RpcMessage,
34};
35
/// Default maximum size of a UDP datagram payload.
///
/// 65507 is the largest UDP payload over IPv4: 65535 bytes minus the 20-byte IP
/// header and the 8-byte UDP header. It is kept as a string — presumably so it
/// can serve as a textual (e.g. command-line) default; confirm with callers.
pub const DEFAULT_MAX_DATAGRAM_SIZE: &str = "65507";
38
/// How many spawned tasks may accumulate before the servers reap the handles of
/// tasks that have already finished, keeping their bookkeeping from growing
/// without bound.
const REAP_TASKS_THRESHOLD: usize = 100;
41
/// The network protocol used to exchange [`RpcMessage`]s.
#[derive(clap::ValueEnum, Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum TransportProtocol {
    /// Datagram-based transport over UDP.
    Udp,
    /// Connection-based transport over TCP.
    Tcp,
}
48
49impl std::str::FromStr for TransportProtocol {
50 type Err = String;
51
52 fn from_str(s: &str) -> Result<Self, Self::Err> {
53 clap::ValueEnum::from_str(s, true)
54 }
55}
56
57impl std::fmt::Display for TransportProtocol {
58 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
59 write!(f, "{self:?}")
60 }
61}
62
63impl TransportProtocol {
64 pub fn scheme(&self) -> &'static str {
65 match self {
66 TransportProtocol::Udp => "udp",
67 TransportProtocol::Tcp => "tcp",
68 }
69 }
70}
71
/// A source of outgoing connections used to send [`RpcMessage`]s to remote peers.
pub trait ConnectionPool: Send {
    /// Sends `message` to the peer at `address`.
    ///
    /// `address` is a textual socket address. Whether host names are resolved
    /// and whether connections are reused depends on the implementation.
    fn send_message_to<'a>(
        &'a mut self,
        message: RpcMessage,
        address: &'a str,
    ) -> future::BoxFuture<'a, Result<(), codec::Error>>;
}
80
/// Server-side logic that answers incoming [`RpcMessage`]s.
///
/// The servers in this module clone the handler for every task they spawn: once
/// per datagram in the UDP server and once per accepted connection in the TCP
/// server.
#[async_trait]
pub trait MessageHandler: Clone {
    /// Handles a single request and returns the reply to send back, if any.
    async fn handle_message(&mut self, message: RpcMessage) -> Option<RpcMessage>;

    /// Handles a notification subscription for `chains`.
    ///
    /// Returns the stream of messages to forward to the client, or `None` to
    /// decline. The default implementation returns `None`. In this module only
    /// the TCP server calls this method.
    async fn handle_subscribe(
        &mut self,
        _chains: Vec<ChainId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }

    /// Handles a request to download the blobs identified by `blob_ids`.
    ///
    /// Returns the stream of messages to send back, or `None` to decline. The
    /// default implementation returns `None`. In this module only the TCP server
    /// calls this method.
    async fn handle_download_blobs(
        &mut self,
        _blob_ids: Vec<BlobId>,
    ) -> Option<Pin<Box<dyn Stream<Item = RpcMessage> + Send>>> {
        None
    }
}
109
/// Handle to a server task started by `TransportProtocol::spawn_server`.
pub struct ServerHandle {
    /// The spawned server task; it yields the server's final result.
    pub handle: TaskHandle<Result<(), std::io::Error>>,
}
115
116impl ServerHandle {
117 pub async fn join(self) -> Result<(), std::io::Error> {
118 self.handle.await.map_err(|_| {
119 std::io::Error::new(
120 std::io::ErrorKind::Interrupted,
121 "Server task did not finish successfully",
122 )
123 })?
124 }
125}
126
/// A bidirectional, message-framed channel: a [`Stream`] of incoming
/// [`RpcMessage`]s and a [`Sink`] for outgoing ones, both failing with
/// [`codec::Error`].
///
/// This is an alias-like trait; it is implemented for every type satisfying
/// these bounds.
pub trait Transport:
    Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
{
}
135
136impl<T> Transport for T where
137 T: Stream<Item = Result<RpcMessage, codec::Error>> + Sink<RpcMessage, Error = codec::Error>
138{
139}
140
141impl TransportProtocol {
142 pub async fn connect(
144 self,
145 address: impl ToSocketAddrs,
146 ) -> Result<impl Transport, std::io::Error> {
147 let mut addresses = lookup_host(address)
148 .await
149 .expect("Invalid address to connect to");
150 let address = addresses
151 .next()
152 .expect("Couldn't resolve address to connect to");
153
154 let stream: futures::future::Either<_, _> = match self {
155 TransportProtocol::Udp => {
156 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
157
158 UdpFramed::new(socket, Codec)
159 .with(move |message| future::ready(Ok((message, address))))
160 .map_ok(|(message, _address)| message)
161 .left_stream()
162 }
163 TransportProtocol::Tcp => {
164 let stream = TcpStream::connect(address).await?;
165
166 Framed::new(stream, Codec).right_stream()
167 }
168 };
169
170 Ok(stream)
171 }
172
173 pub async fn make_outgoing_connection_pool(
175 self,
176 ) -> Result<Box<dyn ConnectionPool>, std::io::Error> {
177 let pool: Box<dyn ConnectionPool> = match self {
178 Self::Udp => Box::new(UdpConnectionPool::new().await?),
179 Self::Tcp => Box::new(TcpConnectionPool::new()),
180 };
181 Ok(pool)
182 }
183
184 pub fn spawn_server<S>(
186 self,
187 address: impl ToSocketAddrs + Send + 'static,
188 state: S,
189 shutdown_signal: CancellationToken,
190 join_set: &mut JoinSet<()>,
191 ) -> ServerHandle
192 where
193 S: MessageHandler + Send + 'static,
194 {
195 let handle = match self {
196 Self::Udp => join_set.spawn_task(UdpServer::run(address, state, shutdown_signal)),
197 Self::Tcp => join_set.spawn_task(TcpServer::run(address, state, shutdown_signal)),
198 };
199 ServerHandle { handle }
200 }
201}
202
/// Outgoing "pool" for UDP: one unconnected socket that sends datagrams to any
/// destination.
struct UdpConnectionPool {
    /// The framed socket; each item sent is a `(message, destination)` pair.
    transport: UdpFramed<Codec>,
}
207
208impl UdpConnectionPool {
209 async fn new() -> Result<Self, std::io::Error> {
210 let socket = UdpSocket::bind(&"0.0.0.0:0").await?;
211 let transport = UdpFramed::new(socket, Codec);
212 Ok(Self { transport })
213 }
214}
215
216impl ConnectionPool for UdpConnectionPool {
217 fn send_message_to<'a>(
218 &'a mut self,
219 message: RpcMessage,
220 address: &'a str,
221 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
222 Box::pin(async move {
223 let address = address.parse().map_err(std::io::Error::other)?;
224 self.transport.send((message, address)).await
225 })
226 }
227}
228
/// A UDP server that handles each incoming datagram in a separate task.
pub struct UdpServer<State> {
    /// Handler cloned into every message-handling task.
    handler: State,
    /// Sending half of the socket, shared by the tasks that send replies.
    udp_sink: SharedUdpSink,
    /// Receiving half of the socket.
    udp_stream: SplitStream<UdpFramed<Codec>>,
    /// The most recently spawned handler task for each peer address.
    active_handlers: HashMap<SocketAddr, TaskHandle<()>>,
    /// Owns every spawned handler task.
    join_set: JoinSet<()>,
}
237
/// The sending half of the UDP server socket. Handler tasks share it behind an
/// async mutex, so only one task sends at a time.
type SharedUdpSink = Arc<Mutex<SplitSink<UdpFramed<Codec>, (RpcMessage, SocketAddr)>>>;
240
241impl<State> UdpServer<State>
242where
243 State: MessageHandler + Send + 'static,
244{
245 pub async fn run(
247 address: impl ToSocketAddrs,
248 state: State,
249 shutdown_signal: CancellationToken,
250 ) -> Result<(), std::io::Error> {
251 let mut server = Self::bind(address, state).await?;
252
253 loop {
254 tokio::select! { biased;
255 _ = shutdown_signal.cancelled() => {
256 server.shutdown().await;
257 return Ok(());
258 }
259 result = server.udp_stream.next() => match result {
260 Some(Ok((message, peer))) => server.handle_message(message, peer),
261 Some(Err(error)) => server.handle_error(error).await?,
262 None => unreachable!("`UdpFramed` should never return `None`"),
263 },
264 }
265 }
266 }
267
268 async fn bind(address: impl ToSocketAddrs, handler: State) -> Result<Self, std::io::Error> {
271 let socket = UdpSocket::bind(address).await?;
272 let (udp_sink, udp_stream) = UdpFramed::new(socket, Codec).split();
273
274 Ok(UdpServer {
275 handler,
276 udp_sink: Arc::new(Mutex::new(udp_sink)),
277 udp_stream,
278 active_handlers: HashMap::new(),
279 join_set: JoinSet::new(),
280 })
281 }
282
283 fn handle_message(&mut self, message: RpcMessage, peer: SocketAddr) {
285 let previous_task = self.active_handlers.remove(&peer);
286 let mut state = self.handler.clone();
287 let udp_sink = self.udp_sink.clone();
288
289 let new_task = self.join_set.spawn_task(async move {
290 if let Some(reply) = state.handle_message(message).await {
291 if let Some(task) = previous_task {
292 if let Err(error) = task.await {
293 warn!("Message handler task panicked: {}", error);
294 }
295 }
296 let status = udp_sink.lock().await.send((reply, peer)).await;
297 if let Err(error) = status {
298 error!("Failed to send query response: {}", error);
299 }
300 }
301 });
302
303 self.active_handlers.insert(peer, new_task);
304
305 if self.active_handlers.len() >= REAP_TASKS_THRESHOLD {
306 self.active_handlers.retain(|_, task| task.is_running());
308 self.join_set.reap_finished_tasks();
309 }
310 }
311
312 async fn handle_error(&mut self, error: codec::Error) -> Result<(), std::io::Error> {
314 match error {
315 codec::Error::IoError(io_error) => {
316 error!("I/O error in UDP server: {io_error}");
317 self.shutdown().await;
318 Err(io_error)
319 }
320 other_error => {
321 warn!("Received an invalid message: {other_error}");
322 Ok(())
323 }
324 }
325 }
326
327 async fn shutdown(&mut self) {
329 let handlers = mem::take(&mut self.active_handlers);
330 let mut handler_results = handlers.into_values().collect::<FuturesUnordered<_>>();
331
332 while let Some(result) = handler_results.next().await {
333 if let Err(error) = result {
334 warn!("Message handler panicked: {}", error);
335 }
336 }
337
338 self.join_set.await_all_tasks().await;
339 }
340}
341
/// Outgoing connection pool that keeps one framed TCP connection per address.
struct TcpConnectionPool {
    /// Open connections, keyed by the address string they were opened with.
    streams: HashMap<String, Framed<TcpStream, Codec>>,
}
346
347impl TcpConnectionPool {
348 fn new() -> Self {
349 let streams = HashMap::new();
350 Self { streams }
351 }
352
353 async fn get_stream(
354 &mut self,
355 address: &str,
356 ) -> Result<&mut Framed<TcpStream, Codec>, io::Error> {
357 if !self.streams.contains_key(address) {
358 match TcpStream::connect(address).await {
359 Ok(s) => {
360 self.streams
361 .insert(address.to_string(), Framed::new(s, Codec));
362 }
363 Err(error) => {
364 error!("Failed to open connection to {}: {}", address, error);
365 return Err(error);
366 }
367 };
368 };
369 Ok(self.streams.get_mut(address).unwrap())
370 }
371}
372
373impl ConnectionPool for TcpConnectionPool {
374 fn send_message_to<'a>(
375 &'a mut self,
376 message: RpcMessage,
377 address: &'a str,
378 ) -> future::BoxFuture<'a, Result<(), codec::Error>> {
379 Box::pin(async move {
380 let stream = self.get_stream(address).await?;
381 let result = stream.send(message).await;
382 if result.is_err() {
383 self.streams.remove(address);
384 }
385 result
386 })
387 }
388}
389
/// Serves a single accepted TCP connection.
pub struct TcpServer<State> {
    /// The connection, framed with the RPC codec.
    connection: Framed<TcpStream, Codec>,
    /// Answers the requests received on this connection.
    handler: State,
    /// Cancelled to ask this connection to close.
    shutdown_signal: CancellationToken,
}
396
397impl<State> TcpServer<State>
398where
399 State: MessageHandler + Send + 'static,
400{
401 pub async fn run(
406 address: impl ToSocketAddrs,
407 handler: State,
408 shutdown_signal: CancellationToken,
409 ) -> Result<(), std::io::Error> {
410 let listener = TcpListener::bind(address).await?;
411
412 let accept_stream = stream::try_unfold(listener, |listener| async move {
413 let (socket, _) = listener.accept().await?;
414 Ok::<_, io::Error>(Some((socket, listener)))
415 });
416 let mut accept_stream = pin!(accept_stream);
417
418 let connection_shutdown_signal = shutdown_signal.child_token();
419 let mut join_set = JoinSet::new();
420 let mut reap_countdown = REAP_TASKS_THRESHOLD;
421
422 loop {
423 tokio::select! { biased;
424 _ = shutdown_signal.cancelled() => {
425 join_set.await_all_tasks().await;
426 return Ok(());
427 }
428 maybe_socket = accept_stream.next() => match maybe_socket {
429 Some(Ok(socket)) => {
430 let server = TcpServer::new_connection(
431 socket,
432 handler.clone(),
433 connection_shutdown_signal.clone(),
434 );
435 join_set.spawn_task(server.serve());
436 reap_countdown -= 1;
437 }
438 Some(Err(error)) => {
439 join_set.await_all_tasks().await;
440 return Err(error);
441 }
442 None => unreachable!(
443 "The `accept_stream` should never finish unless there's an error",
444 ),
445 },
446 }
447
448 if reap_countdown == 0 {
449 join_set.reap_finished_tasks();
450 reap_countdown = REAP_TASKS_THRESHOLD;
451 }
452 }
453 }
454
455 fn new_connection(
458 tcp_stream: TcpStream,
459 handler: State,
460 shutdown_signal: CancellationToken,
461 ) -> Self {
462 TcpServer {
463 connection: Framed::new(tcp_stream, Codec),
464 handler,
465 shutdown_signal,
466 }
467 }
468
469 async fn serve(mut self) {
471 loop {
472 tokio::select! { biased;
473 _ = self.shutdown_signal.cancelled() => {
474 let mut tcp_stream = self.connection.into_inner();
475 if let Err(error) = tcp_stream.shutdown().await {
476 let peer = tcp_stream
477 .peer_addr()
478 .map_or_else(|_| "an unknown peer".to_owned(), |address| address.to_string());
479 warn!("Failed to close connection to {peer}: {error:?}");
480 }
481 return;
482 }
483 result = self.connection.next() => match result {
484 Some(Ok(RpcMessage::SubscribeNotifications(chains))) => {
485 self.handle_subscription(chains).await;
486 return;
487 }
488 Some(Ok(RpcMessage::DownloadBlobs(blob_ids))) => {
489 self.handle_download_blobs(blob_ids).await;
490 return;
491 }
492 Some(Ok(message)) => self.handle_message(message).await,
493 Some(Err(error)) => {
494 Self::handle_error(&error);
495 return;
496 }
497 None => break,
498 },
499 }
500 }
501 }
502
503 async fn handle_message(&mut self, message: RpcMessage) {
505 if let Some(reply) = self.handler.handle_message(message).await {
506 if let Err(error) = self.connection.send(reply).await {
507 error!("Failed to send query response: {error}");
508 }
509 }
510 }
511
512 async fn handle_subscription(&mut self, chains: Vec<ChainId>) {
514 let Some(mut stream) = self.handler.handle_subscribe(chains).await else {
515 return;
516 };
517 loop {
518 tokio::select! { biased;
519 _ = self.shutdown_signal.cancelled() => break,
520 msg = stream.next() => match msg {
521 Some(notification) => {
522 if let Err(error) = self.connection.send(notification).await {
523 error!("Failed to send notification: {error}");
524 break;
525 }
526 }
527 None => break,
528 }
529 }
530 }
531 }
532
533 async fn handle_download_blobs(&mut self, blob_ids: Vec<BlobId>) {
535 let Some(mut stream) = self.handler.handle_download_blobs(blob_ids).await else {
536 return;
537 };
538 loop {
539 tokio::select! { biased;
540 _ = self.shutdown_signal.cancelled() => break,
541 msg = stream.next() => match msg {
542 Some(message) => {
543 if let Err(error) = self.connection.send(message).await {
544 error!("Failed to send blob response: {error}");
545 break;
546 }
547 }
548 None => break,
549 }
550 }
551 }
552 }
553
554 fn handle_error(error: &codec::Error) {
559 if !matches!(
560 error,
561 codec::Error::IoError(error)
562 if error.kind() == io::ErrorKind::UnexpectedEof
563 || error.kind() == io::ErrorKind::ConnectionReset
564 ) {
565 error!("Error while reading TCP stream: {error}");
566 }
567 }
568}