1use crate::codec::EncodeBody;
2use crate::codec::{CompressionEncoding, EnabledCompressionEncodings};
3use crate::metadata::GRPC_CONTENT_TYPE;
4use crate::{
5 body::Body,
6 client::GrpcService,
7 codec::{Codec, Decoder, Streaming},
8 request::SanitizeHeaders,
9 Code, Request, Response, Status,
10};
11use http::{
12 header::{HeaderValue, CONTENT_TYPE, TE},
13 uri::{PathAndQuery, Uri},
14};
15use http_body::Body as HttpBody;
16use std::{fmt, future, pin::pin};
17use tokio_stream::{Stream, StreamExt};
18
19pub struct Grpc<T> {
33 inner: T,
34 config: GrpcConfig,
35}
36
37struct GrpcConfig {
38 origin: Uri,
39 accept_compression_encodings: EnabledCompressionEncodings,
41 send_compression_encodings: Option<CompressionEncoding>,
43 max_decoding_message_size: Option<usize>,
45 max_encoding_message_size: Option<usize>,
47}
48
49impl<T> Grpc<T> {
50 pub fn new(inner: T) -> Self {
52 Self::with_origin(inner, Uri::default())
53 }
54
55 pub fn with_origin(inner: T, origin: Uri) -> Self {
60 Self {
61 inner,
62 config: GrpcConfig {
63 origin,
64 send_compression_encodings: None,
65 accept_compression_encodings: EnabledCompressionEncodings::default(),
66 max_decoding_message_size: None,
67 max_encoding_message_size: None,
68 },
69 }
70 }
71
72 pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
99 self.config.send_compression_encodings = Some(encoding);
100 self
101 }
102
103 pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
130 self.config.accept_compression_encodings.enable(encoding);
131 self
132 }
133
134 pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
160 self.config.max_decoding_message_size = Some(limit);
161 self
162 }
163
164 pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
190 self.config.max_encoding_message_size = Some(limit);
191 self
192 }
193
194 pub async fn ready(&mut self) -> Result<(), T::Error>
200 where
201 T: GrpcService<Body>,
202 {
203 future::poll_fn(|cx| self.inner.poll_ready(cx)).await
204 }
205
206 pub async fn unary<M1, M2, C>(
208 &mut self,
209 request: Request<M1>,
210 path: PathAndQuery,
211 codec: C,
212 ) -> Result<Response<M2>, Status>
213 where
214 T: GrpcService<Body>,
215 T::ResponseBody: HttpBody + Send + 'static,
216 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
217 C: Codec<Encode = M1, Decode = M2>,
218 M1: Send + Sync + 'static,
219 M2: Send + Sync + 'static,
220 {
221 let request = request.map(|m| tokio_stream::once(m));
222 self.client_streaming(request, path, codec).await
223 }
224
225 pub async fn client_streaming<S, M1, M2, C>(
227 &mut self,
228 request: Request<S>,
229 path: PathAndQuery,
230 codec: C,
231 ) -> Result<Response<M2>, Status>
232 where
233 T: GrpcService<Body>,
234 T::ResponseBody: HttpBody + Send + 'static,
235 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
236 S: Stream<Item = M1> + Send + 'static,
237 C: Codec<Encode = M1, Decode = M2>,
238 M1: Send + Sync + 'static,
239 M2: Send + Sync + 'static,
240 {
241 let (mut parts, body, extensions) =
242 self.streaming(request, path, codec).await?.into_parts();
243
244 let mut body = pin!(body);
245
246 let message = body
247 .try_next()
248 .await
249 .map_err(|mut status| {
250 status.metadata_mut().merge(parts.clone());
251 status
252 })?
253 .ok_or_else(|| Status::internal("Missing response message."))?;
254
255 if let Some(trailers) = body.trailers().await? {
256 parts.merge(trailers);
257 }
258
259 Ok(Response::from_parts(parts, message, extensions))
260 }
261
262 pub async fn server_streaming<M1, M2, C>(
264 &mut self,
265 request: Request<M1>,
266 path: PathAndQuery,
267 codec: C,
268 ) -> Result<Response<Streaming<M2>>, Status>
269 where
270 T: GrpcService<Body>,
271 T::ResponseBody: HttpBody + Send + 'static,
272 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
273 C: Codec<Encode = M1, Decode = M2>,
274 M1: Send + Sync + 'static,
275 M2: Send + Sync + 'static,
276 {
277 let request = request.map(|m| tokio_stream::once(m));
278 self.streaming(request, path, codec).await
279 }
280
281 pub async fn streaming<S, M1, M2, C>(
283 &mut self,
284 request: Request<S>,
285 path: PathAndQuery,
286 mut codec: C,
287 ) -> Result<Response<Streaming<M2>>, Status>
288 where
289 T: GrpcService<Body>,
290 T::ResponseBody: HttpBody + Send + 'static,
291 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
292 S: Stream<Item = M1> + Send + 'static,
293 C: Codec<Encode = M1, Decode = M2>,
294 M1: Send + Sync + 'static,
295 M2: Send + Sync + 'static,
296 {
297 let request = request
298 .map(|s| {
299 EncodeBody::new_client(
300 codec.encoder(),
301 s.map(Ok),
302 self.config.send_compression_encodings,
303 self.config.max_encoding_message_size,
304 )
305 })
306 .map(Body::new);
307
308 let request = self.config.prepare_request(request, path);
309
310 let response = self
311 .inner
312 .call(request)
313 .await
314 .map_err(Status::from_error_generic)?;
315
316 let decoder = codec.decoder();
317
318 self.create_response(decoder, response)
319 }
320
321 fn create_response<M2>(
324 &self,
325 decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
326 response: http::Response<T::ResponseBody>,
327 ) -> Result<Response<Streaming<M2>>, Status>
328 where
329 T: GrpcService<Body>,
330 T::ResponseBody: HttpBody + Send + 'static,
331 <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
332 {
333 let encoding = CompressionEncoding::from_encoding_header(
334 response.headers(),
335 self.config.accept_compression_encodings,
336 )?;
337
338 let status_code = response.status();
339 let trailers_only_status = Status::from_header_map(response.headers());
340
341 let expect_additional_trailers = if let Some(status) = trailers_only_status {
344 if status.code() != Code::Ok {
345 return Err(status);
346 }
347
348 false
349 } else {
350 true
351 };
352
353 let response = response.map(|body| {
354 if expect_additional_trailers {
355 Streaming::new_response(
356 decoder,
357 body,
358 status_code,
359 encoding,
360 self.config.max_decoding_message_size,
361 )
362 } else {
363 Streaming::new_empty(decoder, body)
364 }
365 });
366
367 Ok(Response::from_http(response))
368 }
369}
370
371impl GrpcConfig {
372 fn prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body> {
373 let mut parts = self.origin.clone().into_parts();
374
375 match &parts.path_and_query {
376 Some(pnq) if pnq != "/" => {
377 parts.path_and_query = Some(
378 format!("{}{}", pnq.path(), path)
379 .parse()
380 .expect("must form valid path_and_query"),
381 )
382 }
383 _ => {
384 parts.path_and_query = Some(path);
385 }
386 }
387
388 let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
389
390 let mut request = request.into_http(
391 uri,
392 http::Method::POST,
393 http::Version::HTTP_2,
394 SanitizeHeaders::Yes,
395 );
396
397 request
399 .headers_mut()
400 .insert(TE, HeaderValue::from_static("trailers"));
401
402 request
404 .headers_mut()
405 .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);
406
407 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
408 if let Some(encoding) = self.send_compression_encodings {
409 request.headers_mut().insert(
410 crate::codec::compression::ENCODING_HEADER,
411 encoding.into_header_value(),
412 );
413 }
414
415 if let Some(header_value) = self
416 .accept_compression_encodings
417 .into_accept_encoding_header_value()
418 {
419 request.headers_mut().insert(
420 crate::codec::compression::ACCEPT_ENCODING_HEADER,
421 header_value,
422 );
423 }
424
425 request
426 }
427}
428
429impl<T: Clone> Clone for Grpc<T> {
430 fn clone(&self) -> Self {
431 Self {
432 inner: self.inner.clone(),
433 config: GrpcConfig {
434 origin: self.config.origin.clone(),
435 send_compression_encodings: self.config.send_compression_encodings,
436 accept_compression_encodings: self.config.accept_compression_encodings,
437 max_encoding_message_size: self.config.max_encoding_message_size,
438 max_decoding_message_size: self.config.max_decoding_message_size,
439 },
440 }
441 }
442}
443
444impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 f.debug_struct("Grpc")
447 .field("inner", &self.inner)
448 .field("origin", &self.config.origin)
449 .field(
450 "compression_encoding",
451 &self.config.send_compression_encodings,
452 )
453 .field(
454 "accept_compression_encodings",
455 &self.config.accept_compression_encodings,
456 )
457 .field(
458 "max_decoding_message_size",
459 &self.config.max_decoding_message_size,
460 )
461 .field(
462 "max_encoding_message_size",
463 &self.config.max_encoding_message_size,
464 )
465 .finish()
466 }
467}