eventsource_client/
client.rs

1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use hyper::{
5    body::HttpBody,
6    client::{
7        connect::{Connect, Connection},
8        ResponseFuture,
9    },
10    header::{HeaderMap, HeaderName, HeaderValue},
11    service::Service,
12    Body, Request, Uri,
13};
14use log::{debug, info, trace, warn};
15use pin_project::pin_project;
16use std::{
17    boxed,
18    fmt::{self, Debug, Formatter},
19    future::Future,
20    io::ErrorKind,
21    pin::Pin,
22    str::FromStr,
23    task::{Context, Poll},
24    time::{Duration, Instant},
25};
26
27use tokio::{
28    io::{AsyncRead, AsyncWrite},
29    time::Sleep,
30};
31
32use crate::{
33    config::ReconnectOptions,
34    response::{ErrorBody, Response},
35};
36use crate::{
37    error::{Error, Result},
38    event_parser::ConnectionDetails,
39};
40
41use hyper::client::HttpConnector;
42use hyper_timeout::TimeoutConnector;
43
44use crate::event_parser::EventParser;
45use crate::event_parser::SSE;
46
47use crate::retry::{BackoffRetry, RetryStrategy};
48use std::error::Error as StdError;
49
50#[cfg(feature = "rustls")]
51use hyper_rustls::HttpsConnectorBuilder;
52
53type BoxError = Box<dyn std::error::Error + Send + Sync>;
54
55/// Represents a [`Pin`]'d [`Send`] + [`Sync`] stream, returned by [`Client`]'s stream method.
56pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
57
58/// Client is the Server-Sent-Events interface.
59/// This trait is sealed and cannot be implemented for types outside this crate.
60pub trait Client: Send + Sync + private::Sealed {
61    fn stream(&self) -> BoxStream<Result<SSE>>;
62}
63
/*
 * TODO remove debug output
 * TODO specify a list of HTTP statuses that should not be retried (e.g. 204)
 */
68
/// Maximum number of redirects the client will follow before giving up, unless
/// overridden via [`ClientBuilder::redirect_limit`].
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
72
/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
pub struct ClientBuilder {
    // Endpoint the first request is sent to.
    url: Uri,
    // Headers sent with every request; `for_url` seeds `Accept` and `Cache-Control`.
    headers: HeaderMap,
    // Reconnect/backoff settings handed to each stream.
    reconnect_opts: ReconnectOptions,
    // Connector timeouts (`None` = no timeout). They are only applied by the build
    // methods that wrap the connector in a `TimeoutConnector`.
    connect_timeout: Option<Duration>,
    read_timeout: Option<Duration>,
    write_timeout: Option<Duration>,
    // Initial value for the `last-event-id` request header.
    last_event_id: Option<String>,
    // HTTP method; `GET` unless overridden.
    method: String,
    // Optional request body.
    body: Option<String>,
    // Redirect limit; `None` falls back to `DEFAULT_REDIRECT_LIMIT` at build time.
    max_redirects: Option<u32>,
}
86
87impl ClientBuilder {
88    /// Create a builder for a given URL.
89    pub fn for_url(url: &str) -> Result<ClientBuilder> {
90        let url = url
91            .parse()
92            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
93
94        let mut header_map = HeaderMap::new();
95        header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
96        header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
97
98        Ok(ClientBuilder {
99            url,
100            headers: header_map,
101            reconnect_opts: ReconnectOptions::default(),
102            connect_timeout: None,
103            read_timeout: None,
104            write_timeout: None,
105            last_event_id: None,
106            method: String::from("GET"),
107            max_redirects: None,
108            body: None,
109        })
110    }
111
112    /// Set the request method used for the initial connection to the SSE endpoint.
113    pub fn method(mut self, method: String) -> ClientBuilder {
114        self.method = method;
115        self
116    }
117
118    /// Set the request body used for the initial connection to the SSE endpoint.
119    pub fn body(mut self, body: String) -> ClientBuilder {
120        self.body = Some(body);
121        self
122    }
123
124    /// Set the last event id for a stream when it is created. If it is set, it will be sent to the
125    /// server in case it can replay missed events.
126    pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
127        self.last_event_id = Some(last_event_id);
128        self
129    }
130
131    /// Set a HTTP header on the SSE request.
132    pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
133        let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
134
135        let value =
136            HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
137
138        self.headers.insert(name, value);
139        Ok(self)
140    }
141
142    /// Set the Authorization header with the calculated basic authentication value.
143    pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
144        let auth = format!("{}:{}", username, password);
145        let encoded = BASE64_STANDARD.encode(auth);
146        let value = format!("Basic {}", encoded);
147
148        self.header("Authorization", &value)
149    }
150
151    /// Set a connect timeout for the underlying connection. There is no connect timeout by
152    /// default.
153    pub fn connect_timeout(mut self, connect_timeout: Duration) -> ClientBuilder {
154        self.connect_timeout = Some(connect_timeout);
155        self
156    }
157
158    /// Set a read timeout for the underlying connection. There is no read timeout by default.
159    pub fn read_timeout(mut self, read_timeout: Duration) -> ClientBuilder {
160        self.read_timeout = Some(read_timeout);
161        self
162    }
163
164    /// Set a write timeout for the underlying connection. There is no write timeout by default.
165    pub fn write_timeout(mut self, write_timeout: Duration) -> ClientBuilder {
166        self.write_timeout = Some(write_timeout);
167        self
168    }
169
170    /// Configure the client's reconnect behaviour according to the supplied
171    /// [`ReconnectOptions`].
172    ///
173    /// [`ReconnectOptions`]: struct.ReconnectOptions.html
174    pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
175        self.reconnect_opts = opts;
176        self
177    }
178
179    /// Customize the client's following behavior when served a redirect.
180    /// To disable following redirects, pass `0`.
181    /// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
182    pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
183        self.max_redirects = Some(limit);
184        self
185    }
186
187    /// Build with a specific client connector.
188    pub fn build_with_conn<C>(self, conn: C) -> impl Client
189    where
190        C: Service<Uri> + Clone + Send + Sync + 'static,
191        C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin,
192        C::Future: Send + 'static,
193        C::Error: Into<BoxError>,
194    {
195        let mut connector = TimeoutConnector::new(conn);
196        connector.set_connect_timeout(self.connect_timeout);
197        connector.set_read_timeout(self.read_timeout);
198        connector.set_write_timeout(self.write_timeout);
199
200        let client = hyper::Client::builder().build::<_, hyper::Body>(connector);
201
202        ClientImpl {
203            http: client,
204            request_props: RequestProps {
205                url: self.url,
206                headers: self.headers,
207                method: self.method,
208                body: self.body,
209                reconnect_opts: self.reconnect_opts,
210                max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
211            },
212            last_event_id: self.last_event_id,
213        }
214    }
215
216    /// Build with an HTTP client connector.
217    pub fn build_http(self) -> impl Client {
218        self.build_with_conn(HttpConnector::new())
219    }
220
221    #[cfg(feature = "rustls")]
222    /// Build with an HTTPS client connector, using the OS root certificate store.
223    pub fn build(self) -> impl Client {
224        let conn = HttpsConnectorBuilder::new()
225            .with_native_roots()
226            .https_or_http()
227            .enable_http1()
228            .enable_http2()
229            .build();
230
231        self.build_with_conn(conn)
232    }
233
234    /// Build with the given [`hyper::client::Client`].
235    pub fn build_with_http_client<C>(self, http: hyper::Client<C>) -> impl Client
236    where
237        C: Connect + Clone + Send + Sync + 'static,
238    {
239        ClientImpl {
240            http,
241            request_props: RequestProps {
242                url: self.url,
243                headers: self.headers,
244                method: self.method,
245                body: self.body,
246                reconnect_opts: self.reconnect_opts,
247                max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
248            },
249            last_event_id: self.last_event_id,
250        }
251    }
252}
253
/// Request settings captured from [`ClientBuilder`]; each stream receives its own clone.
#[derive(Clone)]
struct RequestProps {
    // Configured endpoint; a stream's current URL is reset to this after redirects.
    url: Uri,
    // Headers sent with every request.
    headers: HeaderMap,
    // HTTP method used for every request.
    method: String,
    // Optional body sent with every request.
    body: Option<String>,
    // Reconnect and backoff configuration.
    reconnect_opts: ReconnectOptions,
    // Redirects to follow before failing with `Error::MaxRedirectLimitReached`.
    max_redirects: u32,
}
263
/// A client implementation that connects to a server using the Server-Sent Events protocol
/// and consumes the event stream indefinitely.
/// Can be parameterized with different hyper Connectors, such as HTTP or HTTPS.
struct ClientImpl<C> {
    // hyper client; cloned into every stream created by `stream()`.
    http: hyper::Client<C>,
    // Request settings from the builder; cloned into every stream.
    request_props: RequestProps,
    // Initial last event id handed to each new stream.
    last_event_id: Option<String>,
}
272
273impl<C> Client for ClientImpl<C>
274where
275    C: Connect + Clone + Send + Sync + 'static,
276{
277    /// Connect to the server and begin consuming the stream. Produces a
278    /// [`Stream`] of [`Event`](crate::Event)s wrapped in [`Result`].
279    ///
280    /// Do not use the stream after it returned an error!
281    ///
282    /// After the first successful connection, the stream will
283    /// reconnect for retryable errors.
284    fn stream(&self) -> BoxStream<Result<SSE>> {
285        Box::pin(ReconnectingRequest::new(
286            self.http.clone(),
287            self.request_props.clone(),
288            self.last_event_id.clone(),
289        ))
290    }
291}
292
/// States of the `ReconnectingRequest` connection state machine.
#[allow(clippy::large_enum_variant)] // false positive
#[pin_project(project = StateProj)]
enum State {
    // Not connected (yet, or again); the next poll sends a request.
    New,
    // Request sent, awaiting the response. `retry` is copied from
    // `ReconnectOptions::retry_initial`. If the request fails, `retry` decides between
    // backing off and retrying, or reporting the error (reconnecting on the next poll).
    Connecting {
        retry: bool,
        #[pin]
        resp: ResponseFuture,
    },
    // Successful response; its body is fed into the event parser.
    Connected(#[pin] hyper::Body),
    // Backing off before returning to `New`.
    WaitingToReconnect(#[pin] Sleep),
    // A 301/307 was received; holds its `Location` header, if present.
    FollowingRedirect(Option<HeaderValue>),
    // Terminal; every poll yields `Error::StreamClosed`.
    StreamClosed,
}
307
308impl State {
309    fn name(&self) -> &'static str {
310        match self {
311            State::New => "new",
312            State::Connecting { retry: false, .. } => "connecting(no-retry)",
313            State::Connecting { retry: true, .. } => "connecting(retry)",
314            State::Connected(_) => "connected",
315            State::WaitingToReconnect(_) => "waiting-to-reconnect",
316            State::FollowingRedirect(_) => "following-redirect",
317            State::StreamClosed => "closed",
318        }
319    }
320}
321
322impl Debug for State {
323    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
324        write!(f, "{}", self.name())
325    }
326}
327
/// Stream of SSE items that connects, and reconnects, to the configured endpoint as it is
/// polled. It is returned, boxed, by [`Client::stream`].
#[must_use = "streams do nothing unless polled"]
#[pin_project]
pub struct ReconnectingRequest<C> {
    // HTTP client used for every connection attempt.
    http: hyper::Client<C>,
    // Request settings copied from the builder.
    props: RequestProps,
    // Current position in the connection state machine.
    #[pin]
    state: State,
    // Computes backoff delays; reset after each successful connection.
    retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
    // URL for the next request; differs from `props.url` while following redirects.
    current_url: Uri,
    // Redirects followed since the last reset, bounded by `props.max_redirects`.
    redirect_count: u32,
    // Turns body bytes into SSE items; recreated before each new request.
    event_parser: EventParser,
    // Most recent event id; sent as the `last-event-id` header when non-empty.
    last_event_id: Option<String>,
}
341
342impl<C> ReconnectingRequest<C> {
343    fn new(
344        http: hyper::Client<C>,
345        props: RequestProps,
346        last_event_id: Option<String>,
347    ) -> ReconnectingRequest<C> {
348        let reconnect_delay = props.reconnect_opts.delay;
349        let delay_max = props.reconnect_opts.delay_max;
350        let backoff_factor = props.reconnect_opts.backoff_factor;
351
352        let url = props.url.clone();
353        ReconnectingRequest {
354            props,
355            http,
356            state: State::New,
357            retry_strategy: Box::new(BackoffRetry::new(
358                reconnect_delay,
359                delay_max,
360                backoff_factor,
361                true,
362            )),
363            redirect_count: 0,
364            current_url: url,
365            event_parser: EventParser::new(),
366            last_event_id,
367        }
368    }
369
370    fn send_request(&self) -> Result<ResponseFuture>
371    where
372        C: Connect + Clone + Send + Sync + 'static,
373    {
374        let mut request_builder = Request::builder()
375            .method(self.props.method.as_str())
376            .uri(&self.current_url);
377
378        for (name, value) in &self.props.headers {
379            request_builder = request_builder.header(name, value);
380        }
381
382        if let Some(id) = self.last_event_id.as_ref() {
383            if !id.is_empty() {
384                let id_as_header =
385                    HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
386
387                request_builder = request_builder.header("last-event-id", id_as_header);
388            }
389        }
390
391        let body = match &self.props.body {
392            Some(body) => Body::from(body.to_string()),
393            None => Body::empty(),
394        };
395
396        let request = request_builder
397            .body(body)
398            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
399
400        Ok(self.http.request(request))
401    }
402
403    fn reset_redirects(self: Pin<&mut Self>) {
404        let url = self.props.url.clone();
405        let this = self.project();
406        *this.current_url = url;
407        *this.redirect_count = 0;
408    }
409
410    fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
411        if self.redirect_count == self.props.max_redirects {
412            return false;
413        }
414        *self.project().redirect_count += 1;
415        true
416    }
417}
418
419impl<C> Stream for ReconnectingRequest<C>
420where
421    C: Connect + Clone + Send + Sync + 'static,
422{
423    type Item = Result<SSE>;
424
425    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
426        trace!("ReconnectingRequest::poll({:?})", &self.state);
427
428        loop {
429            let this = self.as_mut().project();
430            if let Some(event) = this.event_parser.get_event() {
431                return match event {
432                    SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
433                    SSE::Event(ref evt) => {
434                        this.last_event_id.clone_from(&evt.id);
435
436                        if let Some(retry) = evt.retry {
437                            this.retry_strategy
438                                .change_base_delay(Duration::from_millis(retry));
439                        }
440                        Poll::Ready(Some(Ok(event)))
441                    }
442                    SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
443                };
444            }
445
446            trace!("ReconnectingRequest::poll loop({:?})", &this.state);
447
448            let state = this.state.project();
449            match state {
450                StateProj::StreamClosed => return Poll::Ready(Some(Err(Error::StreamClosed))),
451                // New immediately transitions to Connecting, and exists only
452                // to ensure that we only connect when polled.
453                StateProj::New => {
454                    *self.as_mut().project().event_parser = EventParser::new();
455                    match self.send_request() {
456                        Ok(resp) => {
457                            let retry = self.props.reconnect_opts.retry_initial;
458                            self.as_mut()
459                                .project()
460                                .state
461                                .set(State::Connecting { resp, retry })
462                        }
463                        Err(e) => {
464                            // This error seems to be unrecoverable. So we should just shut down the
465                            // stream.
466                            self.as_mut().project().state.set(State::StreamClosed);
467                            return Poll::Ready(Some(Err(e)));
468                        }
469                    }
470                }
471                StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
472                    Ok(resp) => {
473                        debug!("HTTP response: {:#?}", resp);
474
475                        if resp.status().is_success() {
476                            self.as_mut().project().retry_strategy.reset(Instant::now());
477                            self.as_mut().reset_redirects();
478
479                            let status = resp.status();
480                            let headers = resp.headers().clone();
481
482                            self.as_mut()
483                                .project()
484                                .state
485                                .set(State::Connected(resp.into_body()));
486
487                            return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
488                                Response::new(status, headers),
489                            )))));
490                        }
491
492                        if resp.status() == 301 || resp.status() == 307 {
493                            debug!("got redirected ({})", resp.status());
494
495                            if self.as_mut().increment_redirect_counter() {
496                                debug!("following redirect {}", self.redirect_count);
497
498                                self.as_mut().project().state.set(State::FollowingRedirect(
499                                    resp.headers().get(hyper::header::LOCATION).cloned(),
500                                ));
501                                continue;
502                            } else {
503                                debug!("redirect limit reached ({})", self.props.max_redirects);
504
505                                self.as_mut().project().state.set(State::StreamClosed);
506                                return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
507                                    self.props.max_redirects,
508                                ))));
509                            }
510                        }
511
512                        self.as_mut().reset_redirects();
513                        self.as_mut().project().state.set(State::New);
514
515                        return Poll::Ready(Some(Err(Error::UnexpectedResponse(
516                            Response::new(resp.status(), resp.headers().clone()),
517                            ErrorBody::new(resp.into_body()),
518                        ))));
519                    }
520                    Err(e) => {
521                        // This seems basically impossible. AFAIK we can only get this way if we
522                        // poll after it was already ready
523                        warn!("request returned an error: {}", e);
524                        if !*retry {
525                            self.as_mut().project().state.set(State::New);
526                            return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
527                        }
528                        let duration = self
529                            .as_mut()
530                            .project()
531                            .retry_strategy
532                            .next_delay(Instant::now());
533                        self.as_mut()
534                            .project()
535                            .state
536                            .set(State::WaitingToReconnect(delay(duration, "retrying")))
537                    }
538                },
539                StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
540                    Ok(uri) => {
541                        *self.as_mut().project().current_url = uri;
542                        self.as_mut().project().state.set(State::New);
543                    }
544                    Err(e) => {
545                        self.as_mut().project().state.set(State::StreamClosed);
546                        return Poll::Ready(Some(Err(e)));
547                    }
548                },
549                StateProj::Connected(body) => match ready!(body.poll_data(cx)) {
550                    Some(Ok(result)) => {
551                        this.event_parser.process_bytes(result)?;
552                        continue;
553                    }
554                    Some(Err(e)) => {
555                        if self.props.reconnect_opts.reconnect {
556                            let duration = self
557                                .as_mut()
558                                .project()
559                                .retry_strategy
560                                .next_delay(Instant::now());
561                            self.as_mut()
562                                .project()
563                                .state
564                                .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
565                        }
566
567                        if let Some(cause) = e.source() {
568                            if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
569                                if let std::io::ErrorKind::TimedOut = downcast.kind() {
570                                    return Poll::Ready(Some(Err(Error::TimedOut)));
571                                }
572                            }
573                        } else {
574                            return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
575                        }
576                    }
577                    None => {
578                        let duration = self
579                            .as_mut()
580                            .project()
581                            .retry_strategy
582                            .next_delay(Instant::now());
583                        self.as_mut()
584                            .project()
585                            .state
586                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
587
588                        if self.event_parser.was_processing() {
589                            return Poll::Ready(Some(Err(Error::UnexpectedEof)));
590                        }
591                        return Poll::Ready(Some(Err(Error::Eof)));
592                    }
593                },
594                StateProj::WaitingToReconnect(delay) => {
595                    ready!(delay.poll(cx));
596                    info!("Reconnecting");
597                    self.as_mut().project().state.set(State::New);
598                }
599            };
600        }
601    }
602}
603
604fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
605    let header = maybe_header.as_ref().ok_or_else(|| {
606        Error::MalformedLocationHeader(Box::new(std::io::Error::new(
607            ErrorKind::NotFound,
608            "missing Location header",
609        )))
610    })?;
611
612    let header_string = header
613        .to_str()
614        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
615
616    header_string
617        .parse::<Uri>()
618        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
619}
620
621fn delay(dur: Duration, description: &str) -> Sleep {
622    info!("Waiting {:?} before {}", dur, description);
623    tokio::time::sleep(dur)
624}
625
// Sealing support for `Client`: `Sealed` is `pub` but lives in a private module, so code
// outside this crate cannot name it and therefore cannot implement `Client`.
mod private {
    use crate::client::ClientImpl;

    // Marker supertrait of `Client`; implemented only for this crate's client type.
    pub trait Sealed {}
    impl<C> Sealed for ClientImpl<C> {}
}
632
#[cfg(test)]
mod tests {
    use crate::ClientBuilder;
    use hyper::http::HeaderValue;
    use test_case::test_case;

    /// `basic_auth` must set `Authorization: Basic <base64(username:password)>`.
    #[test_case("user", "pass", "dXNlcjpwYXNz")]
    #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
    #[test_case("user2", "", "dXNlcjI6")]
    #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
    #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
    #[test_case(
        "weird@-/:stuff",
        "goes@-/:here",
        "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
    )]
    fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
        let base = ClientBuilder::for_url("http://example.com").expect("failed to build client");
        let authed = base
            .basic_auth(username, password)
            .expect("failed to add authentication");

        let want = HeaderValue::from_str(&format!("Basic {}", expected))
            .expect("unable to create expected header");

        assert_eq!(authed.headers.get("Authorization"), Some(&want));
    }
}