1use crate::{
29 body::{Bytes, HttpBody},
30 BoxError,
31};
32use axum_core::{
33 body::Body,
34 response::{IntoResponse, Response},
35};
36use bytes::{BufMut, BytesMut};
37use futures_util::{
38 ready,
39 stream::{Stream, TryStream},
40};
41use http_body::Frame;
42use pin_project_lite::pin_project;
43use std::{
44 fmt,
45 future::Future,
46 pin::Pin,
47 task::{Context, Poll},
48 time::Duration,
49};
50use sync_wrapper::SyncWrapper;
51use tokio::time::Sleep;
52
53#[derive(Clone)]
55#[must_use]
56pub struct Sse<S> {
57 stream: S,
58 keep_alive: Option<KeepAlive>,
59}
60
61impl<S> Sse<S> {
62 pub fn new(stream: S) -> Self
67 where
68 S: TryStream<Ok = Event> + Send + 'static,
69 S::Error: Into<BoxError>,
70 {
71 Sse {
72 stream,
73 keep_alive: None,
74 }
75 }
76
77 pub fn keep_alive(mut self, keep_alive: KeepAlive) -> Self {
81 self.keep_alive = Some(keep_alive);
82 self
83 }
84}
85
86impl<S> fmt::Debug for Sse<S> {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("Sse")
89 .field("stream", &format_args!("{}", std::any::type_name::<S>()))
90 .field("keep_alive", &self.keep_alive)
91 .finish()
92 }
93}
94
95impl<S, E> IntoResponse for Sse<S>
96where
97 S: Stream<Item = Result<Event, E>> + Send + 'static,
98 E: Into<BoxError>,
99{
100 fn into_response(self) -> Response {
101 (
102 [
103 (http::header::CONTENT_TYPE, mime::TEXT_EVENT_STREAM.as_ref()),
104 (http::header::CACHE_CONTROL, "no-cache"),
105 ],
106 Body::new(SseBody {
107 event_stream: SyncWrapper::new(self.stream),
108 keep_alive: self.keep_alive.map(KeepAliveStream::new),
109 }),
110 )
111 .into_response()
112 }
113}
114
115pin_project! {
116 struct SseBody<S> {
117 #[pin]
118 event_stream: SyncWrapper<S>,
119 #[pin]
120 keep_alive: Option<KeepAliveStream>,
121 }
122}
123
124impl<S, E> HttpBody for SseBody<S>
125where
126 S: Stream<Item = Result<Event, E>>,
127{
128 type Data = Bytes;
129 type Error = E;
130
131 fn poll_frame(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
135 let this = self.project();
136
137 match this.event_stream.get_pin_mut().poll_next(cx) {
138 Poll::Pending => {
139 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
140 keep_alive.poll_event(cx).map(|e| Some(Ok(Frame::data(e))))
141 } else {
142 Poll::Pending
143 }
144 }
145 Poll::Ready(Some(Ok(event))) => {
146 if let Some(keep_alive) = this.keep_alive.as_pin_mut() {
147 keep_alive.reset();
148 }
149 Poll::Ready(Some(Ok(Frame::data(event.finalize()))))
150 }
151 Poll::Ready(Some(Err(error))) => Poll::Ready(Some(Err(error))),
152 Poll::Ready(None) => Poll::Ready(None),
153 }
154 }
155}
156
157#[derive(Debug, Default, Clone)]
159#[must_use]
160pub struct Event {
161 buffer: BytesMut,
162 flags: EventFlags,
163}
164
165impl Event {
166 pub fn data<T>(mut self, data: T) -> Event
181 where
182 T: AsRef<str>,
183 {
184 if self.flags.contains(EventFlags::HAS_DATA) {
185 panic!("Called `EventBuilder::data` multiple times");
186 }
187
188 for line in memchr_split(b'\n', data.as_ref().as_bytes()) {
189 self.field("data", line);
190 }
191
192 self.flags.insert(EventFlags::HAS_DATA);
193
194 self
195 }
196
197 #[cfg(feature = "json")]
207 pub fn json_data<T>(mut self, data: T) -> Result<Event, axum_core::Error>
208 where
209 T: serde::Serialize,
210 {
211 if self.flags.contains(EventFlags::HAS_DATA) {
212 panic!("Called `EventBuilder::json_data` multiple times");
213 }
214
215 self.buffer.extend_from_slice(b"data: ");
216 serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?;
217 self.buffer.put_u8(b'\n');
218
219 self.flags.insert(EventFlags::HAS_DATA);
220
221 Ok(self)
222 }
223
224 pub fn comment<T>(mut self, comment: T) -> Event
235 where
236 T: AsRef<str>,
237 {
238 self.field("", comment.as_ref());
239 self
240 }
241
242 pub fn event<T>(mut self, event: T) -> Event
257 where
258 T: AsRef<str>,
259 {
260 if self.flags.contains(EventFlags::HAS_EVENT) {
261 panic!("Called `EventBuilder::event` multiple times");
262 }
263 self.flags.insert(EventFlags::HAS_EVENT);
264
265 self.field("event", event.as_ref());
266
267 self
268 }
269
270 pub fn retry(mut self, duration: Duration) -> Event {
280 if self.flags.contains(EventFlags::HAS_RETRY) {
281 panic!("Called `EventBuilder::retry` multiple times");
282 }
283 self.flags.insert(EventFlags::HAS_RETRY);
284
285 self.buffer.extend_from_slice(b"retry:");
286
287 let secs = duration.as_secs();
288 let millis = duration.subsec_millis();
289
290 if secs > 0 {
291 self.buffer
293 .extend_from_slice(itoa::Buffer::new().format(secs).as_bytes());
294
295 if millis < 10 {
297 self.buffer.extend_from_slice(b"00");
298 } else if millis < 100 {
299 self.buffer.extend_from_slice(b"0");
300 }
301 }
302
303 self.buffer
305 .extend_from_slice(itoa::Buffer::new().format(millis).as_bytes());
306
307 self.buffer.put_u8(b'\n');
308
309 self
310 }
311
312 pub fn id<T>(mut self, id: T) -> Event
325 where
326 T: AsRef<str>,
327 {
328 if self.flags.contains(EventFlags::HAS_ID) {
329 panic!("Called `EventBuilder::id` multiple times");
330 }
331 self.flags.insert(EventFlags::HAS_ID);
332
333 let id = id.as_ref().as_bytes();
334 assert_eq!(
335 memchr::memchr(b'\0', id),
336 None,
337 "Event ID cannot contain null characters",
338 );
339
340 self.field("id", id);
341 self
342 }
343
344 fn field(&mut self, name: &str, value: impl AsRef<[u8]>) {
345 let value = value.as_ref();
346 assert_eq!(
347 memchr::memchr2(b'\r', b'\n', value),
348 None,
349 "SSE field value cannot contain newlines or carriage returns",
350 );
351 self.buffer.extend_from_slice(name.as_bytes());
352 self.buffer.put_u8(b':');
353 self.buffer.put_u8(b' ');
354 self.buffer.extend_from_slice(value);
355 self.buffer.put_u8(b'\n');
356 }
357
358 fn finalize(mut self) -> Bytes {
359 self.buffer.put_u8(b'\n');
360 self.buffer.freeze()
361 }
362}
363
/// Bit set recording which once-only event fields have already been written.
#[derive(Default, Debug, Copy, Clone, PartialEq)]
struct EventFlags(u8);

impl EventFlags {
    const HAS_DATA: Self = Self::from_bits(1 << 0);
    const HAS_EVENT: Self = Self::from_bits(1 << 1);
    const HAS_RETRY: Self = Self::from_bits(1 << 2);
    const HAS_ID: Self = Self::from_bits(1 << 3);

    /// Returns the raw bit pattern.
    const fn bits(&self) -> u8 {
        self.0
    }

    /// Builds a flag set from a raw bit pattern.
    const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns `true` when every bit set in `other` is also set in `self`.
    const fn contains(&self, other: Self) -> bool {
        let wanted = other.bits();
        (self.bits() & wanted) == wanted
    }

    /// Sets every bit of `other` in `self`.
    fn insert(&mut self, other: Self) {
        self.0 |= other.bits();
    }
}
389
390#[derive(Debug, Clone)]
393#[must_use]
394pub struct KeepAlive {
395 event: Bytes,
396 max_interval: Duration,
397}
398
399impl KeepAlive {
400 pub fn new() -> Self {
402 Self {
403 event: Bytes::from_static(b":\n\n"),
404 max_interval: Duration::from_secs(15),
405 }
406 }
407
408 pub fn interval(mut self, time: Duration) -> Self {
412 self.max_interval = time;
413 self
414 }
415
416 pub fn text<I>(self, text: I) -> Self
425 where
426 I: AsRef<str>,
427 {
428 self.event(Event::default().comment(text))
429 }
430
431 pub fn event(mut self, event: Event) -> Self {
440 self.event = event.finalize();
441 self
442 }
443}
444
445impl Default for KeepAlive {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
451pin_project! {
452 #[derive(Debug)]
453 struct KeepAliveStream {
454 keep_alive: KeepAlive,
455 #[pin]
456 alive_timer: Sleep,
457 }
458}
459
460impl KeepAliveStream {
461 fn new(keep_alive: KeepAlive) -> Self {
462 Self {
463 alive_timer: tokio::time::sleep(keep_alive.max_interval),
464 keep_alive,
465 }
466 }
467
468 fn reset(self: Pin<&mut Self>) {
469 let this = self.project();
470 this.alive_timer
471 .reset(tokio::time::Instant::now() + this.keep_alive.max_interval);
472 }
473
474 fn poll_event(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Bytes> {
475 let this = self.as_mut().project();
476
477 ready!(this.alive_timer.poll(cx));
478
479 let event = this.keep_alive.event.clone();
480
481 self.reset();
482
483 Poll::Ready(event)
484 }
485}
486
487fn memchr_split(needle: u8, haystack: &[u8]) -> MemchrSplit<'_> {
488 MemchrSplit {
489 needle,
490 haystack: Some(haystack),
491 }
492}
493
494struct MemchrSplit<'a> {
495 needle: u8,
496 haystack: Option<&'a [u8]>,
497}
498
499impl<'a> Iterator for MemchrSplit<'a> {
500 type Item = &'a [u8];
501 fn next(&mut self) -> Option<Self::Item> {
502 let haystack = self.haystack?;
503 if let Some(pos) = memchr::memchr(self.needle, haystack) {
504 let (front, back) = haystack.split_at(pos);
505 self.haystack = Some(&back[1..]);
506 Some(front)
507 } else {
508 self.haystack.take()
509 }
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::*, Router};
    use futures_util::stream;
    use std::{collections::HashMap, convert::Infallible};
    use tokio_stream::StreamExt as _;

    /// Only the single separator space after the colon is added by the serializer; whitespace
    /// at the start of the value itself must be kept verbatim.
    #[test]
    fn leading_space_is_not_stripped() {
        let no_leading_space = Event::default().data("\tfoobar");
        assert_eq!(&*no_leading_space.finalize(), b"data: \tfoobar\n\n");

        // Separator space + the value's own leading space = two spaces after the colon.
        let leading_space = Event::default().data(" foobar");
        assert_eq!(&*leading_space.finalize(), b"data:  foobar\n\n");
    }

    /// Every field kind round-trips through a real response.
    #[crate::test]
    async fn basic() {
        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::iter(vec![
                    Event::default().data("one").comment("this is a comment"),
                    Event::default()
                        .json_data(serde_json::json!({ "foo": "bar" }))
                        .unwrap(),
                    Event::default()
                        .event("three")
                        .retry(Duration::from_secs(30))
                        .id("unique-id"),
                ])
                .map(Ok::<_, Infallible>);
                Sse::new(stream)
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        assert_eq!(stream.headers()["content-type"], "text/event-stream");
        assert_eq!(stream.headers()["cache-control"], "no-cache");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "one");
        assert_eq!(event_fields.get("comment").unwrap(), "this is a comment");

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}");
        assert!(!event_fields.contains_key("comment"));

        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("event").unwrap(), "three");
        assert_eq!(event_fields.get("retry").unwrap(), "30000");
        assert_eq!(event_fields.get("id").unwrap(), "unique-id");
        assert!(!event_fields.contains_key("comment"));

        assert!(stream.chunk_text().await.is_none());
    }

    /// With events every 5s and a 1s keep-alive interval, each event is followed by four
    /// keep-alive comments.
    #[tokio::test(start_paused = true)]
    async fn keep_alive() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        for _ in 0..5 {
            // The first message of each round is a real event.
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("data").unwrap(), "msg");

            // Then four keep-alive messages arrive before the next event is due.
            for _ in 0..4 {
                tokio::time::sleep(Duration::from_secs(1)).await;
                let event_fields = parse_event(&stream.chunk_text().await.unwrap());
                assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
            }
        }
    }

    /// Keep-alive messages must stop once the underlying stream is exhausted.
    #[tokio::test(start_paused = true)]
    async fn keep_alive_ends_when_the_stream_ends() {
        const DELAY: Duration = Duration::from_secs(5);

        let app = Router::new().route(
            "/",
            get(|| async {
                let stream = stream::repeat_with(|| Event::default().data("msg"))
                    .map(Ok::<_, Infallible>)
                    .throttle(DELAY)
                    .take(2);

                Sse::new(stream).keep_alive(
                    KeepAlive::new()
                        .interval(Duration::from_secs(1))
                        .text("keep-alive-text"),
                )
            }),
        );

        let client = TestClient::new(app);
        let mut stream = client.get("/").await;

        // First real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // Four keep-alive messages while waiting for the second event.
        for _ in 0..4 {
            tokio::time::sleep(Duration::from_secs(1)).await;
            let event_fields = parse_event(&stream.chunk_text().await.unwrap());
            assert_eq!(event_fields.get("comment").unwrap(), "keep-alive-text");
        }

        // Second (and last) real event.
        let event_fields = parse_event(&stream.chunk_text().await.unwrap());
        assert_eq!(event_fields.get("data").unwrap(), "msg");

        // The stream is done, so the body ends rather than sending more keep-alives.
        assert!(stream.chunk_text().await.is_none());
    }

    /// Parses one serialized event into a `field name -> trimmed value` map.
    ///
    /// Comment lines (empty field name) are stored under the key `"comment"`. The blank line
    /// terminating the event must be the last line of `payload`.
    fn parse_event(payload: &str) -> HashMap<String, String> {
        let mut fields = HashMap::new();

        let mut lines = payload.lines();
        while let Some(line) = lines.next() {
            if line.is_empty() {
                assert!(lines.next().is_none());
                break;
            }

            let (mut key, value) = line.split_once(':').unwrap();
            let value = value.trim();
            if key.is_empty() {
                key = "comment";
            }
            fields.insert(key.to_owned(), value.to_owned());
        }

        fields
    }

    #[test]
    fn memchr_splitting() {
        assert_eq!(
            memchr_split(2, &[]).collect::<Vec<_>>(),
            [&[]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[2]).collect::<Vec<_>>(),
            [&[], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1]).collect::<Vec<_>>(),
            [&[1]] as [&[u8]; 1]
        );
        assert_eq!(
            memchr_split(2, &[1, 2]).collect::<Vec<_>>(),
            [&[1], &[]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[2, 1]).collect::<Vec<_>>(),
            [&[], &[1]] as [&[u8]; 2]
        );
        assert_eq!(
            memchr_split(2, &[1, 2, 2, 1]).collect::<Vec<_>>(),
            [&[1], &[], &[1]] as [&[u8]; 3]
        );
    }
}