1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use prost::{DecodeError, Message};
6use prost_types::{
7 DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8 FileDescriptorSet,
9};
10use tonic::Status;
11
12pub mod v1;
14pub use v1::{ServerReflection, ServerReflectionServer}; pub mod v1alpha;
18
/// Builder for a gRPC reflection server.
///
/// Start with [`Builder::configure`], register descriptor sets, optionally
/// restrict the advertised services, then call [`Builder::build_v1`] or
/// [`Builder::build_v1alpha`].
#[derive(Debug)]
pub struct Builder<'b> {
    /// Descriptor sets that were registered already decoded.
    file_descriptor_sets: Vec<FileDescriptorSet>,
    /// Protobuf-encoded descriptor sets; they are decoded only when the server is built.
    encoded_file_descriptor_sets: Vec<&'b [u8]>,
    /// Whether the build step also registers the reflection service's own descriptor set.
    include_reflection_service: bool,

    /// Service names added explicitly through `with_service_name`.
    service_names: Vec<String>,
    /// `true` until `with_service_name` is first called; while set, every service
    /// found in the registered descriptors is advertised.
    use_all_service_names: bool,
}
29
30impl<'b> Builder<'b> {
31 pub fn configure() -> Self {
33 Builder {
34 file_descriptor_sets: Vec::new(),
35 encoded_file_descriptor_sets: Vec::new(),
36 include_reflection_service: true,
37
38 service_names: Vec::new(),
39 use_all_service_names: true,
40 }
41 }
42
43 pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
46 self.file_descriptor_sets.push(file_descriptor_set);
47 self
48 }
49
50 pub fn register_encoded_file_descriptor_set(
53 mut self,
54 encoded_file_descriptor_set: &'b [u8],
55 ) -> Self {
56 self.encoded_file_descriptor_sets
57 .push(encoded_file_descriptor_set);
58 self
59 }
60
61 pub fn include_reflection_service(mut self, include: bool) -> Self {
64 self.include_reflection_service = include;
65 self
66 }
67
68 pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
73 self.use_all_service_names = false;
74 self.service_names.push(name.into());
75 self
76 }
77
78 #[deprecated(since = "0.12.2", note = "use `build_v1()` instead")]
80 pub fn build(self) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
81 self.build_v1()
82 }
83
84 pub fn build_v1(
86 mut self,
87 ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
88 if self.include_reflection_service {
89 self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
90 }
91
92 Ok(v1::ServerReflectionServer::new(
93 v1::ReflectionService::from(ReflectionServiceState::new(
94 self.service_names,
95 self.encoded_file_descriptor_sets,
96 self.file_descriptor_sets,
97 self.use_all_service_names,
98 )?),
99 ))
100 }
101
102 pub fn build_v1alpha(
104 mut self,
105 ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
106 if self.include_reflection_service {
107 self =
108 self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
109 }
110
111 Ok(v1alpha::ServerReflectionServer::new(
112 v1alpha::ReflectionService::from(ReflectionServiceState::new(
113 self.service_names,
114 self.encoded_file_descriptor_sets,
115 self.file_descriptor_sets,
116 self.use_all_service_names,
117 )?),
118 ))
119 }
120}
121
/// Indexed view of every registered file descriptor, built once by
/// `ReflectionServiceState::new` and used by both the `v1` and `v1alpha`
/// reflection services.
#[derive(Debug)]
struct ReflectionServiceState {
    /// Names returned by `list_services`.
    service_names: Vec<String>,
    /// File descriptors keyed by file name.
    files: HashMap<String, Arc<FileDescriptorProto>>,
    /// Fully-qualified symbol name -> descriptor of the file that declares it.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
128
129impl ReflectionServiceState {
130 fn new(
131 service_names: Vec<String>,
132 encoded_file_descriptor_sets: Vec<&[u8]>,
133 mut file_descriptor_sets: Vec<FileDescriptorSet>,
134 use_all_service_names: bool,
135 ) -> Result<Self, Error> {
136 for encoded in encoded_file_descriptor_sets {
137 file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
138 }
139
140 let mut state = ReflectionServiceState {
141 service_names,
142 files: HashMap::new(),
143 symbols: HashMap::new(),
144 };
145
146 for fds in file_descriptor_sets {
147 for fd in fds.file {
148 let name = match fd.name.clone() {
149 None => {
150 return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
151 }
152 Some(n) => n,
153 };
154
155 if state.files.contains_key(&name) {
156 continue;
157 }
158
159 let fd = Arc::new(fd);
160 state.files.insert(name, fd.clone());
161 state.process_file(fd, use_all_service_names)?;
162 }
163 }
164
165 Ok(state)
166 }
167
168 fn process_file(
169 &mut self,
170 fd: Arc<FileDescriptorProto>,
171 use_all_service_names: bool,
172 ) -> Result<(), Error> {
173 let prefix = &fd.package.clone().unwrap_or_default();
174
175 for msg in &fd.message_type {
176 self.process_message(fd.clone(), prefix, msg)?;
177 }
178
179 for en in &fd.enum_type {
180 self.process_enum(fd.clone(), prefix, en)?;
181 }
182
183 for service in &fd.service {
184 let service_name = extract_name(prefix, "service", service.name.as_ref())?;
185 if use_all_service_names {
186 self.service_names.push(service_name.clone());
187 }
188 self.symbols.insert(service_name.clone(), fd.clone());
189
190 for method in &service.method {
191 let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
192 self.symbols.insert(method_name, fd.clone());
193 }
194 }
195
196 Ok(())
197 }
198
199 fn process_message(
200 &mut self,
201 fd: Arc<FileDescriptorProto>,
202 prefix: &str,
203 msg: &DescriptorProto,
204 ) -> Result<(), Error> {
205 let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
206 self.symbols.insert(message_name.clone(), fd.clone());
207
208 for nested in &msg.nested_type {
209 self.process_message(fd.clone(), &message_name, nested)?;
210 }
211
212 for en in &msg.enum_type {
213 self.process_enum(fd.clone(), &message_name, en)?;
214 }
215
216 for field in &msg.field {
217 self.process_field(fd.clone(), &message_name, field)?;
218 }
219
220 for oneof in &msg.oneof_decl {
221 let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
222 self.symbols.insert(oneof_name, fd.clone());
223 }
224
225 Ok(())
226 }
227
228 fn process_enum(
229 &mut self,
230 fd: Arc<FileDescriptorProto>,
231 prefix: &str,
232 en: &EnumDescriptorProto,
233 ) -> Result<(), Error> {
234 let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
235 self.symbols.insert(enum_name.clone(), fd.clone());
236
237 for value in &en.value {
238 let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
239 self.symbols.insert(value_name, fd.clone());
240 }
241
242 Ok(())
243 }
244
245 fn process_field(
246 &mut self,
247 fd: Arc<FileDescriptorProto>,
248 prefix: &str,
249 field: &FieldDescriptorProto,
250 ) -> Result<(), Error> {
251 let field_name = extract_name(prefix, "field", field.name.as_ref())?;
252 self.symbols.insert(field_name, fd);
253 Ok(())
254 }
255
256 fn list_services(&self) -> &[String] {
257 &self.service_names
258 }
259
260 fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
261 match self.symbols.get(symbol) {
262 None => Err(Status::not_found(format!("symbol '{}' not found", symbol))),
263 Some(fd) => {
264 let mut encoded_fd = Vec::new();
265 if fd.clone().encode(&mut encoded_fd).is_err() {
266 return Err(Status::internal("encoding error"));
267 };
268
269 Ok(encoded_fd)
270 }
271 }
272 }
273
274 fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
275 match self.files.get(filename) {
276 None => Err(Status::not_found(format!("file '{}' not found", filename))),
277 Some(fd) => {
278 let mut encoded_fd = Vec::new();
279 if fd.clone().encode(&mut encoded_fd).is_err() {
280 return Err(Status::internal("encoding error"));
281 }
282
283 Ok(encoded_fd)
284 }
285 }
286 }
287}
288
289fn extract_name(
290 prefix: &str,
291 name_type: &str,
292 maybe_name: Option<&String>,
293) -> Result<String, Error> {
294 match maybe_name {
295 None => Err(Error::InvalidFileDescriptorSet(format!(
296 "missing {} name",
297 name_type
298 ))),
299 Some(name) => {
300 if prefix.is_empty() {
301 Ok(name.to_string())
302 } else {
303 Ok(format!("{}.{}", prefix, name))
304 }
305 }
306 }
307}
308
309#[derive(Debug)]
311pub enum Error {
312 DecodeError(prost::DecodeError),
314 InvalidFileDescriptorSet(String),
316}
317
318impl From<DecodeError> for Error {
319 fn from(e: DecodeError) -> Self {
320 Error::DecodeError(e)
321 }
322}
323
324impl std::error::Error for Error {}
325
326impl Display for Error {
327 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
328 match self {
329 Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
330 Error::InvalidFileDescriptorSet(s) => {
331 write!(f, "invalid FileDescriptorSet - {}", s)
332 }
333 }
334 }
335}