// headers/common/access_control_allow_origin.rs

1use std::convert::TryFrom;
2
3use http::HeaderValue;
4
5use super::origin::Origin;
6use crate::util::{IterExt, TryFromValues};
7use crate::Error;
8
9/// The `Access-Control-Allow-Origin` response header,
10/// part of [CORS](http://www.w3.org/TR/cors/#access-control-allow-origin-response-header)
11///
12/// The `Access-Control-Allow-Origin` header indicates whether a resource
13/// can be shared based by returning the value of the Origin request header,
14/// `*`, or `null` in the response.
15///
16/// ## ABNF
17///
18/// ```text
19/// Access-Control-Allow-Origin = "Access-Control-Allow-Origin" ":" origin-list-or-null | "*"
20/// ```
21///
22/// ## Example values
23/// * `null`
24/// * `*`
25/// * `http://google.com/`
26///
27/// # Examples
28///
29/// ```
30/// use headers::AccessControlAllowOrigin;
31/// use std::convert::TryFrom;
32///
33/// let any_origin = AccessControlAllowOrigin::ANY;
34/// let null_origin = AccessControlAllowOrigin::NULL;
35/// let origin = AccessControlAllowOrigin::try_from("http://web-platform.test:8000");
36/// ```
37#[derive(Clone, Debug, PartialEq, Eq, Hash)]
38pub struct AccessControlAllowOrigin(OriginOrAny);
39
40derive_header! {
41    AccessControlAllowOrigin(_),
42    name: ACCESS_CONTROL_ALLOW_ORIGIN
43}
44
45#[derive(Clone, Debug, PartialEq, Eq, Hash)]
46enum OriginOrAny {
47    Origin(Origin),
48    /// Allow all origins
49    Any,
50}
51
52impl AccessControlAllowOrigin {
53    /// `Access-Control-Allow-Origin: *`
54    pub const ANY: AccessControlAllowOrigin = AccessControlAllowOrigin(OriginOrAny::Any);
55    /// `Access-Control-Allow-Origin: null`
56    pub const NULL: AccessControlAllowOrigin =
57        AccessControlAllowOrigin(OriginOrAny::Origin(Origin::NULL));
58
59    /// Returns the origin if there's one specified.
60    pub fn origin(&self) -> Option<&Origin> {
61        match self.0 {
62            OriginOrAny::Origin(ref origin) => Some(origin),
63            _ => None,
64        }
65    }
66}
67
68impl TryFrom<&str> for AccessControlAllowOrigin {
69    type Error = Error;
70
71    fn try_from(s: &str) -> Result<Self, Error> {
72        let header_value = HeaderValue::from_str(s).map_err(|_| Error::invalid())?;
73        let origin = OriginOrAny::try_from(&header_value)?;
74        Ok(Self(origin))
75    }
76}
77
78impl TryFrom<&HeaderValue> for OriginOrAny {
79    type Error = Error;
80
81    fn try_from(header_value: &HeaderValue) -> Result<Self, Error> {
82        Origin::try_from_value(header_value)
83            .map(OriginOrAny::Origin)
84            .ok_or_else(Error::invalid)
85    }
86}
87
88impl TryFromValues for OriginOrAny {
89    fn try_from_values<'i, I>(values: &mut I) -> Result<Self, Error>
90    where
91        I: Iterator<Item = &'i HeaderValue>,
92    {
93        values
94            .just_one()
95            .and_then(|value| {
96                if value == "*" {
97                    return Some(OriginOrAny::Any);
98                }
99
100                Origin::try_from_value(value).map(OriginOrAny::Origin)
101            })
102            .ok_or_else(Error::invalid)
103    }
104}
105
106impl<'a> From<&'a OriginOrAny> for HeaderValue {
107    fn from(origin: &'a OriginOrAny) -> HeaderValue {
108        match origin {
109            OriginOrAny::Origin(ref origin) => origin.to_value(),
110            OriginOrAny::Any => HeaderValue::from_static("*"),
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    /// Origin string shared by the concrete-origin round-trip tests.
    const WEB_PLATFORM_ORIGIN: &str = "http://web-platform.test:8000";

    /// Asserts that `allow_origin` holds exactly `WEB_PLATFORM_ORIGIN`'s parts.
    fn assert_web_platform_origin(allow_origin: &AccessControlAllowOrigin) {
        let origin = allow_origin.origin().unwrap();
        assert_eq!(origin.scheme(), "http");
        assert_eq!(origin.hostname(), "web-platform.test");
        assert_eq!(origin.port(), Some(8000));
    }

    #[test]
    fn origin() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&[WEB_PLATFORM_ORIGIN]).unwrap();
        assert_web_platform_origin(&decoded);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], WEB_PLATFORM_ORIGIN);
    }

    #[test]
    fn try_from_origin() {
        let parsed = AccessControlAllowOrigin::try_from(WEB_PLATFORM_ORIGIN).unwrap();
        assert_web_platform_origin(&parsed);

        let encoded = test_encode(parsed);
        assert_eq!(encoded["access-control-allow-origin"], WEB_PLATFORM_ORIGIN);
    }

    #[test]
    fn any() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&["*"]).unwrap();
        assert_eq!(decoded, AccessControlAllowOrigin::ANY);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], "*");
    }

    #[test]
    fn null() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&["null"]).unwrap();
        assert_eq!(decoded, AccessControlAllowOrigin::NULL);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], "null");
    }
}