axum/routing/
strip_prefix.rs

1use http::{Request, Uri};
2use std::{
3    sync::Arc,
4    task::{Context, Poll},
5};
6use tower::Layer;
7use tower_layer::layer_fn;
8use tower_service::Service;
9
10#[derive(Clone)]
11pub(super) struct StripPrefix<S> {
12    inner: S,
13    prefix: Arc<str>,
14}
15
16impl<S> StripPrefix<S> {
17    pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
18        let prefix = Arc::from(prefix);
19        layer_fn(move |inner| Self {
20            inner,
21            prefix: Arc::clone(&prefix),
22        })
23    }
24}
25
26impl<S, B> Service<Request<B>> for StripPrefix<S>
27where
28    S: Service<Request<B>>,
29{
30    type Response = S::Response;
31    type Error = S::Error;
32    type Future = S::Future;
33
34    #[inline]
35    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36        self.inner.poll_ready(cx)
37    }
38
39    fn call(&mut self, mut req: Request<B>) -> Self::Future {
40        if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
41            *req.uri_mut() = new_uri;
42        }
43        self.inner.call(req)
44    }
45}
46
47fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
48    let path_and_query = uri.path_and_query()?;
49
50    // Check whether the prefix matches the path and if so how long the matching prefix is.
51    //
52    // For example:
53    //
54    // prefix = /api
55    // path   = /api/users
56    //          ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4
57    //          characters we get the remainder
58    //
59    // prefix = /api/{version}
60    // path   = /api/v0/users
61    //          ^^^^^^^ this much is matched and the length is 7.
62    let mut matching_prefix_length = Some(0);
63    for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
64        // count the `/`
65        *matching_prefix_length.as_mut().unwrap() += 1;
66
67        match item {
68            Item::Both(path_segment, prefix_segment) => {
69                if is_capture(prefix_segment) || path_segment == prefix_segment {
70                    // the prefix segment is either a param, which matches anything, or
71                    // it actually matches the path segment
72                    *matching_prefix_length.as_mut().unwrap() += path_segment.len();
73                } else if prefix_segment.is_empty() {
74                    // the prefix ended in a `/` so we got a match.
75                    //
76                    // For example:
77                    //
78                    // prefix = /foo/
79                    // path   = /foo/bar
80                    //
81                    // The prefix matches and the new path should be `/bar`
82                    break;
83                } else {
84                    // the prefix segment didn't match so there is no match
85                    matching_prefix_length = None;
86                    break;
87                }
88            }
89            // the path had more segments than the prefix but we got a match.
90            //
91            // For example:
92            //
93            // prefix = /foo
94            // path   = /foo/bar
95            Item::First(_) => {
96                break;
97            }
98            // the prefix had more segments than the path so there is no match
99            Item::Second(_) => {
100                matching_prefix_length = None;
101                break;
102            }
103        }
104    }
105
106    // if the prefix matches it will always do so up until a `/`, it cannot match only
107    // part of a segment. Therefore this will always be at a char boundary and `split_at` won't
108    // panic
109    let after_prefix = uri.path().split_at(matching_prefix_length?).1;
110
111    let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
112        (true, None) => after_prefix.parse().unwrap(),
113        (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
114        (false, None) => format!("/{after_prefix}").parse().unwrap(),
115        (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
116    };
117
118    let mut parts = uri.clone().into_parts();
119    parts.path_and_query = Some(new_path_and_query);
120
121    Some(Uri::from_parts(parts).unwrap())
122}
123
/// Splits a path into its `/`-separated segments.
///
/// The leading `/` is dropped before splitting, so `/a/b` yields `a`, `b` rather than an extra
/// empty first segment. A trailing `/` yields a final empty segment, and `/` alone yields one
/// empty segment.
///
/// # Panics
///
/// Panics immediately (not lazily on iteration) if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    let without_leading_slash = s
        .strip_prefix('/')
        .expect("path didn't start with '/'. axum should have caught this higher up.");

    without_leading_slash.split('/')
}
135
136fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
137where
138    I: Iterator,
139    I2: Iterator<Item = I::Item>,
140{
141    let a = a.map(Some).chain(std::iter::repeat_with(|| None));
142    let b = b.map(Some).chain(std::iter::repeat_with(|| None));
143    a.zip(b).map_while(|(a, b)| match (a, b) {
144        (Some(a), Some(b)) => Some(Item::Both(a, b)),
145        (Some(a), None) => Some(Item::First(a)),
146        (None, Some(b)) => Some(Item::Second(b)),
147        (None, None) => None,
148    })
149}
150
/// Reports whether a prefix segment is a single-segment capture such as `{id}`.
///
/// A capture is wrapped in one pair of braces. Segments that begin with `{{` or end with `}}`
/// are rejected, as are segments that begin with `{*`.
fn is_capture(segment: &str) -> bool {
    let braced = segment.starts_with('{') && segment.ends_with('}');
    let doubled_brace = segment.starts_with("{{") || segment.ends_with("}}");
    let starred = segment.starts_with("{*");

    braced && !doubled_brace && !starred
}
158
/// One step of `zip_longest`: which of the two iterators produced a value.
#[derive(Debug)]
enum Item<T> {
    /// Both iterators produced a value (first iterator's value, then the second's).
    Both(T, T),
    /// Only the first iterator produced a value; the second is exhausted.
    First(T),
    /// Only the second iterator produced a value; the first is exhausted.
    Second(T),
}
165
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    /// Declares a `#[test]` that parses `uri`, runs `strip_prefix` with `prefix`, and compares
    /// the resulting URI (rendered as a string) against `expected`.
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    // the prefix is empty, so removing it should have no effect
    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(
        param_0,
        uri = "/",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(
        param_1,
        uri = "/a",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/{param}",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/{param}",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/{param}",
        expected = Some("/"),
    );

    test!(
        param_5,
        uri = "/b/a",
        prefix = "/a/{param}",
        expected = None,
    );

    test!(
        param_6,
        uri = "/a/b",
        prefix = "/{param}/a",
        expected = None,
    );

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/{param}/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/{param}/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/{param}/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/{param}/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/{param}/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    /// Property: `strip_prefix` never panics for generated URI/prefix pairs, whether or not the
    /// prefix matches.
    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    /// A generated URI together with a prefix of the same number of segments.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        /// Builds the URI and the prefix segment by segment. Each prefix segment is either a
        /// capture, a copy of the URI segment, or an unrelated random segment. A trailing `/`
        /// is appended to each independently at random.
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        // Must use the brace syntax that `is_capture` recognises; the legacy
                        // `:a` form would be treated as a plain literal and the capture branch
                        // of `strip_prefix` would never be exercised.
                        prefix.push_str("{a}");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Generates a non-empty string of ASCII alphanumeric characters (1 to 20 of them).
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                // Rejection sampling: draw arbitrary chars and keep only the alphanumeric ones.
                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Draws arbitrary `u8` values until one falls inside the inclusive range `lower..=upper`.
    ///
    /// Both bounds are included so that the smallest case (for example a single path segment)
    /// is actually generated.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}