headers/common/
access_control_allow_origin.rs1use std::convert::TryFrom;
2
3use http::HeaderValue;
4
5use super::origin::Origin;
6use crate::util::{IterExt, TryFromValues};
7use crate::Error;
8
/// Typed representation of the `Access-Control-Allow-Origin` header.
///
/// The value is either one specific [`Origin`] or the `*` wildcard; see
/// [`AccessControlAllowOrigin::ANY`] and [`AccessControlAllowOrigin::NULL`]
/// for the predefined values and [`AccessControlAllowOrigin::origin`] to
/// inspect a concrete origin.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AccessControlAllowOrigin(OriginOrAny);
39
// Binds the newtype to the `ACCESS_CONTROL_ALLOW_ORIGIN` header name.
// NOTE(review): `derive_header!` is defined elsewhere; presumably it
// generates the header encode/decode impl by delegating to the wrapped
// `OriginOrAny` (its `TryFromValues` and `From<&_> for HeaderValue` impls
// below) — confirm against the macro definition.
derive_header! {
    AccessControlAllowOrigin(_),
    name: ACCESS_CONTROL_ALLOW_ORIGIN
}
44
/// Internal payload of [`AccessControlAllowOrigin`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum OriginOrAny {
    /// A single, specific origin.
    Origin(Origin),
    /// The `*` wildcard.
    Any,
}
51
52impl AccessControlAllowOrigin {
53 pub const ANY: AccessControlAllowOrigin = AccessControlAllowOrigin(OriginOrAny::Any);
55 pub const NULL: AccessControlAllowOrigin =
57 AccessControlAllowOrigin(OriginOrAny::Origin(Origin::NULL));
58
59 pub fn origin(&self) -> Option<&Origin> {
61 match self.0 {
62 OriginOrAny::Origin(ref origin) => Some(origin),
63 _ => None,
64 }
65 }
66}
67
68impl TryFrom<&str> for AccessControlAllowOrigin {
69 type Error = Error;
70
71 fn try_from(s: &str) -> Result<Self, Error> {
72 let header_value = HeaderValue::from_str(s).map_err(|_| Error::invalid())?;
73 let origin = OriginOrAny::try_from(&header_value)?;
74 Ok(Self(origin))
75 }
76}
77
78impl TryFrom<&HeaderValue> for OriginOrAny {
79 type Error = Error;
80
81 fn try_from(header_value: &HeaderValue) -> Result<Self, Error> {
82 Origin::try_from_value(header_value)
83 .map(OriginOrAny::Origin)
84 .ok_or_else(Error::invalid)
85 }
86}
87
88impl TryFromValues for OriginOrAny {
89 fn try_from_values<'i, I>(values: &mut I) -> Result<Self, Error>
90 where
91 I: Iterator<Item = &'i HeaderValue>,
92 {
93 values
94 .just_one()
95 .and_then(|value| {
96 if value == "*" {
97 return Some(OriginOrAny::Any);
98 }
99
100 Origin::try_from_value(value).map(OriginOrAny::Origin)
101 })
102 .ok_or_else(Error::invalid)
103 }
104}
105
106impl<'a> From<&'a OriginOrAny> for HeaderValue {
107 fn from(origin: &'a OriginOrAny) -> HeaderValue {
108 match origin {
109 OriginOrAny::Origin(ref origin) => origin.to_value(),
110 OriginOrAny::Any => HeaderValue::from_static("*"),
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    /// Origin string shared by the round-trip tests.
    const WEB_PLATFORM: &str = "http://web-platform.test:8000";

    /// Asserts that `header` carries the origin described by `WEB_PLATFORM`.
    fn assert_web_platform(header: &AccessControlAllowOrigin) {
        let parsed = header.origin().unwrap();
        assert_eq!(parsed.scheme(), "http");
        assert_eq!(parsed.hostname(), "web-platform.test");
        assert_eq!(parsed.port(), Some(8000));
    }

    // Decoding from header values and re-encoding preserves the origin.
    #[test]
    fn origin() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&[WEB_PLATFORM]).unwrap();
        assert_web_platform(&decoded);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], WEB_PLATFORM);
    }

    // Same round trip, but entering through `TryFrom<&str>`.
    #[test]
    fn try_from_origin() {
        let converted = AccessControlAllowOrigin::try_from(WEB_PLATFORM).unwrap();
        assert_web_platform(&converted);

        let encoded = test_encode(converted);
        assert_eq!(encoded["access-control-allow-origin"], WEB_PLATFORM);
    }

    // `*` decodes to the `ANY` constant and encodes back to `*`.
    #[test]
    fn any() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&["*"]).unwrap();
        assert_eq!(decoded, AccessControlAllowOrigin::ANY);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], "*");
    }

    // `null` decodes to the `NULL` constant and encodes back to `null`.
    #[test]
    fn null() {
        let decoded = test_decode::<AccessControlAllowOrigin>(&["null"]).unwrap();
        assert_eq!(decoded, AccessControlAllowOrigin::NULL);

        let encoded = test_encode(decoded);
        assert_eq!(encoded["access-control-allow-origin"], "null");
    }
}