axum/response/
sse.rs

1//! Server-Sent Events (SSE) responses.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     Router,
8//!     routing::get,
9//!     response::sse::{Event, KeepAlive, Sse},
10//! };
11//! use std::{time::Duration, convert::Infallible};
12//! use tokio_stream::StreamExt as _ ;
13//! use futures_util::stream::{self, Stream};
14//!
15//! let app = Router::new().route("/sse", get(sse_handler));
16//!
17//! async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
18//!     // A `Stream` that repeats an event every second
19//!     let stream = stream::repeat_with(|| Event::default().data("hi!"))
20//!         .map(Ok)
21//!         .throttle(Duration::from_secs(1));
22//!
23//!     Sse::new(stream).keep_alive(KeepAlive::default())
24//! }
25//! # let _: Router = app;
26//! ```
27
28use crate::{
29    body::{Bytes, HttpBody},
30    BoxError,
31};
32use axum_core::{
33    body::Body,
34    response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::{
38    ready,
39    stream::{Stream, TryStream},
40};
41use http_body::Frame;
42use pin_project_lite::pin_project;
43use std::{
44    fmt,
45    future::Future,
46    pin::Pin,
47    task::{Context, Poll},
48    time::Duration,
49};
50use sync_wrapper::SyncWrapper;
51use tokio::time::Sleep;
52
53/// An SSE response
54#[derive(Clone)]
55#[must_use]
56pub struct Sse<S> {
57    stream: S,
58    keep_alive: Option<KeepAlive>,
59}
60
61impl<S> Sse<S> {
62    /// Create a new [`Sse`] response that will respond with the given stream of
63    /// [`Event`]s.
64    ///
65    /// See the [module docs](self) for more details.
66    pub fn new(stream: S) -> Self
67    where
68        S: TryStream<Ok = Event> + Send + 'static,
69        S::Error: Into<BoxError>,
70    {
71        Sse {
72            stream,
73            keep_alive: None,
74        }
75    }
76
77    /// Configure the interval between keep-alive messages.
78    ///
79    /// Defaults to no keep-alive messages.
80    pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
81        self.keep_alive = Some(keep_alive);
82        self
83    }
84}
85
86impl<S> fmt::Debug for Sse<S> {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("Sse")
89            .field("stream", &format_args!("{}", std::any::type_name::<S>()))
90            .field("keep_alive", &self.keep_alive)
91            .finish()
92    }
93}
94
95impl<S, E> IntoResponse for Sse<S>
96where
97    S: Stream<Item = Result<Event, E>> + Send + 'static,
98    E: Into<BoxError>,
99{
100    fn into_response(self) -> Response {
101        (
102            [
103                (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
104                (http::header::CACHE_CONTROL, "no-cache"),
105            ],
106            Body::new(SseBody {
107                event_stream: SyncWrapper::new(self.stream),
108                keep_alive: self.keep_alive.map(KeepAliveStream::new),
109            }),
110        )
111            .into_response()
112    }
113}
114
115pin_project! {
116    struct SseBody<S> {
117        #[pin]
118        event_stream: SyncWrapper<S>,
119        #[pin]
120        keep_alive: Option<KeepAliveStream>,
121    }
122}
123
124impl<S, E> HttpBody for SseBody<S>
125where
126    S: Stream<Item = Result<Event, E>>,
127{
128    type Data = Bytes;
129    type Error = E;
130
131    fn poll_frame(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
135        let this = self.project();
136
137        match this.event_stream.get_pin_mut().poll_next(cx) {
138            Poll::Pending => {
139                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
140                    keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
141                } else {
142                    Poll::Pending
143                }
144            }
145            Poll::Ready(Some(Ok(event))) => {
146                if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
147                    keep_alive.reset();
148                }
149                Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
150            }
151            Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
152            Poll::Ready(None) => Poll::Ready(None),
153        }
154    }
155}
156
157/// Server-sent event
158#[derive(Debug, Default, Clone)]
159#[must_use]
160pub struct Event {
161    buffer: BytesMut,
162    flags: EventFlags,
163}
164
165impl Event {
166    /// Set the event's data data field(s) (`data: <content>`)
167    ///
168    /// Newlines in `data` will automatically be broken across `data: ` fields.
169    ///
170    /// This corresponds to [`MessageEvent`'s data field].
171    ///
172    /// Note that events with an empty data field will be ignored by the browser.
173    ///
174    /// # Panics
175    ///
176    /// - Panics if `data` contains any carriage returns, as they cannot be transmitted over SSE.
177    /// - Panics if `data` or `json_data` have already been called.
178    ///
179    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
180    pub fn data<T>(mut self, data: T) -> Event
181    where
182        T: AsRef<str>,
183    {
184        if self.flags.contains(EventFlags::HAS_DATA) {
185            panic!("Called `EventBuilder::data` multiple times");
186        }
187
188        for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
189            self.field("data", line);
190        }
191
192        self.flags.insert(EventFlags::HAS_DATA);
193
194        self
195    }
196
197    /// Set the event's data field to a value serialized as unformatted JSON (`data: <content>`).
198    ///
199    /// This corresponds to [`MessageEvent`'s data field].
200    ///
201    /// # Panics
202    ///
203    /// Panics if `data` or `json_data` have already been called.
204    ///
205    /// [`MessageEvent`'s data field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/data
206    #[cfg(feature = "json")]
207    pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
208    where
209        T: serde::Serialize,
210    {
211        if self.flags.contains(EventFlags::HAS_DATA) {
212            panic!("Called `EventBuilder::json_data` multiple times");
213        }
214
215        self.buffer.extend_from_slice(b"data: ");
216        serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?;
217        self.buffer.put_u8(b'\n');
218
219        self.flags.insert(EventFlags::HAS_DATA);
220
221        Ok(self)
222    }
223
224    /// Set the event's comment field (`:<comment-text>`).
225    ///
226    /// This field will be ignored by most SSE clients.
227    ///
228    /// Unlike other functions, this function can be called multiple times to add many comments.
229    ///
230    /// # Panics
231    ///
232    /// Panics if `comment` contains any newlines or carriage returns, as they are not allowed in
233    /// comments.
234    pub fn comment<T>(mut self, comment: T) -> Event
235    where
236        T: AsRef<str>,
237    {
238        self.field("", comment.as_ref());
239        self
240    }
241
242    /// Set the event's name field (`event:<event-name>`).
243    ///
244    /// This corresponds to the `type` parameter given when calling `addEventListener` on an
245    /// [`EventSource`]. For example, `.event("update")` should correspond to
246    /// `.addEventListener("update", ...)`. If no event type is given, browsers will fire a
247    /// [`message` event] instead.
248    ///
249    /// [`EventSource`]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource
250    /// [`message` event]: https://developer.mozilla.org/en-US/docs/Web/API/EventSource/message_event
251    ///
252    /// # Panics
253    ///
254    /// - Panics if `event` contains any newlines or carriage returns.
255    /// - Panics if this function has already been called on this event.
256    pub fn event<T>(mut self, event: T) -> Event
257    where
258        T: AsRef<str>,
259    {
260        if self.flags.contains(EventFlags::HAS_EVENT) {
261            panic!("Called `EventBuilder::event` multiple times");
262        }
263        self.flags.insert(EventFlags::HAS_EVENT);
264
265        self.field("event", event.as_ref());
266
267        self
268    }
269
270    /// Set the event's retry timeout field (`retry:<timeout>`).
271    ///
272    /// This sets how long clients will wait before reconnecting if they are disconnected from the
273    /// SSE endpoint. Note that this is just a hint: clients are free to wait for longer if they
274    /// wish, such as if they implement exponential backoff.
275    ///
276    /// # Panics
277    ///
278    /// Panics if this function has already been called on this event.
279    pub fn retry(mut self, duration: Duration) -> Event {
280        if self.flags.contains(EventFlags::HAS_RETRY) {
281            panic!("Called `EventBuilder::retry` multiple times");
282        }
283        self.flags.insert(EventFlags::HAS_RETRY);
284
285        self.buffer.extend_from_slice(b"retry:");
286
287        let secs = duration.as_secs();
288        let millis = duration.subsec_millis();
289
290        if secs > 0 {
291            // format seconds
292            self.buffer
293                .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
294
295            // pad milliseconds
296            if millis < 10 {
297                self.buffer.extend_from_slice(b"00");
298            } else if millis < 100 {
299                self.buffer.extend_from_slice(b"0");
300            }
301        }
302
303        // format milliseconds
304        self.buffer
305            .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
306
307        self.buffer.put_u8(b'\n');
308
309        self
310    }
311
312    /// Set the event's identifier field (`id:<identifier>`).
313    ///
314    /// This corresponds to [`MessageEvent`'s `lastEventId` field]. If no ID is in the event itself,
315    /// the browser will set that field to the last known message ID, starting with the empty
316    /// string.
317    ///
318    /// [`MessageEvent`'s `lastEventId` field]: https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent/lastEventId
319    ///
320    /// # Panics
321    ///
322    /// - Panics if `id` contains any newlines, carriage returns or null characters.
323    /// - Panics if this function has already been called on this event.
324    pub fn id<T>(mut self, id: T) -> Event
325    where
326        T: AsRef<str>,
327    {
328        if self.flags.contains(EventFlags::HAS_ID) {
329            panic!("Called `EventBuilder::id` multiple times");
330        }
331        self.flags.insert(EventFlags::HAS_ID);
332
333        let id = id.as_ref().as_bytes();
334        assert_eq!(
335            memchr::memchr(b'\0', id),
336            None,
337            "Event ID cannot contain null characters",
338        );
339
340        self.field("id", id);
341        self
342    }
343
344    fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
345        let value = value.as_ref();
346        assert_eq!(
347            memchr::memchr2(b'\r', b'\n', value),
348            None,
349            "SSE field value cannot contain newlines or carriage returns",
350        );
351        self.buffer.extend_from_slice(name.as_bytes());
352        self.buffer.put_u8(b':');
353        self.buffer.put_u8(b' ');
354        self.buffer.extend_from_slice(value);
355        self.buffer.put_u8(b'\n');
356    }
357
358    fn finalize(mut self) -> Bytes {
359        self.buffer.put_u8(b'\n');
360        self.buffer.freeze()
361    }
362}
363
/// Bit set recording which once-only fields have been written to an event.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// Returns the raw bit pattern.
    const fn bits(&self) -> u8 {
        let Self(raw) = *self;
        raw
    }

    /// Wraps a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        EventFlags(bits)
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
389
390/// Configure the interval between keep-alive messages, the content
391/// of each message, and the associated stream.
392#[derive(Debug, Clone)]
393#[must_use]
394pub struct KeepAlive {
395    event: Bytes,
396    max_interval: Duration,
397}
398
399impl KeepAlive {
400    /// Create a new `KeepAlive`.
401    pub fn new() -> Self {
402        Self {
403            event: Bytes::from_static(b":\n\n"),
404            max_interval: Duration::from_secs(15),
405        }
406    }
407
408    /// Customize the interval between keep-alive messages.
409    ///
410    /// Default is 15 seconds.
411    pub fn interval(mut self, time: Duration) -> Self {
412        self.max_interval = time;
413        self
414    }
415
416    /// Customize the text of the keep-alive message.
417    ///
418    /// Default is an empty comment.
419    ///
420    /// # Panics
421    ///
422    /// Panics if `text` contains any newline or carriage returns, as they are not allowed in SSE
423    /// comments.
424    pub fn text<I>(self, text: I) -> Self
425    where
426        I: AsRef<str>,
427    {
428        self.event(Event::default().comment(text))
429    }
430
431    /// Customize the event of the keep-alive message.
432    ///
433    /// Default is an empty comment.
434    ///
435    /// # Panics
436    ///
437    /// Panics if `event` contains any newline or carriage returns, as they are not allowed in SSE
438    /// comments.
439    pub fn event(mut self, event: Event) -> Self {
440        self.event = event.finalize();
441        self
442    }
443}
444
445impl Default for KeepAlive {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
451pin_project! {
452    #[derive(Debug)]
453    struct KeepAliveStream {
454        keep_alive: KeepAlive,
455        #[pin]
456        alive_timer: Sleep,
457    }
458}
459
460impl KeepAliveStream {
461    fn new(keep_alive: KeepAlive) -> Self {
462        Self {
463            alive_timer: tokio::time::sleep(keep_alive.max_interval),
464            keep_alive,
465        }
466    }
467
468    fn reset(self: Pin<&mut Self>) {
469        let this = self.project();
470        this.alive_timer
471            .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
472    }
473
474    fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
475        let this = self.as_mut().project();
476
477        ready!(this.alive_timer.poll(cx));
478
479        let event = this.keep_alive.event.clone();
480
481        self.reset();
482
483        Poll::Ready(event)
484    }
485}
486
487fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
488    MemchrSplit {
489        needle,
490        haystack: Some(haystack),
491    }
492}
493
/// Iterator state for [`memchr_split`].
struct MemchrSplit<'a> {
    // Byte that separates the yielded slices.
    needle: u8,
    // Input not yet yielded; `None` once the iterator is exhausted.
    haystack: Option<&'a [u8]>,
}
498
499impl<'a> Iterator for MemchrSplit<'a> {
500    type Item = &'a [u8];
501    fn next(&mut self) -> Option<Self::Item> {
502        let haystack = self.haystack?;
503        if let Some(pos) = memchr::memchr(self.needle, haystack) {
504            let (front, back) = haystack.split_at(pos);
505            self.haystack = Some(&back[1..]);
506            Some(front)
507        } else {
508            self.haystack.take()
509        }
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    /// Keep-alive configuration shared by the keep-alive tests: a comment
    /// reading `keep-alive-text` every second.
    fn one_second_keep_alive() -> KeepAlive {
        KeepAlive::new()
            .interval(Duration::from_secs(1))
            .text("keep-alive-text")
    }

    #[test]
    fn leading_space_is_not_stripped() {
        let tab_prefixed = Event::default().data("\tfoobar");
        assert_eq!(&*tab_prefixed.finalize(), b"data: \tfoobar\n\n");

        let space_prefixed = Event::default().data(" foobar");
        assert_eq!(&*space_prefixed.finalize(), b"data:  foobar\n\n");
    }

    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let first = Event::default().data("one").comment("this is a comment");
                let second = Event::default()
                    .json_data(serde_json::json!({ "foo": "bar" }))
                    .unwrap();
                let third = Event::default()
                    .event("three")
                    .retry(Duration::from_secs(30))
                    .id("unique-id");

                let events = stream::iter(vec![first, second, third]).map(Ok::<_, Infallible>);
                Sse::new(events)
            }),
        );

        let client = TestClient::new(app);
        let mut res = client.get("/").await;

        assert_eq!(res.headers()["content-type"], "text/event-stream");
        assert_eq!(res.headers()["cache-control"], "no-cache");

        // first event: data plus a comment
        let fields = parse_event(&res.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "one");
        assert_eq!(fields.get("comment").unwrap(), "this is a comment");

        // second event: JSON data only
        let fields = parse_event(&res.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!fields.contains_key("comment"));

        // third event: name, retry and id
        let fields = parse_event(&res.chunk_text().await.unwrap());
        assert_eq!(fields.get("event").unwrap(), "three");
        assert_eq!(fields.get("retry").unwrap(), "30000");
        assert_eq!(fields.get("id").unwrap(), "unique-id");
        assert!(!fields.contains_key("comment"));

        // the stream ends after the third event
        assert!(res.chunk_text().await.is_none());
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let events = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(events).keep_alive(one_second_keep_alive())
            }),
        );

        let client = TestClient::new(app);
        let mut res = client.get("/").await;

        for _ in 0..5 {
            // first message should be an event
            let fields = parse_event(&res.chunk_text().await.unwrap());
            assert_eq!(fields.get("data").unwrap(), "msg");

            // then 4 seconds of keep-alive messages
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let fields = parse_event(&res.chunk_text().await.unwrap());
                assert_eq!(fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let events = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(events).keep_alive(one_second_keep_alive())
            }),
        );

        let client = TestClient::new(app);
        let mut res = client.get("/").await;

        // first message should be an event
        let fields = parse_event(&res.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "msg");

        // then 4 seconds of keep-alive messages
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let fields = parse_event(&res.chunk_text().await.unwrap());
            assert_eq!(fields.get("comment").unwrap(), "keep-alive-text");
        }

        // then the last event
        let fields = parse_event(&res.chunk_text().await.unwrap());
        assert_eq!(fields.get("data").unwrap(), "msg");

        // then no more events or keep-alive messages
        assert!(res.chunk_text().await.is_none());
    }

    /// Parses one serialized event into a field-name -> value map.
    ///
    /// Comment lines (empty field name) are stored under the key `comment`,
    /// values are trimmed, and a repeated field keeps only its last value.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();
        let mut lines = payload.lines();

        while let Some(line) = lines.next() {
            if line.is_empty() {
                // The blank line terminates the event; nothing may follow it.
                assert!(lines.next().is_none());
                break;
            }

            let (name, value) = line.split_once(':').unwrap();
            let name = if name.is_empty() { "comment" } else { name };
            fields.insert(name.to_owned(), value.trim().to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        // empty input yields a single empty slice
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        // a lone separator yields two empty slices
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        // no separator yields the input unchanged
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        // trailing separator
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        // leading separator
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        // consecutive separators
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}