1#[cfg(feature = "handshake")]
2use crate::compat::SetWaker;
3use crate::{compat::AllowStd, WebSocketStream};
4use log::*;
5use std::{
6 future::Future,
7 io::{Read, Write},
8 pin::Pin,
9 task::{Context, Poll},
10};
11use tokio::io::{AsyncRead, AsyncWrite};
12use tungstenite::WebSocket;
13#[cfg(feature = "handshake")]
14use tungstenite::{
15 handshake::{
16 client::Response, server::Callback, HandshakeError as Error, HandshakeRole,
17 MidHandshake as WsHandshake,
18 },
19 ClientHandshake, ServerHandshake,
20};
21
22pub(crate) async fn without_handshake<F, S>(stream: S, f: F) -> WebSocketStream<S>
23where
24 F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
25 S: AsyncRead + AsyncWrite + Unpin,
26{
27 let start = SkippedHandshakeFuture(Some(SkippedHandshakeFutureInner { f, stream }));
28
29 let ws = start.await;
30
31 WebSocketStream::new(ws)
32}
33
/// Future that builds a `WebSocket` directly from a stream, skipping the handshake.
///
/// The inputs live in an `Option` so that `poll` can move them out with `take()`;
/// after the first poll the `Option` is `None`.
struct SkippedHandshakeFuture<F, S>(Option<SkippedHandshakeFutureInner<F, S>>);
/// Inputs consumed by `SkippedHandshakeFuture` when it is polled.
struct SkippedHandshakeFutureInner<F, S> {
    /// Constructor called with the `AllowStd`-wrapped stream to produce the `WebSocket`.
    f: F,
    /// The raw stream; it is wrapped in `AllowStd` at poll time.
    stream: S,
}
39
40impl<F, S> Future for SkippedHandshakeFuture<F, S>
41where
42 F: FnOnce(AllowStd<S>) -> WebSocket<AllowStd<S>> + Unpin,
43 S: Unpin,
44 AllowStd<S>: Read + Write,
45{
46 type Output = WebSocket<AllowStd<S>>;
47
48 fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
49 let inner = self.get_mut().0.take().expect("future polled after completion");
50 trace!("Setting context when skipping handshake");
51 let stream = AllowStd::new(inner.stream, ctx.waker());
52
53 Poll::Ready((inner.f)(stream))
54 }
55}
56
/// Future that keeps resuming an interrupted tungstenite handshake until it finishes.
///
/// Each poll takes the handshake out of the `Option`. If the handshake is interrupted
/// again, the new mid-handshake state is stored back before returning `Poll::Pending`.
#[cfg(feature = "handshake")]
struct MidHandshake<Role: HandshakeRole>(Option<WsHandshake<Role>>);
59
/// Outcome of the first attempt at a handshake.
#[cfg(feature = "handshake")]
enum StartedHandshake<Role: HandshakeRole> {
    /// The handshake completed on the first attempt.
    Done(Role::FinalResult),
    /// The handshake was interrupted; it is resumed by polling a `MidHandshake`.
    Mid(WsHandshake<Role>),
}
65
/// Future that makes the first handshake attempt; it always resolves on its first poll.
///
/// The inputs live in an `Option` so that `poll` can move them out with `take()`.
#[cfg(feature = "handshake")]
struct StartedHandshakeFuture<F, S>(Option<StartedHandshakeFutureInner<F, S>>);
/// Inputs consumed by `StartedHandshakeFuture` when it is polled.
#[cfg(feature = "handshake")]
struct StartedHandshakeFutureInner<F, S> {
    /// Function that starts the handshake on the `AllowStd`-wrapped stream.
    f: F,
    /// The raw stream; it is wrapped in `AllowStd` at poll time.
    stream: S,
}
73
/// Runs a complete handshake for `Role` on `stream`.
///
/// `f` makes the first attempt on the `AllowStd`-wrapped stream. If that attempt is
/// interrupted, the returned mid-handshake state is driven asynchronously by
/// `MidHandshake` until it finishes.
///
/// # Errors
///
/// Returns `Error::Failure` if either the first attempt or a resumed step fails.
#[cfg(feature = "handshake")]
async fn handshake<Role, F, S>(stream: S, f: F) -> Result<Role::FinalResult, Error<Role>>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let first_attempt = StartedHandshakeFuture(Some(StartedHandshakeFutureInner { f, stream }));

    match first_attempt.await? {
        StartedHandshake::Done(finished) => Ok(finished),
        // Interrupted on the first try: keep resuming until done or failed.
        StartedHandshake::Mid(pending) => MidHandshake::<Role>(Some(pending)).await,
    }
}
92
/// Performs the client side of the WebSocket handshake over `stream`.
///
/// `f` starts the handshake on the `AllowStd`-wrapped stream. On success, the
/// established socket is wrapped in a [`WebSocketStream`] and returned together with
/// the handshake `Response`.
///
/// # Errors
///
/// Propagates the handshake error if the handshake fails.
#[cfg(feature = "handshake")]
pub(crate) async fn client_handshake<F, S>(
    stream: S,
    f: F,
) -> Result<(WebSocketStream<S>, Response), Error<ClientHandshake<AllowStd<S>>>>
where
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ClientHandshake<AllowStd<S>> as HandshakeRole>::FinalResult,
            Error<ClientHandshake<AllowStd<S>>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let (socket, response) = handshake(stream, f).await?;
    Ok((WebSocketStream::new(socket), response))
}
111
/// Performs the server side of the WebSocket handshake over `stream`.
///
/// `f` starts the handshake on the `AllowStd`-wrapped stream using the callback type
/// `C`. On success, the established socket is wrapped in a [`WebSocketStream`].
///
/// # Errors
///
/// Propagates the handshake error if the handshake fails.
#[cfg(feature = "handshake")]
pub(crate) async fn server_handshake<C, F, S>(
    stream: S,
    f: F,
) -> Result<WebSocketStream<S>, Error<ServerHandshake<AllowStd<S>, C>>>
where
    C: Callback + Unpin,
    F: FnOnce(
            AllowStd<S>,
        ) -> Result<
            <ServerHandshake<AllowStd<S>, C> as HandshakeRole>::FinalResult,
            Error<ServerHandshake<AllowStd<S>, C>>,
        > + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // The error type already matches the return type, so no conversion is needed.
    handshake(stream, f)
        .await
        .map(|socket: WebSocket<AllowStd<S>>| WebSocketStream::new(socket))
}
130
#[cfg(feature = "handshake")]
impl<Role, F, S> Future for StartedHandshakeFuture<F, S>
where
    Role: HandshakeRole,
    Role::InternalStream: SetWaker + Unpin,
    F: FnOnce(AllowStd<S>) -> Result<Role::FinalResult, Error<Role>> + Unpin,
    S: Unpin,
    AllowStd<S>: Read + Write,
{
    type Output = Result<StartedHandshake<Role>, Error<Role>>;

    /// Makes the first handshake attempt and always resolves on this poll.
    ///
    /// The outcome is classified as:
    /// - finished, giving `StartedHandshake::Done`;
    /// - interrupted, giving `StartedHandshake::Mid`;
    /// - failed, giving `Err(Error::Failure)`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let StartedHandshakeFutureInner { f, stream } =
            self.0.take().expect("future polled after completion");
        trace!("Setting ctx when starting handshake");

        let outcome = match f(AllowStd::new(stream, ctx.waker())) {
            Ok(finished) => Ok(StartedHandshake::Done(finished)),
            Err(Error::Interrupted(pending)) => Ok(StartedHandshake::Mid(pending)),
            Err(Error::Failure(err)) => Err(Error::Failure(err)),
        };
        Poll::Ready(outcome)
    }
}
154
#[cfg(feature = "handshake")]
impl<Role> Future for MidHandshake<Role>
where
    Role: HandshakeRole + Unpin,
    Role::InternalStream: SetWaker + Unpin,
{
    type Output = Result<Role::FinalResult, Error<Role>>;

    /// Resumes the interrupted handshake once.
    ///
    /// Before stepping, the current task's waker is installed on the underlying
    /// stream. If the step is interrupted again, the new state is stored back and
    /// `Poll::Pending` is returned.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut pending = self.as_mut().0.take().expect("future polled after completion");

        trace!("Setting context in handshake");
        pending.get_mut().get_mut().set_waker(cx.waker());

        match pending.handshake() {
            Ok(finished) => Poll::Ready(Ok(finished)),
            Err(Error::Interrupted(still_pending)) => {
                // Keep the state so the next poll can pick up where this one stopped.
                self.0 = Some(still_pending);
                Poll::Pending
            }
            Err(Error::Failure(err)) => Poll::Ready(Err(Error::Failure(err))),
        }
    }
}