axum/extract/ws.rs
1//! Handle WebSocket connections.
2//!
3//! # Example
4//!
5//! ```
6//! use axum::{
7//! extract::ws::{WebSocketUpgrade, WebSocket},
8//! routing::get,
9//! response::{IntoResponse, Response},
10//! Router,
11//! };
12//!
13//! let app = Router::new().route("/ws", get(handler));
14//!
15//! async fn handler(ws: WebSocketUpgrade) -> Response {
16//! ws.on_upgrade(handle_socket)
17//! }
18//!
19//! async fn handle_socket(mut socket: WebSocket) {
20//! while let Some(msg) = socket.recv().await {
21//! let msg = if let Ok(msg) = msg {
22//! msg
23//! } else {
24//! // client disconnected
25//! return;
26//! };
27//!
28//! if socket.send(msg).await.is_err() {
29//! // client disconnected
30//! return;
31//! }
32//! }
33//! }
34//! # let _: Router = app;
35//! ```
36//!
37//! # Passing data and/or state to an `on_upgrade` callback
38//!
39//! ```
40//! use axum::{
41//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
42//! response::Response,
43//! routing::get,
44//! Router,
45//! };
46//!
47//! #[derive(Clone)]
48//! struct AppState {
49//! // ...
50//! }
51//!
52//! async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
53//! ws.on_upgrade(|socket| handle_socket(socket, state))
54//! }
55//!
56//! async fn handle_socket(socket: WebSocket, state: AppState) {
57//! // ...
58//! }
59//!
60//! let app = Router::new()
61//! .route("/ws", get(handler))
62//! .with_state(AppState { /* ... */ });
63//! # let _: Router = app;
64//! ```
65//!
66//! # Read and write concurrently
67//!
68//! If you need to read and write concurrently from a [`WebSocket`] you can use
69//! [`StreamExt::split`]:
70//!
71//! ```rust,no_run
72//! use axum::{Error, extract::ws::{WebSocket, Message}};
73//! use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
74//!
//! async fn handle_socket(socket: WebSocket) {
//!     let (sender, receiver) = socket.split();
77//!
78//! tokio::spawn(write(sender));
79//! tokio::spawn(read(receiver));
80//! }
81//!
82//! async fn read(receiver: SplitStream<WebSocket>) {
83//! // ...
84//! }
85//!
86//! async fn write(sender: SplitSink<WebSocket, Message>) {
87//! // ...
88//! }
89//! ```
90//!
91//! [`StreamExt::split`]: https://docs.rs/futures/0.3.17/futures/stream/trait.StreamExt.html#method.split
92
93use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use async_trait::async_trait;
97use axum_core::body::Body;
98use futures_util::{
99 sink::{Sink, SinkExt},
100 stream::{Stream, StreamExt},
101};
102use http::{
103 header::{self, HeaderMap, HeaderName, HeaderValue},
104 request::Parts,
105 Method, StatusCode,
106};
107use hyper_util::rt::TokioIo;
108use sha1::{Digest, Sha1};
109use std::{
110 borrow::Cow,
111 future::Future,
112 pin::Pin,
113 task::{Context, Poll},
114};
115use tokio_tungstenite::{
116 tungstenite::{
117 self as ts,
118 protocol::{self, WebSocketConfig},
119 },
120 WebSocketStream,
121};
122
/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
///
/// See the [module docs](self) for an example.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// tungstenite configuration passed to `WebSocketStream::from_raw_socket` once the
    /// connection is upgraded; adjusted by the builder methods on this type.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// The client's `Sec-WebSocket-Key` request header; `sign` turns it into the
    /// `Sec-WebSocket-Accept` response header.
    sec_websocket_key: HeaderValue,
    /// hyper's upgrade handle, awaited in a spawned task by `on_upgrade` to obtain the
    /// upgraded connection.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Called with the error if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// The client's `Sec-WebSocket-Protocol` request header (a comma-separated list of
    /// offered protocols), consulted by `protocols` to choose `protocol`.
    sec_websocket_protocol: Option<HeaderValue>,
}
140
141impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.debug_struct("WebSocketUpgrade")
144 .field("config", &self.config)
145 .field("protocol", &self.protocol)
146 .field("sec_websocket_key", &self.sec_websocket_key)
147 .field("sec_websocket_protocol", &self.sec_websocket_protocol)
148 .finish_non_exhaustive()
149 }
150}
151
152impl<F> WebSocketUpgrade<F> {
153 /// The target minimum size of the write buffer to reach before writing the data
154 /// to the underlying stream.
155 ///
156 /// The default value is 128 KiB.
157 ///
158 /// If set to `0` each message will be eagerly written to the underlying stream.
159 /// It is often more optimal to allow them to buffer a little, hence the default value.
160 ///
161 /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
162 pub fn write_buffer_size(mut self, size: usize) -> Self {
163 self.config.write_buffer_size = size;
164 self
165 }
166
167 /// The max size of the write buffer in bytes. Setting this can provide backpressure
168 /// in the case the write buffer is filling up due to write errors.
169 ///
170 /// The default value is unlimited.
171 ///
172 /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
173 /// when writes to the underlying stream are failing. So the **write buffer can not
174 /// fill up if you are not observing write errors even if not flushing**.
175 ///
176 /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
177 /// and probably a little more depending on error handling strategy.
178 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
179 self.config.max_write_buffer_size = max;
180 self
181 }
182
183 /// Set the maximum message size (defaults to 64 megabytes)
184 pub fn max_message_size(mut self, max: usize) -> Self {
185 self.config.max_message_size = Some(max);
186 self
187 }
188
189 /// Set the maximum frame size (defaults to 16 megabytes)
190 pub fn max_frame_size(mut self, max: usize) -> Self {
191 self.config.max_frame_size = Some(max);
192 self
193 }
194
195 /// Allow server to accept unmasked frames (defaults to false)
196 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
197 self.config.accept_unmasked_frames = accept;
198 self
199 }
200
201 /// Set the known protocols.
202 ///
203 /// If the protocol name specified by `Sec-WebSocket-Protocol` header
204 /// to match any of them, the upgrade response will include `Sec-WebSocket-Protocol` header and
205 /// return the protocol name.
206 ///
207 /// The protocols should be listed in decreasing order of preference: if the client offers
208 /// multiple protocols that the server could support, the server will pick the first one in
209 /// this list.
210 ///
211 /// # Examples
212 ///
213 /// ```
214 /// use axum::{
215 /// extract::ws::{WebSocketUpgrade, WebSocket},
216 /// routing::get,
217 /// response::{IntoResponse, Response},
218 /// Router,
219 /// };
220 ///
221 /// let app = Router::new().route("/ws", get(handler));
222 ///
223 /// async fn handler(ws: WebSocketUpgrade) -> Response {
224 /// ws.protocols(["graphql-ws", "graphql-transport-ws"])
225 /// .on_upgrade(|socket| async {
226 /// // ...
227 /// })
228 /// }
229 /// # let _: Router = app;
230 /// ```
231 pub fn protocols<I>(mut self, protocols: I) -> Self
232 where
233 I: IntoIterator,
234 I::Item: Into<Cow<'static, str>>,
235 {
236 if let Some(req_protocols) = self
237 .sec_websocket_protocol
238 .as_ref()
239 .and_then(|p| p.to_str().ok())
240 {
241 self.protocol = protocols
242 .into_iter()
243 // FIXME: This will often allocate a new `String` and so is less efficient than it
244 // could be. But that can't be fixed without breaking changes to the public API.
245 .map(Into::into)
246 .find(|protocol| {
247 req_protocols
248 .split(',')
249 .any(|req_protocol| req_protocol.trim() == protocol)
250 })
251 .map(|protocol| match protocol {
252 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
253 Cow::Borrowed(s) => HeaderValue::from_static(s),
254 });
255 }
256
257 self
258 }
259
260 /// Provide a callback to call if upgrading the connection fails.
261 ///
262 /// The connection upgrade is performed in a background task. If that fails this callback
263 /// will be called.
264 ///
265 /// By default any errors will be silently ignored.
266 ///
267 /// # Example
268 ///
269 /// ```
270 /// use axum::{
271 /// extract::{WebSocketUpgrade},
272 /// response::Response,
273 /// };
274 ///
275 /// async fn handler(ws: WebSocketUpgrade) -> Response {
276 /// ws.on_failed_upgrade(|error| {
277 /// report_error(error);
278 /// })
279 /// .on_upgrade(|socket| async { /* ... */ })
280 /// }
281 /// #
282 /// # fn report_error(_: axum::Error) {}
283 /// ```
284 pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
285 where
286 C: OnFailedUpgrade,
287 {
288 WebSocketUpgrade {
289 config: self.config,
290 protocol: self.protocol,
291 sec_websocket_key: self.sec_websocket_key,
292 on_upgrade: self.on_upgrade,
293 on_failed_upgrade: callback,
294 sec_websocket_protocol: self.sec_websocket_protocol,
295 }
296 }
297
298 /// Finalize upgrading the connection and call the provided callback with
299 /// the stream.
300 #[must_use = "to set up the WebSocket connection, this response must be returned"]
301 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
302 where
303 C: FnOnce(WebSocket) -> Fut + Send + 'static,
304 Fut: Future<Output = ()> + Send + 'static,
305 F: OnFailedUpgrade,
306 {
307 let on_upgrade = self.on_upgrade;
308 let config = self.config;
309 let on_failed_upgrade = self.on_failed_upgrade;
310
311 let protocol = self.protocol.clone();
312
313 tokio::spawn(async move {
314 let upgraded = match on_upgrade.await {
315 Ok(upgraded) => upgraded,
316 Err(err) => {
317 on_failed_upgrade.call(Error::new(err));
318 return;
319 }
320 };
321 let upgraded = TokioIo::new(upgraded);
322
323 let socket =
324 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
325 .await;
326 let socket = WebSocket {
327 inner: socket,
328 protocol,
329 };
330 callback(socket).await;
331 });
332
333 #[allow(clippy::declare_interior_mutable_const)]
334 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
335 #[allow(clippy::declare_interior_mutable_const)]
336 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
337
338 let mut builder = Response::builder()
339 .status(StatusCode::SWITCHING_PROTOCOLS)
340 .header(header::CONNECTION, UPGRADE)
341 .header(header::UPGRADE, WEBSOCKET)
342 .header(
343 header::SEC_WEBSOCKET_ACCEPT,
344 sign(self.sec_websocket_key.as_bytes()),
345 );
346
347 if let Some(protocol) = self.protocol {
348 builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
349 }
350
351 builder.body(Body::empty()).unwrap()
352 }
353}
354
355/// What to do when a connection upgrade fails.
356///
357/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
358pub trait OnFailedUpgrade: Send + 'static {
359 /// Call the callback.
360 fn call(self, error: Error);
361}
362
363impl<F> OnFailedUpgrade for F
364where
365 F: FnOnce(Error) + Send + 'static,
366{
367 fn call(self, error: Error) {
368 self(error)
369 }
370}
371
372/// The default `OnFailedUpgrade` used by `WebSocketUpgrade`.
373///
374/// It simply ignores the error.
375#[non_exhaustive]
376#[derive(Debug)]
377pub struct DefaultOnFailedUpgrade;
378
379impl OnFailedUpgrade for DefaultOnFailedUpgrade {
380 #[inline]
381 fn call(self, _error: Error) {}
382}
383
384#[async_trait]
385impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
386where
387 S: Send + Sync,
388{
389 type Rejection = WebSocketUpgradeRejection;
390
391 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
392 if parts.method != Method::GET {
393 return Err(MethodNotGet.into());
394 }
395
396 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
397 return Err(InvalidConnectionHeader.into());
398 }
399
400 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
401 return Err(InvalidUpgradeHeader.into());
402 }
403
404 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
405 return Err(InvalidWebSocketVersionHeader.into());
406 }
407
408 let sec_websocket_key = parts
409 .headers
410 .get(header::SEC_WEBSOCKET_KEY)
411 .ok_or(WebSocketKeyHeaderMissing)?
412 .clone();
413
414 let on_upgrade = parts
415 .extensions
416 .remove::<hyper::upgrade::OnUpgrade>()
417 .ok_or(ConnectionNotUpgradable)?;
418
419 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
420
421 Ok(Self {
422 config: Default::default(),
423 protocol: None,
424 sec_websocket_key,
425 on_upgrade,
426 sec_websocket_protocol,
427 on_failed_upgrade: DefaultOnFailedUpgrade,
428 })
429 }
430}
431
432fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
433 if let Some(header) = headers.get(&key) {
434 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
435 } else {
436 false
437 }
438}
439
440fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
441 let header = if let Some(header) = headers.get(&key) {
442 header
443 } else {
444 return false;
445 };
446
447 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
448 header.to_ascii_lowercase().contains(value)
449 } else {
450 false
451 }
452}
453
/// A stream of WebSocket messages.
///
/// It also implements `Sink<Message>`, so messages can be sent through the same value
/// (or via [`WebSocket::send`]).
///
/// See [the module level documentation](self) for more details.
#[derive(Debug)]
pub struct WebSocket {
    /// The tungstenite stream running over the upgraded hyper connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// The subprotocol selected during the handshake, if any; exposed via `protocol()`.
    protocol: Option<HeaderValue>,
}
462
463impl WebSocket {
464 /// Receive another message.
465 ///
466 /// Returns `None` if the stream has closed.
467 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
468 self.next().await
469 }
470
471 /// Send a message.
472 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
473 self.inner
474 .send(msg.into_tungstenite())
475 .await
476 .map_err(Error::new)
477 }
478
479 /// Gracefully close this WebSocket.
480 pub async fn close(mut self) -> Result<(), Error> {
481 self.inner.close(None).await.map_err(Error::new)
482 }
483
484 /// Return the selected WebSocket subprotocol, if one has been chosen.
485 pub fn protocol(&self) -> Option<&HeaderValue> {
486 self.protocol.as_ref()
487 }
488}
489
490impl Stream for WebSocket {
491 type Item = Result<Message, Error>;
492
493 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
494 loop {
495 match futures_util::ready!(self.inner.poll_next_unpin(cx)) {
496 Some(Ok(msg)) => {
497 if let Some(msg) = Message::from_tungstenite(msg) {
498 return Poll::Ready(Some(Ok(msg)));
499 }
500 }
501 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
502 None => return Poll::Ready(None),
503 }
504 }
505 }
506}
507
508impl Sink<Message> for WebSocket {
509 type Error = Error;
510
511 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
512 Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
513 }
514
515 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
516 Pin::new(&mut self.inner)
517 .start_send(item.into_tungstenite())
518 .map_err(Error::new)
519 }
520
521 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
522 Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
523 }
524
525 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
526 Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
527 }
528}
529
/// Status code used to indicate why an endpoint is closing the WebSocket connection.
///
/// Named values are available as constants in the [`close_code`] module.
pub type CloseCode = u16;

/// A struct representing the close command: the payload carried by
/// [`Message::Close`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
    /// The reason as a code.
    pub code: CloseCode,
    /// The reason as text string.
    pub reason: Cow<'t, str>,
}
541
/// A WebSocket message.
//
// This code comes from https://github.com/snapview/tungstenite-rs/blob/master/src/protocol/message.rs and is under following license:
// Copyright (c) 2017 Alexey Galakhov
// Copyright (c) 2016 Jason Housley
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text WebSocket message
    Text(String),
    /// A binary WebSocket message
    Binary(Vec<u8>),
    /// A ping message with the specified payload
    ///
    /// The payload here must be at most 125 bytes long (the WebSocket protocol limits
    /// every control frame payload to 125 bytes or less).
    ///
    /// Ping messages will be automatically responded to by the server, so you do not have to worry
    /// about dealing with them yourself.
    Ping(Vec<u8>),
    /// A pong message with the specified payload
    ///
    /// The payload here must be at most 125 bytes long (the WebSocket protocol limits
    /// every control frame payload to 125 bytes or less).
    ///
    /// Pong messages will be automatically sent to the client if a ping message is received, so
    /// you do not have to worry about constructing them yourself unless you want to implement a
    /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
    Pong(Vec<u8>),
    /// A close message with the optional close frame.
    ///
    /// `None` means the close message carries no status code or reason.
    Close(Option<CloseFrame<'static>>),
}
589
590impl Message {
591 fn into_tungstenite(self) -> ts::Message {
592 match self {
593 Self::Text(text) => ts::Message::Text(text),
594 Self::Binary(binary) => ts::Message::Binary(binary),
595 Self::Ping(ping) => ts::Message::Ping(ping),
596 Self::Pong(pong) => ts::Message::Pong(pong),
597 Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
598 code: ts::protocol::frame::coding::CloseCode::from(close.code),
599 reason: close.reason,
600 })),
601 Self::Close(None) => ts::Message::Close(None),
602 }
603 }
604
605 fn from_tungstenite(message: ts::Message) -> Option<Self> {
606 match message {
607 ts::Message::Text(text) => Some(Self::Text(text)),
608 ts::Message::Binary(binary) => Some(Self::Binary(binary)),
609 ts::Message::Ping(ping) => Some(Self::Ping(ping)),
610 ts::Message::Pong(pong) => Some(Self::Pong(pong)),
611 ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
612 code: close.code.into(),
613 reason: close.reason,
614 }))),
615 ts::Message::Close(None) => Some(Self::Close(None)),
616 // we can ignore `Frame` frames as recommended by the tungstenite maintainers
617 // https://github.com/snapview/tungstenite-rs/issues/268
618 ts::Message::Frame(_) => None,
619 }
620 }
621
622 /// Consume the WebSocket and return it as binary data.
623 pub fn into_data(self) -> Vec<u8> {
624 match self {
625 Self::Text(string) => string.into_bytes(),
626 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
627 Self::Close(None) => Vec::new(),
628 Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
629 }
630 }
631
632 /// Attempt to consume the WebSocket message and convert it to a String.
633 pub fn into_text(self) -> Result<String, Error> {
634 match self {
635 Self::Text(string) => Ok(string),
636 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
637 .map_err(|err| err.utf8_error())
638 .map_err(Error::new)?),
639 Self::Close(None) => Ok(String::new()),
640 Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
641 }
642 }
643
644 /// Attempt to get a &str from the WebSocket message,
645 /// this will try to convert binary data to utf8.
646 pub fn to_text(&self) -> Result<&str, Error> {
647 match *self {
648 Self::Text(ref string) => Ok(string),
649 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
650 Ok(std::str::from_utf8(data).map_err(Error::new)?)
651 }
652 Self::Close(None) => Ok(""),
653 Self::Close(Some(ref frame)) => Ok(&frame.reason),
654 }
655 }
656}
657
658impl From<String> for Message {
659 fn from(string: String) -> Self {
660 Message::Text(string)
661 }
662}
663
664impl<'s> From<&'s str> for Message {
665 fn from(string: &'s str) -> Self {
666 Message::Text(string.into())
667 }
668}
669
670impl<'b> From<&'b [u8]> for Message {
671 fn from(data: &'b [u8]) -> Self {
672 Message::Binary(data.into())
673 }
674}
675
676impl From<Vec<u8>> for Message {
677 fn from(data: Vec<u8>) -> Self {
678 Message::Binary(data)
679 }
680}
681
682impl From<Message> for Vec<u8> {
683 fn from(msg: Message) -> Self {
684 msg.into_data()
685 }
686}
687
688fn sign(key: &[u8]) -> HeaderValue {
689 use base64::engine::Engine as _;
690
691 let mut sha1 = Sha1::default();
692 sha1.update(key);
693 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
694 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
695 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
696}
697
pub mod rejection {
    //! WebSocket-specific rejections.
    //!
    //! Each struct below is produced when the corresponding check performed by the
    //! [`WebSocketUpgrade`](super::WebSocketUpgrade) extractor fails, and is returned
    //! wrapped in [`WebSocketUpgradeRejection`]. The `#[status]` / `#[body]` attributes
    //! presumably set the HTTP status and body of the rejection response; see
    //! `axum_core::__define_rejection` to confirm.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request method is not `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Connection` header is missing or does not include `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Upgrade` header is missing or is not `websocket`
        /// (compared ignoring ASCII case).
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Sec-WebSocket-Version` header is missing or is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// This rejection is returned if the connection cannot be upgraded for example if the
        /// request is HTTP/1.0.
        ///
        /// See [MDN] for more details about connection upgrades.
        ///
        /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Upgrade
        pub struct ConnectionNotUpgradable;
    }

    // One variant per rejection above, in the order the extractor performs its checks.
    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
768
pub mod close_code {
    //! Constants for [`CloseCode`]s.
    //!
    //! The values come from RFC 6455, section 7.4.1, and the IANA "WebSocket Close Code
    //! Number" registry.
    //!
    //! [`CloseCode`]: super::CloseCode

    /// Indicates a normal closure, meaning that the purpose for which the connection was
    /// established has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// Indicates that an endpoint is "going away", such as a server going down or a browser having
    /// navigated away from a page.
    pub const AWAY: u16 = 1001;

    /// Indicates that an endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// Indicates that an endpoint is terminating the connection because it has received a type of
    /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if
    /// it receives a binary message).
    pub const UNSUPPORTED: u16 = 1003;

    /// Indicates that no status code was included in a closing frame.
    ///
    /// This is a reserved value. An endpoint must not send it as the status code of a Close
    /// frame; it exists only so that applications can be told that a received Close frame
    /// carried no code.
    pub const STATUS: u16 = 1005;

    /// Indicates an abnormal closure.
    ///
    /// This is a reserved value. An endpoint must not send it as the status code of a Close
    /// frame; it signals that the connection was closed without a Close frame being
    /// received (for example, because the underlying connection was dropped).
    pub const ABNORMAL: u16 = 1006;

    /// Indicates that an endpoint is terminating the connection because it has received data
    /// within a message that was not consistent with the type of the message (e.g., non-UTF-8
    /// RFC3629 data within a text message).
    pub const INVALID: u16 = 1007;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that violates its policy. This is a generic status code that can be returned when there is
    /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to
    /// hide specific details about the policy.
    pub const POLICY: u16 = 1008;

    /// Indicates that an endpoint is terminating the connection because it has received a message
    /// that is too big for it to process.
    pub const SIZE: u16 = 1009;

    /// Indicates that an endpoint (client) is terminating the connection because it expected
    /// the server to negotiate one or more extensions, but the server didn't return them in the
    /// response message of the WebSocket handshake. The list of extensions that are needed should
    /// be given as the reason for closing. Note that this status code is not used by the server,
    /// because it can fail the WebSocket handshake instead.
    pub const EXTENSION: u16 = 1010;

    /// Indicates that a server is terminating the connection because it encountered an unexpected
    /// condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// Indicates that the server is restarting.
    pub const RESTART: u16 = 1012;

    /// Indicates that the server is overloaded and the client should either connect to a different
    /// IP (when multiple targets exist), or reconnect to the same IP when a user has performed an
    /// action.
    pub const AGAIN: u16 = 1013;
}
830
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::get, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    /// An HTTP/1.0 request carries no upgrade handle, so the extractor must reject it with
    /// `ConnectionNotUpgradable` even though every handshake header is valid.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let request = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let service = get(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            ready(())
        });

        let response = service.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Compile-only check: a handler using the default failure callback is routable.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    /// Compile-only check: a handler with a custom failure callback is routable.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", get(handler));
    }

    /// End-to-end echo server: data and close messages are echoed back, pings are
    /// answered with pongs by tungstenite itself.
    #[crate::test]
    async fn integration_test() {
        async fn handle_socket(mut socket: WebSocket) {
            while let Some(Ok(msg)) = socket.recv().await {
                // tungstenite will respond to pings automatically, so control frames
                // are not echoed.
                if matches!(msg, Message::Ping(_) | Message::Pong(_)) {
                    continue;
                }
                if socket.send(msg).await.is_err() {
                    break;
                }
            }
        }

        let app = Router::new().route(
            "/echo",
            get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
        );

        let addr = spawn_service(app);
        let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
            .await
            .unwrap();

        // A text message comes back unchanged.
        let text = tungstenite::Message::Text("foobar".to_owned());
        socket.send(text.clone()).await.unwrap();
        let echoed = socket.next().await.unwrap().unwrap();
        assert_eq!(text, echoed);

        // A ping is answered with a pong carrying the same payload.
        socket
            .send(tungstenite::Message::Ping(b"ping".to_vec()))
            .await
            .unwrap();
        let pong = socket.next().await.unwrap().unwrap();
        assert_eq!(pong, tungstenite::Message::Pong(b"ping".to_vec()));
    }
}