1use crate::{InsertError, MatchError, Params};
2
3use std::cell::UnsafeCell;
4use std::cmp::min;
5use std::mem;
6
/// The kind of a node in the route tree.
///
/// Declaration order determines the derived `PartialOrd`/`Ord`, so keep it stable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum NodeType {
    /// The root of the tree (assigned when the first route is inserted into an empty tree).
    Root,
    /// A named parameter segment such as `:id`.
    Param,
    /// A trailing catch-all segment such as `*path`.
    CatchAll,
    /// A plain static path segment (the default kind).
    Static,
}
19
20pub struct Node<T> {
24 priority: u32,
25 wild_child: bool,
26 indices: Vec<u8>,
27 node_type: NodeType,
28 value: Option<UnsafeCell<T>>,
30 pub(crate) prefix: Vec<u8>,
31 pub(crate) children: Vec<Self>,
32}
33
34unsafe impl<T: Send> Send for Node<T> {}
36unsafe impl<T: Sync> Sync for Node<T> {}
37
38impl<T> Node<T> {
39 pub fn insert(&mut self, route: impl Into<String>, val: T) -> Result<(), InsertError> {
40 let route = route.into().into_bytes();
41 let mut prefix = route.as_ref();
42
43 self.priority += 1;
44
45 if self.prefix.is_empty() && self.children.is_empty() {
47 self.insert_child(prefix, &route, val)?;
48 self.node_type = NodeType::Root;
49 return Ok(());
50 }
51
52 let mut current = self;
53
54 'walk: loop {
55 let mut i = 0;
61 let max = min(prefix.len(), current.prefix.len());
62
63 while i < max && prefix[i] == current.prefix[i] {
64 i += 1;
65 }
66
67 if i < current.prefix.len() {
69 let mut child = Self {
70 prefix: current.prefix[i..].to_owned(),
71 wild_child: current.wild_child,
72 indices: current.indices.clone(),
73 value: current.value.take(),
74 priority: current.priority - 1,
75 ..Self::default()
76 };
77
78 mem::swap(&mut current.children, &mut child.children);
79
80 current.children = vec![child];
81 current.indices = current.prefix[i..=i].to_owned();
82 current.prefix = prefix[..i].to_owned();
83 current.wild_child = false;
84 }
85
86 if prefix.len() > i {
88 prefix = &prefix[i..];
89
90 let first = prefix[0];
91
92 if current.node_type == NodeType::Param
94 && first == b'/'
95 && current.children.len() == 1
96 {
97 current = &mut current.children[0];
98 current.priority += 1;
99
100 continue 'walk;
101 }
102
103 for mut i in 0..current.indices.len() {
105 if first == current.indices[i] {
106 i = current.update_child_priority(i);
107 current = &mut current.children[i];
108 continue 'walk;
109 }
110 }
111
112 if first != b':' && first != b'*' && current.node_type != NodeType::CatchAll {
113 current.indices.push(first);
114 let mut child = current.add_child(Self::default());
115 child = current.update_child_priority(child);
116 current = &mut current.children[child];
117 } else if current.wild_child {
118 current = current.children.last_mut().unwrap();
120 current.priority += 1;
121
122 if prefix.len() >= current.prefix.len()
124 && current.prefix == prefix[..current.prefix.len()]
125 && current.node_type != NodeType::CatchAll
127 && (current.prefix.len() >= prefix.len()
129 || prefix[current.prefix.len()] == b'/')
130 {
131 continue 'walk;
132 }
133
134 return Err(InsertError::conflict(&route, prefix, current));
135 }
136
137 return current.insert_child(prefix, &route, val);
138 }
139
140 if current.value.is_some() {
142 return Err(InsertError::conflict(&route, prefix, current));
143 }
144
145 current.value = Some(UnsafeCell::new(val));
146
147 return Ok(());
148 }
149 }
150
151 fn add_child(&mut self, child: Node<T>) -> usize {
153 let len = self.children.len();
154
155 if self.wild_child && len > 0 {
156 self.children.insert(len - 1, child);
157 len - 1
158 } else {
159 self.children.push(child);
160 len
161 }
162 }
163
164 fn update_child_priority(&mut self, pos: usize) -> usize {
167 self.children[pos].priority += 1;
168 let priority = self.children[pos].priority;
169
170 let mut new_pos = pos;
172 while new_pos > 0 && self.children[new_pos - 1].priority < priority {
173 self.children.swap(new_pos - 1, new_pos);
175 new_pos -= 1;
176 }
177
178 if new_pos != pos {
180 self.indices = [
181 &self.indices[..new_pos], &self.indices[pos..=pos], &self.indices[new_pos..pos], &self.indices[pos + 1..],
185 ]
186 .concat();
187 }
188
189 new_pos
190 }
191
192 fn insert_child(&mut self, mut prefix: &[u8], route: &[u8], val: T) -> Result<(), InsertError> {
193 let mut current = self;
194
195 loop {
196 let (wildcard, wildcard_index) = match find_wildcard(prefix) {
198 (Some((w, i)), true) => (w, i),
199 (Some(..), false) => return Err(InsertError::TooManyParams),
201 (None, _) => {
203 current.value = Some(UnsafeCell::new(val));
204 current.prefix = prefix.to_owned();
205 return Ok(());
206 }
207 };
208
209 if wildcard.len() < 2 {
211 return Err(InsertError::UnnamedParam);
212 }
213
214 if wildcard[0] == b':' {
216 if wildcard_index > 0 {
218 current.prefix = prefix[..wildcard_index].to_owned();
219 prefix = &prefix[wildcard_index..];
220 }
221
222 let child = Self {
223 node_type: NodeType::Param,
224 prefix: wildcard.to_owned(),
225 ..Self::default()
226 };
227
228 let child = current.add_child(child);
229 current.wild_child = true;
230 current = &mut current.children[child];
231 current.priority += 1;
232
233 if wildcard.len() < prefix.len() {
236 prefix = &prefix[wildcard.len()..];
237 let child = Self {
238 priority: 1,
239 ..Self::default()
240 };
241
242 let child = current.add_child(child);
243 current = &mut current.children[child];
244 continue;
245 }
246
247 current.value = Some(UnsafeCell::new(val));
249 return Ok(());
250 }
251
252 assert_eq!(wildcard[0], b'*');
254
255 if wildcard_index + wildcard.len() != prefix.len() {
257 return Err(InsertError::InvalidCatchAll);
258 }
259
260 if let Some(i) = wildcard_index.checked_sub(1) {
261 if prefix[i] != b'/' {
263 return Err(InsertError::InvalidCatchAll);
264 }
265 }
266
267 if prefix == route && route[0] != b'/' {
269 return Err(InsertError::InvalidCatchAll);
270 }
271
272 if wildcard_index > 0 {
273 current.prefix = prefix[..wildcard_index].to_owned();
274 prefix = &prefix[wildcard_index..];
275 }
276
277 let child = Self {
278 prefix: prefix.to_owned(),
279 node_type: NodeType::CatchAll,
280 value: Some(UnsafeCell::new(val)),
281 priority: 1,
282 ..Self::default()
283 };
284
285 current.add_child(child);
286 current.wild_child = true;
287
288 return Ok(());
289 }
290 }
291}
292
293struct Skipped<'n, 'p, T> {
294 path: &'p [u8],
295 node: &'n Node<T>,
296 params: usize,
297}
298
/// Defines a local `try_backtrack!()` macro for the lookup loop labelled `$label`.
///
/// `try_backtrack!()` pops entries off `$stack` until it finds one whose saved path ends
/// with the current `$remaining` path. For that entry it restores the path, the node and
/// the captured parameter count, sets `$retrying`, and restarts the loop. Entries that do
/// not match are discarded. If none matches, it does nothing and execution falls through.
#[rustfmt::skip]
macro_rules! backtracker {
    ($stack:ident, $remaining:ident, $node:ident, $captured:ident, $retrying:ident, $label:lifetime) => {
        macro_rules! try_backtrack {
            () => {
                while let Some(entry) = $stack.pop() {
                    if !entry.path.ends_with($remaining) {
                        continue;
                    }

                    $remaining = entry.path;
                    $node = &entry.node;
                    $captured.truncate(entry.params);
                    $retrying = true;
                    continue $label;
                }
            };
        }
    };
}
319
320impl<T> Node<T> {
321 pub fn at<'n, 'p>(
325 &'n self,
326 full_path: &'p [u8],
327 ) -> Result<(&'n UnsafeCell<T>, Params<'n, 'p>), MatchError> {
328 let mut current = self;
329 let mut path = full_path;
330 let mut backtracking = false;
331 let mut params = Params::new();
332 let mut skipped_nodes = Vec::new();
333
334 'walk: loop {
335 backtracker!(skipped_nodes, path, current, params, backtracking, 'walk);
336
337 if path.len() > current.prefix.len() {
339 let (prefix, rest) = path.split_at(current.prefix.len());
340
341 if prefix == current.prefix {
343 let first = rest[0];
344 let consumed = path;
345 path = rest;
346
347 if !backtracking {
350 if let Some(i) = current.indices.iter().position(|&c| c == first) {
351 if current.wild_child {
354 skipped_nodes.push(Skipped {
355 path: consumed,
356 node: current,
357 params: params.len(),
358 });
359 }
360
361 if path == b"/"
363 && current.children[i].prefix != b"/"
364 && current.value.is_some()
365 {
366 return Err(MatchError::ExtraTrailingSlash);
367 }
368
369 current = ¤t.children[i];
371 continue 'walk;
372 }
373 }
374
375 if !current.wild_child {
378 if path == b"/" && current.value.is_some() {
380 return Err(MatchError::ExtraTrailingSlash);
381 }
382
383 if path != b"/" {
385 try_backtrack!();
386 }
387
388 return Err(MatchError::NotFound);
390 }
391
392 current = current.children.last().unwrap();
394
395 match current.node_type {
396 NodeType::Param => {
397 match path.iter().position(|&c| c == b'/') {
399 Some(i) => {
400 let (param, rest) = path.split_at(i);
401
402 if let [child] = current.children.as_slice() {
403 if rest == b"/"
405 && child.prefix != b"/"
406 && current.value.is_some()
407 {
408 return Err(MatchError::ExtraTrailingSlash);
409 }
410
411 params.push(¤t.prefix[1..], param);
413
414 path = rest;
416 current = child;
417 backtracking = false;
418 continue 'walk;
419 }
420
421 if path.len() == i + 1 {
424 return Err(MatchError::ExtraTrailingSlash);
425 }
426
427 return Err(MatchError::NotFound);
428 }
429 None => {
431 params.push(¤t.prefix[1..], path);
433
434 if let Some(ref value) = current.value {
436 return Ok((value, params));
437 }
438
439 if let [child] = current.children.as_slice() {
441 current = child;
442
443 if (current.prefix == b"/" && current.value.is_some())
444 || (current.prefix.is_empty()
445 && current.indices == b"/")
446 {
447 return Err(MatchError::MissingTrailingSlash);
448 }
449
450 if path != b"/" {
452 try_backtrack!();
453 }
454 }
455
456 return Err(MatchError::NotFound);
458 }
459 }
460 }
461 NodeType::CatchAll => {
462 return match current.value {
465 Some(ref value) => {
466 params.push(¤t.prefix[1..], path);
467 Ok((value, params))
468 }
469 None => Err(MatchError::NotFound),
470 };
471 }
472 _ => unreachable!(),
473 }
474 }
475 }
476
477 if path == current.prefix {
479 if let Some(ref value) = current.value {
480 return Ok((value, params));
481 }
482
483 if path != b"/" {
485 try_backtrack!();
486 }
487
488 if path == b"/" && current.wild_child && current.node_type != NodeType::Root {
490 return Err(MatchError::unsure(full_path));
491 }
492
493 if !backtracking {
494 if let Some(i) = current.indices.iter().position(|&c| c == b'/') {
496 current = ¤t.children[i];
497
498 if current.prefix.len() == 1 && current.value.is_some() {
499 return Err(MatchError::MissingTrailingSlash);
500 }
501 }
502 }
503
504 return Err(MatchError::NotFound);
505 }
506
507 if current.prefix.split_last() == Some((&b'/', path)) && current.value.is_some() {
509 return Err(MatchError::MissingTrailingSlash);
510 }
511
512 if path != b"/" {
514 try_backtrack!();
515 }
516
517 return Err(MatchError::NotFound);
518 }
519 }
520
521 #[cfg(feature = "__test_helpers")]
522 pub fn check_priorities(&self) -> Result<u32, (u32, u32)> {
523 let mut priority: u32 = 0;
524 for child in &self.children {
525 priority += child.check_priorities()?;
526 }
527
528 if self.value.is_some() {
529 priority += 1;
530 }
531
532 if self.priority != priority {
533 return Err((self.priority, priority));
534 }
535
536 Ok(priority)
537 }
538}
539
/// Locates the first wildcard (`:name` or `*name`) in `path`.
///
/// The wildcard extends up to, but not including, the next `/` (or to the end of `path`).
///
/// Returns:
/// - `(Some((wildcard, start)), valid)`, where `start` is the wildcard's byte offset and
///   `valid` is `false` if another `:` or `*` occurs inside the wildcard's segment;
/// - `(None, false)` when `path` contains no wildcard.
fn find_wildcard(path: &[u8]) -> (Option<(&[u8], usize)>, bool) {
    let is_wild = |c: &u8| *c == b':' || *c == b'*';

    let start = match path.iter().position(is_wild) {
        Some(start) => start,
        None => return (None, false),
    };

    let name = &path[start + 1..];
    let name_len = name.iter().position(|&c| c == b'/').unwrap_or(name.len());
    let valid = !name[..name_len].iter().any(is_wild);

    (Some((&path[start..start + 1 + name_len], start)), valid)
}
564
565impl<T> Clone for Node<T>
566where
567 T: Clone,
568{
569 fn clone(&self) -> Self {
570 let value = match self.value {
571 Some(ref value) => {
572 let value = unsafe { &*value.get() };
574 Some(UnsafeCell::new(value.clone()))
575 }
576 None => None,
577 };
578
579 Self {
580 value,
581 prefix: self.prefix.clone(),
582 wild_child: self.wild_child,
583 node_type: self.node_type,
584 indices: self.indices.clone(),
585 children: self.children.clone(),
586 priority: self.priority,
587 }
588 }
589}
590
591impl<T> Default for Node<T> {
592 fn default() -> Self {
593 Self {
594 prefix: Vec::new(),
595 wild_child: false,
596 node_type: NodeType::Static,
597 indices: Vec::new(),
598 children: Vec::new(),
599 value: None,
600 priority: 0,
601 }
602 }
603}
604
// Test-only `Debug` output for inspecting tree shapes.
#[cfg(test)]
const _: () = {
    use std::fmt::{self, Debug, Formatter};

    impl<T: Debug> Debug for Node<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            // SAFETY: only a shared reference is created for the duration of formatting.
            // NOTE(review): assumes no mutable reference to the value is alive meanwhile.
            let value = self.value.as_ref().map(|cell| unsafe { &*cell.get() });

            let indices: Vec<Option<char>> = self
                .indices
                .iter()
                .map(|&byte| char::from_u32(u32::from(byte)))
                .collect();

            f.debug_struct("Node")
                .field("value", &value)
                .field("prefix", &std::str::from_utf8(&self.prefix))
                .field("node_type", &self.node_type)
                .field("children", &self.children)
                .field("indices", &indices)
                .finish()
        }
    }
};