launchdarkly_server_sdk/
feature_requester.rs

1use crate::reqwest::is_http_error_recoverable;
2use futures::future::BoxFuture;
3use hyper::Body;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use super::stores::store_types::AllData;
8use launchdarkly_server_sdk_evaluation::{Flag, Segment};
9
/// Failure modes reported by [`FeatureRequester::get_all`].
///
/// Callers use the variant to decide whether another polling attempt is
/// worthwhile.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeatureRequesterError {
    /// The attempt failed but a later one may succeed (e.g. a transport
    /// failure, an unreadable/unparseable body, or a recoverable HTTP status).
    Temporary,
    /// The attempt failed in a way that retrying is not expected to fix
    /// (e.g. an HTTP status that is not considered recoverable).
    Permanent,
}
15
16#[derive(Clone, Debug)]
17struct CachedEntry(AllData<Flag, Segment>, String);
18
19pub trait FeatureRequester: Send {
20    fn get_all(&mut self) -> BoxFuture<Result<AllData<Flag, Segment>, FeatureRequesterError>>;
21}
22
23pub struct HyperFeatureRequester<C> {
24    http: Arc<hyper::Client<C>>,
25    url: hyper::Uri,
26    sdk_key: String,
27    cache: Option<CachedEntry>,
28    default_headers: HashMap<&'static str, String>,
29}
30
31impl<C> HyperFeatureRequester<C> {
32    pub fn new(
33        http: hyper::Client<C>,
34        url: hyper::Uri,
35        sdk_key: String,
36        default_headers: HashMap<&'static str, String>,
37    ) -> Self {
38        Self {
39            http: Arc::new(http),
40            url,
41            sdk_key,
42            cache: None,
43            default_headers,
44        }
45    }
46}
47
48impl<C> FeatureRequester for HyperFeatureRequester<C>
49where
50    C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
51{
52    fn get_all(&mut self) -> BoxFuture<Result<AllData<Flag, Segment>, FeatureRequesterError>> {
53        Box::pin(async {
54            let uri = self.url.clone();
55            let key = self.sdk_key.clone();
56
57            let http = self.http.clone();
58            let cache = self.cache.clone();
59
60            let mut request_builder = hyper::http::Request::builder()
61                .uri(uri)
62                .method("GET")
63                .header("Content-Type", "application/json")
64                .header("Authorization", key)
65                .header("User-Agent", &*crate::USER_AGENT);
66
67            for default_header in &self.default_headers {
68                request_builder =
69                    request_builder.header(*default_header.0, default_header.1.as_str());
70            }
71
72            if let Some(cache) = &self.cache {
73                request_builder = request_builder.header("If-None-Match", cache.1.clone());
74            }
75
76            let result = http
77                .request(request_builder.body(Body::empty()).unwrap())
78                .await;
79
80            let response = match result {
81                Ok(response) => response,
82                Err(e) => {
83                    // It appears this type of error will not be an HTTP error.
84                    // It will be a closed connection, aborted write, timeout, etc.
85                    error!("An error occurred while retrieving flag information {}", e,);
86                    return Err(FeatureRequesterError::Temporary);
87                }
88            };
89
90            if response.status() == hyper::StatusCode::NOT_MODIFIED && cache.is_some() {
91                if let Some(entry) = cache {
92                    return Ok(entry.0);
93                }
94            }
95
96            let etag: String = response
97                .headers()
98                .get("etag")
99                .unwrap_or(&crate::EMPTY_HEADER)
100                .to_str()
101                .map_or_else(|_| "".into(), |s| s.into());
102
103            if response.status().is_success() {
104                let bytes = hyper::body::to_bytes(response.into_body())
105                    .await
106                    .map_err(|e| {
107                        error!(
108                            "An error occurred while reading the polling response body: {}",
109                            e
110                        );
111                        FeatureRequesterError::Temporary
112                    })?;
113                let json = serde_json::from_slice::<AllData<Flag, Segment>>(bytes.as_ref());
114
115                return match json {
116                    Ok(all_data) => {
117                        if !etag.is_empty() {
118                            debug!("Caching data for future use with etag: {}", etag);
119                            self.cache = Some(CachedEntry(all_data.clone(), etag));
120                        }
121                        Ok(all_data)
122                    }
123                    Err(e) => {
124                        error!("An error occurred while parsing the json response: {}", e);
125                        Err(FeatureRequesterError::Temporary)
126                    }
127                };
128            }
129
130            error!(
131                "An error occurred while retrieving flag information. (status: {})",
132                response.status().as_str()
133            );
134
135            if !is_http_error_recoverable(response.status().as_u16()) {
136                return Err(FeatureRequesterError::Permanent);
137            }
138
139            Err(FeatureRequesterError::Temporary)
140        })
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use test_case::test_case;

    /// Returns the etag of the cached response, failing the test (instead of
    /// silently skipping the check) when nothing has been cached.
    fn cached_etag<C>(requester: &HyperFeatureRequester<C>) -> &str {
        &requester
            .cache
            .as_ref()
            .expect("the requester should have cached a response")
            .1
    }

    /// The cached etag is set by the first 200, kept across a 304, and
    /// replaced by a later 200 carrying a new etag.
    #[tokio::test]
    async fn updates_etag_as_appropriate() {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(200)
            .with_header("etag", "INITIAL-TAG")
            .with_body(r#"{"flags": {}, "segments": {}}"#)
            .expect(1)
            .create_async()
            .await;
        server
            .mock("GET", "/")
            .with_status(304)
            .match_header("If-None-Match", "INITIAL-TAG")
            .expect(1)
            .create_async()
            .await;
        server
            .mock("GET", "/")
            .with_status(200)
            .match_header("If-None-Match", "INITIAL-TAG")
            .with_header("etag", "UPDATED-TAG")
            .with_body(r#"{"flags": {}, "segments": {}}"#)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());

        assert!(requester.get_all().await.is_ok(), "initial request failed");
        assert_eq!("INITIAL-TAG", cached_etag(&requester));

        assert!(requester.get_all().await.is_ok(), "not-modified request failed");
        assert_eq!("INITIAL-TAG", cached_etag(&requester));

        assert!(requester.get_all().await.is_ok(), "updated request failed");
        assert_eq!("UPDATED-TAG", cached_etag(&requester));
    }

    /// A large, realistic polling payload is read and parsed successfully.
    #[tokio::test]
    async fn can_process_large_body() {
        let payload = std::fs::read("test-data/large-polling-payload.json")
            .expect("Failed to read polling payload file");
        let payload =
            String::from_utf8(payload).expect("Invalid UTF-8 characters in polling payload");

        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(200)
            .with_body(payload)
            .expect(1)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());
        let result = requester.get_all().await;

        assert!(result.is_ok());
    }

    /// HTTP error statuses map to the expected temporary/permanent error kind.
    #[test_case(400, FeatureRequesterError::Temporary)]
    #[test_case(401, FeatureRequesterError::Permanent)]
    #[test_case(408, FeatureRequesterError::Temporary)]
    #[test_case(409, FeatureRequesterError::Permanent)]
    #[test_case(429, FeatureRequesterError::Temporary)]
    #[test_case(430, FeatureRequesterError::Permanent)]
    #[test_case(500, FeatureRequesterError::Temporary)]
    #[tokio::test]
    async fn correctly_determines_unrecoverable_errors(
        status: usize,
        error: FeatureRequesterError,
    ) {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", "/")
            .with_status(status)
            .create_async()
            .await;

        let mut requester = build_feature_requester(server.url());
        let err = requester
            .get_all()
            .await
            .expect_err("get_all returned the wrong response");

        assert_eq!(err, error);
    }

    /// Builds a requester pointed at the mock server `url` with no extra headers.
    fn build_feature_requester(url: String) -> HyperFeatureRequester<hyper::client::HttpConnector> {
        let http = hyper::Client::builder().build(hyper::client::HttpConnector::new());
        let url = hyper::Uri::from_str(&url).expect("Failed parsing the mock server url");

        HyperFeatureRequester::new(
            http,
            url,
            "sdk-key".to_string(),
            HashMap::<&str, String>::new(),
        )
    }
}
260}