// tokio_tungstenite/handshake.rs

1#[cfg(feature = "handshake")]
2use crate::compat::SetWaker;
3use crate::{compat::AllowStd, WebSocketStream};
4use log::*;
5use std::{
6    future::Future,
7    io::{Read, Write},
8    pin::Pin,
9    task::{Context, Poll},
10};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tungstenite::WebSocket;
13#[cfg(feature = "handshake")]
14use tungstenite::{
15    handshake::{
16        client::Response, server::Callback, HandshakeError as Error, HandshakeRole,
17        MidHandshake as WsHandshake,
18    },
19    ClientHandshake, ServerHandshake,
20};
21
22pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
23where
24    F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
25    S: AsyncRead + AsyncWrite + Unpin,
26{
27    let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
28
29    let ws = start.await;
30
31    WebSocketStream::new(ws)
32}
33
34struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
35struct SkippedHandshakeFutureInner<F, S> {
36    f: F,
37    stream: S,
38}
39
40impl<F, S> Future for SkippedHandshakeFuture<F, S>
41where
42    F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
43    S: Unpin,
44    AllowStd<S>: Read + Write,
45{
46    type Output = WebSocket<AllowStd<S>>;
47
48    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
49        let inner = self.get_mut().0.take().expect("future polled after completion");
50        trace!("Setting context when skipping handshake");
51        let stream = AllowStd::new(inner.stream, ctx.waker());
52
53        Poll::Ready((inner.f)(stream))
54    }
55}
56
/// Future that keeps advancing a handshake that was interrupted midway.
///
/// The `Option` holds the in-progress state between polls. Each poll takes it
/// out and puts a new state back only if the handshake is interrupted again.
#[cfg(feature = "handshake")]
struct MidHandshake<Role>(Option<WsHandshake<Role>>)
where
    Role: HandshakeRole;
59
/// Outcome of the first attempt at a handshake.
#[cfg(feature = "handshake")]
enum StartedHandshake<Role>
where
    Role: HandshakeRole,
{
    /// The handshake finished on the first attempt.
    Done(Role::FinalResult),
    /// The handshake was interrupted; this state must be resumed later.
    Mid(WsHandshake<Role>),
}
65
/// Future that makes the first attempt at a handshake by calling `f` on the
/// wrapped `stream`. Its `Option` is taken (left `None`) by the first poll.
#[cfg(feature = "handshake")]
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Closure and raw stream consumed by the first poll of `StartedHandshakeFuture`.
#[cfg(feature = "handshake")]
struct StartedHandshakeFutureInner<F, S> {
    /// Closure that starts the handshake; called with the stream wrapped in `AllowStd`.
    f: F,
    /// Raw stream to be wrapped before `f` is invoked.
    stream: S,
}
73
/// Runs the handshake started by `f` to completion.
///
/// `f` is called once with the wrapped stream. If it finishes immediately, its
/// result is returned. If it is interrupted, the intermediate state is driven
/// by `MidHandshake` until it either succeeds or fails. A failure on the
/// first attempt is returned directly.
#[cfg(feature = "handshake")]
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let first_attempt = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));

    match first_attempt.await? {
        StartedHandshake::Mid(pending) => MidHandshake::<Role>(Some(pending)).await,
        StartedHandshake::Done(result) => Ok(result),
    }
}
92
/// Performs a client-side handshake using `f` and wraps the resulting socket.
///
/// On success, returns the [`WebSocketStream`] together with the `Response`
/// produced by the handshake. Handshake errors are propagated with `?`.
#[cfg(feature = "handshake")]
pub(crate) async fn client_handshake<F, S>(
    stream: S,
    f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
            Error<ClientHandshake<AllowStd<S>>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let (socket, response) = handshake(stream, f).await?;
    Ok((WebSocketStream::new(socket), response))
}
111
/// Performs a server-side handshake using `f` and wraps the resulting socket
/// in a [`WebSocketStream`]. Handshake errors are returned unchanged.
#[cfg(feature = "handshake")]
pub(crate) async fn server_handshake<C, F, S>(
    stream: S,
    f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
    C: Callback + Unpin,
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
            Error<ServerHandshake<AllowStd<S>, C>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    handshake(stream, f).await.map(WebSocketStream::new)
}
130
/// Polling calls the handshake-starting closure exactly once and always
/// resolves immediately:
/// - to `Done` when the handshake completes,
/// - to `Mid` when it is interrupted, or
/// - to the reported failure.
#[cfg(feature = "handshake")]
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
    Role: HandshakeRole,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: Unpin,
    AllowStd<S>: Read + Write,
{
    type Output = Result<StartedHandshake<Role>, Error<Role>>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let StartedHandshakeFutureInner { f, stream } =
            self.get_mut().0.take().expect("future polled after completion");

        trace!("Setting ctx when starting handshake");
        let outcome: Self::Output = match f(AllowStd::new(stream, ctx.waker())) {
            Ok(done) => Ok(StartedHandshake::Done(done)),
            // An interruption is not an error at this stage: hand the state
            // back so the caller can resume it.
            Err(Error::Interrupted(mid)) => Ok(StartedHandshake::Mid(mid)),
            Err(failure @ Error::Failure(_)) => Err(failure),
        };

        Poll::Ready(outcome)
    }
}
154
/// Each poll does three things:
/// - registers the current task's waker on the handshake's internal stream
///   (via `SetWaker`);
/// - attempts to advance the handshake;
/// - on interruption, stores the new intermediate state and returns `Pending`.
///
/// NOTE(review): this relies on the internal stream arranging a wake-up once
/// I/O is ready. That behaviour lives in `SetWaker`/`AllowStd`, not here.
#[cfg(feature = "handshake")]
impl<Role> Future for MidHandshake<Role>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
{
    type Output = Result<Role::FinalResult, Error<Role>>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let mut mid = this.0.take().expect("future polled after completion");

        trace!("Setting context in handshake");
        mid.get_mut().get_mut().set_waker(cx.waker());

        match mid.handshake() {
            Err(Error::Interrupted(next)) => {
                // Not finished yet: keep the new state for the next poll.
                this.0 = Some(next);
                Poll::Pending
            }
            Err(failure @ Error::Failure(_)) => Poll::Ready(Err(failure)),
            Ok(result) => Poll::Ready(Ok(result)),
        }
    }
}