tonic/transport/server/
incoming.rs1use std::{
2 net::{SocketAddr, TcpListener as StdTcpListener},
3 pin::Pin,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use socket2::TcpKeepalive;
9use tokio::net::{TcpListener, TcpStream};
10use tokio_stream::{wrappers::TcpListenerStream, Stream};
11use tracing::warn;
12
/// A stream of connections accepted from a bound tokio [`TcpListener`].
///
/// Every [`TcpStream`] accepted through the `Stream` impl has the configured
/// `TCP_NODELAY` and TCP keepalive settings applied to it before it is yielded.
#[derive(Debug)]
pub struct TcpIncoming {
    /// The listener, wrapped so that accepted connections can be polled as a stream.
    inner: TcpListenerStream,
    /// `TCP_NODELAY` value set on accepted sockets; `None` leaves the option untouched.
    nodelay: Option<bool>,
    /// Keepalive configuration derived from the three `keepalive_*` fields below by
    /// `make_keepalive`; `None` when none of the settings applicable on this platform is set.
    keepalive: Option<TcpKeepalive>,
    /// Raw keepalive time, kept so each `with_keepalive*` builder can rebuild `keepalive`.
    keepalive_time: Option<Duration>,
    /// Raw keepalive probe interval, kept so each `with_keepalive*` builder can rebuild `keepalive`.
    keepalive_interval: Option<Duration>,
    /// Raw keepalive probe count, kept so each `with_keepalive*` builder can rebuild `keepalive`.
    keepalive_retries: Option<u32>,
}
26
27impl TcpIncoming {
28 pub fn bind(addr: SocketAddr) -> std::io::Result<Self> {
60 let std_listener = StdTcpListener::bind(addr)?;
61 std_listener.set_nonblocking(true)?;
62
63 Ok(TcpListener::from_std(std_listener)?.into())
64 }
65
66 pub fn with_nodelay(self, nodelay: Option<bool>) -> Self {
68 Self { nodelay, ..self }
69 }
70
71 pub fn with_keepalive(self, keepalive_time: Option<Duration>) -> Self {
73 Self {
74 keepalive_time,
75 keepalive: make_keepalive(
76 keepalive_time,
77 self.keepalive_interval,
78 self.keepalive_retries,
79 ),
80 ..self
81 }
82 }
83
84 pub fn with_keepalive_interval(self, keepalive_interval: Option<Duration>) -> Self {
86 Self {
87 keepalive_interval,
88 keepalive: make_keepalive(
89 self.keepalive_time,
90 keepalive_interval,
91 self.keepalive_retries,
92 ),
93 ..self
94 }
95 }
96
97 pub fn with_keepalive_retries(self, keepalive_retries: Option<u32>) -> Self {
99 Self {
100 keepalive_retries,
101 keepalive: make_keepalive(
102 self.keepalive_time,
103 self.keepalive_interval,
104 keepalive_retries,
105 ),
106 ..self
107 }
108 }
109
110 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
112 self.inner.as_ref().local_addr()
113 }
114}
115
116impl From<TcpListener> for TcpIncoming {
117 fn from(listener: TcpListener) -> Self {
118 Self {
119 inner: TcpListenerStream::new(listener),
120 nodelay: None,
121 keepalive: None,
122 keepalive_time: None,
123 keepalive_interval: None,
124 keepalive_retries: None,
125 }
126 }
127}
128
129impl Stream for TcpIncoming {
130 type Item = std::io::Result<TcpStream>;
131
132 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 let polled = Pin::new(&mut self.inner).poll_next(cx);
134
135 if let Poll::Ready(Some(Ok(stream))) = &polled {
136 set_accepted_socket_options(stream, self.nodelay, &self.keepalive);
137 }
138
139 polled
140 }
141}
142
143fn set_accepted_socket_options(
145 stream: &TcpStream,
146 nodelay: Option<bool>,
147 keepalive: &Option<TcpKeepalive>,
148) {
149 if let Some(nodelay) = nodelay {
150 if let Err(e) = stream.set_nodelay(nodelay) {
151 warn!("error trying to set TCP_NODELAY: {e}");
152 }
153 }
154
155 if let Some(keepalive) = keepalive {
156 let sock_ref = socket2::SockRef::from(&stream);
157 if let Err(e) = sock_ref.set_tcp_keepalive(keepalive) {
158 warn!("error trying to set TCP_KEEPALIVE: {e}");
159 }
160 }
161}
162
163fn make_keepalive(
164 keepalive_time: Option<Duration>,
165 keepalive_interval: Option<Duration>,
166 keepalive_retries: Option<u32>,
167) -> Option<TcpKeepalive> {
168 let mut dirty = false;
169 let mut keepalive = TcpKeepalive::new();
170 if let Some(t) = keepalive_time {
171 keepalive = keepalive.with_time(t);
172 dirty = true;
173 }
174
175 #[cfg(
176 any(
178 target_os = "android",
179 target_os = "dragonfly",
180 target_os = "freebsd",
181 target_os = "fuchsia",
182 target_os = "illumos",
183 target_os = "ios",
184 target_os = "visionos",
185 target_os = "linux",
186 target_os = "macos",
187 target_os = "netbsd",
188 target_os = "tvos",
189 target_os = "watchos",
190 target_os = "windows",
191 )
192 )]
193 if let Some(t) = keepalive_interval {
194 keepalive = keepalive.with_interval(t);
195 dirty = true;
196 }
197
198 #[cfg(
199 any(
201 target_os = "android",
202 target_os = "dragonfly",
203 target_os = "freebsd",
204 target_os = "fuchsia",
205 target_os = "illumos",
206 target_os = "ios",
207 target_os = "visionos",
208 target_os = "linux",
209 target_os = "macos",
210 target_os = "netbsd",
211 target_os = "tvos",
212 target_os = "watchos",
213 )
214 )]
215 if let Some(r) = keepalive_retries {
216 keepalive = keepalive.with_retries(r);
217 dirty = true;
218 }
219
220 let _ = keepalive_retries;
222 let _ = keepalive_interval;
223
224 dirty.then_some(keepalive)
225}
226
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;

    /// While a `TcpIncoming` holds an address, binding that address again must
    /// fail; once it is dropped, the address can be bound again.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Bind to port 0 so the OS picks a free port, instead of hard-coding a
        // port that another process on the machine might already be using.
        let addr = {
            let t1 = TcpIncoming::bind("127.0.0.1:0".parse().unwrap()).unwrap();
            let addr = t1.local_addr().unwrap();
            let _t2 = TcpIncoming::bind(addr).unwrap_err();
            addr
        };
        // `t1` was dropped at the end of the block above, releasing `addr`.
        let _t3 = TcpIncoming::bind(addr).unwrap();
    }
}