tower_http/cors/mod.rs
1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower::{ServiceBuilder, ServiceExt, Service};
10//! use tower_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//!     Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//!     // allow `GET` and `POST` when accessing the resource
21//!     .allow_methods([Method::GET, Method::POST])
22//!     // allow requests from any origin
23//!     .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//!     .layer(cors)
27//!     .service_fn(handle);
28//!
29//! let request = Request::builder()
30//!     .header(header::ORIGIN, "https://example.com")
31//!     .body(Full::default())
32//!     .unwrap();
33//!
34//! let response = service
35//!     .ready()
36//!     .await?
37//!     .call(request)
38//!     .await?;
39//!
40//! assert_eq!(
41//!     response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
42//!     "*",
43//! );
44//! # Ok(())
45//! # }
46//! ```
47//!
48//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
49
50#![allow(clippy::enum_variant_names)]
51
52use allow_origin::AllowOriginFuture;
53use bytes::{BufMut, BytesMut};
54use http::{
55    header::{self, HeaderName},
56    HeaderMap, HeaderValue, Method, Request, Response,
57};
58use pin_project_lite::pin_project;
59use std::{
60    array,
61    future::Future,
62    mem,
63    pin::Pin,
64    task::{ready, Context, Poll},
65};
66use tower_layer::Layer;
67use tower_service::Service;
68
69mod allow_credentials;
70mod allow_headers;
71mod allow_methods;
72mod allow_origin;
73mod allow_private_network;
74mod expose_headers;
75mod max_age;
76mod vary;
77
78#[cfg(test)]
79mod tests;
80
81pub use self::{
82    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
83    allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
84    expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
85};
86
/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// Start from [`CorsLayer::new`], [`CorsLayer::permissive`] or
/// [`CorsLayer::very_permissive`] and refine the configuration with the
/// builder methods; each builder call replaces the corresponding setting.
///
/// See the [module docs](crate::cors) for an example.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Debug, Clone)]
#[must_use]
pub struct CorsLayer {
    // `Access-Control-Allow-Credentials`, set by `allow_credentials`.
    allow_credentials: AllowCredentials,
    // `Access-Control-Allow-Headers` (preflight only), set by `allow_headers`.
    allow_headers: AllowHeaders,
    // `Access-Control-Allow-Methods` (preflight only), set by `allow_methods`.
    allow_methods: AllowMethods,
    // `Access-Control-Allow-Origin`, computed per request, set by `allow_origin`.
    allow_origin: AllowOrigin,
    // `Access-Control-Allow-Private-Network`, set by `allow_private_network`.
    allow_private_network: AllowPrivateNetwork,
    // `Access-Control-Expose-Headers` (non-preflight only), set by `expose_headers`.
    expose_headers: ExposeHeaders,
    // `Access-Control-Max-Age` (preflight only), set by `max_age`.
    max_age: MaxAge,
    // `Vary`; unlike the other settings its default is non-empty
    // (`preflight_request_headers()`, per the `vary` builder docs).
    vary: Vary,
}
104
// The `*` header value used for wildcard CORS headers.
// NOTE(review): not referenced in this part of the file; presumably used by
// the submodules via `super::WILDCARD` — confirm.
// `declare_interior_mutable_const` is silenced because clippy flags
// `HeaderValue` constants; every use of a `const` yields a fresh value, so no
// state is shared between uses.
#[allow(clippy::declare_interior_mutable_const)]
const WILDCARD: HeaderValue = HeaderValue::from_static("*");
107
108impl CorsLayer {
109    /// Create a new `CorsLayer`.
110    ///
111    /// No headers are sent by default. Use the builder methods to customize
112    /// the behavior.
113    ///
114    /// You need to set at least an allowed origin for browsers to make
115    /// successful cross-origin requests to your service.
116    pub fn new() -> Self {
117        Self {
118            allow_credentials: Default::default(),
119            allow_headers: Default::default(),
120            allow_methods: Default::default(),
121            allow_origin: Default::default(),
122            allow_private_network: Default::default(),
123            expose_headers: Default::default(),
124            max_age: Default::default(),
125            vary: Default::default(),
126        }
127    }
128
129    /// A permissive configuration:
130    ///
131    /// - All request headers allowed.
132    /// - All methods allowed.
133    /// - All origins allowed.
134    /// - All headers exposed.
135    pub fn permissive() -> Self {
136        Self::new()
137            .allow_headers(Any)
138            .allow_methods(Any)
139            .allow_origin(Any)
140            .expose_headers(Any)
141    }
142
143    /// A very permissive configuration:
144    ///
145    /// - **Credentials allowed.**
146    /// - The method received in `Access-Control-Request-Method` is sent back
147    ///   as an allowed method.
148    /// - The origin of the preflight request is sent back as an allowed origin.
149    /// - The header names received in `Access-Control-Request-Headers` are sent
150    ///   back as allowed headers.
151    /// - No headers are currently exposed, but this may change in the future.
152    pub fn very_permissive() -> Self {
153        Self::new()
154            .allow_credentials(true)
155            .allow_headers(AllowHeaders::mirror_request())
156            .allow_methods(AllowMethods::mirror_request())
157            .allow_origin(AllowOrigin::mirror_request())
158    }
159
160    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
161    ///
162    /// ```
163    /// use tower_http::cors::CorsLayer;
164    ///
165    /// let layer = CorsLayer::new().allow_credentials(true);
166    /// ```
167    ///
168    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
169    pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
170    where
171        T: Into<AllowCredentials>,
172    {
173        self.allow_credentials = allow_credentials.into();
174        self
175    }
176
177    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
178    ///
179    /// ```
180    /// use tower_http::cors::CorsLayer;
181    /// use http::header::{AUTHORIZATION, ACCEPT};
182    ///
183    /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
184    /// ```
185    ///
186    /// All headers can be allowed with
187    ///
188    /// ```
189    /// use tower_http::cors::{Any, CorsLayer};
190    ///
191    /// let layer = CorsLayer::new().allow_headers(Any);
192    /// ```
193    ///
194    /// Note that multiple calls to this method will override any previous
195    /// calls.
196    ///
197    /// Also note that `Access-Control-Allow-Headers` is required for requests that have
198    /// `Access-Control-Request-Headers`.
199    ///
200    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
201    pub fn allow_headers<T>(mut self, headers: T) -> Self
202    where
203        T: Into<AllowHeaders>,
204    {
205        self.allow_headers = headers.into();
206        self
207    }
208
209    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
210    ///
211    /// ```
212    /// use std::time::Duration;
213    /// use tower_http::cors::CorsLayer;
214    ///
215    /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
216    /// ```
217    ///
218    /// By default the header will not be set which disables caching and will
219    /// require a preflight call for all requests.
220    ///
221    /// Note that each browser has a maximum internal value that takes
222    /// precedence when the Access-Control-Max-Age is greater. For more details
223    /// see [mdn].
224    ///
225    /// If you need more flexibility, you can use supply a function which can
226    /// dynamically decide the max-age based on the origin and other parts of
227    /// each preflight request:
228    ///
229    /// ```
230    /// # struct MyServerConfig { cors_max_age: Duration }
231    /// use std::time::Duration;
232    ///
233    /// use http::{request::Parts as RequestParts, HeaderValue};
234    /// use tower_http::cors::{CorsLayer, MaxAge};
235    ///
236    /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
237    ///     |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
238    ///         // Let's say you want to be able to reload your config at
239    ///         // runtime and have another middleware that always inserts
240    ///         // the current config into the request extensions
241    ///         let config = parts.extensions.get::<MyServerConfig>().unwrap();
242    ///         config.cors_max_age
243    ///     },
244    /// ));
245    /// ```
246    ///
247    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
248    pub fn max_age<T>(mut self, max_age: T) -> Self
249    where
250        T: Into<MaxAge>,
251    {
252        self.max_age = max_age.into();
253        self
254    }
255
256    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
257    ///
258    /// ```
259    /// use tower_http::cors::CorsLayer;
260    /// use http::Method;
261    ///
262    /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
263    /// ```
264    ///
265    /// All methods can be allowed with
266    ///
267    /// ```
268    /// use tower_http::cors::{Any, CorsLayer};
269    ///
270    /// let layer = CorsLayer::new().allow_methods(Any);
271    /// ```
272    ///
273    /// Note that multiple calls to this method will override any previous
274    /// calls.
275    ///
276    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
277    pub fn allow_methods<T>(mut self, methods: T) -> Self
278    where
279        T: Into<AllowMethods>,
280    {
281        self.allow_methods = methods.into();
282        self
283    }
284
285    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
286    ///
287    /// ```
288    /// use http::HeaderValue;
289    /// use tower_http::cors::CorsLayer;
290    ///
291    /// let layer = CorsLayer::new().allow_origin(
292    ///     "http://example.com".parse::<HeaderValue>().unwrap(),
293    /// );
294    /// ```
295    ///
296    /// Multiple origins can be allowed with
297    ///
298    /// ```
299    /// use tower_http::cors::CorsLayer;
300    ///
301    /// let origins = [
302    ///     "http://example.com".parse().unwrap(),
303    ///     "http://api.example.com".parse().unwrap(),
304    /// ];
305    ///
306    /// let layer = CorsLayer::new().allow_origin(origins);
307    /// ```
308    ///
309    /// All origins can be allowed with
310    ///
311    /// ```
312    /// use tower_http::cors::{Any, CorsLayer};
313    ///
314    /// let layer = CorsLayer::new().allow_origin(Any);
315    /// ```
316    ///
317    /// You can also use a closure
318    ///
319    /// ```
320    /// use tower_http::cors::{CorsLayer, AllowOrigin};
321    /// use http::{request::Parts as RequestParts, HeaderValue};
322    ///
323    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
324    ///     |origin: &HeaderValue, _request_parts: &RequestParts| {
325    ///         origin.as_bytes().ends_with(b".rust-lang.org")
326    ///     },
327    /// ));
328    /// ```
329    ///
330    /// You can also use an async closure:
331    ///
332    /// ```
333    /// # #[derive(Clone)]
334    /// # struct Client;
335    /// # fn get_api_client() -> Client {
336    /// #     Client
337    /// # }
338    /// # impl Client {
339    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
340    /// #         vec![HeaderValue::from_static("http://example.com")]
341    /// #     }
342    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
343    /// #         vec![HeaderValue::from_static("http://example.com")]
344    /// #     }
345    /// # }
346    /// use tower_http::cors::{CorsLayer, AllowOrigin};
347    /// use http::{request::Parts as RequestParts, HeaderValue};
348    ///
349    /// let client = get_api_client();
350    ///
351    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
352    ///     |origin: HeaderValue, _request_parts: &RequestParts| async move {
353    ///         // fetch list of origins that are allowed
354    ///         let origins = client.fetch_allowed_origins().await;
355    ///         origins.contains(&origin)
356    ///     },
357    /// ));
358    ///
359    /// let client = get_api_client();
360    ///
361    /// // if using &RequestParts, make sure all the values are owned
362    /// // before passing into the future
363    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
364    ///     |origin: HeaderValue, parts: &RequestParts| {
365    ///         let path = parts.uri.path().to_owned();
366    ///
367    ///         async move {
368    ///             // fetch list of origins that are allowed for this path
369    ///             let origins = client.fetch_allowed_origins_for_path(path).await;
370    ///             origins.contains(&origin)
371    ///         }
372    ///     },
373    /// ));
374    /// ```
375    ///
376    /// Note that multiple calls to this method will override any previous
377    /// calls.
378    ///
379    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
380    pub fn allow_origin<T>(mut self, origin: T) -> Self
381    where
382        T: Into<AllowOrigin>,
383    {
384        self.allow_origin = origin.into();
385        self
386    }
387
388    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
389    ///
390    /// ```
391    /// use tower_http::cors::CorsLayer;
392    /// use http::header::CONTENT_ENCODING;
393    ///
394    /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
395    /// ```
396    ///
397    /// All headers can be allowed with
398    ///
399    /// ```
400    /// use tower_http::cors::{Any, CorsLayer};
401    ///
402    /// let layer = CorsLayer::new().expose_headers(Any);
403    /// ```
404    ///
405    /// Note that multiple calls to this method will override any previous
406    /// calls.
407    ///
408    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
409    pub fn expose_headers<T>(mut self, headers: T) -> Self
410    where
411        T: Into<ExposeHeaders>,
412    {
413        self.expose_headers = headers.into();
414        self
415    }
416
417    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
418    ///
419    /// ```
420    /// use tower_http::cors::CorsLayer;
421    ///
422    /// let layer = CorsLayer::new().allow_private_network(true);
423    /// ```
424    ///
425    /// [wicg]: https://wicg.github.io/private-network-access/
426    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
427    where
428        T: Into<AllowPrivateNetwork>,
429    {
430        self.allow_private_network = allow_private_network.into();
431        self
432    }
433
434    /// Set the value(s) of the [`Vary`][mdn] header.
435    ///
436    /// In contrast to the other headers, this one has a non-empty default of
437    /// [`preflight_request_headers()`].
438    ///
439    /// You only need to set this is you want to remove some of these defaults,
440    /// or if you use a closure for one of the other headers and want to add a
441    /// vary header accordingly.
442    ///
443    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
444    pub fn vary<T>(mut self, headers: T) -> Self
445    where
446        T: Into<Vary>,
447    {
448        self.vary = headers.into();
449        self
450    }
451}
452
/// Wildcard marker standing for `*` in CORS settings that accept it, such as
/// [`CorsLayer::allow_methods`].
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct Any;

/// Returns the wildcard marker [`Any`].
///
/// Only kept for backward compatibility: write the unit struct `Any` directly.
#[deprecated = "Use Any as a unit struct literal instead"]
pub fn any() -> Any {
    let wildcard = Any;
    wildcard
}
465
466fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
467where
468    I: Iterator<Item = HeaderValue>,
469{
470    match iter.next() {
471        Some(fst) => {
472            let mut result = BytesMut::from(fst.as_bytes());
473            for val in iter {
474                result.reserve(val.len() + 1);
475                result.put_u8(b',');
476                result.extend_from_slice(val.as_bytes());
477            }
478
479            Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
480        }
481        None => None,
482    }
483}
484
485impl Default for CorsLayer {
486    fn default() -> Self {
487        Self::new()
488    }
489}
490
491impl<S> Layer<S> for CorsLayer {
492    type Service = Cors<S>;
493
494    fn layer(&self, inner: S) -> Self::Service {
495        ensure_usable_cors_rules(self);
496
497        Cors {
498            inner,
499            layer: self.clone(),
500        }
501    }
502}
503
/// Middleware which adds headers for [CORS][mdn].
///
/// Preflight (`OPTIONS`) requests are answered directly by this middleware;
/// all other requests are forwarded to the inner service and the CORS
/// headers are merged into its response.
///
/// See the [module docs](crate::cors) for an example.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Debug, Clone)]
#[must_use]
pub struct Cors<S> {
    // The wrapped service; only called for non-`OPTIONS` requests.
    inner: S,
    // The CORS configuration used to compute response headers.
    layer: CorsLayer,
}
515
516impl<S> Cors<S> {
517    /// Create a new `Cors`.
518    ///
519    /// See [`CorsLayer::new`] for more details.
520    pub fn new(inner: S) -> Self {
521        Self {
522            inner,
523            layer: CorsLayer::new(),
524        }
525    }
526
527    /// A permissive configuration.
528    ///
529    /// See [`CorsLayer::permissive`] for more details.
530    pub fn permissive(inner: S) -> Self {
531        Self {
532            inner,
533            layer: CorsLayer::permissive(),
534        }
535    }
536
537    /// A very permissive configuration.
538    ///
539    /// See [`CorsLayer::very_permissive`] for more details.
540    pub fn very_permissive(inner: S) -> Self {
541        Self {
542            inner,
543            layer: CorsLayer::very_permissive(),
544        }
545    }
546
547    define_inner_service_accessors!();
548
549    /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
550    ///
551    /// [`Layer`]: tower_layer::Layer
552    pub fn layer() -> CorsLayer {
553        CorsLayer::new()
554    }
555
556    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
557    ///
558    /// See [`CorsLayer::allow_credentials`] for more details.
559    ///
560    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
561    pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
562    where
563        T: Into<AllowCredentials>,
564    {
565        self.map_layer(|layer| layer.allow_credentials(allow_credentials))
566    }
567
568    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
569    ///
570    /// See [`CorsLayer::allow_headers`] for more details.
571    ///
572    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
573    pub fn allow_headers<T>(self, headers: T) -> Self
574    where
575        T: Into<AllowHeaders>,
576    {
577        self.map_layer(|layer| layer.allow_headers(headers))
578    }
579
580    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
581    ///
582    /// See [`CorsLayer::max_age`] for more details.
583    ///
584    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
585    pub fn max_age<T>(self, max_age: T) -> Self
586    where
587        T: Into<MaxAge>,
588    {
589        self.map_layer(|layer| layer.max_age(max_age))
590    }
591
592    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
593    ///
594    /// See [`CorsLayer::allow_methods`] for more details.
595    ///
596    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
597    pub fn allow_methods<T>(self, methods: T) -> Self
598    where
599        T: Into<AllowMethods>,
600    {
601        self.map_layer(|layer| layer.allow_methods(methods))
602    }
603
604    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
605    ///
606    /// See [`CorsLayer::allow_origin`] for more details.
607    ///
608    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
609    pub fn allow_origin<T>(self, origin: T) -> Self
610    where
611        T: Into<AllowOrigin>,
612    {
613        self.map_layer(|layer| layer.allow_origin(origin))
614    }
615
616    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
617    ///
618    /// See [`CorsLayer::expose_headers`] for more details.
619    ///
620    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
621    pub fn expose_headers<T>(self, headers: T) -> Self
622    where
623        T: Into<ExposeHeaders>,
624    {
625        self.map_layer(|layer| layer.expose_headers(headers))
626    }
627
628    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
629    ///
630    /// See [`CorsLayer::allow_private_network`] for more details.
631    ///
632    /// [wicg]: https://wicg.github.io/private-network-access/
633    pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
634    where
635        T: Into<AllowPrivateNetwork>,
636    {
637        self.map_layer(|layer| layer.allow_private_network(allow_private_network))
638    }
639
640    fn map_layer<F>(mut self, f: F) -> Self
641    where
642        F: FnOnce(CorsLayer) -> CorsLayer,
643    {
644        self.layer = f(self.layer);
645        self
646    }
647}
648
649impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
650where
651    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
652    ResBody: Default,
653{
654    type Response = S::Response;
655    type Error = S::Error;
656    type Future = ResponseFuture<S::Future>;
657
658    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
659        ensure_usable_cors_rules(&self.layer);
660        self.inner.poll_ready(cx)
661    }
662
663    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
664        let (parts, body) = req.into_parts();
665        let origin = parts.headers.get(&header::ORIGIN);
666
667        let mut headers = HeaderMap::new();
668
669        // These headers are applied to both preflight and subsequent regular CORS requests:
670        // https://fetch.spec.whatwg.org/#http-responses
671
672        headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
673        headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
674        headers.extend(self.layer.vary.to_header());
675
676        let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);
677
678        // Return results immediately upon preflight request
679        if parts.method == Method::OPTIONS {
680            // These headers are applied only to preflight requests
681            headers.extend(self.layer.allow_methods.to_header(&parts));
682            headers.extend(self.layer.allow_headers.to_header(&parts));
683            headers.extend(self.layer.max_age.to_header(origin, &parts));
684
685            ResponseFuture {
686                inner: Kind::PreflightCall {
687                    allow_origin_future,
688                    headers,
689                },
690            }
691        } else {
692            // This header is applied only to non-preflight requests
693            headers.extend(self.layer.expose_headers.to_header(&parts));
694
695            let req = Request::from_parts(parts, body);
696            ResponseFuture {
697                inner: Kind::CorsCall {
698                    allow_origin_future,
699                    allow_origin_complete: false,
700                    future: self.inner.call(req),
701                    headers,
702                },
703            }
704        }
705    }
706}
707
pin_project! {
    /// Response future for [`Cors`].
    ///
    /// Resolves either to the inner service's response with the CORS headers
    /// merged in, or, for preflight (`OPTIONS`) requests, to an empty response
    /// carrying only the CORS headers.
    pub struct ResponseFuture<F> {
        // Per-request state; which variant depends on whether the request
        // was treated as a preflight.
        #[pin]
        inner: Kind<F>,
    }
}
715
pin_project! {
    // Internal state of `ResponseFuture`.
    #[project = KindProj]
    enum Kind<F> {
        // A non-`OPTIONS` request that was forwarded to the inner service.
        CorsCall {
            // Produces the allow-origin header(s); its output is fed into
            // `headers` via `HeaderMap::extend`.
            #[pin]
            allow_origin_future: AllowOriginFuture,
            // Set once `allow_origin_future` has completed so it is never
            // polled again while waiting on `future`.
            allow_origin_complete: bool,
            // The inner service's response future.
            #[pin]
            future: F,
            // CORS headers to merge into the inner response.
            headers: HeaderMap,
        },
        // An `OPTIONS` request answered without calling the inner service.
        PreflightCall {
            // Produces the allow-origin header(s) for the preflight response.
            #[pin]
            allow_origin_future: AllowOriginFuture,
            // Headers that become the preflight response's header map.
            headers: HeaderMap,
        },
    }
}
734
735impl<F, B, E> Future for ResponseFuture<F>
736where
737    F: Future<Output = Result<Response<B>, E>>,
738    B: Default,
739{
740    type Output = Result<Response<B>, E>;
741
742    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
743        match self.project().inner.project() {
744            KindProj::CorsCall {
745                allow_origin_future,
746                allow_origin_complete,
747                future,
748                headers,
749            } => {
750                if !*allow_origin_complete {
751                    headers.extend(ready!(allow_origin_future.poll(cx)));
752                    *allow_origin_complete = true;
753                }
754
755                let mut response: Response<B> = ready!(future.poll(cx))?;
756
757                let response_headers = response.headers_mut();
758
759                // vary header can have multiple values, don't overwrite
760                // previously-set value(s).
761                if let Some(vary) = headers.remove(header::VARY) {
762                    response_headers.append(header::VARY, vary);
763                }
764                // extend will overwrite previous headers of remaining names
765                response_headers.extend(headers.drain());
766
767                Poll::Ready(Ok(response))
768            }
769            KindProj::PreflightCall {
770                allow_origin_future,
771                headers,
772            } => {
773                headers.extend(ready!(allow_origin_future.poll(cx)));
774
775                let mut response = Response::new(B::default());
776                mem::swap(response.headers_mut(), headers);
777
778                Poll::Ready(Ok(response))
779            }
780        }
781    }
782}
783
784fn ensure_usable_cors_rules(layer: &CorsLayer) {
785    if layer.allow_credentials.is_true() {
786        assert!(
787            !layer.allow_headers.is_wildcard(),
788            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
789             with `Access-Control-Allow-Headers: *`"
790        );
791
792        assert!(
793            !layer.allow_methods.is_wildcard(),
794            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
795             with `Access-Control-Allow-Methods: *`"
796        );
797
798        assert!(
799            !layer.allow_origin.is_wildcard(),
800            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
801             with `Access-Control-Allow-Origin: *`"
802        );
803
804        assert!(
805            !layer.expose_headers.is_wildcard(),
806            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
807             with `Access-Control-Expose-Headers: *`"
808        );
809    }
810}
811
812/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
813///
814/// This is the default set of header names returned in the `vary` header
815pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
816    #[allow(deprecated)] // Can be changed when MSRV >= 1.53
817    array::IntoIter::new([
818        header::ORIGIN,
819        header::ACCESS_CONTROL_REQUEST_METHOD,
820        header::ACCESS_CONTROL_REQUEST_HEADERS,
821    ])
822}