h2/proto/
ping_pong.rs

1use crate::codec::Codec;
2use crate::frame::Ping;
3use crate::proto::{self, PingPayload};
4
5use atomic_waker::AtomicWaker;
6use bytes::Buf;
7use std::io;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::io::AsyncWrite;
12
/// Acknowledges ping requests from the remote.
///
/// Also tracks the PING queued by `ping_shutdown` and, once
/// `take_user_pings` has been called, the connection-side handle to the
/// shared user ping state.
#[derive(Debug)]
pub(crate) struct PingPong {
    /// PING queued by `ping_shutdown`, together with a flag recording
    /// whether it has been buffered for writing yet.
    pending_ping: Option<PendingPing>,
    /// Payload of a received PING that still has to be echoed back as an
    /// acknowledgement.
    pending_pong: Option<PingPayload>,
    /// Connection-side handle to the shared user ping state; `None` until
    /// `take_user_pings` is called.
    user_pings: Option<UserPingsRx>,
}
20
/// User-facing handle to the shared user ping state, created by
/// `PingPong::take_user_pings`.
///
/// Used to request a user PING (`send_ping`) and to wait for its
/// acknowledgement (`poll_pong`).
#[derive(Debug)]
pub(crate) struct UserPings(Arc<UserPingsInner>);
23
/// Connection-side handle to the same shared state as `UserPings`.
///
/// Dropping it marks the state as closed and wakes the pong waiter.
#[derive(Debug)]
struct UserPingsRx(Arc<UserPingsInner>);
26
/// State shared between `UserPings` and `UserPingsRx`.
#[derive(Debug)]
struct UserPingsInner {
    /// Current user ping state; holds one of the `USER_STATE_*` constants.
    state: AtomicUsize,
    /// Task to wake up the main `Connection`.
    ping_task: AtomicWaker,
    /// Task to wake up `share::PingPong::poll_pong`.
    pong_task: AtomicWaker,
}
35
/// A PING this side has queued and is waiting to have acknowledged.
#[derive(Debug)]
struct PendingPing {
    /// Payload placed in the PING frame; a received ack is compared
    /// against it.
    payload: PingPayload,
    /// Whether the frame has already been buffered into the codec.
    sent: bool,
}
41
/// Status returned from `PingPong::recv_ping`.
#[derive(Debug)]
pub(crate) enum ReceivedPing {
    /// A non-ack PING was received; its payload has been stored and must be
    /// echoed back.
    MustAck,
    /// An ack was received that is not the shutdown ack: either the ack of
    /// a user PING, or one that matches nothing we sent.
    Unknown,
    /// The ack for the pending shutdown PING was received.
    Shutdown,
}
49
/// No user ping pending.
///
/// Initial state; `UserPings::poll_pong` resets to this after observing a
/// received pong.
const USER_STATE_EMPTY: usize = 0;
/// User has called `send_ping`, but PING hasn't been written yet.
const USER_STATE_PENDING_PING: usize = 1;
/// User PING has been written, waiting for PONG.
///
/// Set by `PingPong::send_pending_ping` once the frame has been buffered.
const USER_STATE_PENDING_PONG: usize = 2;
/// We've received user PONG, waiting for user to `poll_pong`.
///
/// Set by `UserPingsRx::receive_pong`.
const USER_STATE_RECEIVED_PONG: usize = 3;
/// The connection is closed.
///
/// Set when `UserPingsRx` is dropped.
const USER_STATE_CLOSED: usize = 4;
60
61// ===== impl PingPong =====
62
63impl PingPong {
64    pub(crate) fn new() -> Self {
65        PingPong {
66            pending_ping: None,
67            pending_pong: None,
68            user_pings: None,
69        }
70    }
71
72    /// Can only be called once. If called a second time, returns `None`.
73    pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
74        if self.user_pings.is_some() {
75            return None;
76        }
77
78        let user_pings = Arc::new(UserPingsInner {
79            state: AtomicUsize::new(USER_STATE_EMPTY),
80            ping_task: AtomicWaker::new(),
81            pong_task: AtomicWaker::new(),
82        });
83        self.user_pings = Some(UserPingsRx(user_pings.clone()));
84        Some(UserPings(user_pings))
85    }
86
87    pub(crate) fn ping_shutdown(&mut self) {
88        assert!(self.pending_ping.is_none());
89
90        self.pending_ping = Some(PendingPing {
91            payload: Ping::SHUTDOWN,
92            sent: false,
93        });
94    }
95
96    /// Process a ping
97    pub(crate) fn recv_ping(&mut self, ping: Ping) -> ReceivedPing {
98        // The caller should always check that `send_pongs` returns ready before
99        // calling `recv_ping`.
100        assert!(self.pending_pong.is_none());
101
102        if ping.is_ack() {
103            if let Some(pending) = self.pending_ping.take() {
104                if &pending.payload == ping.payload() {
105                    assert_eq!(
106                        &pending.payload,
107                        &Ping::SHUTDOWN,
108                        "pending_ping should be for shutdown",
109                    );
110                    tracing::trace!("recv PING SHUTDOWN ack");
111                    return ReceivedPing::Shutdown;
112                }
113
114                // if not the payload we expected, put it back.
115                self.pending_ping = Some(pending);
116            }
117
118            if let Some(ref users) = self.user_pings {
119                if ping.payload() == &Ping::USER && users.receive_pong() {
120                    tracing::trace!("recv PING USER ack");
121                    return ReceivedPing::Unknown;
122                }
123            }
124
125            // else we were acked a ping we didn't send?
126            // The spec doesn't require us to do anything about this,
127            // so for resiliency, just ignore it for now.
128            tracing::warn!("recv PING ack that we never sent: {:?}", ping);
129            ReceivedPing::Unknown
130        } else {
131            // Save the ping's payload to be sent as an acknowledgement.
132            self.pending_pong = Some(ping.into_payload());
133            ReceivedPing::MustAck
134        }
135    }
136
137    /// Send any pending pongs.
138    pub(crate) fn send_pending_pong<T, B>(
139        &mut self,
140        cx: &mut Context,
141        dst: &mut Codec<T, B>,
142    ) -> Poll<io::Result<()>>
143    where
144        T: AsyncWrite + Unpin,
145        B: Buf,
146    {
147        if let Some(pong) = self.pending_pong.take() {
148            if !dst.poll_ready(cx)?.is_ready() {
149                self.pending_pong = Some(pong);
150                return Poll::Pending;
151            }
152
153            dst.buffer(Ping::pong(pong).into())
154                .expect("invalid pong frame");
155        }
156
157        Poll::Ready(Ok(()))
158    }
159
160    /// Send any pending pings.
161    pub(crate) fn send_pending_ping<T, B>(
162        &mut self,
163        cx: &mut Context,
164        dst: &mut Codec<T, B>,
165    ) -> Poll<io::Result<()>>
166    where
167        T: AsyncWrite + Unpin,
168        B: Buf,
169    {
170        if let Some(ref mut ping) = self.pending_ping {
171            if !ping.sent {
172                if !dst.poll_ready(cx)?.is_ready() {
173                    return Poll::Pending;
174                }
175
176                dst.buffer(Ping::new(ping.payload).into())
177                    .expect("invalid ping frame");
178                ping.sent = true;
179            }
180        } else if let Some(ref users) = self.user_pings {
181            if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING {
182                if !dst.poll_ready(cx)?.is_ready() {
183                    return Poll::Pending;
184                }
185
186                dst.buffer(Ping::new(Ping::USER).into())
187                    .expect("invalid ping frame");
188                users
189                    .0
190                    .state
191                    .store(USER_STATE_PENDING_PONG, Ordering::Release);
192            } else {
193                users.0.ping_task.register(cx.waker());
194            }
195        }
196
197        Poll::Ready(Ok(()))
198    }
199}
200
201impl ReceivedPing {
202    pub(crate) fn is_shutdown(&self) -> bool {
203        matches!(*self, Self::Shutdown)
204    }
205}
206
207// ===== impl UserPings =====
208
209impl UserPings {
210    pub(crate) fn send_ping(&self) -> Result<(), Option<proto::Error>> {
211        let prev = self
212            .0
213            .state
214            .compare_exchange(
215                USER_STATE_EMPTY,        // current
216                USER_STATE_PENDING_PING, // new
217                Ordering::AcqRel,
218                Ordering::Acquire,
219            )
220            .unwrap_or_else(|v| v);
221
222        match prev {
223            USER_STATE_EMPTY => {
224                self.0.ping_task.wake();
225                Ok(())
226            }
227            USER_STATE_CLOSED => Err(Some(broken_pipe().into())),
228            _ => {
229                // Was already pending, user error!
230                Err(None)
231            }
232        }
233    }
234
235    pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll<Result<(), proto::Error>> {
236        // Must register before checking state, in case state were to change
237        // before we could register, and then the ping would just be lost.
238        self.0.pong_task.register(cx.waker());
239        let prev = self
240            .0
241            .state
242            .compare_exchange(
243                USER_STATE_RECEIVED_PONG, // current
244                USER_STATE_EMPTY,         // new
245                Ordering::AcqRel,
246                Ordering::Acquire,
247            )
248            .unwrap_or_else(|v| v);
249
250        match prev {
251            USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())),
252            USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())),
253            _ => Poll::Pending,
254        }
255    }
256}
257
258// ===== impl UserPingsRx =====
259
260impl UserPingsRx {
261    fn receive_pong(&self) -> bool {
262        let prev = self
263            .0
264            .state
265            .compare_exchange(
266                USER_STATE_PENDING_PONG,  // current
267                USER_STATE_RECEIVED_PONG, // new
268                Ordering::AcqRel,
269                Ordering::Acquire,
270            )
271            .unwrap_or_else(|v| v);
272
273        if prev == USER_STATE_PENDING_PONG {
274            self.0.pong_task.wake();
275            true
276        } else {
277            false
278        }
279    }
280}
281
282impl Drop for UserPingsRx {
283    fn drop(&mut self) {
284        self.0.state.store(USER_STATE_CLOSED, Ordering::Release);
285        self.0.pong_task.wake();
286    }
287}
288
/// Builds the `BrokenPipe` I/O error returned to users once the user ping
/// state is closed.
fn broken_pipe() -> io::Error {
    io::Error::from(io::ErrorKind::BrokenPipe)
}