1use http::{Request, Uri};
2use std::{
3 sync::Arc,
4 task::{Context, Poll},
5};
6use tower::Layer;
7use tower_layer::layer_fn;
8use tower_service::Service;
9
10#[derive(Clone)]
11pub(super) struct StripPrefix<S> {
12 inner: S,
13 prefix: Arc<str>,
14}
15
16impl<S> StripPrefix<S> {
17 pub(super) fn layer(prefix: &str) -> impl Layer<S, Service = Self> + Clone {
18 let prefix = Arc::from(prefix);
19 layer_fn(move |inner| Self {
20 inner,
21 prefix: Arc::clone(&prefix),
22 })
23 }
24}
25
26impl<S, B> Service<Request<B>> for StripPrefix<S>
27where
28 S: Service<Request<B>>,
29{
30 type Response = S::Response;
31 type Error = S::Error;
32 type Future = S::Future;
33
34 #[inline]
35 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
36 self.inner.poll_ready(cx)
37 }
38
39 fn call(&mut self, mut req: Request<B>) -> Self::Future {
40 if let Some(new_uri) = strip_prefix(req.uri(), &self.prefix) {
41 *req.uri_mut() = new_uri;
42 }
43 self.inner.call(req)
44 }
45}
46
47fn strip_prefix(uri: &Uri, prefix: &str) -> Option<Uri> {
48 let path_and_query = uri.path_and_query()?;
49
50 let mut matching_prefix_length = Some(0);
63 for item in zip_longest(segments(path_and_query.path()), segments(prefix)) {
64 *matching_prefix_length.as_mut().unwrap() += 1;
66
67 match item {
68 Item::Both(path_segment, prefix_segment) => {
69 if is_capture(prefix_segment) || path_segment == prefix_segment {
70 *matching_prefix_length.as_mut().unwrap() += path_segment.len();
73 } else if prefix_segment.is_empty() {
74 break;
83 } else {
84 matching_prefix_length = None;
86 break;
87 }
88 }
89 Item::First(_) => {
96 break;
97 }
98 Item::Second(_) => {
100 matching_prefix_length = None;
101 break;
102 }
103 }
104 }
105
106 let after_prefix = uri.path().split_at(matching_prefix_length?).1;
110
111 let new_path_and_query = match (after_prefix.starts_with('/'), path_and_query.query()) {
112 (true, None) => after_prefix.parse().unwrap(),
113 (true, Some(query)) => format!("{after_prefix}?{query}").parse().unwrap(),
114 (false, None) => format!("/{after_prefix}").parse().unwrap(),
115 (false, Some(query)) => format!("/{after_prefix}?{query}").parse().unwrap(),
116 };
117
118 let mut parts = uri.clone().into_parts();
119 parts.path_and_query = Some(new_path_and_query);
120
121 Some(Uri::from_parts(parts).unwrap())
122}
123
/// Splits an absolute path into its `/`-separated segments.
///
/// The leading `/` is consumed rather than producing an empty first segment.
/// `"/a/b"` yields `["a", "b"]`, `"/"` yields `[""]`, and `"/a/"` yields `["a", ""]`.
///
/// # Panics
///
/// Panics immediately if `s` does not start with `/`.
fn segments(s: &str) -> impl Iterator<Item = &str> {
    match s.strip_prefix('/') {
        Some(rest) => rest.split('/'),
        None => panic!("path didn't start with '/'. axum should have caught this higher up."),
    }
}
135
136fn zip_longest<I, I2>(a: I, b: I2) -> impl Iterator<Item = Item<I::Item>>
137where
138 I: Iterator,
139 I2: Iterator<Item = I::Item>,
140{
141 let a = a.map(Some).chain(std::iter::repeat_with(|| None));
142 let b = b.map(Some).chain(std::iter::repeat_with(|| None));
143 a.zip(b).map_while(|(a, b)| match (a, b) {
144 (Some(a), Some(b)) => Some(Item::Both(a, b)),
145 (Some(a), None) => Some(Item::First(a)),
146 (None, Some(b)) => Some(Item::Second(b)),
147 (None, None) => None,
148 })
149}
150
/// Reports whether a route segment is a single-segment capture such as `{param}`.
///
/// The segment must be wrapped in braces. The name inside must not start with `{`
/// or end with `}`, which would make it an escaped literal like `{{x}}`. It also
/// must not start with `*`, which marks a `{*wildcard}`.
fn is_capture(segment: &str) -> bool {
    match segment.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
        Some(name) => !(name.starts_with('{') || name.ends_with('}') || name.starts_with('*')),
        None => false,
    }
}
158
/// One step produced by [`zip_longest`], recording which inputs still had an element.
#[derive(Debug)]
enum Item<E> {
    /// Both iterators produced an element: `(first, second)`.
    Both(E, E),
    /// Only the first iterator produced an element; the second is exhausted.
    First(E),
    /// Only the second iterator produced an element; the first is exhausted.
    Second(E),
}
165
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use quickcheck::Arbitrary;
    use quickcheck_macros::quickcheck;

    /// Generates a `#[test]` that runs `strip_prefix(uri, prefix)` and compares the
    /// resulting URI, rendered with `to_string`, against `expected`.
    macro_rules! test {
        (
            $name:ident,
            uri = $uri:literal,
            prefix = $prefix:literal,
            expected = $expected:expr,
        ) => {
            #[test]
            fn $name() {
                let uri = $uri.parse().unwrap();
                let new_uri = strip_prefix(&uri, $prefix).map(|uri| uri.to_string());
                assert_eq!(new_uri.as_deref(), $expected);
            }
        };
    }

    test!(empty, uri = "/", prefix = "/", expected = Some("/"),);

    test!(
        single_segment,
        uri = "/a",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        single_segment_root_uri,
        uri = "/",
        prefix = "/a",
        expected = None,
    );

    test!(
        single_segment_root_prefix,
        uri = "/a",
        prefix = "/",
        expected = Some("/a"),
    );

    test!(
        single_segment_no_match,
        uri = "/a",
        prefix = "/b",
        expected = None,
    );

    test!(
        single_segment_trailing_slash,
        uri = "/a/",
        prefix = "/a/",
        expected = Some("/"),
    );

    test!(
        single_segment_trailing_slash_2,
        uri = "/a",
        prefix = "/a/",
        expected = None,
    );

    test!(
        single_segment_trailing_slash_3,
        uri = "/a/",
        prefix = "/a",
        expected = Some("/"),
    );

    test!(
        multi_segment,
        uri = "/a/b",
        prefix = "/a",
        expected = Some("/b"),
    );

    test!(
        multi_segment_2,
        uri = "/b/a",
        prefix = "/a",
        expected = None,
    );

    test!(
        multi_segment_3,
        uri = "/a",
        prefix = "/a/b",
        expected = None,
    );

    test!(
        multi_segment_4,
        uri = "/a/b",
        prefix = "/b",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash,
        uri = "/a/b/",
        prefix = "/a/b/",
        expected = Some("/"),
    );

    test!(
        multi_segment_trailing_slash_2,
        uri = "/a/b",
        prefix = "/a/b/",
        expected = None,
    );

    test!(
        multi_segment_trailing_slash_3,
        uri = "/a/b/",
        prefix = "/a/b",
        expected = Some("/"),
    );

    test!(
        param_0,
        uri = "/",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(
        param_1,
        uri = "/a",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(
        param_2,
        uri = "/a/b",
        prefix = "/{param}",
        expected = Some("/b"),
    );

    test!(
        param_3,
        uri = "/b/a",
        prefix = "/{param}",
        expected = Some("/a"),
    );

    test!(
        param_4,
        uri = "/a/b",
        prefix = "/a/{param}",
        expected = Some("/"),
    );

    test!(
        param_5,
        uri = "/b/a",
        prefix = "/a/{param}",
        expected = None,
    );

    test!(
        param_6,
        uri = "/a/b",
        prefix = "/{param}/a",
        expected = None,
    );

    test!(
        param_7,
        uri = "/b/a",
        prefix = "/{param}/a",
        expected = Some("/"),
    );

    test!(
        param_8,
        uri = "/a/b/c",
        prefix = "/a/{param}/c",
        expected = Some("/"),
    );

    test!(
        param_9,
        uri = "/c/b/a",
        prefix = "/a/{param}/c",
        expected = None,
    );

    test!(
        param_10,
        uri = "/a/",
        prefix = "/{param}",
        expected = Some("/"),
    );

    test!(param_11, uri = "/a", prefix = "/{param}/", expected = None,);

    test!(
        param_12,
        uri = "/a/",
        prefix = "/{param}/",
        expected = Some("/"),
    );

    test!(
        param_13,
        uri = "/a/a",
        prefix = "/a/",
        expected = Some("/a"),
    );

    // The query string must survive the path rewrite.
    test!(
        query_is_preserved,
        uri = "/a/b?x=1",
        prefix = "/a",
        expected = Some("/b?x=1"),
    );

    test!(
        query_is_preserved_when_path_is_consumed,
        uri = "/a?x=1",
        prefix = "/a",
        expected = Some("/?x=1"),
    );

    // Absolute-form URIs keep their scheme and authority.
    test!(
        scheme_and_authority_are_preserved,
        uri = "http://example.com/a/b",
        prefix = "/a",
        expected = Some("http://example.com/b"),
    );

    #[quickcheck]
    fn does_not_panic(uri_and_prefix: UriAndPrefix) -> bool {
        let UriAndPrefix { uri, prefix } = uri_and_prefix;
        strip_prefix(&uri, &prefix);
        true
    }

    /// A random URI paired with a prefix of the same segment count. Each prefix
    /// segment is a capture, a copy of the URI segment, or an unrelated segment.
    #[derive(Clone, Debug)]
    struct UriAndPrefix {
        uri: Uri,
        prefix: String,
    }

    impl Arbitrary for UriAndPrefix {
        fn arbitrary(g: &mut quickcheck::Gen) -> Self {
            let mut uri = String::new();
            let mut prefix = String::new();

            let size = u8_between(1, 20, g);

            for _ in 0..size {
                let segment = ascii_alphanumeric(g);

                uri.push('/');
                uri.push_str(&segment);

                prefix.push('/');

                let make_matching_segment = bool::arbitrary(g);
                let make_capture = bool::arbitrary(g);

                match (make_matching_segment, make_capture) {
                    (_, true) => {
                        // Must use the `{name}` syntax recognised by `is_capture`;
                        // the old `:name` form is treated as a plain literal.
                        prefix.push_str("{a}");
                    }
                    (true, false) => {
                        prefix.push_str(&segment);
                    }
                    (false, false) => {
                        prefix.push_str(&ascii_alphanumeric(g));
                    }
                }
            }

            if bool::arbitrary(g) {
                uri.push('/');
            }

            if bool::arbitrary(g) {
                prefix.push('/');
            }

            Self {
                uri: uri.parse().unwrap(),
                prefix,
            }
        }
    }

    /// Generates a non-empty string of 1 to 20 ASCII alphanumeric characters.
    fn ascii_alphanumeric(g: &mut quickcheck::Gen) -> String {
        #[derive(Clone)]
        struct AsciiAlphanumeric(String);

        impl Arbitrary for AsciiAlphanumeric {
            fn arbitrary(g: &mut quickcheck::Gen) -> Self {
                let mut out = String::new();

                let size = u8_between(1, 20, g) as usize;

                while out.len() < size {
                    let c = char::arbitrary(g);
                    if c.is_ascii_alphanumeric() {
                        out.push(c);
                    }
                }
                Self(out)
            }
        }

        let out = AsciiAlphanumeric::arbitrary(g).0;
        assert!(!out.is_empty());
        out
    }

    /// Draws a random `u8` in the inclusive range `lower..=upper`.
    fn u8_between(lower: u8, upper: u8, g: &mut quickcheck::Gen) -> u8 {
        loop {
            let size = u8::arbitrary(g);
            if (lower..=upper).contains(&size) {
                break size;
            }
        }
    }
}
475}