axum/extract/
ws.rs

1//! Handle WebSocket connections.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//!     extract::ws::{WebSocketUpgrade, WebSocket},
8//!     routing::get,
9//!     response::{IntoResponse, Response},
10//!     Router,
11//! };
12//!
13//! let app = Router::new().route("/ws", get(handler));
14//!
15//! async fn handler(ws: WebSocketUpgrade) -> Response {
16//!     ws.on_upgrade(handle_socket)
17//! }
18//!
19//! async fn handle_socket(mut socket: WebSocket) {
20//!     while let Some(msg) = socket.recv().await {
21//!         let msg = if let Ok(msg) = msg {
22//!             msg
23//!         } else {
24//!             // client disconnected
25//!             return;
26//!         };
27//!
28//!         if socket.send(msg).await.is_err() {
29//!             // client disconnected
30//!             return;
31//!         }
32//!     }
33//! }
34//! # let _: Router = app;
35//! ```
36//!
37//! # Passing data and/or state to an `on_upgrade` callback
38//!
39//! ```
40//! use axum::{
41//!     extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42//!     response::Response,
43//!     routing::get,
44//!     Router,
45//! };
46//!
47//! #[derive(Clone)]
48//! struct AppState {
49//!     // ...
50//! }
51//!
52//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
53//!     ws.on_upgrade(|socket| handle_socket(socket, state))
54//! }
55//!
56//! async fn handle_socket(socket: WebSocket, state: AppState) {
57//!     // ...
58//! }
59//!
60//! let app = Router::new()
61//!     .route("/ws", get(handler))
62//!     .with_state(AppState { /* ... */ });
63//! # let _: Router = app;
64//! ```
65//!
66//! # Read and write concurrently
67//!
68//! If you need to read and write concurrently from a [`WebSocket`] you can use
69//! [`StreamExt::split`]:
70//!
71//! ```rust,no_run
72//! use axum::{Error, extract::ws::{WebSocket, Message}};
73//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
74//!
//! async fn handle_socket(socket: WebSocket) {
//!     let (sender, receiver) = socket.split();
77//!
78//!     tokio::spawn(write(sender));
79//!     tokio::spawn(read(receiver));
80//! }
81//!
82//! async fn read(receiver: SplitStream<WebSocket>) {
83//!     // ...
84//! }
85//!
86//! async fn write(sender: SplitSink<WebSocket, Message>) {
87//!     // ...
88//! }
89//! ```
90//!
91//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
92
93use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use async_trait::async_trait;
97use axum_core::body::Body;
98use futures_util::{
99    sink::{Sink, SinkExt},
100    stream::{Stream, StreamExt},
101};
102use http::{
103    header::{self, HeaderMap, HeaderName, HeaderValue},
104    request::Parts,
105    Method, StatusCode,
106};
107use hyper_util::rt::TokioIo;
108use sha1::{Digest, Sha1};
109use std::{
110    borrow::Cow,
111    future::Future,
112    pin::Pin,
113    task::{Context, Poll},
114};
115use tokio_tungstenite::{
116    tungstenite::{
117        self as ts,
118        protocol::{self, WebSocketConfig},
119    },
120    WebSocketStream,
121};
122
123/// Extractor for establishing WebSocket connections.
124///
125/// Note: This extractor requires the request method to be `GET` so it should
126/// always be used with [`get`](crate::routing::get). Requests with other methods will be
127/// rejected.
128///
129/// See the [module docs](self) for an example.
130#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
131pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
132    config: WebSocketConfig,
133    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
134    protocol: Option<HeaderValue>,
135    sec_websocket_key: HeaderValue,
136    on_upgrade: hyper::upgrade::OnUpgrade,
137    on_failed_upgrade: F,
138    sec_websocket_protocol: Option<HeaderValue>,
139}
140
141impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("WebSocketUpgrade")
144            .field("config", &self.config)
145            .field("protocol", &self.protocol)
146            .field("sec_websocket_key", &self.sec_websocket_key)
147            .field("sec_websocket_protocol", &self.sec_websocket_protocol)
148            .finish_non_exhaustive()
149    }
150}
151
152impl<F> WebSocketUpgrade<F> {
153    /// The target minimum size of the write buffer to reach before writing the data
154    /// to the underlying stream.
155    ///
156    /// The default value is 128 KiB.
157    ///
158    /// If set to `0` each message will be eagerly written to the underlying stream.
159    /// It is often more optimal to allow them to buffer a little, hence the default value.
160    ///
161    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
162    pub fn write_buffer_size(mut self, size: usize) -> Self {
163        self.config.write_buffer_size = size;
164        self
165    }
166
167    /// The max size of the write buffer in bytes. Setting this can provide backpressure
168    /// in the case the write buffer is filling up due to write errors.
169    ///
170    /// The default value is unlimited.
171    ///
172    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
173    /// when writes to the underlying stream are failing. So the **write buffer can not
174    /// fill up if you are not observing write errors even if not flushing**.
175    ///
176    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
177    /// and probably a little more depending on error handling strategy.
178    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
179        self.config.max_write_buffer_size = max;
180        self
181    }
182
183    /// Set the maximum message size (defaults to 64 megabytes)
184    pub fn max_message_size(mut self, max: usize) -> Self {
185        self.config.max_message_size = Some(max);
186        self
187    }
188
189    /// Set the maximum frame size (defaults to 16 megabytes)
190    pub fn max_frame_size(mut self, max: usize) -> Self {
191        self.config.max_frame_size = Some(max);
192        self
193    }
194
195    /// Allow server to accept unmasked frames (defaults to false)
196    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
197        self.config.accept_unmasked_frames = accept;
198        self
199    }
200
201    /// Set the known protocols.
202    ///
203    /// If the protocol name specified by `Sec-WebSocket-Protocol` header
204    /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
205    /// return the protocol name.
206    ///
207    /// The protocols should be listed in decreasing order of preference: if the client offers
208    /// multiple protocols that the server could support, the server will pick the first one in
209    /// this list.
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// use axum::{
215    ///     extract::ws::{WebSocketUpgrade, WebSocket},
216    ///     routing::get,
217    ///     response::{IntoResponse, Response},
218    ///     Router,
219    /// };
220    ///
221    /// let app = Router::new().route("/ws", get(handler));
222    ///
223    /// async fn handler(ws: WebSocketUpgrade) -> Response {
224    ///     ws.protocols(["graphql-ws", "graphql-transport-ws"])
225    ///         .on_upgrade(|socket| async {
226    ///             // ...
227    ///         })
228    /// }
229    /// # let _: Router = app;
230    /// ```
231    pub fn protocols<I>(mut self, protocols: I) -> Self
232    where
233        I: IntoIterator,
234        I::Item: Into<Cow<'static, str>>,
235    {
236        if let Some(req_protocols) = self
237            .sec_websocket_protocol
238            .as_ref()
239            .and_then(|p| p.to_str().ok())
240        {
241            self.protocol = protocols
242                .into_iter()
243                // FIXME: This will often allocate a new `String` and so is less efficient than it
244                // could be. But that can't be fixed without breaking changes to the public API.
245                .map(Into::into)
246                .find(|protocol| {
247                    req_protocols
248                        .split(',')
249                        .any(|req_protocol| req_protocol.trim() == protocol)
250                })
251                .map(|protocol| match protocol {
252                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
253                    Cow::Borrowed(s) => HeaderValue::from_static(s),
254                });
255        }
256
257        self
258    }
259
260    /// Provide a callback to call if upgrading the connection fails.
261    ///
262    /// The connection upgrade is performed in a background task. If that fails this callback
263    /// will be called.
264    ///
265    /// By default any errors will be silently ignored.
266    ///
267    /// # Example
268    ///
269    /// ```
270    /// use axum::{
271    ///     extract::{WebSocketUpgrade},
272    ///     response::Response,
273    /// };
274    ///
275    /// async fn handler(ws: WebSocketUpgrade) -> Response {
276    ///     ws.on_failed_upgrade(|error| {
277    ///         report_error(error);
278    ///     })
279    ///     .on_upgrade(|socket| async { /* ... */ })
280    /// }
281    /// #
282    /// # fn report_error(_: axum::Error) {}
283    /// ```
284    pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
285    where
286        C: OnFailedUpgrade,
287    {
288        WebSocketUpgrade {
289            config: self.config,
290            protocol: self.protocol,
291            sec_websocket_key: self.sec_websocket_key,
292            on_upgrade: self.on_upgrade,
293            on_failed_upgrade: callback,
294            sec_websocket_protocol: self.sec_websocket_protocol,
295        }
296    }
297
298    /// Finalize upgrading the connection and call the provided callback with
299    /// the stream.
300    #[must_use = "to set up the WebSocket connection, this response must be returned"]
301    pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
302    where
303        C: FnOnce(WebSocket) -> Fut + Send + 'static,
304        Fut: Future<Output = ()> + Send + 'static,
305        F: OnFailedUpgrade,
306    {
307        let on_upgrade = self.on_upgrade;
308        let config = self.config;
309        let on_failed_upgrade = self.on_failed_upgrade;
310
311        let protocol = self.protocol.clone();
312
313        tokio::spawn(async move {
314            let upgraded = match on_upgrade.await {
315                Ok(upgraded) => upgraded,
316                Err(err) => {
317                    on_failed_upgrade.call(Error::new(err));
318                    return;
319                }
320            };
321            let upgraded = TokioIo::new(upgraded);
322
323            let socket =
324                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
325                    .await;
326            let socket = WebSocket {
327                inner: socket,
328                protocol,
329            };
330            callback(socket).await;
331        });
332
333        #[allow(clippy::declare_interior_mutable_const)]
334        const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
335        #[allow(clippy::declare_interior_mutable_const)]
336        const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
337
338        let mut builder = Response::builder()
339            .status(StatusCode::SWITCHING_PROTOCOLS)
340            .header(header::CONNECTION, UPGRADE)
341            .header(header::UPGRADE, WEBSOCKET)
342            .header(
343                header::SEC_WEBSOCKET_ACCEPT,
344                sign(self.sec_websocket_key.as_bytes()),
345            );
346
347        if let Some(protocol) = self.protocol {
348            builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
349        }
350
351        builder.body(Body::empty()).unwrap()
352    }
353}
354
355/// What to do when a connection upgrade fails.
356///
357/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
358pub trait OnFailedUpgrade: Send + 'static {
359    /// Call the callback.
360    fn call(self, error: Error);
361}
362
363impl<F> OnFailedUpgrade for F
364where
365    F: FnOnce(Error) + Send + 'static,
366{
367    fn call(self, error: Error) {
368        self(error)
369    }
370}
371
372/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
373///
374/// It simply ignores the error.
375#[non_exhaustive]
376#[derive(Debug)]
377pub struct DefaultOnFailedUpgrade;
378
379impl OnFailedUpgrade for DefaultOnFailedUpgrade {
380    #[inline]
381    fn call(self, _error: Error) {}
382}
383
384#[async_trait]
385impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
386where
387    S: Send + Sync,
388{
389    type Rejection = WebSocketUpgradeRejection;
390
391    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
392        if parts.method != Method::GET {
393            return Err(MethodNotGet.into());
394        }
395
396        if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
397            return Err(InvalidConnectionHeader.into());
398        }
399
400        if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
401            return Err(InvalidUpgradeHeader.into());
402        }
403
404        if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
405            return Err(InvalidWebSocketVersionHeader.into());
406        }
407
408        let sec_websocket_key = parts
409            .headers
410            .get(header::SEC_WEBSOCKET_KEY)
411            .ok_or(WebSocketKeyHeaderMissing)?
412            .clone();
413
414        let on_upgrade = parts
415            .extensions
416            .remove::<hyper::upgrade::OnUpgrade>()
417            .ok_or(ConnectionNotUpgradable)?;
418
419        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
420
421        Ok(Self {
422            config: Default::default(),
423            protocol: None,
424            sec_websocket_key,
425            on_upgrade,
426            sec_websocket_protocol,
427            on_failed_upgrade: DefaultOnFailedUpgrade,
428        })
429    }
430}
431
432fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
433    if let Some(header) = headers.get(&key) {
434        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
435    } else {
436        false
437    }
438}
439
440fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
441    let header = if let Some(header) = headers.get(&key) {
442        header
443    } else {
444        return false;
445    };
446
447    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
448        header.to_ascii_lowercase().contains(value)
449    } else {
450        false
451    }
452}
453
/// A stream of WebSocket messages.
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
    // The `tokio_tungstenite` stream running over hyper's upgraded connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    // The subprotocol selected during the handshake, if any (see `WebSocket::protocol`).
    protocol: Option<HeaderValue>,
}
462
463impl WebSocket {
464    /// Receive another message.
465    ///
466    /// Returns `None` if the stream has closed.
467    pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
468        self.next().await
469    }
470
471    /// Send a message.
472    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
473        self.inner
474            .send(msg.into_tungstenite())
475            .await
476            .map_err(Error::new)
477    }
478
479    /// Gracefully close this WebSocket.
480    pub async fn close(mut self) -> Result<(), Error> {
481        self.inner.close(None).await.map_err(Error::new)
482    }
483
484    /// Return the selected WebSocket subprotocol, if one has been chosen.
485    pub fn protocol(&self) -> Option<&HeaderValue> {
486        self.protocol.as_ref()
487    }
488}
489
490impl Stream for WebSocket {
491    type Item = Result<Message, Error>;
492
493    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
494        loop {
495            match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
496                Some(Ok(msg)) => {
497                    if let Some(msg) = Message::from_tungstenite(msg) {
498                        return Poll::Ready(Some(Ok(msg)));
499                    }
500                }
501                Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
502                None => return Poll::Ready(None),
503            }
504        }
505    }
506}
507
508impl Sink<Message> for WebSocket {
509    type Error = Error;
510
511    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
512        Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
513    }
514
515    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
516        Pin::new(&mut self.inner)
517            .start_send(item.into_tungstenite())
518            .map_err(Error::new)
519    }
520
521    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
522        Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
523    }
524
525    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
526        Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
527    }
528}
529
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Named constants for the common values live in the [`close_code`] module.
pub type CloseCode = u16;

/// A struct representing the close command.
///
/// Carried by [`Message::Close`] when the closing side supplied a status code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CloseFrame<'t> {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Cow<'t, str>,
}
541
542/// A WebSocket message.
543//
544// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
545// Copyright (c) 2017 Alexey Galakhov
546// Copyright (c) 2016 Jason Housley
547//
548// Permission is hereby granted, free of charge, to any person obtaining a copy
549// of this software and associated documentation files (the "Software"), to deal
550// in the Software without restriction, including without limitation the rights
551// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
552// copies of the Software, and to permit persons to whom the Software is
553// furnished to do so, subject to the following conditions:
554//
555// The above copyright notice and this permission notice shall be included in
556// all copies or substantial portions of the Software.
557//
558// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
559// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
560// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
561// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
562// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
563// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
564// THE SOFTWARE.
565#[derive(Debug, Eq, PartialEq, Clone)]
566pub enum Message {
567    /// A text WebSocket message
568    Text(String),
569    /// A binary WebSocket message
570    Binary(Vec<u8>),
571    /// A ping message with the specified payload
572    ///
573    /// The payload here must have a length less than 125 bytes.
574    ///
575    /// Ping messages will be automatically responded to by the server, so you do not have to worry
576    /// about dealing with them yourself.
577    Ping(Vec<u8>),
578    /// A pong message with the specified payload
579    ///
580    /// The payload here must have a length less than 125 bytes.
581    ///
582    /// Pong messages will be automatically sent to the client if a ping message is received, so
583    /// you do not have to worry about constructing them yourself unless you want to implement a
584    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
585    Pong(Vec<u8>),
586    /// A close message with the optional close frame.
587    Close(Option<CloseFrame<'static>>),
588}
589
590impl Message {
591    fn into_tungstenite(self) -> ts::Message {
592        match self {
593            Self::Text(text) => ts::Message::Text(text),
594            Self::Binary(binary) => ts::Message::Binary(binary),
595            Self::Ping(ping) => ts::Message::Ping(ping),
596            Self::Pong(pong) => ts::Message::Pong(pong),
597            Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
598                code: ts::protocol::frame::coding::CloseCode::from(close.code),
599                reason: close.reason,
600            })),
601            Self::Close(None) => ts::Message::Close(None),
602        }
603    }
604
605    fn from_tungstenite(message: ts::Message) -> Option<Self> {
606        match message {
607            ts::Message::Text(text) => Some(Self::Text(text)),
608            ts::Message::Binary(binary) => Some(Self::Binary(binary)),
609            ts::Message::Ping(ping) => Some(Self::Ping(ping)),
610            ts::Message::Pong(pong) => Some(Self::Pong(pong)),
611            ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
612                code: close.code.into(),
613                reason: close.reason,
614            }))),
615            ts::Message::Close(None) => Some(Self::Close(None)),
616            // we can ignore `Frame` frames as recommended by the tungstenite maintainers
617            // https://github.com/snapview/tungstenite-rs/issues/268
618            ts::Message::Frame(_) => None,
619        }
620    }
621
622    /// Consume the WebSocket and return it as binary data.
623    pub fn into_data(self) -> Vec<u8> {
624        match self {
625            Self::Text(string) => string.into_bytes(),
626            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
627            Self::Close(None) => Vec::new(),
628            Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
629        }
630    }
631
632    /// Attempt to consume the WebSocket message and convert it to a String.
633    pub fn into_text(self) -> Result<String, Error> {
634        match self {
635            Self::Text(string) => Ok(string),
636            Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
637                .map_err(|err| err.utf8_error())
638                .map_err(Error::new)?),
639            Self::Close(None) => Ok(String::new()),
640            Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
641        }
642    }
643
644    /// Attempt to get a &str from the WebSocket message,
645    /// this will try to convert binary data to utf8.
646    pub fn to_text(&self) -> Result<&str, Error> {
647        match *self {
648            Self::Text(ref string) => Ok(string),
649            Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
650                Ok(std::str::from_utf8(data).map_err(Error::new)?)
651            }
652            Self::Close(None) => Ok(""),
653            Self::Close(Some(ref frame)) => Ok(&frame.reason),
654        }
655    }
656}
657
658impl From<String> for Message {
659    fn from(string: String) -> Self {
660        Message::Text(string)
661    }
662}
663
664impl<'s> From<&'s str> for Message {
665    fn from(string: &'s str) -> Self {
666        Message::Text(string.into())
667    }
668}
669
670impl<'b> From<&'b [u8]> for Message {
671    fn from(data: &'b [u8]) -> Self {
672        Message::Binary(data.into())
673    }
674}
675
676impl From<Vec<u8>> for Message {
677    fn from(data: Vec<u8>) -> Self {
678        Message::Binary(data)
679    }
680}
681
682impl From<Message> for Vec<u8> {
683    fn from(msg: Message) -> Self {
684        msg.into_data()
685    }
686}
687
688fn sign(key: &[u8]) -> HeaderValue {
689    use base64::engine::Engine as _;
690
691    let mut sha1 = Sha1::default();
692    sha1.update(key);
693    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
694    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
695    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
696}
697
pub mod rejection {
    //! WebSocket specific rejections.
    //!
    //! Each rejection below is returned by the `FromRequestParts` implementation of
    //! [`WebSocketUpgrade`](super::WebSocketUpgrade) when one handshake check fails. The
    //! checks run in the order the types are declared here.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request method is not `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Connection` header does not contain `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Upgrade` header is not `websocket` (compared case-insensitively).
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Sec-WebSocket-Version` header is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    // Aggregates every rejection above so the extractor has a single `Rejection` type.
    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
768
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! Codes `1000`–`1011` are defined in RFC 6455, section 7.4.1. `1012` and `1013` were
    //! later added to the IANA "WebSocket Close Code Number Registry".
    //!
    //! Note that [`STATUS`] (`1005`) and [`ABNORMAL`] (`1006`) are reserved: they describe
    //! how a connection ended locally and must not be sent in a Close frame.
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
    /// it receives a binary message).
    pub const UNSUPPORTED: u16 = 1003;

    /// Indicates that no status code was included in a closing frame.
    ///
    /// Reserved: this value is for reporting only and must not be sent in a Close frame.
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure, i.e. the connection was lost without a Close frame
    /// being received.
    ///
    /// Reserved: this value is for reporting only and must not be sent in a Close frame.
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message (e.g., non-UTF-8
    /// RFC3629 data within a text message).
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy. This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because it has expected
    /// the server to negotiate one or more extension, but the server didn't return them in the
    /// response message of the WebSocket handshake. The list of extensions that are needed should
    /// be given as the reason for closing. Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting.
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    pub const AGAIN: u16 = 1013;
}
830
// Unit and integration tests for the WebSocket extractor and socket wrapper.
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::get, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    // A request built directly with `Request::builder` carries no
    // `hyper::upgrade::OnUpgrade` extension. Every header check passes, so the extractor
    // must fail at the last step with `ConnectionNotUpgradable`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        // The handler takes the extractor as a `Result`, so the rejection is asserted here
        // instead of being turned into an error response.
        let svc = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // The handler itself returned `()`, so the request completed with 200 OK.
        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check: `WebSocketUpgrade` with the default failure callback can be used
    // as a handler extractor.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    // Compile-only check: a closure can be installed via `on_failed_upgrade`.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    // End-to-end: serve an echo endpoint on a real socket and talk to it with a
    // `tokio_tungstenite` client.
    #[crate::test]
    async fn integration_test() {
        let app = Router::new().route(
            "/echo",
            get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
        );

        // Echoes text, binary and close messages back to the client; pings and pongs are
        // not echoed by hand.
        async fn handle_socket(mut socket: WebSocket) {
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // tungstenite will respond to pings automatically
                    }
                }
            }
        }

        let addr = spawn_service(app);
        let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
            .await
            .unwrap();

        // A text message must come back unchanged.
        let input = tungstenite::Message::Text("foobar".to_owned());
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        // The handler ignores pings, so the pong received here was produced automatically
        // by the server-side stream.
        socket
            .send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong("ping".to_owned().into_bytes())
        );
    }
}