tonic_reflection/server/
v1.rs

1use std::sync::Arc;
2
3use tokio::sync::mpsc;
4use tokio_stream::{wrappers::ReceiverStream, StreamExt};
5use tonic::{Request, Response, Status, Streaming};
6
7use super::ReflectionServiceState;
8use crate::pb::v1::server_reflection_request::MessageRequest;
9use crate::pb::v1::server_reflection_response::MessageResponse;
10pub use crate::pb::v1::server_reflection_server::{ServerReflection, ServerReflectionServer};
11use crate::pb::v1::{
12    ExtensionNumberResponse, FileDescriptorResponse, ListServiceResponse, ServerReflectionRequest,
13    ServerReflectionResponse, ServiceResponse,
14};
15
16#[derive(Debug)]
17pub(super) struct ReflectionService {
18    state: Arc<ReflectionServiceState>,
19}
20
21#[tonic::async_trait]
22impl ServerReflection for ReflectionService {
23    type ServerReflectionInfoStream = ReceiverStream<Result<ServerReflectionResponse, Status>>;
24
25    async fn server_reflection_info(
26        &self,
27        req: Request<Streaming<ServerReflectionRequest>>,
28    ) -> Result<Response<Self::ServerReflectionInfoStream>, Status> {
29        let mut req_rx = req.into_inner();
30        let (resp_tx, resp_rx) = mpsc::channel::<Result<ServerReflectionResponse, Status>>(1);
31
32        let state = self.state.clone();
33
34        tokio::spawn(async move {
35            while let Some(req) = req_rx.next().await {
36                let Ok(req) = req else {
37                    return;
38                };
39
40                let resp_msg = match req.message_request.clone() {
41                    None => Err(Status::invalid_argument("invalid MessageRequest")),
42                    Some(msg) => match msg {
43                        MessageRequest::FileByFilename(s) => state.file_by_filename(&s).map(|fd| {
44                            MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
45                                file_descriptor_proto: vec![fd],
46                            })
47                        }),
48                        MessageRequest::FileContainingSymbol(s) => {
49                            state.symbol_by_name(&s).map(|fd| {
50                                MessageResponse::FileDescriptorResponse(FileDescriptorResponse {
51                                    file_descriptor_proto: vec![fd],
52                                })
53                            })
54                        }
55                        MessageRequest::FileContainingExtension(_) => {
56                            Err(Status::not_found("extensions are not supported"))
57                        }
58                        MessageRequest::AllExtensionNumbersOfType(_) => {
59                            // NOTE: Workaround. Some grpc clients (e.g. grpcurl) expect this method not to fail.
60                            // https://github.com/hyperium/tonic/issues/1077
61                            Ok(MessageResponse::AllExtensionNumbersResponse(
62                                ExtensionNumberResponse::default(),
63                            ))
64                        }
65                        MessageRequest::ListServices(_) => {
66                            Ok(MessageResponse::ListServicesResponse(ListServiceResponse {
67                                service: state
68                                    .list_services()
69                                    .iter()
70                                    .map(|s| ServiceResponse { name: s.clone() })
71                                    .collect(),
72                            }))
73                        }
74                    },
75                };
76
77                match resp_msg {
78                    Ok(resp_msg) => {
79                        let resp = ServerReflectionResponse {
80                            valid_host: req.host.clone(),
81                            original_request: Some(req.clone()),
82                            message_response: Some(resp_msg),
83                        };
84                        resp_tx.send(Ok(resp)).await.expect("send");
85                    }
86                    Err(status) => {
87                        resp_tx.send(Err(status)).await.expect("send");
88                        return;
89                    }
90                }
91            }
92        });
93
94        Ok(Response::new(ReceiverStream::new(resp_rx)))
95    }
96}
97
98impl From<ReflectionServiceState> for ReflectionService {
99    fn from(state: ReflectionServiceState) -> Self {
100        Self {
101            state: Arc::new(state),
102        }
103    }
104}