use once_cell::sync::Lazy;
use opentelemetry::{
propagation::{text_map_propagator::FieldIter, Extractor, Injector, TextMapPropagator},
trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId, TraceState},
Context,
};
use std::str::FromStr;
/// `traceparent` version written by [`TraceContextPropagator`] when injecting.
const SUPPORTED_VERSION: u8 = 0;
/// Highest `traceparent` version accepted when extracting; `ff` (255) is rejected.
const MAX_VERSION: u8 = 254;
/// Header carrying version, trace id, span id and trace flags.
const TRACEPARENT_HEADER: &str = "traceparent";
/// Header carrying the serialized [`TraceState`].
const TRACESTATE_HEADER: &str = "tracestate";
/// Header names reported by `TextMapPropagator::fields`, allocated once on first use.
static TRACE_CONTEXT_HEADER_FIELDS: Lazy<[String; 2]> = Lazy::new(|| {
    [
        String::from(TRACEPARENT_HEADER),
        String::from(TRACESTATE_HEADER),
    ]
});
/// Propagates span context using the W3C Trace Context `traceparent` and
/// `tracestate` headers.
///
/// Construct it with [`TraceContextPropagator::new`] or [`Default::default`].
#[derive(Debug, Default, Clone)]
pub struct TraceContextPropagator {
    // Private zero-sized field: code outside this module cannot build the struct
    // with a literal, so construction goes through `new` / `Default`.
    _private: (),
}
impl TraceContextPropagator {
    /// Creates a new W3C Trace Context propagator.
    pub fn new() -> Self {
        TraceContextPropagator { _private: () }
    }

    /// Extracts a remote [`SpanContext`] from the `traceparent` header and, when
    /// present, the `tracestate` header read through `extractor`.
    ///
    /// Returns `Err(())` when `traceparent` is missing or malformed:
    /// - fewer than four `-`-separated fields (a single trailing `-` is ignored);
    /// - version / trace-id / parent-id / trace-flags are not exactly 2 / 32 / 16 / 2
    ///   lowercase hex digits;
    /// - the version is above `MAX_VERSION` (i.e. `ff`);
    /// - version `00` with more than four fields, or with flags above `0x02`;
    /// - the decoded span context is not valid (e.g. all-zero trace and span ids).
    ///
    /// A `tracestate` header that fails to parse does not cause an error; an empty
    /// [`TraceState`] is used instead.
    fn extract_span_context(&self, extractor: &dyn Extractor) -> Result<SpanContext, ()> {
        /// Returns `true` if `field` consists of exactly `len` lowercase hex digits.
        ///
        /// `from_str_radix` alone is too lenient for the fixed-width `traceparent`
        /// grammar. It accepts values that are too short or padded with extra leading
        /// zeros, and it accepts a leading `+`. Without this check, `00-1-2-01` would
        /// decode to trace id 1 / span id 2.
        fn is_lower_hex(field: &str, len: usize) -> bool {
            field.len() == len
                && field
                    .bytes()
                    .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
        }

        let header_value = extractor.get(TRACEPARENT_HEADER).unwrap_or("").trim();
        // `split_terminator` drops the empty piece after a trailing `-`.
        let parts = header_value.split_terminator('-').collect::<Vec<&str>>();
        if parts.len() < 4 {
            return Err(());
        }
        let (version_hex, trace_id_hex, span_id_hex, flags_hex) =
            (parts[0], parts[1], parts[2], parts[3]);

        // Every field must be fixed-width lowercase hex. This also covers rejecting
        // upper-case trace and span ids.
        let well_formed = is_lower_hex(version_hex, 2)
            && is_lower_hex(trace_id_hex, 32)
            && is_lower_hex(span_id_hex, 16)
            && is_lower_hex(flags_hex, 2);
        if !well_formed {
            return Err(());
        }

        // Version `ff` is rejected. Version `00` must have exactly four fields; higher
        // versions may carry extra trailing fields, which are ignored here.
        let version = u8::from_str_radix(version_hex, 16).map_err(|_| ())?;
        if version > MAX_VERSION || (version == 0 && parts.len() != 4) {
            return Err(());
        }

        let trace_id = TraceId::from_hex(trace_id_hex).map_err(|_| ())?;
        let span_id = SpanId::from_hex(span_id_hex).map_err(|_| ())?;

        let opts = u8::from_str_radix(flags_hex, 16).map_err(|_| ())?;
        if version == 0 && opts > 2 {
            return Err(());
        }
        // Only the sampled bit is kept; every other flag bit is cleared.
        let trace_flags = TraceFlags::new(opts) & TraceFlags::SAMPLED;

        // An unparsable tracestate degrades to an empty one instead of failing extraction.
        let trace_state = match extractor.get(TRACESTATE_HEADER) {
            Some(trace_state_str) => {
                TraceState::from_str(trace_state_str).unwrap_or_else(|_| TraceState::default())
            }
            None => TraceState::default(),
        };

        let span_context = SpanContext::new(trace_id, span_id, trace_flags, true, trace_state);
        // Rejects contexts such as one with all-zero trace and span ids.
        if !span_context.is_valid() {
            return Err(());
        }
        Ok(span_context)
    }
}
impl TextMapPropagator for TraceContextPropagator {
    /// Writes `traceparent` and `tracestate` headers for the span held in `cx`.
    ///
    /// Nothing is written when the span context is invalid. The version field is
    /// always `SUPPORTED_VERSION`, and only the sampled bit of the trace flags is
    /// emitted.
    fn inject_context(&self, cx: &Context, injector: &mut dyn Injector) {
        let span = cx.span();
        let span_ctx = span.span_context();
        if !span_ctx.is_valid() {
            return;
        }

        let sampled_only = span_ctx.trace_flags() & TraceFlags::SAMPLED;
        let traceparent = format!(
            "{:02x}-{}-{}-{:02x}",
            SUPPORTED_VERSION,
            span_ctx.trace_id(),
            span_ctx.span_id(),
            sampled_only
        );
        injector.set(TRACEPARENT_HEADER, traceparent);
        injector.set(TRACESTATE_HEADER, span_ctx.trace_state().header());
    }

    /// Returns `cx` extended with the remote span context parsed from `extractor`.
    /// If no valid `traceparent` header is found, returns an unchanged clone of `cx`.
    fn extract_with_context(&self, cx: &Context, extractor: &dyn Extractor) -> Context {
        match self.extract_span_context(extractor) {
            Ok(remote) => cx.with_remote_span_context(remote),
            Err(()) => cx.clone(),
        }
    }

    /// Returns the header names this propagator reads and writes.
    fn fields(&self) -> FieldIter<'_> {
        FieldIter::new(TRACE_CONTEXT_HEADER_FIELDS.as_ref())
    }
}
// Tests for the W3C propagator. Gated on the `testing` feature because
// `inject_w3c` wraps span contexts in `crate::testing::trace::TestSpan`.
#[cfg(all(test, feature = "testing", feature = "trace"))]
mod tests {
use super::*;
use crate::testing::trace::TestSpan;
use std::collections::HashMap;
/// Valid inputs for `extract_w3c`: (traceparent, tracestate, span context that
/// extraction is expected to produce).
#[rustfmt::skip]
fn extract_data() -> Vec<(&'static str, &'static str, SpanContext)> {
vec![
// Version 00, unsampled and sampled.
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
// Version 02: only the sampled bit (0x01) of the flags is kept, so 09 -> SAMPLED and 08 -> default.
("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-08", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
// Non-zero version with an extra trailing field, which is ignored.
("02-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-XYZxsf09", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
// A single trailing `-` is accepted (split_terminator drops the empty piece).
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01-", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
("01-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-09-", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
]
}
/// Inputs for `extract_w3c_reject_invalid`: (traceparent that must be rejected,
/// reason printed if it is not).
///
/// NOTE(review): some cases (e.g. "wrong version length", "upper case version")
/// also use all-zero trace and span ids. They may therefore be rejected by the
/// zero-id validity check rather than by the condition their reason names.
#[rustfmt::skip]
fn extract_data_invalid() -> Vec<(&'static str, &'static str)> {
vec![
("0000-00000000000000000000000000000000-0000000000000000-01", "wrong version length"),
("00-ab00000000000000000000000000000000-cd00000000000000-01", "wrong trace ID length"),
("00-ab000000000000000000000000000000-cd0000000000000000-01", "wrong span ID length"),
("00-ab000000000000000000000000000000-cd00000000000000-0100", "wrong trace flag length"),
("qw-00000000000000000000000000000000-0000000000000000-01", "bogus version"),
("00-qw000000000000000000000000000000-cd00000000000000-01", "bogus trace ID"),
("00-ab000000000000000000000000000000-qw00000000000000-01", "bogus span ID"),
("00-ab000000000000000000000000000000-cd00000000000000-qw", "bogus trace flag"),
("A0-00000000000000000000000000000000-0000000000000000-01", "upper case version"),
("00-AB000000000000000000000000000000-cd00000000000000-01", "upper case trace ID"),
("00-ab000000000000000000000000000000-CD00000000000000-01", "upper case span ID"),
("00-ab000000000000000000000000000000-cd00000000000000-A1", "upper case trace flag"),
("00-00000000000000000000000000000000-0000000000000000-01", "zero trace ID and span ID"),
("00-ab000000000000000000000000000000-cd00000000000000-09", "trace-flag unused bits set"),
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7", "missing options"),
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-", "empty options"),
]
}
/// Inputs for `inject_w3c`: (expected traceparent, expected tracestate, span
/// context to inject).
///
/// Flags 0xff are injected as `01` because only the sampled bit is written. The
/// empty context writes no headers, which the test reads back as empty strings.
#[rustfmt::skip]
fn inject_data() -> Vec<(&'static str, &'static str, SpanContext)> {
vec![
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::SAMPLED, true, TraceState::from_str("foo=bar").unwrap())),
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::default(), true, TraceState::from_str("foo=bar").unwrap())),
("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", "foo=bar", SpanContext::new(TraceId::from_u128(0x4bf9_2f35_77b3_4da6_a3ce_929d_0e0e_4736), SpanId::from_u64(0x00f0_67aa_0ba9_02b7), TraceFlags::new(0xff), true, TraceState::from_str("foo=bar").unwrap())),
("", "", SpanContext::empty_context()),
]
}
/// Every entry in `extract_data` extracts to exactly the expected span context.
#[test]
fn extract_w3c() {
let propagator = TraceContextPropagator::new();
for (trace_parent, trace_state, expected_context) in extract_data() {
let mut extractor = HashMap::new();
extractor.insert(TRACEPARENT_HEADER.to_string(), trace_parent.to_string());
extractor.insert(TRACESTATE_HEADER.to_string(), trace_state.to_string());
assert_eq!(
propagator.extract(&extractor).span().span_context(),
&expected_context
)
}
}
/// After extraction, `TraceState::header` reproduces the incoming tracestate value.
#[test]
fn extract_w3c_tracestate() {
let propagator = TraceContextPropagator::new();
let state = "foo=bar".to_string();
let parent = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00".to_string();
let mut extractor = HashMap::new();
extractor.insert(TRACEPARENT_HEADER.to_string(), parent);
extractor.insert(TRACESTATE_HEADER.to_string(), state.clone());
assert_eq!(
propagator
.extract(&extractor)
.span()
.span_context()
.trace_state()
.header(),
state
)
}
/// Every entry in `extract_data_invalid` extracts to an empty span context.
#[test]
fn extract_w3c_reject_invalid() {
let propagator = TraceContextPropagator::new();
for (invalid_header, reason) in extract_data_invalid() {
let mut extractor = HashMap::new();
extractor.insert(TRACEPARENT_HEADER.to_string(), invalid_header.to_string());
assert_eq!(
propagator.extract(&extractor).span().span_context(),
&SpanContext::empty_context(),
"{}",
reason
)
}
}
/// Injecting each context from `inject_data` writes the expected headers.
#[test]
fn inject_w3c() {
let propagator = TraceContextPropagator::new();
for (expected_trace_parent, expected_trace_state, context) in inject_data() {
let mut injector = HashMap::new();
propagator.inject_context(
&Context::current_with_span(TestSpan(context)),
&mut injector,
);
assert_eq!(
Extractor::get(&injector, TRACEPARENT_HEADER).unwrap_or(""),
expected_trace_parent
);
assert_eq!(
Extractor::get(&injector, TRACESTATE_HEADER).unwrap_or(""),
expected_trace_state
);
}
}
/// Injecting from the current context leaves a pre-set tracestate entry unchanged.
// NOTE(review): relies on the current context having no valid span in this test — confirm.
#[test]
fn inject_w3c_tracestate() {
let propagator = TraceContextPropagator::new();
let state = "foo=bar";
let mut injector: HashMap<String, String> = HashMap::new();
injector.set(TRACESTATE_HEADER, state.to_string());
Context::map_current(|cx| propagator.inject_context(cx, &mut injector));
assert_eq!(Extractor::get(&injector, TRACESTATE_HEADER), Some(state))
}
}