// protobuf/reflect/message.rs

1use std::collections::HashMap;
2use std::marker;
3
4use crate::descriptor::DescriptorProto;
5use crate::descriptor::FileDescriptorProto;
6use crate::descriptorx::find_message_by_rust_name;
7use crate::reflect::acc::FieldAccessor;
8use crate::reflect::find_message_or_enum::find_message_or_enum;
9use crate::reflect::find_message_or_enum::MessageOrEnum;
10use crate::reflect::FieldDescriptor;
11use crate::Message;
12
13trait MessageFactory: Send + Sync + 'static {
14    fn new_instance(&self) -> Box<dyn Message>;
15}
16
17struct MessageFactoryImpl<M>(marker::PhantomData<M>);
18
19impl<M> MessageFactory for MessageFactoryImpl<M>
20where
21    M: 'static + Message + Default + Clone + PartialEq,
22{
23    fn new_instance(&self) -> Box<dyn Message> {
24        let m: M = Default::default();
25        Box::new(m)
26    }
27}
28
29/// Dynamic message type
30pub struct MessageDescriptor {
31    full_name: String,
32    proto: &'static DescriptorProto,
33    factory: &'static dyn MessageFactory,
34    fields: Vec<FieldDescriptor>,
35
36    index_by_name: HashMap<String, usize>,
37    index_by_name_or_json_name: HashMap<String, usize>,
38    index_by_number: HashMap<u32, usize>,
39}
40
41impl MessageDescriptor {
42    /// Get underlying `DescriptorProto` object.
43    pub fn get_proto(&self) -> &DescriptorProto {
44        self.proto
45    }
46
47    /// Get a message descriptor for given message type
48    pub fn for_type<M: Message>() -> &'static MessageDescriptor {
49        M::descriptor_static()
50    }
51
52    fn compute_full_name(package: &str, path_to_package: &str, proto: &DescriptorProto) -> String {
53        let mut full_name = package.to_owned();
54        if path_to_package.len() != 0 {
55            if full_name.len() != 0 {
56                full_name.push('.');
57            }
58            full_name.push_str(path_to_package);
59        }
60        if full_name.len() != 0 {
61            full_name.push('.');
62        }
63        full_name.push_str(proto.get_name());
64        full_name
65    }
66
67    // Non-generic part of `new` is a separate function
68    // to reduce code bloat from multiple instantiations.
69    fn new_non_generic_by_rust_name(
70        rust_name: &'static str,
71        fields: Vec<FieldAccessor>,
72        file: &'static FileDescriptorProto,
73        factory: &'static dyn MessageFactory,
74    ) -> MessageDescriptor {
75        let proto = find_message_by_rust_name(file, rust_name);
76
77        let mut field_proto_by_name = HashMap::new();
78        for field_proto in proto.message.get_field() {
79            field_proto_by_name.insert(field_proto.get_name(), field_proto);
80        }
81
82        let mut index_by_name = HashMap::new();
83        let mut index_by_name_or_json_name = HashMap::new();
84        let mut index_by_number = HashMap::new();
85
86        let mut full_name = file.get_package().to_string();
87        if full_name.len() > 0 {
88            full_name.push('.');
89        }
90        full_name.push_str(proto.message.get_name());
91
92        let fields: Vec<_> = fields
93            .into_iter()
94            .map(|f| {
95                let proto = *field_proto_by_name.get(&f.name).unwrap();
96                FieldDescriptor::new(f, proto)
97            })
98            .collect();
99        for (i, f) in fields.iter().enumerate() {
100            assert!(index_by_number
101                .insert(f.proto().get_number() as u32, i)
102                .is_none());
103            assert!(index_by_name
104                .insert(f.proto().get_name().to_owned(), i)
105                .is_none());
106            assert!(index_by_name_or_json_name
107                .insert(f.proto().get_name().to_owned(), i)
108                .is_none());
109
110            let json_name = f.json_name().to_owned();
111
112            if json_name != f.proto().get_name() {
113                assert!(index_by_name_or_json_name.insert(json_name, i).is_none());
114            }
115        }
116        MessageDescriptor {
117            full_name,
118            proto: proto.message,
119            factory,
120            fields,
121            index_by_name,
122            index_by_name_or_json_name,
123            index_by_number,
124        }
125    }
126
127    // Non-generic part of `new` is a separate function
128    // to reduce code bloat from multiple instantiations.
129    fn new_non_generic_by_pb_name(
130        protobuf_name_to_package: &'static str,
131        fields: Vec<FieldAccessor>,
132        file_descriptor_proto: &'static FileDescriptorProto,
133        factory: &'static dyn MessageFactory,
134    ) -> MessageDescriptor {
135        let (path_to_package, proto) =
136            match find_message_or_enum(file_descriptor_proto, protobuf_name_to_package) {
137                (path_to_package, MessageOrEnum::Message(m)) => (path_to_package, m),
138                (_, MessageOrEnum::Enum(_)) => panic!("not a message"),
139            };
140
141        let mut field_proto_by_name = HashMap::new();
142        for field_proto in proto.get_field() {
143            field_proto_by_name.insert(field_proto.get_name(), field_proto);
144        }
145
146        let mut index_by_name = HashMap::new();
147        let mut index_by_name_or_json_name = HashMap::new();
148        let mut index_by_number = HashMap::new();
149
150        let full_name = MessageDescriptor::compute_full_name(
151            file_descriptor_proto.get_package(),
152            &path_to_package,
153            &proto,
154        );
155        let fields: Vec<_> = fields
156            .into_iter()
157            .map(|f| {
158                let proto = *field_proto_by_name.get(&f.name).unwrap();
159                FieldDescriptor::new(f, proto)
160            })
161            .collect();
162
163        for (i, f) in fields.iter().enumerate() {
164            assert!(index_by_number
165                .insert(f.proto().get_number() as u32, i)
166                .is_none());
167            assert!(index_by_name
168                .insert(f.proto().get_name().to_owned(), i)
169                .is_none());
170            assert!(index_by_name_or_json_name
171                .insert(f.proto().get_name().to_owned(), i)
172                .is_none());
173
174            let json_name = f.json_name().to_owned();
175
176            if json_name != f.proto().get_name() {
177                assert!(index_by_name_or_json_name.insert(json_name, i).is_none());
178            }
179        }
180        MessageDescriptor {
181            full_name,
182            proto,
183            factory,
184            fields,
185            index_by_name,
186            index_by_name_or_json_name,
187            index_by_number,
188        }
189    }
190
191    /// Construct a new message descriptor.
192    ///
193    /// This operation is called from generated code and rarely
194    /// need to be called directly.
195    #[doc(hidden)]
196    #[deprecated(
197        since = "2.12",
198        note = "Please regenerate .rs files from .proto files to use newer APIs"
199    )]
200    pub fn new<M: 'static + Message + Default + Clone + PartialEq>(
201        rust_name: &'static str,
202        fields: Vec<FieldAccessor>,
203        file: &'static FileDescriptorProto,
204    ) -> MessageDescriptor {
205        let factory = &MessageFactoryImpl(marker::PhantomData::<M>);
206        MessageDescriptor::new_non_generic_by_rust_name(rust_name, fields, file, factory)
207    }
208
209    /// Construct a new message descriptor.
210    ///
211    /// This operation is called from generated code and rarely
212    /// need to be called directly.
213    #[doc(hidden)]
214    pub fn new_pb_name<M: 'static + Message + Default + Clone + PartialEq>(
215        protobuf_name_to_package: &'static str,
216        fields: Vec<FieldAccessor>,
217        file_descriptor_proto: &'static FileDescriptorProto,
218    ) -> MessageDescriptor {
219        let factory = &MessageFactoryImpl(marker::PhantomData::<M>);
220        MessageDescriptor::new_non_generic_by_pb_name(
221            protobuf_name_to_package,
222            fields,
223            file_descriptor_proto,
224            factory,
225        )
226    }
227
228    /// New empty message
229    pub fn new_instance(&self) -> Box<dyn Message> {
230        self.factory.new_instance()
231    }
232
233    /// Message name as given in `.proto` file
234    pub fn name(&self) -> &'static str {
235        self.proto.get_name()
236    }
237
238    /// Fully qualified protobuf message name
239    pub fn full_name(&self) -> &str {
240        &self.full_name[..]
241    }
242
243    /// Message field descriptors.
244    pub fn fields(&self) -> &[FieldDescriptor] {
245        &self.fields
246    }
247
248    /// Find message field by protobuf field name
249    ///
250    /// Note: protobuf field name might be different for Rust field name.
251    pub fn get_field_by_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor> {
252        let &index = self.index_by_name.get(name)?;
253        Some(&self.fields[index])
254    }
255
256    /// Find message field by field name or field JSON name
257    pub fn get_field_by_name_or_json_name<'a>(&'a self, name: &str) -> Option<&'a FieldDescriptor> {
258        let &index = self.index_by_name_or_json_name.get(name)?;
259        Some(&self.fields[index])
260    }
261
262    /// Find message field by field name
263    pub fn get_field_by_number(&self, number: u32) -> Option<&FieldDescriptor> {
264        let &index = self.index_by_number.get(&number)?;
265        Some(&self.fields[index])
266    }
267
268    /// Find field by name
269    // TODO: deprecate
270    pub fn field_by_name<'a>(&'a self, name: &str) -> &'a FieldDescriptor {
271        // TODO: clone is weird
272        let &index = self.index_by_name.get(&name.to_string()).unwrap();
273        &self.fields[index]
274    }
275
276    /// Find field by number
277    // TODO: deprecate
278    pub fn field_by_number<'a>(&'a self, number: u32) -> &'a FieldDescriptor {
279        let &index = self.index_by_number.get(&number).unwrap();
280        &self.fields[index]
281    }
282}