1use super::compression::{decompress, CompressionEncoding, CompressionSettings};
2use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE};
3use crate::{body::Body, metadata::MetadataMap, Code, Status};
4use bytes::{Buf, BufMut, BytesMut};
5use http::{HeaderMap, StatusCode};
6use http_body::Body as HttpBody;
7use http_body_util::BodyExt;
8use std::{
9 fmt, future,
10 pin::Pin,
11 task::ready,
12 task::{Context, Poll},
13};
14use sync_wrapper::SyncWrapper;
15use tokio_stream::Stream;
16use tracing::{debug, trace};
17
/// A stream of decoded messages of type `T` read from a request or response body.
///
/// Pairs a boxed [`Decoder`] with the framing state kept in `StreamingInner`.
/// Messages are obtained through the [`Stream`] implementation or the
/// `message` helper; trailing metadata is available through `trailers`.
pub struct Streaming<T> {
    /// Turns each framed payload into a `T`. Held inside a `SyncWrapper`; the
    /// test-only assertion at the end of this file checks that `Streaming`
    /// is both `Send` and `Sync`.
    decoder: SyncWrapper<Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>>,
    /// Body, buffers and framing state that do not depend on `T`.
    inner: StreamingInner,
}
26
/// The `T`-independent half of [`Streaming`]: the body being read plus all
/// buffering and framing state.
struct StreamingInner {
    /// Body that is polled for data and trailer frames.
    body: SyncWrapper<Body>,
    /// Where the framing state machine currently stands.
    state: State,
    /// Whether a request or a response is being decoded.
    direction: Direction,
    /// Bytes received from `body` that have not been decoded yet.
    buf: BytesMut,
    /// Trailers received from `body`, until they are taken by a caller.
    trailers: Option<HeaderMap>,
    /// Scratch buffer that receives the decompressed payload of a compressed message.
    decompress_buf: BytesMut,
    /// Encoding applied to messages whose compression flag is set; such a
    /// message is rejected when this is `None`.
    encoding: Option<CompressionEncoding>,
    /// Upper bound on a message's announced length;
    /// `DEFAULT_MAX_RECV_MESSAGE_SIZE` is used when `None`.
    max_message_size: Option<usize>,
}
37
// `Streaming<T>` is declared `Unpin` for every `T`; `message` relies on this
// to build a `Pin<&mut Self>` with the safe `Pin::new`.
impl<T> Unpin for Streaming<T> {}
39
/// Progress of the message framing state machine.
#[derive(Debug, Clone)]
enum State {
    /// Waiting until a full message header is buffered.
    ReadHeader,
    /// A header has been consumed; waiting until `len` payload bytes are buffered.
    ReadBody {
        /// Encoding the payload has to be decompressed with, or `None` when
        /// the header's compression flag was `0`.
        compression: Option<CompressionEncoding>,
        /// Payload length announced by the header, in bytes.
        len: usize,
    },
    /// Terminal state. `poll_next` takes the contained status out and yields
    /// it; once it is `None` the stream simply reports its end.
    Error(Option<Status>),
}
49
/// Which kind of body a [`Streaming`] is decoding.
#[derive(Debug, PartialEq, Eq)]
enum Direction {
    /// A request body. A `Cancelled` body error ends such a stream without an error.
    Request,
    /// A response body together with the HTTP status it arrived with. The
    /// status is passed to `infer_grpc_status` when the body ends and is
    /// quoted in protocol error messages.
    Response(StatusCode),
    /// Selected by `Streaming::new_empty`; no status inference is performed
    /// for it when the body ends.
    EmptyResponse,
}
56
57impl<T> Streaming<T> {
58 pub fn new_response<B, D>(
61 decoder: D,
62 body: B,
63 status_code: StatusCode,
64 encoding: Option<CompressionEncoding>,
65 max_message_size: Option<usize>,
66 ) -> Self
67 where
68 B: HttpBody + Send + 'static,
69 B::Error: Into<crate::BoxError>,
70 D: Decoder<Item = T, Error = Status> + Send + 'static,
71 {
72 Self::new(
73 decoder,
74 body,
75 Direction::Response(status_code),
76 encoding,
77 max_message_size,
78 )
79 }
80
81 pub fn new_empty<B, D>(decoder: D, body: B) -> Self
83 where
84 B: HttpBody + Send + 'static,
85 B::Error: Into<crate::BoxError>,
86 D: Decoder<Item = T, Error = Status> + Send + 'static,
87 {
88 Self::new(decoder, body, Direction::EmptyResponse, None, None)
89 }
90
91 pub fn new_request<B, D>(
94 decoder: D,
95 body: B,
96 encoding: Option<CompressionEncoding>,
97 max_message_size: Option<usize>,
98 ) -> Self
99 where
100 B: HttpBody + Send + 'static,
101 B::Error: Into<crate::BoxError>,
102 D: Decoder<Item = T, Error = Status> + Send + 'static,
103 {
104 Self::new(
105 decoder,
106 body,
107 Direction::Request,
108 encoding,
109 max_message_size,
110 )
111 }
112
113 fn new<B, D>(
114 decoder: D,
115 body: B,
116 direction: Direction,
117 encoding: Option<CompressionEncoding>,
118 max_message_size: Option<usize>,
119 ) -> Self
120 where
121 B: HttpBody + Send + 'static,
122 B::Error: Into<crate::BoxError>,
123 D: Decoder<Item = T, Error = Status> + Send + 'static,
124 {
125 let buffer_size = decoder.buffer_settings().buffer_size;
126 Self {
127 decoder: SyncWrapper::new(Box::new(decoder)),
128 inner: StreamingInner {
129 body: SyncWrapper::new(Body::new(
130 body.map_frame(|frame| {
131 frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
132 })
133 .map_err(|err| Status::map_error(err.into())),
134 )),
135 state: State::ReadHeader,
136 direction,
137 buf: BytesMut::with_capacity(buffer_size),
138 trailers: None,
139 decompress_buf: BytesMut::new(),
140 encoding,
141 max_message_size,
142 },
143 }
144 }
145}
146
147impl StreamingInner {
148 fn decode_chunk(
149 &mut self,
150 buffer_settings: BufferSettings,
151 ) -> Result<Option<DecodeBuf<'_>>, Status> {
152 if let State::ReadHeader = self.state {
153 if self.buf.remaining() < HEADER_SIZE {
154 return Ok(None);
155 }
156
157 let compression_encoding = match self.buf.get_u8() {
158 0 => None,
159 1 => {
160 {
161 if self.encoding.is_some() {
162 self.encoding
163 } else {
164 return Err(Status::internal( "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
169 }
170 }
171 }
172 f => {
173 trace!("unexpected compression flag");
174 let message = if let Direction::Response(status) = self.direction {
175 format!(
176 "protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1) while receiving response with status: {status}"
177 )
178 } else {
179 format!("protocol error: received message with invalid compression flag: {f} (valid flags are 0 and 1), while sending request")
180 };
181 return Err(Status::internal(message));
182 }
183 };
184
185 let len = self.buf.get_u32() as usize;
186 let limit = self
187 .max_message_size
188 .unwrap_or(DEFAULT_MAX_RECV_MESSAGE_SIZE);
189 if len > limit {
190 return Err(Status::out_of_range(
191 format!(
192 "Error, decoded message length too large: found {len} bytes, the limit is: {limit} bytes"
193 ),
194 ));
195 }
196
197 self.buf.reserve(len);
198
199 self.state = State::ReadBody {
200 compression: compression_encoding,
201 len,
202 }
203 }
204
205 if let State::ReadBody { len, compression } = self.state {
206 if self.buf.remaining() < len || self.buf.len() < len {
209 return Ok(None);
210 }
211
212 let decode_buf = if let Some(encoding) = compression {
213 self.decompress_buf.clear();
214
215 if let Err(err) = decompress(
216 CompressionSettings {
217 encoding,
218 buffer_growth_interval: buffer_settings.buffer_size,
219 },
220 &mut self.buf,
221 &mut self.decompress_buf,
222 len,
223 ) {
224 let message = if let Direction::Response(status) = self.direction {
225 format!(
226 "Error decompressing: {err}, while receiving response with status: {status}"
227 )
228 } else {
229 format!("Error decompressing: {err}, while sending request")
230 };
231 return Err(Status::internal(message));
232 }
233 let decompressed_len = self.decompress_buf.len();
234 DecodeBuf::new(&mut self.decompress_buf, decompressed_len)
235 } else {
236 DecodeBuf::new(&mut self.buf, len)
237 };
238
239 return Ok(Some(decode_buf));
240 }
241
242 Ok(None)
243 }
244
245 fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<()>, Status>> {
247 let frame = match ready!(Pin::new(self.body.get_mut()).poll_frame(cx)) {
248 Some(Ok(frame)) => frame,
249 Some(Err(status)) => {
250 if self.direction == Direction::Request && status.code() == Code::Cancelled {
251 return Poll::Ready(Ok(None));
252 }
253
254 let _ = std::mem::replace(&mut self.state, State::Error(Some(status.clone())));
255 debug!("decoder inner stream error: {:?}", status);
256 return Poll::Ready(Err(status));
257 }
258 None => {
259 return Poll::Ready(if self.buf.has_remaining() {
261 trace!("unexpected EOF decoding stream, state: {:?}", self.state);
262 Err(Status::internal("Unexpected EOF decoding stream."))
263 } else {
264 Ok(None)
265 });
266 }
267 };
268
269 Poll::Ready(if frame.is_data() {
270 self.buf.put(frame.into_data().unwrap());
271 Ok(Some(()))
272 } else if frame.is_trailers() {
273 if let Some(trailers) = &mut self.trailers {
274 trailers.extend(frame.into_trailers().unwrap());
275 } else {
276 self.trailers = Some(frame.into_trailers().unwrap());
277 }
278
279 Ok(None)
280 } else {
281 panic!("unexpected frame: {frame:?}");
282 })
283 }
284
285 fn response(&mut self) -> Result<(), Status> {
286 if let Direction::Response(status) = self.direction {
287 if let Err(Some(e)) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) {
288 self.trailers.take();
291 return Err(e);
292 }
293 }
294 Ok(())
295 }
296}
297
298impl<T> Streaming<T> {
299 pub async fn message(&mut self) -> Result<Option<T>, Status> {
328 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
329 Some(Ok(m)) => Ok(Some(m)),
330 Some(Err(e)) => Err(e),
331 None => Ok(None),
332 }
333 }
334
335 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
351 if let Some(trailers) = self.inner.trailers.take() {
354 return Ok(Some(MetadataMap::from_headers(trailers)));
355 }
356
357 while self.message().await?.is_some() {}
359
360 if let Some(trailers) = self.inner.trailers.take() {
363 return Ok(Some(MetadataMap::from_headers(trailers)));
364 }
365
366 Ok(None)
368 }
369
370 fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
371 match self
372 .inner
373 .decode_chunk(self.decoder.get_mut().buffer_settings())?
374 {
375 Some(mut decode_buf) => match self.decoder.get_mut().decode(&mut decode_buf)? {
376 Some(msg) => {
377 self.inner.state = State::ReadHeader;
378 Ok(Some(msg))
379 }
380 None => Ok(None),
381 },
382 None => Ok(None),
383 }
384 }
385}
386
387impl<T> Stream for Streaming<T> {
388 type Item = Result<T, Status>;
389
390 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
391 loop {
392 if let State::Error(status) = &mut self.inner.state {
396 return Poll::Ready(status.take().map(Err));
397 }
398
399 if let Some(item) = self.decode_chunk()? {
400 return Poll::Ready(Some(Ok(item)));
401 }
402
403 if ready!(self.inner.poll_frame(cx))?.is_none() {
404 match self.inner.response() {
405 Ok(()) => return Poll::Ready(None),
406 Err(err) => self.inner.state = State::Error(Some(err)),
407 }
408 }
409 }
410 }
411}
412
413impl<T> fmt::Debug for Streaming<T> {
414 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
415 f.debug_struct("Streaming").finish()
416 }
417}
418
// Compile-time check, built only for tests, that `Streaming<()>` implements
// both `Send` and `Sync`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send, Sync);