// reqwest/redirect.rs
1use std::error::Error as StdError;
8use std::fmt;
9
10use crate::header::{HeaderMap, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_AUTHENTICATE};
11use hyper::StatusCode;
12
13use crate::Url;
14
/// Controls whether, and how, redirects are followed.
///
/// A `Policy` is built with one of three constructors:
/// - [`Policy::limited`]: follow redirects until a maximum number of previously
///   visited URLs is reached,
/// - [`Policy::none`]: never follow redirects,
/// - [`Policy::custom`]: delegate every decision to a user-supplied closure.
///
/// The [`Default`] policy is `Policy::limited(10)`.
pub struct Policy {
    // Which of the three policy kinds this is.
    inner: PolicyKind,
}
27
/// A single redirect that a [`Policy`] has to decide on.
///
/// It exposes the status code that triggered the redirect, the URL being
/// redirected to, and the URLs already visited in this chain. It is consumed
/// by [`Attempt::follow`], [`Attempt::stop`] or [`Attempt::error`] to produce
/// the [`Action`] to take.
#[derive(Debug)]
pub struct Attempt<'a> {
    // Status code passed in by the caller of `Policy::check`.
    status: StatusCode,
    // Target URL of this redirect.
    next: &'a Url,
    // URLs visited so far; the most recent one is the last element.
    previous: &'a [Url],
}
36
/// The decision a [`Policy`] returns for one redirect [`Attempt`].
///
/// Values are created through [`Attempt::follow`], [`Attempt::stop`] or
/// [`Attempt::error`]; the crate reads the wrapped kind via `Policy::check`.
#[derive(Debug)]
pub struct Action {
    // The concrete decision.
    inner: ActionKind,
}
42
43impl Policy {
44 pub fn limited(max: usize) -> Self {
48 Self {
49 inner: PolicyKind::Limit(max),
50 }
51 }
52
53 pub fn none() -> Self {
55 Self {
56 inner: PolicyKind::None,
57 }
58 }
59
60 pub fn custom<T>(policy: T) -> Self
99 where
100 T: Fn(Attempt) -> Action + Send + Sync + 'static,
101 {
102 Self {
103 inner: PolicyKind::Custom(Box::new(policy)),
104 }
105 }
106
107 pub fn redirect(&self, attempt: Attempt) -> Action {
128 match self.inner {
129 PolicyKind::Custom(ref custom) => custom(attempt),
130 PolicyKind::Limit(max) => {
131 if attempt.previous.len() >= max {
132 attempt.error(TooManyRedirects)
133 } else {
134 attempt.follow()
135 }
136 }
137 PolicyKind::None => attempt.stop(),
138 }
139 }
140
141 pub(crate) fn check(&self, status: StatusCode, next: &Url, previous: &[Url]) -> ActionKind {
142 self.redirect(Attempt {
143 status,
144 next,
145 previous,
146 })
147 .inner
148 }
149
150 pub(crate) fn is_default(&self) -> bool {
151 matches!(self.inner, PolicyKind::Limit(10))
152 }
153}
154
155impl Default for Policy {
156 fn default() -> Policy {
157 Policy::limited(10)
159 }
160}
161
162impl<'a> Attempt<'a> {
163 pub fn status(&self) -> StatusCode {
165 self.status
166 }
167
168 pub fn url(&self) -> &Url {
170 self.next
171 }
172
173 pub fn previous(&self) -> &[Url] {
175 self.previous
176 }
177 pub fn follow(self) -> Action {
179 Action {
180 inner: ActionKind::Follow,
181 }
182 }
183
184 pub fn stop(self) -> Action {
188 Action {
189 inner: ActionKind::Stop,
190 }
191 }
192
193 pub fn error<E: Into<Box<dyn StdError + Send + Sync>>>(self, error: E) -> Action {
197 Action {
198 inner: ActionKind::Error(error.into()),
199 }
200 }
201}
202
/// Internal representation of a [`Policy`].
enum PolicyKind {
    /// Every decision is made by this user-supplied closure (`Policy::custom`).
    Custom(Box<dyn Fn(Attempt) -> Action + Send + Sync + 'static>),
    /// Follow while fewer than this many previous URLs are recorded; otherwise
    /// fail with a too-many-redirects error (`Policy::limited`).
    Limit(usize),
    /// Never follow redirects (`Policy::none`).
    None,
}
208
209impl fmt::Debug for Policy {
210 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211 f.debug_tuple("Policy").field(&self.inner).finish()
212 }
213}
214
215impl fmt::Debug for PolicyKind {
216 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217 match *self {
218 PolicyKind::Custom(..) => f.pad("Custom"),
219 PolicyKind::Limit(max) => f.debug_tuple("Limit").field(&max).finish(),
220 PolicyKind::None => f.pad("None"),
221 }
222 }
223}
224
/// Crate-internal form of an [`Action`], returned by `Policy::check`.
#[derive(Debug)]
pub(crate) enum ActionKind {
    /// Follow the redirect (`Attempt::follow`).
    Follow,
    /// Do not follow the redirect (`Attempt::stop`).
    Stop,
    /// Fail the redirect with the contained error (`Attempt::error`).
    Error(Box<dyn StdError + Send + Sync>),
}
233
234pub(crate) fn remove_sensitive_headers(headers: &mut HeaderMap, next: &Url, previous: &[Url]) {
235 if let Some(previous) = previous.last() {
236 let cross_host = next.host_str() != previous.host_str()
237 || next.port_or_known_default() != previous.port_or_known_default();
238 if cross_host {
239 headers.remove(AUTHORIZATION);
240 headers.remove(COOKIE);
241 headers.remove("cookie2");
242 headers.remove(PROXY_AUTHORIZATION);
243 headers.remove(WWW_AUTHENTICATE);
244 }
245 }
246}
247
/// Error used by a limited [`Policy`] once the number of previously visited
/// URLs reaches the configured limit.
#[derive(Debug)]
struct TooManyRedirects;

impl fmt::Display for TooManyRedirects {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "too many redirects")
    }
}

// A plain marker error: it wraps nothing, so the default `source()` (None) applies.
impl StdError for TooManyRedirects {}
258
/// The default policy (limit 10) follows when 9 URLs were already visited and
/// errors with `TooManyRedirects` once there are 10.
#[test]
fn test_redirect_policy_limit() {
    let policy = Policy::default();
    let next = Url::parse("http://x.y/z").unwrap();

    // Nine previously visited URLs: still below the limit.
    let mut previous: Vec<Url> = Vec::new();
    for i in 0..9 {
        previous.push(Url::parse(&format!("http://a.b/c/{i}")).unwrap());
    }
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(matches!(action, ActionKind::Follow), "unexpected {action:?}");

    // A tenth visited URL reaches the limit.
    previous.push(Url::parse("http://a.b.d/e/33").unwrap());
    let action = policy.check(StatusCode::FOUND, &next, &previous);
    assert!(
        matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()),
        "unexpected {action:?}"
    );
}
279
/// A limit of zero refuses a redirect when a previous URL exists.
#[test]
fn test_redirect_policy_limit_to_0() {
    let policy = Policy::limited(0);
    let next = Url::parse("http://x.y/z").unwrap();
    let previous = [Url::parse("http://a.b/c").unwrap()];

    let action = policy.check(StatusCode::FOUND, &next, &previous);
    if !matches!(&action, ActionKind::Error(err) if err.is::<TooManyRedirects>()) {
        panic!("unexpected {action:?}");
    }
}
291
/// A custom policy's closure decides: stop on host `foo`, follow anything else.
#[test]
fn test_redirect_policy_custom() {
    let policy = Policy::custom(|attempt| {
        let to_foo = attempt.url().host_str() == Some("foo");
        if to_foo {
            attempt.stop()
        } else {
            attempt.follow()
        }
    });

    let bar = Url::parse("http://bar/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &bar, &[]);
    assert!(matches!(action, ActionKind::Follow), "unexpected {action:?}");

    let foo = Url::parse("http://foo/baz").unwrap();
    let action = policy.check(StatusCode::FOUND, &foo, &[]);
    assert!(matches!(action, ActionKind::Stop), "unexpected {action:?}");
}
314
/// Credential headers survive redirects within one origin and are dropped when
/// the host or the effective port of the most recent URL differs.
///
/// Compared with the earlier version, this also covers the empty-chain
/// early-out, an explicit default port (same origin), a port change on the same
/// host, and `Proxy-Authorization` removal.
#[test]
fn test_remove_sensitive_headers() {
    use hyper::header::{HeaderValue, ACCEPT, AUTHORIZATION, COOKIE};

    // Request headers before any redirect: one harmless header plus credentials.
    let initial = || {
        let mut headers = HeaderMap::new();
        headers.insert(ACCEPT, HeaderValue::from_static("*/*"));
        headers.insert(AUTHORIZATION, HeaderValue::from_static("let me in"));
        headers.insert(COOKIE, HeaderValue::from_static("foo=bar"));
        headers.insert(PROXY_AUTHORIZATION, HeaderValue::from_static("proxy creds"));
        headers
    };

    // Expected headers after a cross-origin redirect: only the harmless one is left.
    let mut stripped = initial();
    stripped.remove(AUTHORIZATION);
    stripped.remove(COOKIE);
    stripped.remove(PROXY_AUTHORIZATION);

    let next = Url::parse("http://initial-domain.com/path").unwrap();

    // Headers left after redirecting to `next` from the given chain of URLs.
    let after = |chain: &[&str]| {
        let previous: Vec<Url> = chain.iter().map(|u| Url::parse(u).unwrap()).collect();
        let mut headers = initial();
        remove_sensitive_headers(&mut headers, &next, &previous);
        headers
    };

    // No previous URL: nothing to compare against, so nothing is removed.
    assert_eq!(after(&[]), initial());
    // Same host, only the path changes: kept.
    assert_eq!(after(&["http://initial-domain.com/new_path"]), initial());
    // An explicit `:80` equals http's default port: still the same origin.
    assert_eq!(after(&["http://initial-domain.com:80/new_path"]), initial());
    // The most recent URL is on another host: stripped.
    assert_eq!(
        after(&["http://initial-domain.com/new_path", "http://new-domain.com/path"]),
        stripped
    );
    // Same host but a different port: stripped.
    assert_eq!(after(&["http://initial-domain.com:8080/path"]), stripped);
}