1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use hyper::{
5 body::HttpBody,
6 client::{
7 connect::{Connect, Connection},
8 ResponseFuture,
9 },
10 header::{HeaderMap, HeaderName, HeaderValue},
11 service::Service,
12 Body, Request, Uri,
13};
14use log::{debug, info, trace, warn};
15use pin_project::pin_project;
16use std::{
17 boxed,
18 fmt::{self, Debug, Formatter},
19 future::Future,
20 io::ErrorKind,
21 pin::Pin,
22 str::FromStr,
23 task::{Context, Poll},
24 time::{Duration, Instant},
25};
26
27use tokio::{
28 io::{AsyncRead, AsyncWrite},
29 time::Sleep,
30};
31
32use crate::{
33 config::ReconnectOptions,
34 response::{ErrorBody, Response},
35};
36use crate::{
37 error::{Error, Result},
38 event_parser::ConnectionDetails,
39};
40
41use hyper::client::HttpConnector;
42use hyper_timeout::TimeoutConnector;
43
44use crate::event_parser::EventParser;
45use crate::event_parser::SSE;
46
47use crate::retry::{BackoffRetry, RetryStrategy};
48use std::error::Error as StdError;
49
50#[cfg(feature = "rustls")]
51use hyper_rustls::HttpsConnectorBuilder;
52
53type BoxError = Box<dyn std::error::Error + Send + Sync>;
54
55pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
57
58pub trait Client: Send + Sync + private::Sealed {
61 fn stream(&self) -> BoxStream<Result<SSE>>;
62}
63
/// Maximum number of consecutive redirects followed when [`ClientBuilder::redirect_limit`] is
/// not called.
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
72
73pub struct ClientBuilder {
75 url: Uri,
76 headers: HeaderMap,
77 reconnect_opts: ReconnectOptions,
78 connect_timeout: Option<Duration>,
79 read_timeout: Option<Duration>,
80 write_timeout: Option<Duration>,
81 last_event_id: Option<String>,
82 method: String,
83 body: Option<String>,
84 max_redirects: Option<u32>,
85}
86
87impl ClientBuilder {
88 pub fn for_url(url: &str) -> Result<ClientBuilder> {
90 let url = url
91 .parse()
92 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
93
94 let mut header_map = HeaderMap::new();
95 header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
96 header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
97
98 Ok(ClientBuilder {
99 url,
100 headers: header_map,
101 reconnect_opts: ReconnectOptions::default(),
102 connect_timeout: None,
103 read_timeout: None,
104 write_timeout: None,
105 last_event_id: None,
106 method: String::from("GET"),
107 max_redirects: None,
108 body: None,
109 })
110 }
111
112 pub fn method(mut self, method: String) -> ClientBuilder {
114 self.method = method;
115 self
116 }
117
118 pub fn body(mut self, body: String) -> ClientBuilder {
120 self.body = Some(body);
121 self
122 }
123
124 pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
127 self.last_event_id = Some(last_event_id);
128 self
129 }
130
131 pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
133 let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
134
135 let value =
136 HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
137
138 self.headers.insert(name, value);
139 Ok(self)
140 }
141
142 pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
144 let auth = format!("{}:{}", username, password);
145 let encoded = BASE64_STANDARD.encode(auth);
146 let value = format!("Basic {}", encoded);
147
148 self.header("Authorization", &value)
149 }
150
151 pub fn connect_timeout(mut self, connect_timeout: Duration) -> ClientBuilder {
154 self.connect_timeout = Some(connect_timeout);
155 self
156 }
157
158 pub fn read_timeout(mut self, read_timeout: Duration) -> ClientBuilder {
160 self.read_timeout = Some(read_timeout);
161 self
162 }
163
164 pub fn write_timeout(mut self, write_timeout: Duration) -> ClientBuilder {
166 self.write_timeout = Some(write_timeout);
167 self
168 }
169
170 pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
175 self.reconnect_opts = opts;
176 self
177 }
178
179 pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
183 self.max_redirects = Some(limit);
184 self
185 }
186
187 pub fn build_with_conn<C>(self, conn: C) -> impl Client
189 where
190 C: Service<Uri> + Clone + Send + Sync + 'static,
191 C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin,
192 C::Future: Send + 'static,
193 C::Error: Into<BoxError>,
194 {
195 let mut connector = TimeoutConnector::new(conn);
196 connector.set_connect_timeout(self.connect_timeout);
197 connector.set_read_timeout(self.read_timeout);
198 connector.set_write_timeout(self.write_timeout);
199
200 let client = hyper::Client::builder().build::<_, hyper::Body>(connector);
201
202 ClientImpl {
203 http: client,
204 request_props: RequestProps {
205 url: self.url,
206 headers: self.headers,
207 method: self.method,
208 body: self.body,
209 reconnect_opts: self.reconnect_opts,
210 max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
211 },
212 last_event_id: self.last_event_id,
213 }
214 }
215
216 pub fn build_http(self) -> impl Client {
218 self.build_with_conn(HttpConnector::new())
219 }
220
221 #[cfg(feature = "rustls")]
222 pub fn build(self) -> impl Client {
224 let conn = HttpsConnectorBuilder::new()
225 .with_native_roots()
226 .https_or_http()
227 .enable_http1()
228 .enable_http2()
229 .build();
230
231 self.build_with_conn(conn)
232 }
233
234 pub fn build_with_http_client<C>(self, http: hyper::Client<C>) -> impl Client
236 where
237 C: Connect + Clone + Send + Sync + 'static,
238 {
239 ClientImpl {
240 http,
241 request_props: RequestProps {
242 url: self.url,
243 headers: self.headers,
244 method: self.method,
245 body: self.body,
246 reconnect_opts: self.reconnect_opts,
247 max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
248 },
249 last_event_id: self.last_event_id,
250 }
251 }
252}
253
254#[derive(Clone)]
255struct RequestProps {
256 url: Uri,
257 headers: HeaderMap,
258 method: String,
259 body: Option<String>,
260 reconnect_opts: ReconnectOptions,
261 max_redirects: u32,
262}
263
264struct ClientImpl<C> {
268 http: hyper::Client<C>,
269 request_props: RequestProps,
270 last_event_id: Option<String>,
271}
272
273impl<C> Client for ClientImpl<C>
274where
275 C: Connect + Clone + Send + Sync + 'static,
276{
277 fn stream(&self) -> BoxStream<Result<SSE>> {
285 Box::pin(ReconnectingRequest::new(
286 self.http.clone(),
287 self.request_props.clone(),
288 self.last_event_id.clone(),
289 ))
290 }
291}
292
293#[allow(clippy::large_enum_variant)] #[pin_project(project = StateProj)]
295enum State {
296 New,
297 Connecting {
298 retry: bool,
299 #[pin]
300 resp: ResponseFuture,
301 },
302 Connected(#[pin] hyper::Body),
303 WaitingToReconnect(#[pin] Sleep),
304 FollowingRedirect(Option<HeaderValue>),
305 StreamClosed,
306}
307
308impl State {
309 fn name(&self) -> &'static str {
310 match self {
311 State::New => "new",
312 State::Connecting { retry: false, .. } => "connecting(no-retry)",
313 State::Connecting { retry: true, .. } => "connecting(retry)",
314 State::Connected(_) => "connected",
315 State::WaitingToReconnect(_) => "waiting-to-reconnect",
316 State::FollowingRedirect(_) => "following-redirect",
317 State::StreamClosed => "closed",
318 }
319 }
320}
321
322impl Debug for State {
323 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
324 write!(f, "{}", self.name())
325 }
326}
327
328#[must_use = "streams do nothing unless polled"]
329#[pin_project]
330pub struct ReconnectingRequest<C> {
331 http: hyper::Client<C>,
332 props: RequestProps,
333 #[pin]
334 state: State,
335 retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
336 current_url: Uri,
337 redirect_count: u32,
338 event_parser: EventParser,
339 last_event_id: Option<String>,
340}
341
342impl<C> ReconnectingRequest<C> {
343 fn new(
344 http: hyper::Client<C>,
345 props: RequestProps,
346 last_event_id: Option<String>,
347 ) -> ReconnectingRequest<C> {
348 let reconnect_delay = props.reconnect_opts.delay;
349 let delay_max = props.reconnect_opts.delay_max;
350 let backoff_factor = props.reconnect_opts.backoff_factor;
351
352 let url = props.url.clone();
353 ReconnectingRequest {
354 props,
355 http,
356 state: State::New,
357 retry_strategy: Box::new(BackoffRetry::new(
358 reconnect_delay,
359 delay_max,
360 backoff_factor,
361 true,
362 )),
363 redirect_count: 0,
364 current_url: url,
365 event_parser: EventParser::new(),
366 last_event_id,
367 }
368 }
369
370 fn send_request(&self) -> Result<ResponseFuture>
371 where
372 C: Connect + Clone + Send + Sync + 'static,
373 {
374 let mut request_builder = Request::builder()
375 .method(self.props.method.as_str())
376 .uri(&self.current_url);
377
378 for (name, value) in &self.props.headers {
379 request_builder = request_builder.header(name, value);
380 }
381
382 if let Some(id) = self.last_event_id.as_ref() {
383 if !id.is_empty() {
384 let id_as_header =
385 HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
386
387 request_builder = request_builder.header("last-event-id", id_as_header);
388 }
389 }
390
391 let body = match &self.props.body {
392 Some(body) => Body::from(body.to_string()),
393 None => Body::empty(),
394 };
395
396 let request = request_builder
397 .body(body)
398 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
399
400 Ok(self.http.request(request))
401 }
402
403 fn reset_redirects(self: Pin<&mut Self>) {
404 let url = self.props.url.clone();
405 let this = self.project();
406 *this.current_url = url;
407 *this.redirect_count = 0;
408 }
409
410 fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
411 if self.redirect_count == self.props.max_redirects {
412 return false;
413 }
414 *self.project().redirect_count += 1;
415 true
416 }
417}
418
419impl<C> Stream for ReconnectingRequest<C>
420where
421 C: Connect + Clone + Send + Sync + 'static,
422{
423 type Item = Result<SSE>;
424
425 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
426 trace!("ReconnectingRequest::poll({:?})", &self.state);
427
428 loop {
429 let this = self.as_mut().project();
430 if let Some(event) = this.event_parser.get_event() {
431 return match event {
432 SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
433 SSE::Event(ref evt) => {
434 this.last_event_id.clone_from(&evt.id);
435
436 if let Some(retry) = evt.retry {
437 this.retry_strategy
438 .change_base_delay(Duration::from_millis(retry));
439 }
440 Poll::Ready(Some(Ok(event)))
441 }
442 SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
443 };
444 }
445
446 trace!("ReconnectingRequest::poll loop({:?})", &this.state);
447
448 let state = this.state.project();
449 match state {
450 StateProj::StreamClosed => return Poll::Ready(Some(Err(Error::StreamClosed))),
451 StateProj::New => {
454 *self.as_mut().project().event_parser = EventParser::new();
455 match self.send_request() {
456 Ok(resp) => {
457 let retry = self.props.reconnect_opts.retry_initial;
458 self.as_mut()
459 .project()
460 .state
461 .set(State::Connecting { resp, retry })
462 }
463 Err(e) => {
464 self.as_mut().project().state.set(State::StreamClosed);
467 return Poll::Ready(Some(Err(e)));
468 }
469 }
470 }
471 StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
472 Ok(resp) => {
473 debug!("HTTP response: {:#?}", resp);
474
475 if resp.status().is_success() {
476 self.as_mut().project().retry_strategy.reset(Instant::now());
477 self.as_mut().reset_redirects();
478
479 let status = resp.status();
480 let headers = resp.headers().clone();
481
482 self.as_mut()
483 .project()
484 .state
485 .set(State::Connected(resp.into_body()));
486
487 return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
488 Response::new(status, headers),
489 )))));
490 }
491
492 if resp.status() == 301 || resp.status() == 307 {
493 debug!("got redirected ({})", resp.status());
494
495 if self.as_mut().increment_redirect_counter() {
496 debug!("following redirect {}", self.redirect_count);
497
498 self.as_mut().project().state.set(State::FollowingRedirect(
499 resp.headers().get(hyper::header::LOCATION).cloned(),
500 ));
501 continue;
502 } else {
503 debug!("redirect limit reached ({})", self.props.max_redirects);
504
505 self.as_mut().project().state.set(State::StreamClosed);
506 return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
507 self.props.max_redirects,
508 ))));
509 }
510 }
511
512 self.as_mut().reset_redirects();
513 self.as_mut().project().state.set(State::New);
514
515 return Poll::Ready(Some(Err(Error::UnexpectedResponse(
516 Response::new(resp.status(), resp.headers().clone()),
517 ErrorBody::new(resp.into_body()),
518 ))));
519 }
520 Err(e) => {
521 warn!("request returned an error: {}", e);
524 if !*retry {
525 self.as_mut().project().state.set(State::New);
526 return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
527 }
528 let duration = self
529 .as_mut()
530 .project()
531 .retry_strategy
532 .next_delay(Instant::now());
533 self.as_mut()
534 .project()
535 .state
536 .set(State::WaitingToReconnect(delay(duration, "retrying")))
537 }
538 },
539 StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
540 Ok(uri) => {
541 *self.as_mut().project().current_url = uri;
542 self.as_mut().project().state.set(State::New);
543 }
544 Err(e) => {
545 self.as_mut().project().state.set(State::StreamClosed);
546 return Poll::Ready(Some(Err(e)));
547 }
548 },
549 StateProj::Connected(body) => match ready!(body.poll_data(cx)) {
550 Some(Ok(result)) => {
551 this.event_parser.process_bytes(result)?;
552 continue;
553 }
554 Some(Err(e)) => {
555 if self.props.reconnect_opts.reconnect {
556 let duration = self
557 .as_mut()
558 .project()
559 .retry_strategy
560 .next_delay(Instant::now());
561 self.as_mut()
562 .project()
563 .state
564 .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
565 }
566
567 if let Some(cause) = e.source() {
568 if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
569 if let std::io::ErrorKind::TimedOut = downcast.kind() {
570 return Poll::Ready(Some(Err(Error::TimedOut)));
571 }
572 }
573 } else {
574 return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
575 }
576 }
577 None => {
578 let duration = self
579 .as_mut()
580 .project()
581 .retry_strategy
582 .next_delay(Instant::now());
583 self.as_mut()
584 .project()
585 .state
586 .set(State::WaitingToReconnect(delay(duration, "retrying")));
587
588 if self.event_parser.was_processing() {
589 return Poll::Ready(Some(Err(Error::UnexpectedEof)));
590 }
591 return Poll::Ready(Some(Err(Error::Eof)));
592 }
593 },
594 StateProj::WaitingToReconnect(delay) => {
595 ready!(delay.poll(cx));
596 info!("Reconnecting");
597 self.as_mut().project().state.set(State::New);
598 }
599 };
600 }
601 }
602}
603
604fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
605 let header = maybe_header.as_ref().ok_or_else(|| {
606 Error::MalformedLocationHeader(Box::new(std::io::Error::new(
607 ErrorKind::NotFound,
608 "missing Location header",
609 )))
610 })?;
611
612 let header_string = header
613 .to_str()
614 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
615
616 header_string
617 .parse::<Uri>()
618 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
619}
620
621fn delay(dur: Duration, description: &str) -> Sleep {
622 info!("Waiting {:?} before {}", dur, description);
623 tokio::time::sleep(dur)
624}
625
626mod private {
627 use crate::client::ClientImpl;
628
629 pub trait Sealed {}
630 impl<C> Sealed for ClientImpl<C> {}
631}
632
633#[cfg(test)]
634mod tests {
635 use crate::ClientBuilder;
636 use hyper::http::HeaderValue;
637 use test_case::test_case;
638
639 #[test_case("user", "pass", "dXNlcjpwYXNz")]
640 #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
641 #[test_case("user2", "", "dXNlcjI6")]
642 #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
643 #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
644 #[test_case(
645 "weird@-/:stuff",
646 "goes@-/:here",
647 "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
648 )]
649 fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
650 let builder = ClientBuilder::for_url("http://example.com")
651 .expect("failed to build client")
652 .basic_auth(username, password)
653 .expect("failed to add authentication");
654
655 let actual = builder.headers.get("Authorization");
656 let expected = HeaderValue::from_str(format!("Basic {}", expected).as_str())
657 .expect("unable to create expected header");
658
659 assert_eq!(Some(&expected), actual);
660 }
661}