tonic_reflection/server/
mod.rs

1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use prost::{DecodeError, Message};
6use prost_types::{
7    DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8    FileDescriptorSet,
9};
10use tonic::Status;
11
12/// v1 interface for the gRPC Reflection Service server.
13pub mod v1;
14/// v1alpha interface for the gRPC Reflection Service server.
15pub mod v1alpha;
16
/// A builder used to construct a gRPC Reflection Service.
///
/// Descriptor sets can be registered either already decoded or as encoded bytes; the
/// encoded ones are only decoded when one of the `build_*` methods is called.
#[derive(Debug)]
pub struct Builder<'b> {
    /// Descriptor sets registered in decoded form.
    file_descriptor_sets: Vec<FileDescriptorSet>,
    /// Descriptor sets registered as encoded bytes, borrowed for `'b`.
    encoded_file_descriptor_sets: Vec<&'b [u8]>,
    /// Whether the reflection service's own descriptor set is registered at build time.
    include_reflection_service: bool,

    /// Service names explicitly requested through `with_service_name`.
    service_names: Vec<String>,
    /// Starts as `true` and is cleared by the first call to `with_service_name`; while
    /// `true`, every service found in the registered descriptor sets is advertised.
    use_all_service_names: bool,
}
27
28impl<'b> Builder<'b> {
29    /// Create a new builder that can configure a gRPC Reflection Service.
30    pub fn configure() -> Self {
31        Builder {
32            file_descriptor_sets: Vec::new(),
33            encoded_file_descriptor_sets: Vec::new(),
34            include_reflection_service: true,
35
36            service_names: Vec::new(),
37            use_all_service_names: true,
38        }
39    }
40
41    /// Registers an instance of `prost_types::FileDescriptorSet` with the gRPC Reflection
42    /// Service builder.
43    pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
44        self.file_descriptor_sets.push(file_descriptor_set);
45        self
46    }
47
48    /// Registers a byte slice containing an encoded `prost_types::FileDescriptorSet` with
49    /// the gRPC Reflection Service builder.
50    pub fn register_encoded_file_descriptor_set(
51        mut self,
52        encoded_file_descriptor_set: &'b [u8],
53    ) -> Self {
54        self.encoded_file_descriptor_sets
55            .push(encoded_file_descriptor_set);
56        self
57    }
58
59    /// Serve the gRPC Reflection Service descriptor via the Reflection Service. This is enabled
60    /// by default - set `include` to false to disable.
61    pub fn include_reflection_service(mut self, include: bool) -> Self {
62        self.include_reflection_service = include;
63        self
64    }
65
66    /// Advertise a fully-qualified gRPC service name.
67    ///
68    /// If not called, then all services present in the registered file descriptor sets
69    /// will be advertised.
70    pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
71        self.use_all_service_names = false;
72        self.service_names.push(name.into());
73        self
74    }
75
76    /// Build a v1 gRPC Reflection Service to be served via Tonic.
77    pub fn build_v1(
78        mut self,
79    ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
80        if self.include_reflection_service {
81            self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
82        }
83
84        Ok(v1::ServerReflectionServer::new(
85            v1::ReflectionService::from(ReflectionServiceState::new(
86                self.service_names,
87                self.encoded_file_descriptor_sets,
88                self.file_descriptor_sets,
89                self.use_all_service_names,
90            )?),
91        ))
92    }
93
94    /// Build a v1alpha gRPC Reflection Service to be served via Tonic.
95    pub fn build_v1alpha(
96        mut self,
97    ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
98        if self.include_reflection_service {
99            self =
100                self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
101        }
102
103        Ok(v1alpha::ServerReflectionServer::new(
104            v1alpha::ReflectionService::from(ReflectionServiceState::new(
105                self.service_names,
106                self.encoded_file_descriptor_sets,
107                self.file_descriptor_sets,
108                self.use_all_service_names,
109            )?),
110        ))
111    }
112}
113
/// Lookup tables backing the reflection service.
#[derive(Debug)]
struct ReflectionServiceState {
    /// Fully-qualified names of the services returned by `list_services`.
    service_names: Vec<String>,
    /// Registered file descriptors, keyed by the descriptor's file name.
    files: HashMap<String, Arc<FileDescriptorProto>>,
    /// For each indexed symbol (message, enum, enum value, field, oneof, service, method),
    /// keyed by its package-prefixed dotted name, the file descriptor that declares it.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
120
121impl ReflectionServiceState {
122    fn new(
123        service_names: Vec<String>,
124        encoded_file_descriptor_sets: Vec<&[u8]>,
125        mut file_descriptor_sets: Vec<FileDescriptorSet>,
126        use_all_service_names: bool,
127    ) -> Result<Self, Error> {
128        for encoded in encoded_file_descriptor_sets {
129            file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
130        }
131
132        let mut state = ReflectionServiceState {
133            service_names,
134            files: HashMap::new(),
135            symbols: HashMap::new(),
136        };
137
138        for fds in file_descriptor_sets {
139            for fd in fds.file {
140                let name = match fd.name.clone() {
141                    None => {
142                        return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
143                    }
144                    Some(n) => n,
145                };
146
147                if state.files.contains_key(&name) {
148                    continue;
149                }
150
151                let fd = Arc::new(fd);
152                state.files.insert(name, fd.clone());
153                state.process_file(fd, use_all_service_names)?;
154            }
155        }
156
157        Ok(state)
158    }
159
160    fn process_file(
161        &mut self,
162        fd: Arc<FileDescriptorProto>,
163        use_all_service_names: bool,
164    ) -> Result<(), Error> {
165        let prefix = &fd.package.clone().unwrap_or_default();
166
167        for msg in &fd.message_type {
168            self.process_message(fd.clone(), prefix, msg)?;
169        }
170
171        for en in &fd.enum_type {
172            self.process_enum(fd.clone(), prefix, en)?;
173        }
174
175        for service in &fd.service {
176            let service_name = extract_name(prefix, "service", service.name.as_ref())?;
177            if use_all_service_names {
178                self.service_names.push(service_name.clone());
179            }
180            self.symbols.insert(service_name.clone(), fd.clone());
181
182            for method in &service.method {
183                let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
184                self.symbols.insert(method_name, fd.clone());
185            }
186        }
187
188        Ok(())
189    }
190
191    fn process_message(
192        &mut self,
193        fd: Arc<FileDescriptorProto>,
194        prefix: &str,
195        msg: &DescriptorProto,
196    ) -> Result<(), Error> {
197        let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
198        self.symbols.insert(message_name.clone(), fd.clone());
199
200        for nested in &msg.nested_type {
201            self.process_message(fd.clone(), &message_name, nested)?;
202        }
203
204        for en in &msg.enum_type {
205            self.process_enum(fd.clone(), &message_name, en)?;
206        }
207
208        for field in &msg.field {
209            self.process_field(fd.clone(), &message_name, field)?;
210        }
211
212        for oneof in &msg.oneof_decl {
213            let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
214            self.symbols.insert(oneof_name, fd.clone());
215        }
216
217        Ok(())
218    }
219
220    fn process_enum(
221        &mut self,
222        fd: Arc<FileDescriptorProto>,
223        prefix: &str,
224        en: &EnumDescriptorProto,
225    ) -> Result<(), Error> {
226        let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
227        self.symbols.insert(enum_name.clone(), fd.clone());
228
229        for value in &en.value {
230            let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
231            self.symbols.insert(value_name, fd.clone());
232        }
233
234        Ok(())
235    }
236
237    fn process_field(
238        &mut self,
239        fd: Arc<FileDescriptorProto>,
240        prefix: &str,
241        field: &FieldDescriptorProto,
242    ) -> Result<(), Error> {
243        let field_name = extract_name(prefix, "field", field.name.as_ref())?;
244        self.symbols.insert(field_name, fd);
245        Ok(())
246    }
247
248    fn list_services(&self) -> &[String] {
249        &self.service_names
250    }
251
252    fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
253        match self.symbols.get(symbol) {
254            None => Err(Status::not_found(format!("symbol '{symbol}' not found"))),
255            Some(fd) => {
256                let mut encoded_fd = Vec::new();
257                if fd.clone().encode(&mut encoded_fd).is_err() {
258                    return Err(Status::internal("encoding error"));
259                };
260
261                Ok(encoded_fd)
262            }
263        }
264    }
265
266    fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
267        match self.files.get(filename) {
268            None => Err(Status::not_found(format!("file '{filename}' not found"))),
269            Some(fd) => {
270                let mut encoded_fd = Vec::new();
271                if fd.clone().encode(&mut encoded_fd).is_err() {
272                    return Err(Status::internal("encoding error"));
273                }
274
275                Ok(encoded_fd)
276            }
277        }
278    }
279}
280
281fn extract_name(
282    prefix: &str,
283    name_type: &str,
284    maybe_name: Option<&String>,
285) -> Result<String, Error> {
286    match maybe_name {
287        None => Err(Error::InvalidFileDescriptorSet(format!(
288            "missing {name_type} name"
289        ))),
290        Some(name) => {
291            if prefix.is_empty() {
292                Ok(name.to_string())
293            } else {
294                Ok(format!("{prefix}.{name}"))
295            }
296        }
297    }
298}
299
/// Represents an error in the construction of a gRPC Reflection Service.
#[derive(Debug)]
pub enum Error {
    /// An error was encountered decoding a `prost_types::FileDescriptorSet` from a buffer.
    /// Carries the underlying `prost` decode error.
    DecodeError(prost::DecodeError),
    /// An invalid `prost_types::FileDescriptorProto` was encountered. The string describes
    /// the problem, e.g. which kind of element was missing its name.
    InvalidFileDescriptorSet(String),
}
308
309impl From<DecodeError> for Error {
310    fn from(e: DecodeError) -> Self {
311        Error::DecodeError(e)
312    }
313}
314
315impl std::error::Error for Error {}
316
317impl Display for Error {
318    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319        match self {
320            Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
321            Error::InvalidFileDescriptorSet(s) => {
322                write!(f, "invalid FileDescriptorSet - {s}")
323            }
324        }
325    }
326}