1use crate::trace::{TraceError, TraceResult};
2use std::collections::VecDeque;
3use std::fmt;
4use std::hash::Hash;
5use std::num::ParseIntError;
6use std::ops::{BitAnd, BitOr, Not};
7use std::str::FromStr;
8use thiserror::Error;
9
/// A byte of trace flags attached to a span context.
///
/// Only the lowest bit ([`TraceFlags::SAMPLED`]) has a named meaning in this module;
/// the remaining bits are carried through unchanged.
#[derive(Clone, Debug, Default, PartialEq, Eq, Copy, Hash)]
pub struct TraceFlags(u8);

impl TraceFlags {
    /// Flags with no bits set: the trace is not sampled.
    pub const NOT_SAMPLED: TraceFlags = TraceFlags(0x00);

    /// Flags with only the sampled bit (`0x01`) set.
    pub const SAMPLED: TraceFlags = TraceFlags(0x01);

    /// Wraps a raw flag byte without interpreting it.
    pub const fn new(flags: u8) -> Self {
        TraceFlags(flags)
    }

    /// Returns `true` when the sampled bit (`0x01`) is set, regardless of other bits.
    pub fn is_sampled(&self) -> bool {
        (self.0 & TraceFlags::SAMPLED.0) != 0
    }

    /// Returns a copy of these flags with the sampled bit set or cleared.
    ///
    /// All other bits are preserved.
    pub fn with_sampled(&self, sampled: bool) -> Self {
        let bits = if sampled {
            self.0 | TraceFlags::SAMPLED.0
        } else {
            self.0 & !TraceFlags::SAMPLED.0
        };
        TraceFlags(bits)
    }

    /// Returns the raw flag byte.
    pub fn to_u8(self) -> u8 {
        let TraceFlags(bits) = self;
        bits
    }
}

impl BitAnd for TraceFlags {
    type Output = Self;

    /// Bitwise AND of the two flag bytes.
    fn bitand(self, rhs: Self) -> Self::Output {
        TraceFlags(self.0 & rhs.0)
    }
}

impl BitOr for TraceFlags {
    type Output = Self;

    /// Bitwise OR of the two flag bytes.
    fn bitor(self, rhs: Self) -> Self::Output {
        TraceFlags(self.0 | rhs.0)
    }
}

impl Not for TraceFlags {
    type Output = Self;

    /// Bitwise complement of all eight bits.
    fn not(self) -> Self::Output {
        TraceFlags(!self.0)
    }
}

impl fmt::LowerHex for TraceFlags {
    /// Formats the raw byte as lowercase hex, honoring the formatter's width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u8 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
93
/// A 16-byte trace identifier, held as a single `u128`.
///
/// Byte conversions use big-endian order. `Debug` and `Display` both render the id as
/// 32 zero-padded lowercase hex digits.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct TraceId(u128);

impl TraceId {
    /// The all-zero trace id, treated as invalid by `SpanContext::is_valid`.
    pub const INVALID: TraceId = TraceId(0);

    /// Builds a trace id from 16 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 16]) -> Self {
        Self(u128::from_be_bytes(bytes))
    }

    /// Returns the id as 16 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 16] {
        u128::to_be_bytes(self.0)
    }

    /// Parses a hexadecimal string with `u128::from_str_radix(hex, 16)`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseIntError`] if `hex` is empty, contains non-hex characters, or
    /// overflows 128 bits. NOTE(review): `from_str_radix` also accepts a leading `+`
    /// and does not require exactly 32 digits.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let value = u128::from_str_radix(hex, 16)?;
        Ok(Self(value))
    }
}

impl From<u128> for TraceId {
    /// Wraps the raw 128-bit value.
    fn from(value: u128) -> Self {
        Self(value)
    }
}

impl fmt::Debug for TraceId {
    /// Writes the id as 32 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:032x}", self.0)
    }
}

impl fmt::Display for TraceId {
    /// Writes the id as 32 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:032x}", self.0)
    }
}

impl fmt::LowerHex for TraceId {
    /// Delegates to `u128`'s `LowerHex`, so width/fill flags are honored (no padding by default).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u128 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
154
/// An 8-byte span identifier, held as a single `u64`.
///
/// Byte conversions use big-endian order. `Debug` and `Display` both render the id as
/// 16 zero-padded lowercase hex digits.
#[derive(Clone, PartialEq, Eq, Copy, Hash)]
pub struct SpanId(u64);

impl SpanId {
    /// The all-zero span id, treated as invalid by `SpanContext::is_valid`.
    pub const INVALID: SpanId = SpanId(0);

    /// Builds a span id from 8 big-endian bytes.
    pub const fn from_bytes(bytes: [u8; 8]) -> Self {
        Self(u64::from_be_bytes(bytes))
    }

    /// Returns the id as 8 big-endian bytes.
    pub const fn to_bytes(self) -> [u8; 8] {
        u64::to_be_bytes(self.0)
    }

    /// Parses a hexadecimal string with `u64::from_str_radix(hex, 16)`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseIntError`] if `hex` is empty, contains non-hex characters, or
    /// overflows 64 bits. NOTE(review): `from_str_radix` also accepts a leading `+`
    /// and does not require exactly 16 digits.
    pub fn from_hex(hex: &str) -> Result<Self, ParseIntError> {
        let value = u64::from_str_radix(hex, 16)?;
        Ok(Self(value))
    }
}

impl From<u64> for SpanId {
    /// Wraps the raw 64-bit value.
    fn from(value: u64) -> Self {
        Self(value)
    }
}

impl fmt::Debug for SpanId {
    /// Writes the id as 16 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}

impl fmt::Display for SpanId {
    /// Writes the id as 16 zero-padded lowercase hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:016x}", self.0)
    }
}

impl fmt::LowerHex for SpanId {
    /// Delegates to `u64`'s `LowerHex`, so width/fill flags are honored (no padding by default).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u64 as fmt::LowerHex>::fmt(&self.0, f)
    }
}
215
/// Ordered list of vendor key/value pairs, modelled on the W3C `tracestate` header.
///
/// `None` represents a state with no entries. Entries are kept in a `VecDeque` so new
/// ones can be placed at the front cheaply.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct TraceState(Option<VecDeque<(String, String)>>);
225
226impl TraceState {
227 pub const NONE: TraceState = TraceState(None);
229
230 fn valid_key(key: &str) -> bool {
234 if key.len() > 256 {
235 return false;
236 }
237
238 let allowed_special = |b: u8| (b == b'_' || b == b'-' || b == b'*' || b == b'/');
239 let mut vendor_start = None;
240 for (i, &b) in key.as_bytes().iter().enumerate() {
241 if !(b.is_ascii_lowercase() || b.is_ascii_digit() || allowed_special(b) || b == b'@') {
242 return false;
243 }
244
245 if i == 0 && (!b.is_ascii_lowercase() && !b.is_ascii_digit()) {
246 return false;
247 } else if b == b'@' {
248 if vendor_start.is_some() || i + 14 < key.len() {
249 return false;
250 }
251 vendor_start = Some(i);
252 } else if let Some(start) = vendor_start {
253 if i == start + 1 && !(b.is_ascii_lowercase() || b.is_ascii_digit()) {
254 return false;
255 }
256 }
257 }
258
259 true
260 }
261
262 fn valid_value(value: &str) -> bool {
266 if value.len() > 256 {
267 return false;
268 }
269
270 !(value.contains(',') || value.contains('='))
271 }
272
273 pub fn from_key_value<T, K, V>(trace_state: T) -> TraceResult<Self>
287 where
288 T: IntoIterator<Item = (K, V)>,
289 K: ToString,
290 V: ToString,
291 {
292 let ordered_data = trace_state
293 .into_iter()
294 .map(|(key, value)| {
295 let (key, value) = (key.to_string(), value.to_string());
296 if !TraceState::valid_key(key.as_str()) {
297 return Err(TraceStateError::Key(key));
298 }
299 if !TraceState::valid_value(value.as_str()) {
300 return Err(TraceStateError::Value(value));
301 }
302
303 Ok((key, value))
304 })
305 .collect::<Result<VecDeque<_>, TraceStateError>>()?;
306
307 if ordered_data.is_empty() {
308 Ok(TraceState(None))
309 } else {
310 Ok(TraceState(Some(ordered_data)))
311 }
312 }
313
314 pub fn get(&self, key: &str) -> Option<&str> {
316 self.0.as_ref().and_then(|kvs| {
317 kvs.iter().find_map(|item| {
318 if item.0.as_str() == key {
319 Some(item.1.as_str())
320 } else {
321 None
322 }
323 })
324 })
325 }
326
327 pub fn insert<K, V>(&self, key: K, value: V) -> TraceResult<TraceState>
334 where
335 K: Into<String>,
336 V: Into<String>,
337 {
338 let (key, value) = (key.into(), value.into());
339 if !TraceState::valid_key(key.as_str()) {
340 return Err(TraceStateError::Key(key).into());
341 }
342 if !TraceState::valid_value(value.as_str()) {
343 return Err(TraceStateError::Value(value).into());
344 }
345
346 let mut trace_state = self.delete_from_deque(key.clone());
347 let kvs = trace_state.0.get_or_insert(VecDeque::with_capacity(1));
348
349 kvs.push_front((key, value));
350
351 Ok(trace_state)
352 }
353
354 pub fn delete<K: Into<String>>(&self, key: K) -> TraceResult<TraceState> {
362 let key = key.into();
363 if !TraceState::valid_key(key.as_str()) {
364 return Err(TraceStateError::Key(key).into());
365 }
366
367 Ok(self.delete_from_deque(key))
368 }
369
370 fn delete_from_deque(&self, key: String) -> TraceState {
372 let mut owned = self.clone();
373 if let Some(kvs) = owned.0.as_mut() {
374 if let Some(index) = kvs.iter().position(|x| *x.0 == *key) {
375 kvs.remove(index);
376 }
377 }
378 owned
379 }
380
381 pub fn header(&self) -> String {
384 self.header_delimited("=", ",")
385 }
386
387 pub fn header_delimited(&self, entry_delimiter: &str, list_delimiter: &str) -> String {
389 self.0
390 .as_ref()
391 .map(|kvs| {
392 kvs.iter()
393 .map(|(key, value)| format!("{}{}{}", key, entry_delimiter, value))
394 .collect::<Vec<String>>()
395 .join(list_delimiter)
396 })
397 .unwrap_or_default()
398 }
399}
400
401impl FromStr for TraceState {
402 type Err = TraceError;
403
404 fn from_str(s: &str) -> Result<Self, Self::Err> {
405 let list_members: Vec<&str> = s.split_terminator(',').collect();
406 let mut key_value_pairs: Vec<(String, String)> = Vec::with_capacity(list_members.len());
407
408 for list_member in list_members {
409 match list_member.find('=') {
410 None => return Err(TraceStateError::List(list_member.to_string()).into()),
411 Some(separator_index) => {
412 let (key, value) = list_member.split_at(separator_index);
413 key_value_pairs
414 .push((key.to_string(), value.trim_start_matches('=').to_string()));
415 }
416 }
417 }
418
419 TraceState::from_key_value(key_value_pairs)
420 }
421}
422
423#[derive(Error, Debug)]
425#[non_exhaustive]
426enum TraceStateError {
427 #[error("{0} is not a valid key in TraceState, see https://www.w3.org/TR/trace-context/#key for more details")]
431 Key(String),
432
433 #[error("{0} is not a valid value in TraceState, see https://www.w3.org/TR/trace-context/#value for more details")]
437 Value(String),
438
439 #[error("{0} is not a valid list member in TraceState, see https://www.w3.org/TR/trace-context/#list for more details")]
443 List(String),
444}
445
446impl From<TraceStateError> for TraceError {
447 fn from(err: TraceStateError) -> Self {
448 TraceError::Other(Box::new(err))
449 }
450}
451
452#[derive(Clone, Debug, PartialEq, Hash, Eq)]
462pub struct SpanContext {
463 trace_id: TraceId,
464 span_id: SpanId,
465 trace_flags: TraceFlags,
466 is_remote: bool,
467 trace_state: TraceState,
468}
469
470impl SpanContext {
471 pub const NONE: SpanContext = SpanContext {
473 trace_id: TraceId::INVALID,
474 span_id: SpanId::INVALID,
475 trace_flags: TraceFlags::NOT_SAMPLED,
476 is_remote: false,
477 trace_state: TraceState::NONE,
478 };
479
480 pub fn empty_context() -> Self {
482 SpanContext::NONE
483 }
484
485 pub fn new(
487 trace_id: TraceId,
488 span_id: SpanId,
489 trace_flags: TraceFlags,
490 is_remote: bool,
491 trace_state: TraceState,
492 ) -> Self {
493 SpanContext {
494 trace_id,
495 span_id,
496 trace_flags,
497 is_remote,
498 trace_state,
499 }
500 }
501
502 pub fn trace_id(&self) -> TraceId {
504 self.trace_id
505 }
506
507 pub fn span_id(&self) -> SpanId {
509 self.span_id
510 }
511
512 pub fn trace_flags(&self) -> TraceFlags {
517 self.trace_flags
518 }
519
520 pub fn is_valid(&self) -> bool {
523 self.trace_id != TraceId::INVALID && self.span_id != SpanId::INVALID
524 }
525
526 pub fn is_remote(&self) -> bool {
528 self.is_remote
529 }
530
531 pub fn is_sampled(&self) -> bool {
535 self.trace_flags.is_sampled()
536 }
537
538 pub fn trace_state(&self) -> &TraceState {
540 &self.trace_state
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// `(id, 32-digit hex, big-endian bytes)` triples for `TraceId`.
    #[rustfmt::skip]
    fn trace_id_test_data() -> Vec<(TraceId, &'static str, [u8; 16])> {
        vec![
            (TraceId(0), "00000000000000000000000000000000", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            (TraceId(42), "0000000000000000000000000000002a", [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42]),
            (TraceId(126642714606581564793456114182061442190), "5f467fe7bf42676c05e20ba4a90e448e", [95, 70, 127, 231, 191, 66, 103, 108, 5, 226, 11, 164, 169, 14, 68, 142])
        ]
    }

    /// `(id, 16-digit hex, big-endian bytes)` triples for `SpanId`.
    #[rustfmt::skip]
    fn span_id_test_data() -> Vec<(SpanId, &'static str, [u8; 8])> {
        vec![
            (SpanId(0), "0000000000000000", [0, 0, 0, 0, 0, 0, 0, 0]),
            (SpanId(42), "000000000000002a", [0, 0, 0, 0, 0, 0, 0, 42]),
            (SpanId(5508496025762705295), "4c721bf33e3caf8f", [76, 114, 27, 243, 62, 60, 175, 143])
        ]
    }

    /// `(state, expected header, key to update)` triples for `TraceState`.
    #[rustfmt::skip]
    fn trace_state_test_data() -> Vec<(TraceState, &'static str, &'static str)> {
        vec![
            (TraceState::from_key_value(vec![("foo", "bar")]).unwrap(), "foo=bar", "foo"),
            (TraceState::from_key_value(vec![("foo", ""), ("apple", "banana")]).unwrap(), "foo=,apple=banana", "apple"),
            (TraceState::from_key_value(vec![("foo", "bar"), ("apple", "banana")]).unwrap(), "foo=bar,apple=banana", "apple"),
        ]
    }

    #[test]
    fn test_trace_id() {
        for (id, hex, bytes) in trace_id_test_data() {
            assert_eq!(id.to_string(), hex);
            assert_eq!(format!("{:032x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(id, TraceId::from_hex(hex).unwrap());
            assert_eq!(id, TraceId::from_bytes(bytes));
        }
    }

    #[test]
    fn test_span_id() {
        for (id, hex, bytes) in span_id_test_data() {
            assert_eq!(id.to_string(), hex);
            assert_eq!(format!("{:016x}", id), hex);
            assert_eq!(id.to_bytes(), bytes);

            assert_eq!(id, SpanId::from_hex(hex).unwrap());
            assert_eq!(id, SpanId::from_bytes(bytes));
        }
    }

    #[test]
    fn test_trace_state() {
        for (trace_state, header, key) in trace_state_test_data() {
            assert_eq!(trace_state.header(), header);

            let new_value = format!("{}-test", trace_state.get(key).unwrap());
            let updated = trace_state
                .insert(key, new_value.clone())
                .expect("inserting a valid key/value pair should succeed");

            // The re-inserted member must be moved to the front of the header.
            let expected_member = format!("{}={}", key, new_value);
            assert_eq!(updated.header().find(&expected_member), Some(0));

            let deleted = updated
                .delete(key.to_string())
                .expect("deleting a valid key should succeed");
            assert!(deleted.get(key).is_none());
        }
    }

    #[test]
    fn test_trace_state_key() {
        let cases: [(&'static str, bool); 7] = [
            ("123", true),
            ("bar", true),
            ("foo@bar", true),
            ("foo@0123456789abcdef", false),
            ("foo@012345678", true),
            ("FOO@BAR", false),
            ("你好", false),
        ];

        for &(key, expected) in cases.iter() {
            assert_eq!(TraceState::valid_key(key), expected, "test key: {:?}", key);
        }
    }

    #[test]
    fn test_trace_state_insert() {
        let original = TraceState::from_key_value(vec![("foo", "bar")]).unwrap();
        let inserted = original.insert("testkey", "testvalue").unwrap();

        // `insert` returns a new state and must leave the receiver untouched.
        assert!(original.get("testkey").is_none());
        assert_eq!(inserted.get("testkey"), Some("testvalue"));
    }
}