tower_http/cors/
mod.rs

1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower::{ServiceBuilder, ServiceExt, Service};
10//! use tower_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//!     Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//!     // allow `GET` and `POST` when accessing the resource
21//!     .allow_methods([Method::GET, Method::POST])
22//!     // allow requests from any origin
23//!     .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//!     .layer(cors)
27//!     .service_fn(handle);
28//!
29//! let request = Request::builder()
30//!     .header(header::ORIGIN, "https://example.com")
31//!     .body(Full::default())
32//!     .unwrap();
33//!
34//! let response = service
35//!     .ready()
36//!     .await?
37//!     .call(request)
38//!     .await?;
39//!
40//! assert_eq!(
41//!     response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
42//!     "*",
43//! );
44//! # Ok(())
45//! # }
46//! ```
47//!
48//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
49
50#![allow(clippy::enum_variant_names)]
51
52use allow_origin::AllowOriginFuture;
53use bytes::{BufMut, BytesMut};
54use http::{
55    header::{self, HeaderName},
56    HeaderMap, HeaderValue, Method, Request, Response,
57};
58use pin_project_lite::pin_project;
59use std::{
60    future::Future,
61    mem,
62    pin::Pin,
63    task::{ready, Context, Poll},
64};
65use tower_layer::Layer;
66use tower_service::Service;
67
68mod allow_credentials;
69mod allow_headers;
70mod allow_methods;
71mod allow_origin;
72mod allow_private_network;
73mod expose_headers;
74mod max_age;
75mod vary;
76
77#[cfg(test)]
78mod tests;
79
80pub use self::{
81    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
82    allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
83    expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
84};
85
86/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
87///
88/// See the [module docs](crate::cors) for an example.
89///
90/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
91#[derive(Debug, Clone)]
92#[must_use]
93pub struct CorsLayer {
94    allow_credentials: AllowCredentials,
95    allow_headers: AllowHeaders,
96    allow_methods: AllowMethods,
97    allow_origin: AllowOrigin,
98    allow_private_network: AllowPrivateNetwork,
99    expose_headers: ExposeHeaders,
100    max_age: MaxAge,
101    vary: Vary,
102}
103
104#[allow(clippy::declare_interior_mutable_const)]
105const WILDCARD: HeaderValue = HeaderValue::from_static("*");
106
107impl CorsLayer {
108    /// Create a new `CorsLayer`.
109    ///
110    /// No headers are sent by default. Use the builder methods to customize
111    /// the behavior.
112    ///
113    /// You need to set at least an allowed origin for browsers to make
114    /// successful cross-origin requests to your service.
115    pub fn new() -> Self {
116        Self {
117            allow_credentials: Default::default(),
118            allow_headers: Default::default(),
119            allow_methods: Default::default(),
120            allow_origin: Default::default(),
121            allow_private_network: Default::default(),
122            expose_headers: Default::default(),
123            max_age: Default::default(),
124            vary: Default::default(),
125        }
126    }
127
128    /// A permissive configuration:
129    ///
130    /// - All request headers allowed.
131    /// - All methods allowed.
132    /// - All origins allowed.
133    /// - All headers exposed.
134    pub fn permissive() -> Self {
135        Self::new()
136            .allow_headers(Any)
137            .allow_methods(Any)
138            .allow_origin(Any)
139            .expose_headers(Any)
140    }
141
142    /// A very permissive configuration:
143    ///
144    /// - **Credentials allowed.**
145    /// - The method received in `Access-Control-Request-Method` is sent back
146    ///   as an allowed method.
147    /// - The origin of the preflight request is sent back as an allowed origin.
148    /// - The header names received in `Access-Control-Request-Headers` are sent
149    ///   back as allowed headers.
150    /// - No headers are currently exposed, but this may change in the future.
151    pub fn very_permissive() -> Self {
152        Self::new()
153            .allow_credentials(true)
154            .allow_headers(AllowHeaders::mirror_request())
155            .allow_methods(AllowMethods::mirror_request())
156            .allow_origin(AllowOrigin::mirror_request())
157    }
158
159    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
160    ///
161    /// ```
162    /// use tower_http::cors::CorsLayer;
163    ///
164    /// let layer = CorsLayer::new().allow_credentials(true);
165    /// ```
166    ///
167    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
168    pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
169    where
170        T: Into<AllowCredentials>,
171    {
172        self.allow_credentials = allow_credentials.into();
173        self
174    }
175
176    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
177    ///
178    /// ```
179    /// use tower_http::cors::CorsLayer;
180    /// use http::header::{AUTHORIZATION, ACCEPT};
181    ///
182    /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
183    /// ```
184    ///
185    /// All headers can be allowed with
186    ///
187    /// ```
188    /// use tower_http::cors::{Any, CorsLayer};
189    ///
190    /// let layer = CorsLayer::new().allow_headers(Any);
191    /// ```
192    ///
193    /// Note that multiple calls to this method will override any previous
194    /// calls.
195    ///
196    /// Also note that `Access-Control-Allow-Headers` is required for requests that have
197    /// `Access-Control-Request-Headers`.
198    ///
199    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
200    pub fn allow_headers<T>(mut self, headers: T) -> Self
201    where
202        T: Into<AllowHeaders>,
203    {
204        self.allow_headers = headers.into();
205        self
206    }
207
208    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
209    ///
210    /// ```
211    /// use std::time::Duration;
212    /// use tower_http::cors::CorsLayer;
213    ///
214    /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
215    /// ```
216    ///
217    /// By default the header will not be set which disables caching and will
218    /// require a preflight call for all requests.
219    ///
220    /// Note that each browser has a maximum internal value that takes
221    /// precedence when the Access-Control-Max-Age is greater. For more details
222    /// see [mdn].
223    ///
224    /// If you need more flexibility, you can use supply a function which can
225    /// dynamically decide the max-age based on the origin and other parts of
226    /// each preflight request:
227    ///
228    /// ```
229    /// # struct MyServerConfig { cors_max_age: Duration }
230    /// use std::time::Duration;
231    ///
232    /// use http::{request::Parts as RequestParts, HeaderValue};
233    /// use tower_http::cors::{CorsLayer, MaxAge};
234    ///
235    /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
236    ///     |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
237    ///         // Let's say you want to be able to reload your config at
238    ///         // runtime and have another middleware that always inserts
239    ///         // the current config into the request extensions
240    ///         let config = parts.extensions.get::<MyServerConfig>().unwrap();
241    ///         config.cors_max_age
242    ///     },
243    /// ));
244    /// ```
245    ///
246    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
247    pub fn max_age<T>(mut self, max_age: T) -> Self
248    where
249        T: Into<MaxAge>,
250    {
251        self.max_age = max_age.into();
252        self
253    }
254
255    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
256    ///
257    /// ```
258    /// use tower_http::cors::CorsLayer;
259    /// use http::Method;
260    ///
261    /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
262    /// ```
263    ///
264    /// All methods can be allowed with
265    ///
266    /// ```
267    /// use tower_http::cors::{Any, CorsLayer};
268    ///
269    /// let layer = CorsLayer::new().allow_methods(Any);
270    /// ```
271    ///
272    /// Note that multiple calls to this method will override any previous
273    /// calls.
274    ///
275    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
276    pub fn allow_methods<T>(mut self, methods: T) -> Self
277    where
278        T: Into<AllowMethods>,
279    {
280        self.allow_methods = methods.into();
281        self
282    }
283
284    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
285    ///
286    /// ```
287    /// use http::HeaderValue;
288    /// use tower_http::cors::CorsLayer;
289    ///
290    /// let layer = CorsLayer::new().allow_origin(
291    ///     "http://example.com".parse::<HeaderValue>().unwrap(),
292    /// );
293    /// ```
294    ///
295    /// Multiple origins can be allowed with
296    ///
297    /// ```
298    /// use tower_http::cors::CorsLayer;
299    ///
300    /// let origins = [
301    ///     "http://example.com".parse().unwrap(),
302    ///     "http://api.example.com".parse().unwrap(),
303    /// ];
304    ///
305    /// let layer = CorsLayer::new().allow_origin(origins);
306    /// ```
307    ///
308    /// All origins can be allowed with
309    ///
310    /// ```
311    /// use tower_http::cors::{Any, CorsLayer};
312    ///
313    /// let layer = CorsLayer::new().allow_origin(Any);
314    /// ```
315    ///
316    /// You can also use a closure
317    ///
318    /// ```
319    /// use tower_http::cors::{CorsLayer, AllowOrigin};
320    /// use http::{request::Parts as RequestParts, HeaderValue};
321    ///
322    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
323    ///     |origin: &HeaderValue, _request_parts: &RequestParts| {
324    ///         origin.as_bytes().ends_with(b".rust-lang.org")
325    ///     },
326    /// ));
327    /// ```
328    ///
329    /// You can also use an async closure:
330    ///
331    /// ```
332    /// # #[derive(Clone)]
333    /// # struct Client;
334    /// # fn get_api_client() -> Client {
335    /// #     Client
336    /// # }
337    /// # impl Client {
338    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
339    /// #         vec![HeaderValue::from_static("http://example.com")]
340    /// #     }
341    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
342    /// #         vec![HeaderValue::from_static("http://example.com")]
343    /// #     }
344    /// # }
345    /// use tower_http::cors::{CorsLayer, AllowOrigin};
346    /// use http::{request::Parts as RequestParts, HeaderValue};
347    ///
348    /// let client = get_api_client();
349    ///
350    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
351    ///     |origin: HeaderValue, _request_parts: &RequestParts| async move {
352    ///         // fetch list of origins that are allowed
353    ///         let origins = client.fetch_allowed_origins().await;
354    ///         origins.contains(&origin)
355    ///     },
356    /// ));
357    ///
358    /// let client = get_api_client();
359    ///
360    /// // if using &RequestParts, make sure all the values are owned
361    /// // before passing into the future
362    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
363    ///     |origin: HeaderValue, parts: &RequestParts| {
364    ///         let path = parts.uri.path().to_owned();
365    ///
366    ///         async move {
367    ///             // fetch list of origins that are allowed for this path
368    ///             let origins = client.fetch_allowed_origins_for_path(path).await;
369    ///             origins.contains(&origin)
370    ///         }
371    ///     },
372    /// ));
373    /// ```
374    ///
375    /// Note that multiple calls to this method will override any previous
376    /// calls.
377    ///
378    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
379    pub fn allow_origin<T>(mut self, origin: T) -> Self
380    where
381        T: Into<AllowOrigin>,
382    {
383        self.allow_origin = origin.into();
384        self
385    }
386
387    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
388    ///
389    /// ```
390    /// use tower_http::cors::CorsLayer;
391    /// use http::header::CONTENT_ENCODING;
392    ///
393    /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
394    /// ```
395    ///
396    /// All headers can be allowed with
397    ///
398    /// ```
399    /// use tower_http::cors::{Any, CorsLayer};
400    ///
401    /// let layer = CorsLayer::new().expose_headers(Any);
402    /// ```
403    ///
404    /// Note that multiple calls to this method will override any previous
405    /// calls.
406    ///
407    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
408    pub fn expose_headers<T>(mut self, headers: T) -> Self
409    where
410        T: Into<ExposeHeaders>,
411    {
412        self.expose_headers = headers.into();
413        self
414    }
415
416    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
417    ///
418    /// ```
419    /// use tower_http::cors::CorsLayer;
420    ///
421    /// let layer = CorsLayer::new().allow_private_network(true);
422    /// ```
423    ///
424    /// [wicg]: https://wicg.github.io/private-network-access/
425    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
426    where
427        T: Into<AllowPrivateNetwork>,
428    {
429        self.allow_private_network = allow_private_network.into();
430        self
431    }
432
433    /// Set the value(s) of the [`Vary`][mdn] header.
434    ///
435    /// In contrast to the other headers, this one has a non-empty default of
436    /// [`preflight_request_headers()`].
437    ///
438    /// You only need to set this is you want to remove some of these defaults,
439    /// or if you use a closure for one of the other headers and want to add a
440    /// vary header accordingly.
441    ///
442    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
443    pub fn vary<T>(mut self, headers: T) -> Self
444    where
445        T: Into<Vary>,
446    {
447        self.vary = headers.into();
448        self
449    }
450}
451
/// Wildcard (`*`) marker accepted by the CORS builder methods that support
/// it, such as [`CorsLayer::allow_methods`], [`CorsLayer::allow_headers`],
/// [`CorsLayer::allow_origin`] and [`CorsLayer::expose_headers`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;
457
458/// Represents a wildcard value (`*`) used with some CORS headers such as
459/// [`CorsLayer::allow_methods`].
460#[deprecated = "Use Any as a unit struct literal instead"]
461pub fn any() -> Any {
462    Any
463}
464
465fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
466where
467    I: Iterator<Item = HeaderValue>,
468{
469    match iter.next() {
470        Some(fst) => {
471            let mut result = BytesMut::from(fst.as_bytes());
472            for val in iter {
473                result.reserve(val.len() + 1);
474                result.put_u8(b',');
475                result.extend_from_slice(val.as_bytes());
476            }
477
478            Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
479        }
480        None => None,
481    }
482}
483
484impl Default for CorsLayer {
485    fn default() -> Self {
486        Self::new()
487    }
488}
489
490impl<S> Layer<S> for CorsLayer {
491    type Service = Cors<S>;
492
493    fn layer(&self, inner: S) -> Self::Service {
494        ensure_usable_cors_rules(self);
495
496        Cors {
497            inner,
498            layer: self.clone(),
499        }
500    }
501}
502
503/// Middleware which adds headers for [CORS][mdn].
504///
505/// See the [module docs](crate::cors) for an example.
506///
507/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
508#[derive(Debug, Clone)]
509#[must_use]
510pub struct Cors<S> {
511    inner: S,
512    layer: CorsLayer,
513}
514
515impl<S> Cors<S> {
516    /// Create a new `Cors`.
517    ///
518    /// See [`CorsLayer::new`] for more details.
519    pub fn new(inner: S) -> Self {
520        Self {
521            inner,
522            layer: CorsLayer::new(),
523        }
524    }
525
526    /// A permissive configuration.
527    ///
528    /// See [`CorsLayer::permissive`] for more details.
529    pub fn permissive(inner: S) -> Self {
530        Self {
531            inner,
532            layer: CorsLayer::permissive(),
533        }
534    }
535
536    /// A very permissive configuration.
537    ///
538    /// See [`CorsLayer::very_permissive`] for more details.
539    pub fn very_permissive(inner: S) -> Self {
540        Self {
541            inner,
542            layer: CorsLayer::very_permissive(),
543        }
544    }
545
546    define_inner_service_accessors!();
547
548    /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
549    ///
550    /// [`Layer`]: tower_layer::Layer
551    pub fn layer() -> CorsLayer {
552        CorsLayer::new()
553    }
554
555    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
556    ///
557    /// See [`CorsLayer::allow_credentials`] for more details.
558    ///
559    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
560    pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
561    where
562        T: Into<AllowCredentials>,
563    {
564        self.map_layer(|layer| layer.allow_credentials(allow_credentials))
565    }
566
567    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
568    ///
569    /// See [`CorsLayer::allow_headers`] for more details.
570    ///
571    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
572    pub fn allow_headers<T>(self, headers: T) -> Self
573    where
574        T: Into<AllowHeaders>,
575    {
576        self.map_layer(|layer| layer.allow_headers(headers))
577    }
578
579    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
580    ///
581    /// See [`CorsLayer::max_age`] for more details.
582    ///
583    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
584    pub fn max_age<T>(self, max_age: T) -> Self
585    where
586        T: Into<MaxAge>,
587    {
588        self.map_layer(|layer| layer.max_age(max_age))
589    }
590
591    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
592    ///
593    /// See [`CorsLayer::allow_methods`] for more details.
594    ///
595    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
596    pub fn allow_methods<T>(self, methods: T) -> Self
597    where
598        T: Into<AllowMethods>,
599    {
600        self.map_layer(|layer| layer.allow_methods(methods))
601    }
602
603    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
604    ///
605    /// See [`CorsLayer::allow_origin`] for more details.
606    ///
607    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
608    pub fn allow_origin<T>(self, origin: T) -> Self
609    where
610        T: Into<AllowOrigin>,
611    {
612        self.map_layer(|layer| layer.allow_origin(origin))
613    }
614
615    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
616    ///
617    /// See [`CorsLayer::expose_headers`] for more details.
618    ///
619    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
620    pub fn expose_headers<T>(self, headers: T) -> Self
621    where
622        T: Into<ExposeHeaders>,
623    {
624        self.map_layer(|layer| layer.expose_headers(headers))
625    }
626
627    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
628    ///
629    /// See [`CorsLayer::allow_private_network`] for more details.
630    ///
631    /// [wicg]: https://wicg.github.io/private-network-access/
632    pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
633    where
634        T: Into<AllowPrivateNetwork>,
635    {
636        self.map_layer(|layer| layer.allow_private_network(allow_private_network))
637    }
638
639    fn map_layer<F>(mut self, f: F) -> Self
640    where
641        F: FnOnce(CorsLayer) -> CorsLayer,
642    {
643        self.layer = f(self.layer);
644        self
645    }
646}
647
648impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
649where
650    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
651    ResBody: Default,
652{
653    type Response = S::Response;
654    type Error = S::Error;
655    type Future = ResponseFuture<S::Future>;
656
657    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
658        ensure_usable_cors_rules(&self.layer);
659        self.inner.poll_ready(cx)
660    }
661
662    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
663        let (parts, body) = req.into_parts();
664        let origin = parts.headers.get(&header::ORIGIN);
665
666        let mut headers = HeaderMap::new();
667
668        // These headers are applied to both preflight and subsequent regular CORS requests:
669        // https://fetch.spec.whatwg.org/#http-responses
670
671        headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
672        headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
673        headers.extend(self.layer.vary.to_header());
674
675        let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);
676
677        // Return results immediately upon preflight request
678        if parts.method == Method::OPTIONS {
679            // These headers are applied only to preflight requests
680            headers.extend(self.layer.allow_methods.to_header(&parts));
681            headers.extend(self.layer.allow_headers.to_header(&parts));
682            headers.extend(self.layer.max_age.to_header(origin, &parts));
683
684            ResponseFuture {
685                inner: Kind::PreflightCall {
686                    allow_origin_future,
687                    headers,
688                },
689            }
690        } else {
691            // This header is applied only to non-preflight requests
692            headers.extend(self.layer.expose_headers.to_header(&parts));
693
694            let req = Request::from_parts(parts, body);
695            ResponseFuture {
696                inner: Kind::CorsCall {
697                    allow_origin_future,
698                    allow_origin_complete: false,
699                    future: self.inner.call(req),
700                    headers,
701                },
702            }
703        }
704    }
705}
706
707pin_project! {
708    /// Response future for [`Cors`].
709    pub struct ResponseFuture<F> {
710        #[pin]
711        inner: Kind<F>,
712    }
713}
714
715pin_project! {
716    #[project = KindProj]
717    enum Kind<F> {
718        CorsCall {
719            #[pin]
720            allow_origin_future: AllowOriginFuture,
721            allow_origin_complete: bool,
722            #[pin]
723            future: F,
724            headers: HeaderMap,
725        },
726        PreflightCall {
727            #[pin]
728            allow_origin_future: AllowOriginFuture,
729            headers: HeaderMap,
730        },
731    }
732}
733
734impl<F, B, E> Future for ResponseFuture<F>
735where
736    F: Future<Output = Result<Response<B>, E>>,
737    B: Default,
738{
739    type Output = Result<Response<B>, E>;
740
741    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
742        match self.project().inner.project() {
743            KindProj::CorsCall {
744                allow_origin_future,
745                allow_origin_complete,
746                future,
747                headers,
748            } => {
749                if !*allow_origin_complete {
750                    headers.extend(ready!(allow_origin_future.poll(cx)));
751                    *allow_origin_complete = true;
752                }
753
754                let mut response: Response<B> = ready!(future.poll(cx))?;
755
756                let response_headers = response.headers_mut();
757
758                // vary header can have multiple values, don't overwrite
759                // previously-set value(s).
760                if let Some(vary) = headers.remove(header::VARY) {
761                    response_headers.append(header::VARY, vary);
762                }
763                // extend will overwrite previous headers of remaining names
764                response_headers.extend(headers.drain());
765
766                Poll::Ready(Ok(response))
767            }
768            KindProj::PreflightCall {
769                allow_origin_future,
770                headers,
771            } => {
772                headers.extend(ready!(allow_origin_future.poll(cx)));
773
774                let mut response = Response::new(B::default());
775                mem::swap(response.headers_mut(), headers);
776
777                Poll::Ready(Ok(response))
778            }
779        }
780    }
781}
782
783fn ensure_usable_cors_rules(layer: &CorsLayer) {
784    if layer.allow_credentials.is_true() {
785        assert!(
786            !layer.allow_headers.is_wildcard(),
787            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
788             with `Access-Control-Allow-Headers: *`"
789        );
790
791        assert!(
792            !layer.allow_methods.is_wildcard(),
793            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
794             with `Access-Control-Allow-Methods: *`"
795        );
796
797        assert!(
798            !layer.allow_origin.is_wildcard(),
799            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
800             with `Access-Control-Allow-Origin: *`"
801        );
802
803        assert!(
804            !layer.expose_headers.is_wildcard(),
805            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
806             with `Access-Control-Expose-Headers: *`"
807        );
808    }
809}
810
811/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
812///
813/// This is the default set of header names returned in the `vary` header
814pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
815    IntoIterator::into_iter([
816        header::ORIGIN,
817        header::ACCESS_CONTROL_REQUEST_METHOD,
818        header::ACCESS_CONTROL_REQUEST_HEADERS,
819    ])
820}