launchdarkly_server_sdk/
feature_requester.rs
1use crate::reqwest::is_http_error_recoverable;
2use futures::future::BoxFuture;
3use hyper::Body;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use super::stores::store_types::AllData;
8use launchdarkly_server_sdk_evaluation::{Flag, Segment};
9
/// Failure classification returned by a [`FeatureRequester`].
///
/// The variant tells the caller whether another attempt could succeed.
#[derive(Debug, Eq, PartialEq)]
pub enum FeatureRequesterError {
    /// A failure that a later attempt may not hit, such as a transport error, an
    /// unreadable or unparsable body, or an HTTP status deemed recoverable.
    Temporary,
    /// A failure that retrying the same request will not fix.
    Permanent,
}
15
16#[derive(Clone, Debug)]
17struct CachedEntry(AllData<Flag, Segment>, String);
18
19pub trait FeatureRequester: Send {
20 fn get_all(&mut self) -> BoxFuture<Result<AllData<Flag, Segment>, FeatureRequesterError>>;
21}
22
23pub struct HyperFeatureRequester<C> {
24 http: Arc<hyper::Client<C>>,
25 url: hyper::Uri,
26 sdk_key: String,
27 cache: Option<CachedEntry>,
28 default_headers: HashMap<&'static str, String>,
29}
30
31impl<C> HyperFeatureRequester<C> {
32 pub fn new(
33 http: hyper::Client<C>,
34 url: hyper::Uri,
35 sdk_key: String,
36 default_headers: HashMap<&'static str, String>,
37 ) -> Self {
38 Self {
39 http: Arc::new(http),
40 url,
41 sdk_key,
42 cache: None,
43 default_headers,
44 }
45 }
46}
47
48impl<C> FeatureRequester for HyperFeatureRequester<C>
49where
50 C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
51{
52 fn get_all(&mut self) -> BoxFuture<Result<AllData<Flag, Segment>, FeatureRequesterError>> {
53 Box::pin(async {
54 let uri = self.url.clone();
55 let key = self.sdk_key.clone();
56
57 let http = self.http.clone();
58 let cache = self.cache.clone();
59
60 let mut request_builder = hyper::http::Request::builder()
61 .uri(uri)
62 .method("GET")
63 .header("Content-Type", "application/json")
64 .header("Authorization", key)
65 .header("User-Agent", &*crate::USER_AGENT);
66
67 for default_header in &self.default_headers {
68 request_builder =
69 request_builder.header(*default_header.0, default_header.1.as_str());
70 }
71
72 if let Some(cache) = &self.cache {
73 request_builder = request_builder.header("If-None-Match", cache.1.clone());
74 }
75
76 let result = http
77 .request(request_builder.body(Body::empty()).unwrap())
78 .await;
79
80 let response = match result {
81 Ok(response) => response,
82 Err(e) => {
83 error!("An error occurred while retrieving flag information {}", e,);
86 return Err(FeatureRequesterError::Temporary);
87 }
88 };
89
90 if response.status() == hyper::StatusCode::NOT_MODIFIED && cache.is_some() {
91 if let Some(entry) = cache {
92 return Ok(entry.0);
93 }
94 }
95
96 let etag: String = response
97 .headers()
98 .get("etag")
99 .unwrap_or(&crate::EMPTY_HEADER)
100 .to_str()
101 .map_or_else(|_| "".into(), |s| s.into());
102
103 if response.status().is_success() {
104 let bytes = hyper::body::to_bytes(response.into_body())
105 .await
106 .map_err(|e| {
107 error!(
108 "An error occurred while reading the polling response body: {}",
109 e
110 );
111 FeatureRequesterError::Temporary
112 })?;
113 let json = serde_json::from_slice::<AllData<Flag, Segment>>(bytes.as_ref());
114
115 return match json {
116 Ok(all_data) => {
117 if !etag.is_empty() {
118 debug!("Caching data for future use with etag: {}", etag);
119 self.cache = Some(CachedEntry(all_data.clone(), etag));
120 }
121 Ok(all_data)
122 }
123 Err(e) => {
124 error!("An error occurred while parsing the json response: {}", e);
125 Err(FeatureRequesterError::Temporary)
126 }
127 };
128 }
129
130 error!(
131 "An error occurred while retrieving flag information. (status: {})",
132 response.status().as_str()
133 );
134
135 if !is_http_error_recoverable(response.status().as_u16()) {
136 return Err(FeatureRequesterError::Permanent);
137 }
138
139 Err(FeatureRequesterError::Temporary)
140 })
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use test_case::test_case;

    /// Returns the etag cached by `requester`, failing the test if nothing is cached.
    ///
    /// This replaces `if let Some(cache)` guards, which silently skipped the assertion
    /// (and let the test pass) when caching never happened.
    fn cached_etag<C>(requester: &HyperFeatureRequester<C>) -> &str {
        &requester
            .cache
            .as_ref()
            .expect("expected the requester to have cached a response")
            .1
    }

    #[tokio::test]
    async fn updates_etag_as_appropriate() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(200)
            .with_header("etag", "INITIAL-TAG")
            .with_body(r#"{"flags": {}, "segments": {}}"#)
            .expect(1)
            .create_async()
            .await;
        server
            .mock("GET", "/")
            .with_status(304)
            .match_header("If-None-Match", "INITIAL-TAG")
            .expect(1)
            .create_async()
            .await;
        server
            .mock("GET", "/")
            .with_status(200)
            .match_header("If-None-Match", "INITIAL-TAG")
            .with_header("etag", "UPDATED-TAG")
            .with_body(r#"{"flags": {}, "segments": {}}"#)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());

        assert!(requester.get_all().await.is_ok());
        assert_eq!("INITIAL-TAG", cached_etag(&requester));

        // The 304 reply must be served from the cache without replacing the etag.
        assert!(requester.get_all().await.is_ok());
        assert_eq!("INITIAL-TAG", cached_etag(&requester));

        assert!(requester.get_all().await.is_ok());
        assert_eq!("UPDATED-TAG", cached_etag(&requester));
    }

    #[tokio::test]
    async fn can_process_large_body() {
        let payload = std::fs::read("test-data/large-polling-payload.json")
            .expect("Failed to read polling payload file");
        let payload =
            String::from_utf8(payload).expect("Invalid UTF-8 characters in polling payload");

        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(200)
            .with_body(payload)
            .expect(1)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());
        let result = requester.get_all().await;

        assert!(result.is_ok());
    }

    #[test_case(400, FeatureRequesterError::Temporary)]
    #[test_case(401, FeatureRequesterError::Permanent)]
    #[test_case(408, FeatureRequesterError::Temporary)]
    #[test_case(409, FeatureRequesterError::Permanent)]
    #[test_case(429, FeatureRequesterError::Temporary)]
    #[test_case(430, FeatureRequesterError::Permanent)]
    #[test_case(500, FeatureRequesterError::Temporary)]
    #[tokio::test]
    async fn correctly_determines_unrecoverable_errors(
        status: usize,
        error: FeatureRequesterError,
    ) {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(status)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());

        match requester.get_all().await {
            Err(err) => assert_eq!(err, error),
            Ok(_) => panic!("get_all returned the wrong response"),
        }
    }

    /// Builds a requester pointed at `url` with a placeholder SDK key and no extra headers.
    fn build_feature_requester(url: String) -> HyperFeatureRequester<hyper::client::HttpConnector> {
        let http = hyper::Client::builder().build(hyper::client::HttpConnector::new());
        let url = hyper::Uri::from_str(&url).expect("Failed parsing the mock server url");

        HyperFeatureRequester::new(
            http,
            url,
            "sdk-key".to_string(),
            HashMap::<&str, String>::new(),
        )
    }
}