http_types/trace/
trace_context.rs

1use rand::Rng;
2use std::fmt;
3
4use crate::headers::{HeaderName, HeaderValue, Headers, TRACEPARENT};
5use crate::Status;
6
/// Read and write [Trace-Context](https://w3c.github.io/trace-context/)
/// `traceparent` headers.
///
/// A `TraceContext` can be created from scratch, parsed from the headers of a
/// request or response, and applied back onto a set of headers.
///
/// # Specifications
///
/// - [Trace-Context (Working Draft)](https://w3c.github.io/trace-context/)
///
/// # Examples
///
/// ```
/// # fn main() -> http_types::Result<()> {
/// #
/// use http_types::trace::TraceContext;
///
/// let mut res = http_types::Response::new(200);
///
/// res.insert_header(
///     "traceparent",
///     "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01"
/// );
///
/// let context = TraceContext::from_headers(&res)?.unwrap();
///
/// let trace_id = u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16);
/// let parent_id = u64::from_str_radix("00f067aa0ba902b7", 16);
///
/// assert_eq!(context.trace_id(), trace_id.unwrap());
/// assert_eq!(context.parent_id(), parent_id.ok());
/// assert_eq!(context.sampled(), true);
/// #
/// # Ok(()) }
/// ```
#[derive(Debug)]
pub struct TraceContext {
    // Identifier of this context; it is what gets written into the
    // `traceparent` header and becomes the `parent_id` of any child.
    id: u64,
    // Version of the trace-context format (first header field).
    version: u8,
    // Identifier of the whole trace; children inherit it unchanged.
    trace_id: u128,
    // Identifier of the parent context, `None` for a root context.
    parent_id: Option<u64>,
    // Trace flags; the lowest bit is the "sampled" flag.
    flags: u8,
}
46
47impl TraceContext {
48    /// Generate a new TraceContext object without a parent.
49    ///
50    /// By default root TraceContext objects are sampled.
51    /// To mark it unsampled, call `context.set_sampled(false)`.
52    ///
53    /// # Examples
54    /// ```
55    /// use http_types::trace::TraceContext;
56    ///
57    /// let context = TraceContext::new();
58    ///
59    /// assert_eq!(context.parent_id(), None);
60    /// assert_eq!(context.sampled(), true);
61    /// ```
62    pub fn new() -> Self {
63        let mut rng = rand::thread_rng();
64
65        Self {
66            id: rng.gen(),
67            version: 0,
68            trace_id: rng.gen(),
69            parent_id: None,
70            flags: 1,
71        }
72    }
73
74    /// Create and return TraceContext object based on `traceparent` HTTP header.
75    ///
76    /// # Errors
77    ///
78    /// This function may error if the header is malformed. An error with a
79    /// status code of `400: Bad Request` will be generated.
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// # fn main() -> http_types::Result<()> {
85    /// #
86    /// use http_types::trace::TraceContext;
87    ///
88    /// let mut res = http_types::Response::new(200);
89    /// res.insert_header(
90    ///   "traceparent",
91    ///   "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01"
92    /// );
93    ///
94    /// let context = TraceContext::from_headers(&res)?.unwrap();
95    ///
96    /// let trace_id = u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16);
97    /// let parent_id = u64::from_str_radix("00f067aa0ba902b7", 16);
98    ///
99    /// assert_eq!(context.trace_id(), trace_id.unwrap());
100    /// assert_eq!(context.parent_id(), parent_id.ok());
101    /// assert_eq!(context.sampled(), true);
102    /// #
103    /// # Ok(()) }
104    /// ```
105    pub fn from_headers(headers: impl AsRef<Headers>) -> crate::Result<Option<Self>> {
106        let headers = headers.as_ref();
107        let mut rng = rand::thread_rng();
108
109        let traceparent = match headers.get(TRACEPARENT) {
110            Some(header) => header,
111            None => return Ok(None),
112        };
113        let parts: Vec<&str> = traceparent.as_str().split('-').collect();
114
115        Ok(Some(Self {
116            id: rng.gen(),
117            version: u8::from_str_radix(parts[0], 16)?,
118            trace_id: u128::from_str_radix(parts[1], 16).status(400)?,
119            parent_id: Some(u64::from_str_radix(parts[2], 16).status(400)?),
120            flags: u8::from_str_radix(parts[3], 16).status(400)?,
121        }))
122    }
123
124    /// Add the traceparent header to the http headers
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// # fn main() -> http_types::Result<()> {
130    /// #
131    /// use http_types::trace::TraceContext;
132    /// use http_types::{Request, Response, Url, Method};
133    ///
134    /// let mut req = Request::new(Method::Get, Url::parse("https://example.com").unwrap());
135    /// req.insert_header(
136    ///   "traceparent",
137    ///   "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01"
138    /// );
139    ///
140    /// let parent = TraceContext::from_headers(&req)?.unwrap();
141    ///
142    /// let mut res = Response::new(200);
143    /// parent.apply(&mut res);
144    ///
145    /// let child = TraceContext::from_headers(&res)?.unwrap();
146    ///
147    /// assert_eq!(child.version(), parent.version());
148    /// assert_eq!(child.trace_id(), parent.trace_id());
149    /// assert_eq!(child.parent_id(), Some(parent.id()));
150    /// #
151    /// # Ok(()) }
152    /// ```
153    pub fn apply(&self, mut headers: impl AsMut<Headers>) {
154        let headers = headers.as_mut();
155        headers.insert(TRACEPARENT, self.value());
156    }
157
158    /// Get the `HeaderName`.
159    pub fn name(&self) -> HeaderName {
160        TRACEPARENT
161    }
162
163    /// Get the `HeaderValue`.
164    pub fn value(&self) -> HeaderValue {
165        let output = format!("{}", self);
166        unsafe { HeaderValue::from_bytes_unchecked(output.into()) }
167    }
168
169    /// Generate a child of the current TraceContext and return it.
170    ///
171    /// The child will have a new randomly genrated `id` and its `parent_id` will be set to the
172    /// `id` of this TraceContext.
173    pub fn child(&self) -> Self {
174        let mut rng = rand::thread_rng();
175
176        Self {
177            id: rng.gen(),
178            version: self.version,
179            trace_id: self.trace_id,
180            parent_id: Some(self.id),
181            flags: self.flags,
182        }
183    }
184
185    /// Return the id of the TraceContext.
186    pub fn id(&self) -> u64 {
187        self.id
188    }
189
190    /// Return the version of the TraceContext spec used.
191    ///
192    /// You probably don't need this.
193    pub fn version(&self) -> u8 {
194        self.version
195    }
196
197    /// Return the trace id of the TraceContext.
198    ///
199    /// All children will have the same `trace_id`.
200    pub fn trace_id(&self) -> u128 {
201        self.trace_id
202    }
203
204    /// Return the id of the parent TraceContext.
205    #[inline]
206    pub fn parent_id(&self) -> Option<u64> {
207        self.parent_id
208    }
209
210    /// Returns true if the trace is sampled
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// # fn main() -> http_types::Result<()> {
216    /// #
217    /// use http_types::trace::TraceContext;
218    /// use http_types::Response;
219    ///
220    /// let mut res = Response::new(200);
221    /// res.insert_header("traceparent", "00-00000000000000000000000000000001-0000000000000002-01");
222    /// let context = TraceContext::from_headers(&res)?.unwrap();
223    /// assert_eq!(context.sampled(), true);
224    /// #
225    /// # Ok(()) }
226    /// ```
227    pub fn sampled(&self) -> bool {
228        (self.flags & 0b00000001) == 1
229    }
230
231    /// Change sampled flag
232    ///
233    /// # Examples
234    ///
235    /// ```
236    /// use http_types::trace::TraceContext;
237    ///
238    /// let mut context = TraceContext::new();
239    /// assert_eq!(context.sampled(), true);
240    /// context.set_sampled(false);
241    /// assert_eq!(context.sampled(), false);
242    /// ```
243    pub fn set_sampled(&mut self, sampled: bool) {
244        let x = sampled as u8;
245        self.flags ^= (x ^ self.flags) & (1 << 0);
246    }
247}
248
249impl fmt::Display for TraceContext {
250    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
251        write!(
252            f,
253            "{:02x}-{:032x}-{:016x}-{:02x}",
254            self.version, self.trace_id, self.id, self.flags
255        )
256    }
257}
258
#[cfg(test)]
mod test {
    use super::*;

    /// Parse a `TraceContext` from headers that hold only the given
    /// `traceparent` value, panicking if no context is returned.
    fn context_from(traceparent: &str) -> crate::Result<TraceContext> {
        let mut headers = crate::Headers::new();
        headers.insert(TRACEPARENT, traceparent);
        let parsed = TraceContext::from_headers(&mut headers)?;
        Ok(parsed.unwrap())
    }

    // Every field of a well-formed header ends up in the matching accessor.
    #[test]
    fn default() -> crate::Result<()> {
        let context = context_from("00-01-deadbeef-00")?;
        assert_eq!(context.version(), 0);
        assert_eq!(context.trace_id(), 1);
        assert_eq!(context.parent_id().unwrap(), 3735928559);
        assert_eq!(context.flags, 0);
        assert!(!context.sampled());
        Ok(())
    }

    // A freshly created root context has no parent and is sampled.
    #[test]
    fn no_header() {
        let context = TraceContext::new();
        assert_eq!(context.version(), 0);
        assert_eq!(context.parent_id(), None);
        assert_eq!(context.flags, 1);
        assert!(context.sampled());
    }

    // Flags `00` means the sampled bit is clear.
    #[test]
    fn not_sampled() -> crate::Result<()> {
        let context = context_from("00-01-02-00")?;
        assert!(!context.sampled());
        Ok(())
    }

    // Flags `01` means the sampled bit is set.
    #[test]
    fn sampled() -> crate::Result<()> {
        let context = context_from("00-01-02-01")?;
        assert!(context.sampled());
        Ok(())
    }
}