tonic/transport/server/
incoming.rs

1use std::{
2    net::{SocketAddr, TcpListener as StdTcpListener},
3    pin::Pin,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use socket2::TcpKeepalive;
9use tokio::net::{TcpListener, TcpStream};
10use tokio_stream::{wrappers::TcpListenerStream, Stream};
11use tracing::warn;
12
13/// Binds a socket address for a [Router](super::Router)
14///
15/// An incoming stream, usable with [Router::serve_with_incoming](super::Router::serve_with_incoming),
16/// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address.
17#[derive(Debug)]
18pub struct TcpIncoming {
19    inner: TcpListenerStream,
20    nodelay: Option<bool>,
21    keepalive: Option<TcpKeepalive>,
22    keepalive_time: Option<Duration>,
23    keepalive_interval: Option<Duration>,
24    keepalive_retries: Option<u32>,
25}
26
27impl TcpIncoming {
28    /// Creates an instance by binding (opening) the specified socket address.
29    ///
30    /// Returns a TcpIncoming if the socket address was successfully bound.
31    ///
32    /// # Examples
33    /// ```no_run
34    /// # use tower_service::Service;
35    /// # use http::{request::Request, response::Response};
36    /// # use tonic::{body::Body, server::NamedService, transport::{Server, server::TcpIncoming}};
37    /// # use core::convert::Infallible;
38    /// # use std::error::Error;
39    /// # fn main() { }  // Cannot have type parameters, hence instead define:
40    /// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
41    /// # where
42    /// #   S: Service<Request<Body>, Response = Response<Body>, Error = Infallible> + NamedService + Clone + Send + Sync + 'static,
43    /// #   S::Future: Send + 'static,
44    /// # {
45    /// // Find a free port
46    /// let mut port = 1322;
47    /// let tinc = loop {
48    ///    let addr = format!("127.0.0.1:{}", port).parse().unwrap();
49    ///    match TcpIncoming::bind(addr) {
50    ///       Ok(t) => break t,
51    ///       Err(_) => port += 1
52    ///    }
53    /// };
54    /// Server::builder()
55    ///    .add_service(some_service)
56    ///    .serve_with_incoming(tinc);
57    /// # Ok(())
58    /// # }
59    pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
60        let std_listener = StdTcpListener::bind(addr)?;
61        std_listener.set_nonblocking(true)?;
62
63        Ok(TcpListener::from_std(std_listener)?.into())
64    }
65
66    /// Sets the `TCP_NODELAY` option on the accepted connection.
67    pub fn with_nodelay(self, nodelay: Option<bool>) -> Self {
68        Self { nodelay, ..self }
69    }
70
71    /// Sets the `TCP_KEEPALIVE` option on the accepted connection.
72    pub fn with_keepalive(self, keepalive_time: Option<Duration>) -> Self {
73        Self {
74            keepalive_time,
75            keepalive: make_keepalive(
76                keepalive_time,
77                self.keepalive_interval,
78                self.keepalive_retries,
79            ),
80            ..self
81        }
82    }
83
84    /// Sets the `TCP_KEEPINTVL` option on the accepted connection.
85    pub fn with_keepalive_interval(self, keepalive_interval: Option<Duration>) -> Self {
86        Self {
87            keepalive_interval,
88            keepalive: make_keepalive(
89                self.keepalive_time,
90                keepalive_interval,
91                self.keepalive_retries,
92            ),
93            ..self
94        }
95    }
96
97    /// Sets the `TCP_KEEPCNT` option on the accepted connection.
98    pub fn with_keepalive_retries(self, keepalive_retries: Option<u32>) -> Self {
99        Self {
100            keepalive_retries,
101            keepalive: make_keepalive(
102                self.keepalive_time,
103                self.keepalive_interval,
104                keepalive_retries,
105            ),
106            ..self
107        }
108    }
109
110    /// Returns the local address that this tcp incoming is bound to.
111    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
112        self.inner.as_ref().local_addr()
113    }
114}
115
116impl From<TcpListener> for TcpIncoming {
117    fn from(listener: TcpListener) -> Self {
118        Self {
119            inner: TcpListenerStream::new(listener),
120            nodelay: None,
121            keepalive: None,
122            keepalive_time: None,
123            keepalive_interval: None,
124            keepalive_retries: None,
125        }
126    }
127}
128
129impl Stream for TcpIncoming {
130    type Item = std::io::Result<TcpStream>;
131
132    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        let polled = Pin::new(&mut self.inner).poll_next(cx);
134
135        if let Poll::Ready(Some(Ok(stream))) = &polled {
136            set_accepted_socket_options(stream, self.nodelay, &self.keepalive);
137        }
138
139        polled
140    }
141}
142
143// Consistent with hyper-0.14, this function does not return an error.
144fn set_accepted_socket_options(
145    stream: &TcpStream,
146    nodelay: Option<bool>,
147    keepalive: &Option<TcpKeepalive>,
148) {
149    if let Some(nodelay) = nodelay {
150        if let Err(e) = stream.set_nodelay(nodelay) {
151            warn!("error trying to set TCP_NODELAY: {e}");
152        }
153    }
154
155    if let Some(keepalive) = keepalive {
156        let sock_ref = socket2::SockRef::from(&stream);
157        if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) {
158            warn!("error trying to set TCP_KEEPALIVE: {e}");
159        }
160    }
161}
162
163fn make_keepalive(
164    keepalive_time: Option<Duration>,
165    keepalive_interval: Option<Duration>,
166    keepalive_retries: Option<u32>,
167) -> Option<TcpKeepalive> {
168    let mut dirty = false;
169    let mut keepalive = TcpKeepalive::new();
170    if let Some(t) = keepalive_time {
171        keepalive = keepalive.with_time(t);
172        dirty = true;
173    }
174
175    #[cfg(
176        // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#511-525
177        any(
178            target_os = "android",
179            target_os = "dragonfly",
180            target_os = "freebsd",
181            target_os = "fuchsia",
182            target_os = "illumos",
183            target_os = "ios",
184            target_os = "visionos",
185            target_os = "linux",
186            target_os = "macos",
187            target_os = "netbsd",
188            target_os = "tvos",
189            target_os = "watchos",
190            target_os = "windows",
191        )
192    )]
193    if let Some(t) = keepalive_interval {
194        keepalive = keepalive.with_interval(t);
195        dirty = true;
196    }
197
198    #[cfg(
199        // See https://docs.rs/socket2/0.5.8/src/socket2/lib.rs.html#557-570
200        any(
201            target_os = "android",
202            target_os = "dragonfly",
203            target_os = "freebsd",
204            target_os = "fuchsia",
205            target_os = "illumos",
206            target_os = "ios",
207            target_os = "visionos",
208            target_os = "linux",
209            target_os = "macos",
210            target_os = "netbsd",
211            target_os = "tvos",
212            target_os = "watchos",
213        )
214    )]
215    if let Some(r) = keepalive_retries {
216        keepalive = keepalive.with_retries(r);
217        dirty = true;
218    }
219
220    // avoid clippy errors for targets that do not use these fields.
221    let _ = keepalive_retries;
222    let _ = keepalive_interval;
223
224    dirty.then_some(keepalive)
225}
226
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;

    /// A socket address can be bound by only one `TcpIncoming` at a time, and becomes
    /// available again once that instance is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Port 0 asks the OS for a free port, so the test does not fail just because a
        // fixed port happens to be in use on the machine running it.
        let first = TcpIncoming::bind("127.0.0.1:0".parse().unwrap()).unwrap();
        let addr = first.local_addr().unwrap();

        // While the first instance is alive, binding the same address must fail.
        let _second = TcpIncoming::bind(addr).unwrap_err();

        // Dropping the first instance releases the address for reuse.
        drop(first);
        let _third = TcpIncoming::bind(addr).unwrap();
    }
}