// headers/common/access_control_allow_origin.rs

1use std::convert::TryFrom;
2
3use super::origin::Origin;
4use util::{IterExt, TryFromValues};
5use HeaderValue;
6
7/// The `Access-Control-Allow-Origin` response header,
8/// part of [CORS](http://www.w3.org/TR/cors/#access-control-allow-origin-response-header)
9///
10/// The `Access-Control-Allow-Origin` header indicates whether a resource
11/// can be shared based by returning the value of the Origin request header,
12/// `*`, or `null` in the response.
13///
14/// ## ABNF
15///
16/// ```text
17/// Access-Control-Allow-Origin = "Access-Control-Allow-Origin" ":" origin-list-or-null | "*"
18/// ```
19///
20/// ## Example values
21/// * `null`
22/// * `*`
23/// * `http://google.com/`
24///
25/// # Examples
26///
27/// ```
28/// # extern crate headers;
29/// use headers::AccessControlAllowOrigin;
30/// use std::convert::TryFrom;
31///
32/// let any_origin = AccessControlAllowOrigin::ANY;
33/// let null_origin = AccessControlAllowOrigin::NULL;
34/// let origin = AccessControlAllowOrigin::try_from("http://web-platform.test:8000");
35/// ```
36#[derive(Clone, Debug, PartialEq, Eq, Hash)]
37pub struct AccessControlAllowOrigin(OriginOrAny);
38
// Wires `AccessControlAllowOrigin` up to the header name
// `ACCESS_CONTROL_ALLOW_ORIGIN`.
// NOTE(review): presumably this generates the typed `Header` impl, with decoding
// going through `OriginOrAny: TryFromValues` and encoding through
// `From<&OriginOrAny> for HeaderValue`. Confirm against the `derive_header!`
// macro definition.
derive_header! {
    AccessControlAllowOrigin(_),
    name: ACCESS_CONTROL_ALLOW_ORIGIN
}
43
/// Inner value of `AccessControlAllowOrigin`: either one concrete origin or
/// the `*` wildcard.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum OriginOrAny {
    /// A single origin, encoded via `Origin::into_value`. The `null` origin is
    /// also stored here (`AccessControlAllowOrigin::NULL` wraps `Origin::NULL`).
    Origin(Origin),
    /// Allow all origins; encoded and decoded as `*`.
    Any,
}
50
51impl AccessControlAllowOrigin {
52    /// `Access-Control-Allow-Origin: *`
53    pub const ANY: AccessControlAllowOrigin = AccessControlAllowOrigin(OriginOrAny::Any);
54    /// `Access-Control-Allow-Origin: null`
55    pub const NULL: AccessControlAllowOrigin =
56        AccessControlAllowOrigin(OriginOrAny::Origin(Origin::NULL));
57
58    /// Returns the origin if there's one specified.
59    pub fn origin(&self) -> Option<&Origin> {
60        match self.0 {
61            OriginOrAny::Origin(ref origin) => Some(origin),
62            _ => None,
63        }
64    }
65}
66
67impl TryFrom<&str> for AccessControlAllowOrigin {
68    type Error = ::Error;
69
70    fn try_from(s: &str) -> Result<Self, ::Error> {
71        let header_value = HeaderValue::from_str(s).map_err(|_| ::Error::invalid())?;
72        let origin = OriginOrAny::try_from(&header_value)?;
73        Ok(Self(origin))
74    }
75}
76
77impl TryFrom<&HeaderValue> for OriginOrAny {
78    type Error = ::Error;
79
80    fn try_from(header_value: &HeaderValue) -> Result<Self, ::Error> {
81        Origin::try_from_value(header_value)
82            .map(OriginOrAny::Origin)
83            .ok_or_else(::Error::invalid)
84    }
85}
86
87impl TryFromValues for OriginOrAny {
88    fn try_from_values<'i, I>(values: &mut I) -> Result<Self, ::Error>
89    where
90        I: Iterator<Item = &'i HeaderValue>,
91    {
92        values
93            .just_one()
94            .and_then(|value| {
95                if value == "*" {
96                    return Some(OriginOrAny::Any);
97                }
98
99                Origin::try_from_value(value).map(OriginOrAny::Origin)
100            })
101            .ok_or_else(::Error::invalid)
102    }
103}
104
105impl<'a> From<&'a OriginOrAny> for HeaderValue {
106    fn from(origin: &'a OriginOrAny) -> HeaderValue {
107        match origin {
108            OriginOrAny::Origin(ref origin) => origin.into_value(),
109            OriginOrAny::Any => HeaderValue::from_static("*"),
110        }
111    }
112}
113
#[cfg(test)]
mod tests {

    use super::super::{test_decode, test_encode};
    use super::*;

    /// Checks that `allow_origin` carries `http://web-platform.test:8000`.
    fn assert_web_platform_origin(allow_origin: &AccessControlAllowOrigin) {
        let origin = allow_origin.origin().unwrap();
        assert_eq!(origin.scheme(), "http");
        assert_eq!(origin.hostname(), "web-platform.test");
        assert_eq!(origin.port(), Some(8000));
    }

    #[test]
    fn origin() {
        let s = "http://web-platform.test:8000";

        let allow_origin = test_decode::<AccessControlAllowOrigin>(&[s]).unwrap();
        assert_web_platform_origin(&allow_origin);

        let headers = test_encode(allow_origin);
        assert_eq!(headers["access-control-allow-origin"], s);
    }

    #[test]
    fn try_from_origin() {
        let s = "http://web-platform.test:8000";

        let allow_origin = AccessControlAllowOrigin::try_from(s).unwrap();
        assert_web_platform_origin(&allow_origin);

        let headers = test_encode(allow_origin);
        assert_eq!(headers["access-control-allow-origin"], s);
    }

    #[test]
    fn any() {
        let allow_origin = test_decode::<AccessControlAllowOrigin>(&["*"]).unwrap();
        assert_eq!(allow_origin, AccessControlAllowOrigin::ANY);

        let headers = test_encode(allow_origin);
        assert_eq!(headers["access-control-allow-origin"], "*");
    }

    #[test]
    fn any_has_no_origin() {
        assert!(AccessControlAllowOrigin::ANY.origin().is_none());
    }

    #[test]
    fn null() {
        let allow_origin = test_decode::<AccessControlAllowOrigin>(&["null"]).unwrap();
        assert_eq!(allow_origin, AccessControlAllowOrigin::NULL);

        let headers = test_encode(allow_origin);
        assert_eq!(headers["access-control-allow-origin"], "null");
    }

    #[test]
    fn rejects_multiple_values() {
        // The header carries exactly one value; a second one must fail decoding.
        assert!(test_decode::<AccessControlAllowOrigin>(&["*", "null"]).is_none());
    }

    #[test]
    fn rejects_invalid_origin() {
        let bad = "not a valid origin";
        assert!(test_decode::<AccessControlAllowOrigin>(&[bad]).is_none());
        assert!(AccessControlAllowOrigin::try_from(bad).is_err());
    }
}
169}