1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::sync::Arc;
4
5use prost::{DecodeError, Message};
6use prost_types::{
7 DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
8 FileDescriptorSet,
9};
10use tonic::Status;
11
12pub mod v1;
14pub mod v1alpha;
16
/// Configures and builds a gRPC server reflection service.
///
/// Created with [`Builder::configure`]; finish with `build_v1` or `build_v1alpha`.
#[derive(Debug)]
pub struct Builder<'b> {
    /// Descriptor sets registered in already-decoded form.
    file_descriptor_sets: Vec<FileDescriptorSet>,
    /// Protobuf-encoded descriptor sets, borrowed for `'b` and only decoded at build time.
    encoded_file_descriptor_sets: Vec<&'b [u8]>,
    /// Whether the reflection API's own descriptor set is registered when building.
    include_reflection_service: bool,

    /// Service names explicitly requested via `with_service_name`.
    service_names: Vec<String>,
    /// When `true`, every service found in the registered descriptors is advertised;
    /// `with_service_name` sets this to `false`.
    use_all_service_names: bool,
}
27
28impl<'b> Builder<'b> {
29 pub fn configure() -> Self {
31 Builder {
32 file_descriptor_sets: Vec::new(),
33 encoded_file_descriptor_sets: Vec::new(),
34 include_reflection_service: true,
35
36 service_names: Vec::new(),
37 use_all_service_names: true,
38 }
39 }
40
41 pub fn register_file_descriptor_set(mut self, file_descriptor_set: FileDescriptorSet) -> Self {
44 self.file_descriptor_sets.push(file_descriptor_set);
45 self
46 }
47
48 pub fn register_encoded_file_descriptor_set(
51 mut self,
52 encoded_file_descriptor_set: &'b [u8],
53 ) -> Self {
54 self.encoded_file_descriptor_sets
55 .push(encoded_file_descriptor_set);
56 self
57 }
58
59 pub fn include_reflection_service(mut self, include: bool) -> Self {
62 self.include_reflection_service = include;
63 self
64 }
65
66 pub fn with_service_name(mut self, name: impl Into<String>) -> Self {
71 self.use_all_service_names = false;
72 self.service_names.push(name.into());
73 self
74 }
75
76 pub fn build_v1(
78 mut self,
79 ) -> Result<v1::ServerReflectionServer<impl v1::ServerReflection>, Error> {
80 if self.include_reflection_service {
81 self = self.register_encoded_file_descriptor_set(crate::pb::v1::FILE_DESCRIPTOR_SET);
82 }
83
84 Ok(v1::ServerReflectionServer::new(
85 v1::ReflectionService::from(ReflectionServiceState::new(
86 self.service_names,
87 self.encoded_file_descriptor_sets,
88 self.file_descriptor_sets,
89 self.use_all_service_names,
90 )?),
91 ))
92 }
93
94 pub fn build_v1alpha(
96 mut self,
97 ) -> Result<v1alpha::ServerReflectionServer<impl v1alpha::ServerReflection>, Error> {
98 if self.include_reflection_service {
99 self =
100 self.register_encoded_file_descriptor_set(crate::pb::v1alpha::FILE_DESCRIPTOR_SET);
101 }
102
103 Ok(v1alpha::ServerReflectionServer::new(
104 v1alpha::ReflectionService::from(ReflectionServiceState::new(
105 self.service_names,
106 self.encoded_file_descriptor_sets,
107 self.file_descriptor_sets,
108 self.use_all_service_names,
109 )?),
110 ))
111 }
112}
113
/// Indexed view of all registered descriptors, used to build both the `v1` and `v1alpha`
/// reflection services.
#[derive(Debug)]
struct ReflectionServiceState {
    /// Names returned by `list_services`.
    service_names: Vec<String>,
    /// File descriptors keyed by their file name.
    files: HashMap<String, Arc<FileDescriptorProto>>,
    /// Package-qualified, dot-joined symbol names (messages, fields, oneofs, enums, enum
    /// values, services, methods) mapped to the file that defines them.
    symbols: HashMap<String, Arc<FileDescriptorProto>>,
}
120
121impl ReflectionServiceState {
122 fn new(
123 service_names: Vec<String>,
124 encoded_file_descriptor_sets: Vec<&[u8]>,
125 mut file_descriptor_sets: Vec<FileDescriptorSet>,
126 use_all_service_names: bool,
127 ) -> Result<Self, Error> {
128 for encoded in encoded_file_descriptor_sets {
129 file_descriptor_sets.push(FileDescriptorSet::decode(encoded)?);
130 }
131
132 let mut state = ReflectionServiceState {
133 service_names,
134 files: HashMap::new(),
135 symbols: HashMap::new(),
136 };
137
138 for fds in file_descriptor_sets {
139 for fd in fds.file {
140 let name = match fd.name.clone() {
141 None => {
142 return Err(Error::InvalidFileDescriptorSet("missing name".to_string()));
143 }
144 Some(n) => n,
145 };
146
147 if state.files.contains_key(&name) {
148 continue;
149 }
150
151 let fd = Arc::new(fd);
152 state.files.insert(name, fd.clone());
153 state.process_file(fd, use_all_service_names)?;
154 }
155 }
156
157 Ok(state)
158 }
159
160 fn process_file(
161 &mut self,
162 fd: Arc<FileDescriptorProto>,
163 use_all_service_names: bool,
164 ) -> Result<(), Error> {
165 let prefix = &fd.package.clone().unwrap_or_default();
166
167 for msg in &fd.message_type {
168 self.process_message(fd.clone(), prefix, msg)?;
169 }
170
171 for en in &fd.enum_type {
172 self.process_enum(fd.clone(), prefix, en)?;
173 }
174
175 for service in &fd.service {
176 let service_name = extract_name(prefix, "service", service.name.as_ref())?;
177 if use_all_service_names {
178 self.service_names.push(service_name.clone());
179 }
180 self.symbols.insert(service_name.clone(), fd.clone());
181
182 for method in &service.method {
183 let method_name = extract_name(&service_name, "method", method.name.as_ref())?;
184 self.symbols.insert(method_name, fd.clone());
185 }
186 }
187
188 Ok(())
189 }
190
191 fn process_message(
192 &mut self,
193 fd: Arc<FileDescriptorProto>,
194 prefix: &str,
195 msg: &DescriptorProto,
196 ) -> Result<(), Error> {
197 let message_name = extract_name(prefix, "message", msg.name.as_ref())?;
198 self.symbols.insert(message_name.clone(), fd.clone());
199
200 for nested in &msg.nested_type {
201 self.process_message(fd.clone(), &message_name, nested)?;
202 }
203
204 for en in &msg.enum_type {
205 self.process_enum(fd.clone(), &message_name, en)?;
206 }
207
208 for field in &msg.field {
209 self.process_field(fd.clone(), &message_name, field)?;
210 }
211
212 for oneof in &msg.oneof_decl {
213 let oneof_name = extract_name(&message_name, "oneof", oneof.name.as_ref())?;
214 self.symbols.insert(oneof_name, fd.clone());
215 }
216
217 Ok(())
218 }
219
220 fn process_enum(
221 &mut self,
222 fd: Arc<FileDescriptorProto>,
223 prefix: &str,
224 en: &EnumDescriptorProto,
225 ) -> Result<(), Error> {
226 let enum_name = extract_name(prefix, "enum", en.name.as_ref())?;
227 self.symbols.insert(enum_name.clone(), fd.clone());
228
229 for value in &en.value {
230 let value_name = extract_name(&enum_name, "enum value", value.name.as_ref())?;
231 self.symbols.insert(value_name, fd.clone());
232 }
233
234 Ok(())
235 }
236
237 fn process_field(
238 &mut self,
239 fd: Arc<FileDescriptorProto>,
240 prefix: &str,
241 field: &FieldDescriptorProto,
242 ) -> Result<(), Error> {
243 let field_name = extract_name(prefix, "field", field.name.as_ref())?;
244 self.symbols.insert(field_name, fd);
245 Ok(())
246 }
247
248 fn list_services(&self) -> &[String] {
249 &self.service_names
250 }
251
252 fn symbol_by_name(&self, symbol: &str) -> Result<Vec<u8>, Status> {
253 match self.symbols.get(symbol) {
254 None => Err(Status::not_found(format!("symbol '{symbol}' not found"))),
255 Some(fd) => {
256 let mut encoded_fd = Vec::new();
257 if fd.clone().encode(&mut encoded_fd).is_err() {
258 return Err(Status::internal("encoding error"));
259 };
260
261 Ok(encoded_fd)
262 }
263 }
264 }
265
266 fn file_by_filename(&self, filename: &str) -> Result<Vec<u8>, Status> {
267 match self.files.get(filename) {
268 None => Err(Status::not_found(format!("file '{filename}' not found"))),
269 Some(fd) => {
270 let mut encoded_fd = Vec::new();
271 if fd.clone().encode(&mut encoded_fd).is_err() {
272 return Err(Status::internal("encoding error"));
273 }
274
275 Ok(encoded_fd)
276 }
277 }
278 }
279}
280
281fn extract_name(
282 prefix: &str,
283 name_type: &str,
284 maybe_name: Option<&String>,
285) -> Result<String, Error> {
286 match maybe_name {
287 None => Err(Error::InvalidFileDescriptorSet(format!(
288 "missing {name_type} name"
289 ))),
290 Some(name) => {
291 if prefix.is_empty() {
292 Ok(name.to_string())
293 } else {
294 Ok(format!("{prefix}.{name}"))
295 }
296 }
297 }
298}
299
/// Errors that can occur while building the reflection service.
#[derive(Debug)]
pub enum Error {
    /// A registered encoded `FileDescriptorSet` could not be decoded.
    DecodeError(prost::DecodeError),
    /// A descriptor set was structurally invalid; the string describes the problem
    /// (in this module: a file or element missing its name).
    InvalidFileDescriptorSet(String),
}
308
309impl From<DecodeError> for Error {
310 fn from(e: DecodeError) -> Self {
311 Error::DecodeError(e)
312 }
313}
314
315impl std::error::Error for Error {}
316
317impl Display for Error {
318 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
319 match self {
320 Error::DecodeError(_) => f.write_str("error decoding FileDescriptorSet from buffer"),
321 Error::InvalidFileDescriptorSet(s) => {
322 write!(f, "invalid FileDescriptorSet - {s}")
323 }
324 }
325 }
326}