1use crate::{SpanId, TraceFlags, TraceId};
2use std::collections::VecDeque;
3use std::hash::Hash;
4use std::str::FromStr;
5use thiserror::Error;
6
/// Ordered list of vendor-specific key-value pairs carried alongside a trace.
///
/// The wrapped value is `None` when there are no list members; otherwise it
/// is a deque of `(key, value)` pairs kept in list order.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct TraceState(Option<VecDeque<(String, String)>>);
16
17impl TraceState {
18 pub const NONE: TraceState = TraceState(None);
20
21 fn valid_key(key: &str) -> bool {
25 if key.len() > 256 {
26 return false;
27 }
28
29 let allowed_special = |b: u8| (b == b'_' || b == b'-' || b == b'*' || b == b'/');
30 let mut vendor_start = None;
31 for (i, &b) in key.as_bytes().iter().enumerate() {
32 if !(b.is_ascii_lowercase() || b.is_ascii_digit() || allowed_special(b) || b == b'@') {
33 return false;
34 }
35
36 if i == 0 && (!b.is_ascii_lowercase() && !b.is_ascii_digit()) {
37 return false;
38 } else if b == b'@' {
39 if vendor_start.is_some() || i + 14 < key.len() {
40 return false;
41 }
42 vendor_start = Some(i);
43 } else if let Some(start) = vendor_start {
44 if i == start + 1 && !(b.is_ascii_lowercase() || b.is_ascii_digit()) {
45 return false;
46 }
47 }
48 }
49
50 true
51 }
52
53 fn valid_value(value: &str) -> bool {
57 if value.len() > 256 {
58 return false;
59 }
60
61 !(value.contains(',') || value.contains('='))
62 }
63
64 pub fn from_key_value<T, K, V>(trace_state: T) -> TraceStateResult<Self>
78 where
79 T: IntoIterator<Item = (K, V)>,
80 K: ToString,
81 V: ToString,
82 {
83 let ordered_data = trace_state
84 .into_iter()
85 .map(|(key, value)| {
86 let (key, value) = (key.to_string(), value.to_string());
87 if !TraceState::valid_key(key.as_str()) {
88 return Err(TraceStateError::Key(key));
89 }
90 if !TraceState::valid_value(value.as_str()) {
91 return Err(TraceStateError::Value(value));
92 }
93
94 Ok((key, value))
95 })
96 .collect::<Result<VecDeque<_>, TraceStateError>>()?;
97
98 if ordered_data.is_empty() {
99 Ok(TraceState(None))
100 } else {
101 Ok(TraceState(Some(ordered_data)))
102 }
103 }
104
105 pub fn get(&self, key: &str) -> Option<&str> {
107 self.0.as_ref().and_then(|kvs| {
108 kvs.iter().find_map(|item| {
109 if item.0.as_str() == key {
110 Some(item.1.as_str())
111 } else {
112 None
113 }
114 })
115 })
116 }
117
118 pub fn insert<K, V>(&self, key: K, value: V) -> TraceStateResult<TraceState>
125 where
126 K: Into<String>,
127 V: Into<String>,
128 {
129 let (key, value) = (key.into(), value.into());
130 if !TraceState::valid_key(key.as_str()) {
131 return Err(TraceStateError::Key(key));
132 }
133 if !TraceState::valid_value(value.as_str()) {
134 return Err(TraceStateError::Value(value));
135 }
136
137 let mut trace_state = self.delete_from_deque(&key);
138 let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));
139
140 kvs.push_front((key, value));
141
142 Ok(trace_state)
143 }
144
145 pub fn delete<K: Into<String>>(&self, key: K) -> TraceStateResult<TraceState> {
153 let key = key.into();
154 if !TraceState::valid_key(key.as_str()) {
155 return Err(TraceStateError::Key(key));
156 }
157
158 Ok(self.delete_from_deque(&key))
159 }
160
161 fn delete_from_deque(&self, key: &str) -> TraceState {
163 let mut owned = self.clone();
164 if let Some(kvs) = owned.0.as_mut() {
165 if let Some(index) = kvs.iter().position(|x| x.0 == key) {
166 kvs.remove(index);
167 }
168 }
169 owned
170 }
171
172 pub fn header(&self) -> String {
175 self.header_delimited("=", ",")
176 }
177
178 pub fn header_delimited(&self, entry_delimiter: &str, list_delimiter: &str) -> String {
180 self.0
181 .as_ref()
182 .map(|kvs| {
183 kvs.iter()
184 .map(|(key, value)| format!("{}{}{}", key, entry_delimiter, value))
185 .collect::<Vec<String>>()
186 .join(list_delimiter)
187 })
188 .unwrap_or_default()
189 }
190}
191
192impl FromStr for TraceState {
193 type Err = TraceStateError;
194
195 fn from_str(s: &str) -> Result<Self, Self::Err> {
196 let list_members: Vec<&str> = s.split_terminator(',').collect();
197 let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());
198
199 for list_member in list_members {
200 match list_member.find('=') {
201 None => return Err(TraceStateError::List(list_member.to_string())),
202 Some(separator_index) => {
203 let (key, value) = list_member.split_at(separator_index);
204 key_value_pairs
205 .push((key.to_string(), value.trim_start_matches('=').to_string()));
206 }
207 }
208 }
209
210 TraceState::from_key_value(key_value_pairs)
211 }
212}
213
214type TraceStateResult<T> = Result<T, TraceStateError>;
216
217#[derive(Error, Debug)]
219#[non_exhaustive]
220pub enum TraceStateError {
221 #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
225 Key(String),
226
227 #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
231 Value(String),
232
233 #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
237 List(String),
238}
239
240#[derive(Clone, Debug, PartialEq, Hash, Eq)]
250pub struct SpanContext {
251 trace_id: TraceId,
252 span_id: SpanId,
253 trace_flags: TraceFlags,
254 is_remote: bool,
255 trace_state: TraceState,
256}
257
258impl SpanContext {
259 pub const NONE: SpanContext = SpanContext {
261 trace_id: TraceId::INVALID,
262 span_id: SpanId::INVALID,
263 trace_flags: TraceFlags::NOT_SAMPLED,
264 is_remote: false,
265 trace_state: TraceState::NONE,
266 };
267
268 pub fn empty_context() -> Self {
270 SpanContext::NONE
271 }
272
273 pub fn new(
275 trace_id: TraceId,
276 span_id: SpanId,
277 trace_flags: TraceFlags,
278 is_remote: bool,
279 trace_state: TraceState,
280 ) -> Self {
281 SpanContext {
282 trace_id,
283 span_id,
284 trace_flags,
285 is_remote,
286 trace_state,
287 }
288 }
289
290 pub fn trace_id(&self) -> TraceId {
292 self.trace_id
293 }
294
295 pub fn span_id(&self) -> SpanId {
297 self.span_id
298 }
299
300 pub fn trace_flags(&self) -> TraceFlags {
305 self.trace_flags
306 }
307
308 pub fn is_valid(&self) -> bool {
311 self.trace_id != TraceId::INVALID && self.span_id != SpanId::INVALID
312 }
313
314 pub fn is_remote(&self) -> bool {
316 self.is_remote
317 }
318
319 pub fn is_sampled(&self) -> bool {
323 self.trace_flags.is_sampled()
324 }
325
326 pub fn trace_state(&self) -> &TraceState {
328 &self.trace_state
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{trace::TraceContextExt, Context};

    /// Fixtures as `(trace state, expected header, key to exercise)`.
    #[rustfmt::skip]
    fn trace_state_test_data() -> Vec<(TraceState, &'static str, &'static str)> {
        vec![
            (TraceState::from_key_value(vec![("foo", "bar")]).unwrap(), "foo=bar", "foo"),
            (TraceState::from_key_value(vec![("foo", ""), ("apple", "banana")]).unwrap(), "foo=,apple=banana", "apple"),
            (TraceState::from_key_value(vec![("foo", "bar"), ("apple", "banana")]).unwrap(), "foo=bar,apple=banana", "apple"),
        ]
    }

    /// Round-trips each fixture through `header`, `insert` and `delete`.
    #[test]
    fn test_trace_state() {
        for (state, expected_header, key) in trace_state_test_data() {
            assert_eq!(state.header(), expected_header);

            // Derive a fresh value from the current one so the update is visible.
            let new_value = format!("{}-{}", state.get(key).unwrap(), "test");

            let insert_result = state.insert(key, new_value.clone());
            assert!(insert_result.is_ok());
            let updated_state = insert_result.unwrap();

            // The updated member must be rendered first in the header.
            let expected_member = format!("{}={}", key, new_value);
            let position = updated_state.header().find(&expected_member);
            assert!(position.is_some());
            assert_eq!(position.unwrap(), 0);

            let delete_result = updated_state.delete(key.to_string());
            assert!(delete_result.is_ok());
            let remaining_state = delete_result.unwrap();

            assert!(remaining_state.get(key).is_none());
        }
    }

    /// Checks key validation against accepted and rejected samples.
    #[test]
    fn test_trace_state_key() {
        let cases: [(&'static str, bool); 7] = [
            ("123", true),
            ("bar", true),
            ("foo@bar", true),
            ("foo@0123456789abcdef", false),
            ("foo@012345678", true),
            ("FOO@BAR", false),
            ("你好", false),
        ];

        for &(key, expected) in cases.iter() {
            assert_eq!(TraceState::valid_key(key), expected, "test key: {:?}", key);
        }
    }

    /// `insert` must return a new state and leave the receiver unchanged.
    #[test]
    fn test_trace_state_insert() {
        let original = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
        let inserted = original.insert("testkey", "testvalue").unwrap();

        assert!(original.get("testkey").is_none());
        assert_eq!(inserted.get("testkey").unwrap(), "testvalue");
    }

    /// Pins the `Debug` rendering of a context with and without a span.
    #[test]
    fn test_context_span_debug() {
        let plain_cx = Context::current();
        assert_eq!(
            format!("{:?}", plain_cx),
            "Context { span: \"None\", entries count: 0, suppress_telemetry: false }"
        );

        let remote_cx = Context::current().with_remote_span_context(SpanContext::NONE);
        assert_eq!(
            format!("{:?}", remote_cx),
            "Context { \
               span: SpanContext { \
                       trace_id: 00000000000000000000000000000000, \
                       span_id: 0000000000000000, \
                       trace_flags: TraceFlags(0), \
                       is_remote: false, \
                       trace_state: TraceState(None) \
                     }, \
               entries count: 1, suppress_telemetry: false \
             }"
        );
    }
}