tonic_reflection/server/
v1.rs

1use std::{fmt, sync::Arc};
2
3use tokio::sync::mpsc;
4use tokio_stream::{Stream, StreamExt};
5use tonic::{Request, Response, Status, Streaming};
6
7use super::ReflectionServiceState;
8use crate::pb::v1::server_reflection_request::MessageRequest;
9use crate::pb::v1::server_reflection_response::MessageResponse;
10pub use crate::pb::v1::server_reflection_server::{ServerReflection, ServerReflectionServer};
11use crate::pb::v1::{
12    ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest,
13    ServerReflectionResponse, ServiceResponse,
14};
15
16/// An implementation for `ServerReflection`.
17#[derive(Debug)]
18pub struct ReflectionService {
19    state: Arc<ReflectionServiceState>,
20}
21
22#[tonic::async_trait]
23impl ServerReflection for ReflectionService {
24    type ServerReflectionInfoStream = ServerReflectionInfoStream;
25
26    async fn server_reflection_info(
27        &self,
28        req: Request<Streaming<ServerReflectionRequest>>,
29    ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
30        let mut req_rx = req.into_inner();
31        let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
32
33        let state = self.state.clone();
34
35        tokio::spawn(async move {
36            while let Some(req) = req_rx.next().await {
37                let Ok(req) = req else {
38                    return;
39                };
40
41                let resp_msg = match req.message_request.clone() {
42                    None => Err(Status::invalid_argument("invalid MessageRequest")),
43                    Some(msg) => match msg {
44                        MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| {
45                            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
46                                file_descriptor_proto: vec![fd],
47                            })
48                        }),
49                        MessageRequest::FileContainingSymbol(s) => {
50                            state.symbol_by_name(&s).map(|fd| {
51                                MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
52                                    file_descriptor_proto: vec![fd],
53                                })
54                            })
55                        }
56                        MessageRequest::FileContainingExtension(_) => {
57                            Err(Status::not_found("extensions are not supported"))
58                        }
59                        MessageRequest::AllExtensionNumbersOfType(_) => {
60                            // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail.
61                            // https://github.com/hyperium/tonic/issues/1077
62                            Ok(MessageResponse::AllExtensionNumbersResponse(
63                                ExtensionNumberResponse::default(),
64                            ))
65                        }
66                        MessageRequest::ListServices(_) => {
67                            Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
68                                service: state
69                                    .list_services()
70                                    .iter()
71                                    .map(|s| ServiceResponse { name: s.clone() })
72                                    .collect(),
73                            }))
74                        }
75                    },
76                };
77
78                match resp_msg {
79                    Ok(resp_msg) => {
80                        let resp = ServerReflectionResponse {
81                            valid_host: req.host.clone(),
82                            original_request: Some(req.clone()),
83                            message_response: Some(resp_msg),
84                        };
85                        resp_tx.send(Ok(resp)).await.expect("send");
86                    }
87                    Err(status) => {
88                        resp_tx.send(Err(status)).await.expect("send");
89                        return;
90                    }
91                }
92            }
93        });
94
95        Ok(Response::new(ServerReflectionInfoStream::new(resp_rx)))
96    }
97}
98
99impl From<ReflectionServiceState> for ReflectionService {
100    fn from(state: ReflectionServiceState) -> Self {
101        Self {
102            state: Arc::new(state),
103        }
104    }
105}
106
107/// A response stream.
108pub struct ServerReflectionInfoStream {
109    inner: tokio_stream::wrappers::ReceiverStream<Result<ServerReflectionResponse, Status>>,
110}
111
112impl ServerReflectionInfoStream {
113    fn new(resp_rx: mpsc::Receiver<Result<ServerReflectionResponse, Status>>) -> Self {
114        let inner = tokio_stream::wrappers::ReceiverStream::new(resp_rx);
115        Self { inner }
116    }
117}
118
119impl Stream for ServerReflectionInfoStream {
120    type Item = Result<ServerReflectionResponse, Status>;
121
122    fn poll_next(
123        mut self: std::pin::Pin<&mut Self>,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Option<Self::Item>> {
126        std::pin::Pin::new(&mut self.inner).poll_next(cx)
127    }
128
129    fn size_hint(&self) -> (usize, Option<usize>) {
130        self.inner.size_hint()
131    }
132}
133
134impl fmt::Debug for ServerReflectionInfoStream {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_tuple("ServerReflectionInfoStream").finish()
137    }
138}