matchit/
tree.rs

1use crate::{InsertError, MatchError, Params};
2
3use std::cell::UnsafeCell;
4use std::cmp::min;
5use std::mem;
6
/// The kind of segment a tree node represents.
///
/// The declaration order is significant: it defines the derived ordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum NodeType {
    /// The root node of the tree.
    Root,
    /// A named route parameter, e.g. `/:id`.
    Param,
    /// A catch-all parameter, e.g. `/*file`; only allowed as the final segment of a route.
    CatchAll,
    /// Any other (purely static) segment.
    Static,
}
19
20/// A radix tree used for URL path matching.
21///
22/// See [the crate documentation](crate) for details.
23pub struct Node<T> {
24    priority: u32,
25    wild_child: bool,
26    indices: Vec<u8>,
27    node_type: NodeType,
28    // see `at_inner` for why an unsafe cell is needed.
29    value: Option<UnsafeCell<T>>,
30    pub(crate) prefix: Vec<u8>,
31    pub(crate) children: Vec<Self>,
32}
33
34// SAFETY: we expose `value` per rust's usual borrowing rules, so we can just delegate these traits
35unsafe impl<T: Send> Send for Node<T> {}
36unsafe impl<T: Sync> Sync for Node<T> {}
37
38impl<T> Node<T> {
39    pub fn insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError> {
40        let route = route.into().into_bytes();
41        let mut prefix = route.as_ref();
42
43        self.priority += 1;
44
45        // empty tree
46        if self.prefix.is_empty() && self.children.is_empty() {
47            self.insert_child(prefix, &route, val)?;
48            self.node_type = NodeType::Root;
49            return Ok(());
50        }
51
52        let mut current = self;
53
54        'walk: loop {
55            // find the longest common prefix
56            //
57            // this also implies that the common prefix contains
58            // no ':' or '*', since the existing key can't contain
59            // those chars
60            let mut i = 0;
61            let max = min(prefix.len(), current.prefix.len());
62
63            while i < max && prefix[i] == current.prefix[i] {
64                i += 1;
65            }
66
67            // split edge
68            if i < current.prefix.len() {
69                let mut child = Self {
70                    prefix: current.prefix[i..].to_owned(),
71                    wild_child: current.wild_child,
72                    indices: current.indices.clone(),
73                    value: current.value.take(),
74                    priority: current.priority - 1,
75                    ..Self::default()
76                };
77
78                mem::swap(&mut current.children, &mut child.children);
79
80                current.children = vec![child];
81                current.indices = current.prefix[i..=i].to_owned();
82                current.prefix = prefix[..i].to_owned();
83                current.wild_child = false;
84            }
85
86            // make new node a child of this node
87            if prefix.len() > i {
88                prefix = &prefix[i..];
89
90                let first = prefix[0];
91
92                // `/` after param
93                if current.node_type == NodeType::Param
94                    && first == b'/'
95                    && current.children.len() == 1
96                {
97                    current = &mut current.children[0];
98                    current.priority += 1;
99
100                    continue 'walk;
101                }
102
103                // check if a child with the next path byte exists
104                for mut i in 0..current.indices.len() {
105                    if first == current.indices[i] {
106                        i = current.update_child_priority(i);
107                        current = &mut current.children[i];
108                        continue 'walk;
109                    }
110                }
111
112                if first != b':' && first != b'*' && current.node_type != NodeType::CatchAll {
113                    current.indices.push(first);
114                    let mut child = current.add_child(Self::default());
115                    child = current.update_child_priority(child);
116                    current = &mut current.children[child];
117                } else if current.wild_child {
118                    // inserting a wildcard node, check if it conflicts with the existing wildcard
119                    current = current.children.last_mut().unwrap();
120                    current.priority += 1;
121
122                    // check if the wildcard matches
123                    if prefix.len() >= current.prefix.len()
124                        && current.prefix == prefix[..current.prefix.len()]
125                        // adding a child to a catchall Node is not possible
126                        && current.node_type != NodeType::CatchAll
127                        // check for longer wildcard, e.g. :name and :names
128                        && (current.prefix.len() >= prefix.len()
129                            || prefix[current.prefix.len()] == b'/')
130                    {
131                        continue 'walk;
132                    }
133
134                    return Err(InsertError::conflict(&route, prefix, current));
135                }
136
137                return current.insert_child(prefix, &route, val);
138            }
139
140            // otherwise add value to current node
141            if current.value.is_some() {
142                return Err(InsertError::conflict(&route, prefix, current));
143            }
144
145            current.value = Some(UnsafeCell::new(val));
146
147            return Ok(());
148        }
149    }
150
151    // add a child node, keeping wildcards at the end
152    fn add_child(&mut self, child: Node<T>) -> usize {
153        let len = self.children.len();
154
155        if self.wild_child && len > 0 {
156            self.children.insert(len - 1, child);
157            len - 1
158        } else {
159            self.children.push(child);
160            len
161        }
162    }
163
164    // increments priority of the given child and reorders if necessary
165    // returns the new position (index) of the child
166    fn update_child_priority(&mut self, pos: usize) -> usize {
167        self.children[pos].priority += 1;
168        let priority = self.children[pos].priority;
169
170        // adjust position (move to front)
171        let mut new_pos = pos;
172        while new_pos > 0 && self.children[new_pos - 1].priority < priority {
173            // swap node positions
174            self.children.swap(new_pos - 1, new_pos);
175            new_pos -= 1;
176        }
177
178        // build new index list
179        if new_pos != pos {
180            self.indices = [
181                &self.indices[..new_pos],    // unchanged prefix, might be empty
182                &self.indices[pos..=pos],    // the index char we move
183                &self.indices[new_pos..pos], // rest without char at 'pos'
184                &self.indices[pos + 1..],
185            ]
186            .concat();
187        }
188
189        new_pos
190    }
191
192    fn insert_child(&mut self, mut prefix: &[u8], route: &[u8], val: T) -> Result<(), InsertError> {
193        let mut current = self;
194
195        loop {
196            // search for a wildcard segment
197            let (wildcard, wildcard_index) = match find_wildcard(prefix) {
198                (Some((w, i)), true) => (w, i),
199                // the wildcard name contains invalid characters (':' or '*')
200                (Some(..), false) => return Err(InsertError::TooManyParams),
201                // no wildcard, simply use the current node
202                (None, _) => {
203                    current.value = Some(UnsafeCell::new(val));
204                    current.prefix = prefix.to_owned();
205                    return Ok(());
206                }
207            };
208
209            // check if the wildcard has a name
210            if wildcard.len() < 2 {
211                return Err(InsertError::UnnamedParam);
212            }
213
214            // route parameter
215            if wildcard[0] == b':' {
216                // insert prefix before the current wildcard
217                if wildcard_index > 0 {
218                    current.prefix = prefix[..wildcard_index].to_owned();
219                    prefix = &prefix[wildcard_index..];
220                }
221
222                let child = Self {
223                    node_type: NodeType::Param,
224                    prefix: wildcard.to_owned(),
225                    ..Self::default()
226                };
227
228                let child = current.add_child(child);
229                current.wild_child = true;
230                current = &mut current.children[child];
231                current.priority += 1;
232
233                // if the route doesn't end with the wildcard, then there
234                // will be another non-wildcard subroute starting with '/'
235                if wildcard.len() < prefix.len() {
236                    prefix = &prefix[wildcard.len()..];
237                    let child = Self {
238                        priority: 1,
239                        ..Self::default()
240                    };
241
242                    let child = current.add_child(child);
243                    current = &mut current.children[child];
244                    continue;
245                }
246
247                // otherwise we're done. Insert the value in the new leaf
248                current.value = Some(UnsafeCell::new(val));
249                return Ok(());
250            }
251
252            // catch all route
253            assert_eq!(wildcard[0], b'*');
254
255            // "/foo/*catchall/bar"
256            if wildcard_index + wildcard.len() != prefix.len() {
257                return Err(InsertError::InvalidCatchAll);
258            }
259
260            if let Some(i) = wildcard_index.checked_sub(1) {
261                // "/foo/bar*catchall"
262                if prefix[i] != b'/' {
263                    return Err(InsertError::InvalidCatchAll);
264                }
265            }
266
267            // "*catchall"
268            if prefix == route && route[0] != b'/' {
269                return Err(InsertError::InvalidCatchAll);
270            }
271
272            if wildcard_index > 0 {
273                current.prefix = prefix[..wildcard_index].to_owned();
274                prefix = &prefix[wildcard_index..];
275            }
276
277            let child = Self {
278                prefix: prefix.to_owned(),
279                node_type: NodeType::CatchAll,
280                value: Some(UnsafeCell::new(val)),
281                priority: 1,
282                ..Self::default()
283            };
284
285            current.add_child(child);
286            current.wild_child = true;
287
288            return Ok(());
289        }
290    }
291}
292
293struct Skipped<'n, 'p, T> {
294    path: &'p [u8],
295    node: &'n Node<T>,
296    params: usize,
297}
298
/// Defines a local `try_backtrack!()` macro for the lookup loop labelled `$walk`.
///
/// `try_backtrack!()` pops skipped nodes until it finds one whose recorded path ends with the
/// current remaining `$path`. It then rewinds `$path`, `$current` and `$params` to that point,
/// sets `$backtracking` and restarts `$walk`. If no skipped node fits, it falls through with
/// `$skipped_nodes` drained.
#[rustfmt::skip]
macro_rules! backtracker {
    ($skipped_nodes:ident, $path:ident, $current:ident, $params:ident, $backtracking:ident, $walk:lifetime) => {
        macro_rules! try_backtrack {
            () => {
                loop {
                    let skipped = match $skipped_nodes.pop() {
                        Some(skipped) => skipped,
                        None => break,
                    };

                    // this skipped node cannot lead to the remaining path; discard it
                    if !skipped.path.ends_with($path) {
                        continue;
                    }

                    // resume matching at the skipped node, this time via its wildcard child
                    $path = skipped.path;
                    $current = skipped.node;
                    $params.truncate(skipped.params);
                    $backtracking = true;
                    continue $walk;
                }
            };
        }
    };
}
319
impl<T> Node<T> {
    /// Looks up the value stored for `full_path`, capturing route parameters along the way.
    ///
    /// Static children are tried before the wildcard child. Whenever a static child is taken
    /// at a node that also has a wildcard child, that node is recorded so the lookup can
    /// backtrack to it (via `try_backtrack!`) if the static branch does not match.
    ///
    /// On success, returns the cell holding the value together with the captured parameters.
    /// The value is handed out as an `UnsafeCell` so that shared and mutable lookups can share
    /// this routine; dereferencing it soundly is up to the caller.
    ///
    /// # Errors
    ///
    /// `MatchError::NotFound` when nothing matches; `ExtraTrailingSlash` /
    /// `MissingTrailingSlash` when a route would match with a trailing slash removed / added;
    /// and `MatchError::unsure(full_path)` in one case where the code cannot tell (see the
    /// TODO below).
    //
    // It's a bit sad that we have to introduce unsafe here but rust doesn't really have a way
    // to abstract over mutability, so UnsafeCell lets us avoid having to duplicate logic between
    // `at` and `at_mut`.
    pub fn at<'n, 'p>(
        &'n self,
        full_path: &'p [u8],
    ) -> Result<(&'n UnsafeCell<T>, Params<'n, 'p>), MatchError> {
        let mut current = self;
        // the part of `full_path` that has not been matched yet
        let mut path = full_path;
        // true while resuming at a skipped node, whose static children were already tried
        let mut backtracking = false;
        let mut params = Params::new();
        // nodes with a wildcard child that were passed over in favor of a static child
        let mut skipped_nodes = Vec::new();

        'walk: loop {
            // defines `try_backtrack!()`, which resumes the walk at a suitable skipped node
            backtracker!(skipped_nodes, path, current, params, backtracking, 'walk);

            // the path is longer than this node's prefix - we are expecting a child node
            if path.len() > current.prefix.len() {
                let (prefix, rest) = path.split_at(current.prefix.len());

                // prefix matches
                if prefix == current.prefix {
                    let first = rest[0];
                    let consumed = path;
                    path = rest;

                    // try searching for a matching static child unless we are currently
                    // backtracking, which would mean we already traversed them
                    if !backtracking {
                        if let Some(i) = current.indices.iter().position(|&c| c == first) {
                            // keep track of wildcard routes we skipped to backtrack to later if
                            // we don't find a match
                            if current.wild_child {
                                skipped_nodes.push(Skipped {
                                    path: consumed,
                                    node: current,
                                    params: params.len(),
                                });
                            }

                            // child won't match because of an extra trailing slash
                            if path == b"/"
                                && current.children[i].prefix != b"/"
                                && current.value.is_some()
                            {
                                return Err(MatchError::ExtraTrailingSlash);
                            }

                            // continue with the child node
                            current = &current.children[i];
                            continue 'walk;
                        }
                    }

                    // we didn't find a match and there are no children with wildcards,
                    // there is no match
                    if !current.wild_child {
                        // extra trailing slash
                        if path == b"/" && current.value.is_some() {
                            return Err(MatchError::ExtraTrailingSlash);
                        }

                        // try backtracking
                        if path != b"/" {
                            try_backtrack!();
                        }

                        // nothing found
                        return Err(MatchError::NotFound);
                    }

                    // handle the wildcard child, which is always at the end of the list
                    current = current.children.last().unwrap();

                    match current.node_type {
                        NodeType::Param => {
                            // check if there are more segments in the path other than this parameter
                            match path.iter().position(|&c| c == b'/') {
                                Some(i) => {
                                    // `param` is this segment's value; `rest` starts at the '/'
                                    let (param, rest) = path.split_at(i);

                                    if let [child] = current.children.as_slice() {
                                        // child won't match because of an extra trailing slash
                                        if rest == b"/"
                                            && child.prefix != b"/"
                                            && current.value.is_some()
                                        {
                                            return Err(MatchError::ExtraTrailingSlash);
                                        }

                                        // store the parameter value
                                        params.push(&current.prefix[1..], param);

                                        // continue with the child node
                                        path = rest;
                                        current = child;
                                        // we've moved past the skipped node, so static children
                                        // must be searched again
                                        backtracking = false;
                                        continue 'walk;
                                    }

                                    // this node has no children yet the path has more segments...
                                    // either the path has an extra trailing slash or there is no match
                                    if path.len() == i + 1 {
                                        return Err(MatchError::ExtraTrailingSlash);
                                    }

                                    return Err(MatchError::NotFound);
                                }
                                // this is the last path segment
                                None => {
                                    // store the parameter value
                                    params.push(&current.prefix[1..], path);

                                    // found the matching value
                                    if let Some(ref value) = current.value {
                                        return Ok((value, params));
                                    }

                                    // check the child node in case the path is missing a trailing slash
                                    if let [child] = current.children.as_slice() {
                                        current = child;

                                        if (current.prefix == b"/" && current.value.is_some())
                                            || (current.prefix.is_empty()
                                                && current.indices == b"/")
                                        {
                                            return Err(MatchError::MissingTrailingSlash);
                                        }

                                        // no match, try backtracking
                                        if path != b"/" {
                                            try_backtrack!();
                                        }
                                    }

                                    // this node doesn't have the value, no match
                                    return Err(MatchError::NotFound);
                                }
                            }
                        }
                        NodeType::CatchAll => {
                            // catch all segments are only allowed at the end of the route,
                            // either this node has the value or there is no match
                            return match current.value {
                                Some(ref value) => {
                                    // the catch-all captures the whole remaining path
                                    params.push(&current.prefix[1..], path);
                                    Ok((value, params))
                                }
                                None => Err(MatchError::NotFound),
                            };
                        }
                        _ => unreachable!(),
                    }
                }
            }

            // this is it, we should have reached the node containing the value
            if path == current.prefix {
                if let Some(ref value) = current.value {
                    return Ok((value, params));
                }

                // nope, try backtracking
                if path != b"/" {
                    try_backtrack!();
                }

                // TODO: does this always mean there is an extra trailing slash?
                if path == b"/" && current.wild_child && current.node_type != NodeType::Root {
                    return Err(MatchError::unsure(full_path));
                }

                if !backtracking {
                    // check if the path is missing a trailing slash
                    if let Some(i) = current.indices.iter().position(|&c| c == b'/') {
                        current = &current.children[i];

                        if current.prefix.len() == 1 && current.value.is_some() {
                            return Err(MatchError::MissingTrailingSlash);
                        }
                    }
                }

                return Err(MatchError::NotFound);
            }

            // nothing matches, check for a missing trailing slash
            // (this node's prefix is exactly the remaining path followed by a single '/')
            if current.prefix.split_last() == Some((&b'/', path)) && current.value.is_some() {
                return Err(MatchError::MissingTrailingSlash);
            }

            // last chance, try backtracking
            if path != b"/" {
                try_backtrack!();
            }

            return Err(MatchError::NotFound);
        }
    }

    /// Test helper: recomputes every node's priority as the number of values stored in its
    /// subtree and compares it with the stored `priority`.
    ///
    /// Returns the recomputed priority of `self`, or `(stored, computed)` for the first
    /// mismatching node found (children are checked before their parent).
    #[cfg(feature = "__test_helpers")]
    pub fn check_priorities(&self) -> Result<u32, (u32, u32)> {
        let mut priority: u32 = 0;
        for child in &self.children {
            priority += child.check_priorities()?;
        }

        if self.value.is_some() {
            priority += 1;
        }

        if self.priority != priority {
            return Err((self.priority, priority));
        }

        Ok(priority)
    }
}
539
/// Searches `path` for its first wildcard segment (`:name` or `*name`) and validates it.
///
/// On a hit, returns `(Some((segment, start)), valid)`:
/// - `segment` runs from the wildcard marker up to, but excluding, the next `/` (or the end
///   of `path`);
/// - `start` is the index of the marker;
/// - `valid` is `false` when the segment's name contains another `:` or `*`.
///
/// Returns `(None, false)` when `path` contains no wildcard at all.
fn find_wildcard(path: &[u8]) -> (Option<(&[u8], usize)>, bool) {
    let is_marker = |b: u8| b == b':' || b == b'*';

    let start = match path.iter().position(|&b| is_marker(b)) {
        Some(start) => start,
        None => return (None, false),
    };

    // the name is everything after the marker up to the next '/'
    let tail = &path[start + 1..];
    let name_len = tail.iter().position(|&b| b == b'/').unwrap_or(tail.len());
    let valid = !tail[..name_len].iter().any(|&b| is_marker(b));

    (Some((&path[start..start + 1 + name_len], start)), valid)
}
564
565impl<T> Clone for Node<T>
566where
567    T: Clone,
568{
569    fn clone(&self) -> Self {
570        let value = match self.value {
571            Some(ref value) => {
572                // safety: we only expose &mut T through &mut self
573                let value = unsafe { &*value.get() };
574                Some(UnsafeCell::new(value.clone()))
575            }
576            None => None,
577        };
578
579        Self {
580            value,
581            prefix: self.prefix.clone(),
582            wild_child: self.wild_child,
583            node_type: self.node_type,
584            indices: self.indices.clone(),
585            children: self.children.clone(),
586            priority: self.priority,
587        }
588    }
589}
590
591impl<T> Default for Node<T> {
592    fn default() -> Self {
593        Self {
594            prefix: Vec::new(),
595            wild_child: false,
596            node_type: NodeType::Static,
597            indices: Vec::new(),
598            children: Vec::new(),
599            value: None,
600            priority: 0,
601        }
602    }
603}
604
// visualize the tree structure when debugging
#[cfg(test)]
const _: () = {
    use std::fmt::{self, Debug, Formatter};

    impl<T: Debug> Debug for Node<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // SAFETY: `&mut T` is only exposed through `&mut self`, so reading the value
            // through `&self` cannot alias a live mutable reference.
            let value = unsafe { self.value.as_ref().map(|x| &*x.get()) };

            // each index is a single route byte; `char::from(u8)` is infallible (unlike
            // `char::from_u32`, which would print every entry wrapped in `Some(..)`)
            let indices = self
                .indices
                .iter()
                .map(|&x| char::from(x))
                .collect::<Vec<_>>();

            f.debug_struct("Node")
                .field("value", &value)
                .field("prefix", &std::str::from_utf8(&self.prefix))
                .field("node_type", &self.node_type)
                .field("wild_child", &self.wild_child)
                .field("priority", &self.priority)
                .field("children", &self.children)
                .field("indices", &indices)
                .finish()
        }
    }
};