1use super::compression::{
2 compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
3};
4use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
5use crate::Status;
6use bytes::{BufMut, Bytes, BytesMut};
7use http::HeaderMap;
8use http_body::{Body, Frame};
9use pin_project::pin_project;
10use std::{
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14use tokio_stream::{adapters::Fuse, Stream, StreamExt};
15
/// Stream adapter that turns a stream of messages into length-prefixed gRPC
/// frames, batching consecutive frames into `Bytes` chunks.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
struct EncodedBytes<T, U> {
    /// Message source. It is fused because `poll_next` may poll it again after it
    /// has already returned `None` (to drain the remaining buffered bytes first).
    #[pin]
    source: Fuse<U>,
    /// Serializes each message into bytes.
    encoder: T,
    /// Compression applied to every frame; `None` means frames are sent uncompressed.
    compression_encoding: Option<CompressionEncoding>,
    /// Per-message size limit; `None` falls back to `DEFAULT_MAX_SEND_MESSAGE_SIZE`.
    max_message_size: Option<usize>,
    /// Accumulates encoded frames until they are yielded as one chunk.
    buf: BytesMut,
    /// Scratch space holding a message's uncompressed bytes before compression.
    uncompression_buf: BytesMut,
    /// Source error held back until the bytes buffered before it have been yielded.
    error: Option<Status>,
}
33
34impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
35 fn new(
36 encoder: T,
37 source: U,
38 compression_encoding: Option<CompressionEncoding>,
39 compression_override: SingleMessageCompressionOverride,
40 max_message_size: Option<usize>,
41 ) -> Self {
42 let buffer_settings = encoder.buffer_settings();
43 let buf = BytesMut::with_capacity(buffer_settings.buffer_size);
44
45 let compression_encoding =
46 if compression_override == SingleMessageCompressionOverride::Disable {
47 None
48 } else {
49 compression_encoding
50 };
51
52 let uncompression_buf = if compression_encoding.is_some() {
53 BytesMut::with_capacity(buffer_settings.buffer_size)
54 } else {
55 BytesMut::new()
56 };
57
58 Self {
59 source: source.fuse(),
60 encoder,
61 compression_encoding,
62 max_message_size,
63 buf,
64 uncompression_buf,
65 error: None,
66 }
67 }
68}
69
70impl<T, U> Stream for EncodedBytes<T, U>
71where
72 T: Encoder<Error = Status>,
73 U: Stream<Item = Result<T::Item, Status>>,
74{
75 type Item = Result<Bytes, Status>;
76
77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 let EncodedBytesProj {
79 mut source,
80 encoder,
81 compression_encoding,
82 max_message_size,
83 buf,
84 uncompression_buf,
85 error,
86 } = self.project();
87 let buffer_settings = encoder.buffer_settings();
88
89 if let Some(status) = error.take() {
90 return Poll::Ready(Some(Err(status)));
91 }
92
93 loop {
94 match source.as_mut().poll_next(cx) {
95 Poll::Pending if buf.is_empty() => {
96 return Poll::Pending;
97 }
98 Poll::Ready(None) if buf.is_empty() => {
99 return Poll::Ready(None);
100 }
101 Poll::Pending | Poll::Ready(None) => {
102 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
103 }
104 Poll::Ready(Some(Ok(item))) => {
105 if let Err(status) = encode_item(
106 encoder,
107 buf,
108 uncompression_buf,
109 *compression_encoding,
110 *max_message_size,
111 buffer_settings,
112 item,
113 ) {
114 return Poll::Ready(Some(Err(status)));
115 }
116
117 if buf.len() >= buffer_settings.yield_threshold {
118 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
119 }
120 }
121 Poll::Ready(Some(Err(status))) => {
122 if buf.is_empty() {
123 return Poll::Ready(Some(Err(status)));
124 }
125 *error = Some(status);
126 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
127 }
128 }
129 }
130 }
131}
132
133fn encode_item<T>(
134 encoder: &mut T,
135 buf: &mut BytesMut,
136 uncompression_buf: &mut BytesMut,
137 compression_encoding: Option<CompressionEncoding>,
138 max_message_size: Option<usize>,
139 buffer_settings: BufferSettings,
140 item: T::Item,
141) -> Result<(), Status>
142where
143 T: Encoder<Error = Status>,
144{
145 let offset = buf.len();
146
147 buf.reserve(HEADER_SIZE);
148 unsafe {
149 buf.advance_mut(HEADER_SIZE);
150 }
151
152 if let Some(encoding) = compression_encoding {
153 uncompression_buf.clear();
154
155 encoder
156 .encode(item, &mut EncodeBuf::new(uncompression_buf))
157 .map_err(|err| Status::internal(format!("Error encoding: {err}")))?;
158
159 let uncompressed_len = uncompression_buf.len();
160
161 compress(
162 CompressionSettings {
163 encoding,
164 buffer_growth_interval: buffer_settings.buffer_size,
165 },
166 uncompression_buf,
167 buf,
168 uncompressed_len,
169 )
170 .map_err(|err| Status::internal(format!("Error compressing: {err}")))?;
171 } else {
172 encoder
173 .encode(item, &mut EncodeBuf::new(buf))
174 .map_err(|err| Status::internal(format!("Error encoding: {err}")))?;
175 }
176
177 finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
179}
180
181fn finish_encoding(
182 compression_encoding: Option<CompressionEncoding>,
183 max_message_size: Option<usize>,
184 buf: &mut [u8],
185) -> Result<(), Status> {
186 let len = buf.len() - HEADER_SIZE;
187 let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
188 if len > limit {
189 return Err(Status::out_of_range(format!(
190 "Error, encoded message length too large: found {len} bytes, the limit is: {limit} bytes"
191 )));
192 }
193
194 if len > u32::MAX as usize {
195 return Err(Status::resource_exhausted(format!(
196 "Cannot return body with more than 4GB of data but got {len} bytes"
197 )));
198 }
199 {
200 let mut buf = &mut buf[..HEADER_SIZE];
201 buf.put_u8(compression_encoding.is_some() as u8);
202 buf.put_u32(len as u32);
203 }
204
205 Ok(())
206}
207
/// Which side of the RPC an `EncodeBody` encodes for. It decides how stream errors
/// and end-of-stream are reported.
#[derive(Debug)]
enum Role {
    /// Stream errors are returned as body errors; no trailers are produced.
    Client,
    /// Stream errors and end-of-stream are reported as status trailers.
    Server,
}
213
/// HTTP body that encodes a stream of messages as gRPC frames, plus status
/// trailers when it is used on the server side.
#[pin_project]
#[derive(Debug)]
pub struct EncodeBody<T, U> {
    /// Produces the encoded data chunks.
    #[pin]
    inner: EncodedBytes<T, U>,
    /// Role and end-of-stream bookkeeping.
    state: EncodeState,
}
222
/// Bookkeeping for how an `EncodeBody` finishes.
#[derive(Debug)]
struct EncodeState {
    /// Status reported in the final trailers; `trailers()` uses `Status::ok("")` when
    /// this is `None`.
    /// NOTE(review): nothing in this file ever sets this field — confirm it is still needed.
    error: Option<Status>,
    /// Whether this body is encoding for a client or a server.
    role: Role,
    /// Set once trailers have been emitted (only happens for `Role::Server`).
    is_end_stream: bool,
}
229
230impl<T: Encoder, U: Stream> EncodeBody<T, U> {
231 pub fn new_client(
234 encoder: T,
235 source: U,
236 compression_encoding: Option<CompressionEncoding>,
237 max_message_size: Option<usize>,
238 ) -> Self {
239 Self {
240 inner: EncodedBytes::new(
241 encoder,
242 source,
243 compression_encoding,
244 SingleMessageCompressionOverride::default(),
245 max_message_size,
246 ),
247 state: EncodeState {
248 error: None,
249 role: Role::Client,
250 is_end_stream: false,
251 },
252 }
253 }
254
255 pub fn new_server(
258 encoder: T,
259 source: U,
260 compression_encoding: Option<CompressionEncoding>,
261 compression_override: SingleMessageCompressionOverride,
262 max_message_size: Option<usize>,
263 ) -> Self {
264 Self {
265 inner: EncodedBytes::new(
266 encoder,
267 source,
268 compression_encoding,
269 compression_override,
270 max_message_size,
271 ),
272 state: EncodeState {
273 error: None,
274 role: Role::Server,
275 is_end_stream: false,
276 },
277 }
278 }
279}
280
281impl EncodeState {
282 fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
283 match self.role {
284 Role::Client => None,
285 Role::Server => {
286 if self.is_end_stream {
287 return None;
288 }
289
290 self.is_end_stream = true;
291 let status = if let Some(status) = self.error.take() {
292 status
293 } else {
294 Status::ok("")
295 };
296 Some(status.to_header_map())
297 }
298 }
299 }
300}
301
302impl<T, U> Body for EncodeBody<T, U>
303where
304 T: Encoder<Error = Status>,
305 U: Stream<Item = Result<T::Item, Status>>,
306{
307 type Data = Bytes;
308 type Error = Status;
309
310 fn is_end_stream(&self) -> bool {
311 self.state.is_end_stream
312 }
313
314 fn poll_frame(
315 self: Pin<&mut Self>,
316 cx: &mut Context<'_>,
317 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
318 let self_proj = self.project();
319 match ready!(self_proj.inner.poll_next(cx)) {
320 Some(Ok(d)) => Some(Ok(Frame::data(d))).into(),
321 Some(Err(status)) => match self_proj.state.role {
322 Role::Client => Some(Err(status)).into(),
323 Role::Server => {
324 self_proj.state.is_end_stream = true;
325 Some(Ok(Frame::trailers(status.to_header_map()?))).into()
326 }
327 },
328 None => self_proj
329 .state
330 .trailers()
331 .map(|t| t.map(Frame::trailers))
332 .into(),
333 }
334 }
335}