tonic_reflection/server/
mod.rs

1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use prost::{DecodeError, Message};
6use prost_types::{
7    DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8    FileDescriptorSet,
9};
10use tonic::Status;
11
12/// v1 interface for the gRPC Reflection Service server.
13pub mod v1;
14/// Deprecated; access these via `v1` instead.
15pub use v1::{ServerReflection, ServerReflectionServer}; // For backwards compatibility
16/// v1alpha interface for the gRPC Reflection Service server.
17pub mod v1alpha;
18
/// A builder used to construct a gRPC Reflection Service.
///
/// Create one with `Builder::configure`, register descriptor sets, then call
/// `build_v1` or `build_v1alpha`. The lifetime `'b` is the lifetime of the encoded
/// descriptor byte slices handed to `register_encoded_file_descriptor_set`.
#[derive(Debug)]
pub struct Builder<'b> {
    /// Already-decoded descriptor sets registered by the caller.
    file_descriptor_sets: Vec<FileDescriptorSet>,
    /// Encoded descriptor sets registered by the caller; stored as borrowed slices and
    /// only decoded when the service is built.
    encoded_file_descriptor_sets: Vec<&'b [u8]>,
    /// When `true`, building also registers the reflection service's own descriptor set.
    include_reflection_service: bool,

    /// Service names explicitly requested through `with_service_name`.
    service_names: Vec<String>,
    /// `true` until `with_service_name` is called; while `true`, every service found in
    /// the registered descriptor sets is advertised.
    use_all_service_names: bool,
}
29
30impl<'b> Builder<'b> {
31    /// Create a new builder that can configure a gRPC Reflection Service.
32    pub fn configure() -> Self {
33        Builder {
34            file_descriptor_sets: Vec::new(),
35            encoded_file_descriptor_sets: Vec::new(),
36            include_reflection_service: true,
37
38            service_names: Vec::new(),
39            use_all_service_names: true,
40        }
41    }
42
43    /// Registers an instance of `prost_types::FileDescriptorSet` with the gRPC Reflection
44    /// Service builder.
45    pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
46        self.file_descriptor_sets.push(file_descriptor_set);
47        self
48    }
49
50    /// Registers a byte slice containing an encoded `prost_types::FileDescriptorSet` with
51    /// the gRPC Reflection Service builder.
52    pub fn register_encoded_file_descriptor_set(
53        mut self,
54        encoded_file_descriptor_set: &'b [u8],
55    ) -> Self {
56        self.encoded_file_descriptor_sets
57            .push(encoded_file_descriptor_set);
58        self
59    }
60
61    /// Serve the gRPC Reflection Service descriptor via the Reflection Service. This is enabled
62    /// by default - set `include` to false to disable.
63    pub fn include_reflection_service(mut self, include: bool) -> Self {
64        self.include_reflection_service = include;
65        self
66    }
67
68    /// Advertise a fully-qualified gRPC service name.
69    ///
70    /// If not called, then all services present in the registered file descriptor sets
71    /// will be advertised.
72    pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
73        self.use_all_service_names = false;
74        self.service_names.push(name.into());
75        self
76    }
77
78    /// Build a v1 gRPC Reflection Service to be served via Tonic.
79    #[deprecated(since = "0.12.2", note = "use `build_v1()` instead")]
80    pub fn build(self) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
81        self.build_v1()
82    }
83
84    /// Build a v1 gRPC Reflection Service to be served via Tonic.
85    pub fn build_v1(
86        mut self,
87    ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
88        if self.include_reflection_service {
89            self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
90        }
91
92        Ok(v1::ServerReflectionServer::new(
93            v1::ReflectionService::from(ReflectionServiceState::new(
94                self.service_names,
95                self.encoded_file_descriptor_sets,
96                self.file_descriptor_sets,
97                self.use_all_service_names,
98            )?),
99        ))
100    }
101
102    /// Build a v1alpha gRPC Reflection Service to be served via Tonic.
103    pub fn build_v1alpha(
104        mut self,
105    ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
106        if self.include_reflection_service {
107            self =
108                self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
109        }
110
111        Ok(v1alpha::ServerReflectionServer::new(
112            v1alpha::ReflectionService::from(ReflectionServiceState::new(
113                self.service_names,
114                self.encoded_file_descriptor_sets,
115                self.file_descriptor_sets,
116                self.use_all_service_names,
117            )?),
118        ))
119    }
120}
121
/// Lookup tables backing the reflection service: the advertised service names plus
/// indexes from file name and from fully-qualified symbol name to file descriptor.
#[derive(Debug)]
struct ReflectionServiceState {
    /// Service names returned by `list_services`.
    service_names: Vec<String>,
    /// File descriptors keyed by their file name.
    files: HashMap<String, Arc<FileDescriptorProto>>,
    /// For every known symbol (message, enum, enum value, field, oneof, service,
    /// method), the descriptor of the file that declares it, keyed by the
    /// dot-separated qualified name.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
128
129impl ReflectionServiceState {
130    fn new(
131        service_names: Vec<String>,
132        encoded_file_descriptor_sets: Vec<&[u8]>,
133        mut file_descriptor_sets: Vec<FileDescriptorSet>,
134        use_all_service_names: bool,
135    ) -> Result<Self, Error> {
136        for encoded in encoded_file_descriptor_sets {
137            file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
138        }
139
140        let mut state = ReflectionServiceState {
141            service_names,
142            files: HashMap::new(),
143            symbols: HashMap::new(),
144        };
145
146        for fds in file_descriptor_sets {
147            for fd in fds.file {
148                let name = match fd.name.clone() {
149                    None => {
150                        return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
151                    }
152                    Some(n) => n,
153                };
154
155                if state.files.contains_key(&name) {
156                    continue;
157                }
158
159                let fd = Arc::new(fd);
160                state.files.insert(name, fd.clone());
161                state.process_file(fd, use_all_service_names)?;
162            }
163        }
164
165        Ok(state)
166    }
167
168    fn process_file(
169        &mut self,
170        fd: Arc<FileDescriptorProto>,
171        use_all_service_names: bool,
172    ) -> Result<(), Error> {
173        let prefix = &fd.package.clone().unwrap_or_default();
174
175        for msg in &fd.message_type {
176            self.process_message(fd.clone(), prefix, msg)?;
177        }
178
179        for en in &fd.enum_type {
180            self.process_enum(fd.clone(), prefix, en)?;
181        }
182
183        for service in &fd.service {
184            let service_name = extract_name(prefix, "service", service.name.as_ref())?;
185            if use_all_service_names {
186                self.service_names.push(service_name.clone());
187            }
188            self.symbols.insert(service_name.clone(), fd.clone());
189
190            for method in &service.method {
191                let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
192                self.symbols.insert(method_name, fd.clone());
193            }
194        }
195
196        Ok(())
197    }
198
199    fn process_message(
200        &mut self,
201        fd: Arc<FileDescriptorProto>,
202        prefix: &str,
203        msg: &DescriptorProto,
204    ) -> Result<(), Error> {
205        let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
206        self.symbols.insert(message_name.clone(), fd.clone());
207
208        for nested in &msg.nested_type {
209            self.process_message(fd.clone(), &message_name, nested)?;
210        }
211
212        for en in &msg.enum_type {
213            self.process_enum(fd.clone(), &message_name, en)?;
214        }
215
216        for field in &msg.field {
217            self.process_field(fd.clone(), &message_name, field)?;
218        }
219
220        for oneof in &msg.oneof_decl {
221            let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
222            self.symbols.insert(oneof_name, fd.clone());
223        }
224
225        Ok(())
226    }
227
228    fn process_enum(
229        &mut self,
230        fd: Arc<FileDescriptorProto>,
231        prefix: &str,
232        en: &EnumDescriptorProto,
233    ) -> Result<(), Error> {
234        let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
235        self.symbols.insert(enum_name.clone(), fd.clone());
236
237        for value in &en.value {
238            let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
239            self.symbols.insert(value_name, fd.clone());
240        }
241
242        Ok(())
243    }
244
245    fn process_field(
246        &mut self,
247        fd: Arc<FileDescriptorProto>,
248        prefix: &str,
249        field: &FieldDescriptorProto,
250    ) -> Result<(), Error> {
251        let field_name = extract_name(prefix, "field", field.name.as_ref())?;
252        self.symbols.insert(field_name, fd);
253        Ok(())
254    }
255
256    fn list_services(&self) -> &[String] {
257        &self.service_names
258    }
259
260    fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
261        match self.symbols.get(symbol) {
262            None => Err(Status::not_found(format!("symbol '{}' not found", symbol))),
263            Some(fd) => {
264                let mut encoded_fd = Vec::new();
265                if fd.clone().encode(&mut encoded_fd).is_err() {
266                    return Err(Status::internal("encoding error"));
267                };
268
269                Ok(encoded_fd)
270            }
271        }
272    }
273
274    fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
275        match self.files.get(filename) {
276            None => Err(Status::not_found(format!("file '{}' not found", filename))),
277            Some(fd) => {
278                let mut encoded_fd = Vec::new();
279                if fd.clone().encode(&mut encoded_fd).is_err() {
280                    return Err(Status::internal("encoding error"));
281                }
282
283                Ok(encoded_fd)
284            }
285        }
286    }
287}
288
289fn extract_name(
290    prefix: &str,
291    name_type: &str,
292    maybe_name: Option<&String>,
293) -> Result<String, Error> {
294    match maybe_name {
295        None => Err(Error::InvalidFileDescriptorSet(format!(
296            "missing {} name",
297            name_type
298        ))),
299        Some(name) => {
300            if prefix.is_empty() {
301                Ok(name.to_string())
302            } else {
303                Ok(format!("{}.{}", prefix, name))
304            }
305        }
306    }
307}
308
/// Represents an error in the construction of a gRPC Reflection Service.
///
/// Returned by the `build`, `build_v1` and `build_v1alpha` methods of `Builder`.
#[derive(Debug)]
pub enum Error {
    /// An error was encountered decoding a `prost_types::FileDescriptorSet` from a buffer.
    ///
    /// Wraps the underlying `prost::DecodeError`.
    DecodeError(prost::DecodeError),
    /// An invalid `prost_types::FileDescriptorProto` was encountered.
    ///
    /// The string describes the problem, for example a missing file or symbol name.
    InvalidFileDescriptorSet(String),
}
317
318impl From<DecodeError> for Error {
319    fn from(e: DecodeError) -> Self {
320        Error::DecodeError(e)
321    }
322}
323
324impl std::error::Error for Error {}
325
326impl Display for Error {
327    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
328        match self {
329            Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
330            Error::InvalidFileDescriptorSet(s) => {
331                write!(f, "invalid FileDescriptorSet - {}", s)
332            }
333        }
334    }
335}