1use crate::codec::UserError;
2use crate::frame::{Reason, StreamId};
3use crate::{client, server};
4
5use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE;
6use crate::proto::*;
7
8use bytes::Bytes;
9use futures_core::Stream;
10use std::io;
11use std::marker::PhantomData;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use std::time::Duration;
15use tokio::io::AsyncRead;
16
17#[derive(Debug)]
19pub(crate) struct Connection<T, P, B: Buf = Bytes>
20where
21 P: Peer,
22{
23 codec: Codec<T, Prioritized<B>>,
25
26 inner: ConnectionInner<P, B>,
27}
28
29#[derive(Debug)]
32struct ConnectionInner<P, B: Buf = Bytes>
33where
34 P: Peer,
35{
36 state: State,
38
39 error: Option<frame::GoAway>,
44
45 go_away: GoAway,
47
48 ping_pong: PingPong,
50
51 settings: Settings,
53
54 streams: Streams<B, P>,
56
57 span: tracing::Span,
59
60 _phantom: PhantomData<P>,
62}
63
64struct DynConnection<'a, B: Buf = Bytes> {
65 state: &'a mut State,
66
67 go_away: &'a mut GoAway,
68
69 streams: DynStreams<'a, B>,
70
71 error: &'a mut Option<frame::GoAway>,
72
73 ping_pong: &'a mut PingPong,
74}
75
/// Construction-time options for a `Connection`.
///
/// `Connection::new` copies most of these into `streams::Config`; `settings`
/// also seeds the connection's `Settings` state.
#[derive(Debug, Clone)]
pub(crate) struct Config {
    /// First stream id this side will use (becomes `local_next_stream_id`).
    pub next_stream_id: StreamId,
    /// Initial limit on locally initiated streams (`initial_max_send_streams`).
    pub initial_max_send_streams: usize,
    /// Send buffer size limit (becomes `local_max_buffer_size`).
    pub max_send_buffer_size: usize,
    /// Becomes `local_reset_duration` in the stream config.
    pub reset_stream_duration: Duration,
    /// Becomes `local_reset_max` in the stream config.
    pub reset_stream_max: usize,
    /// Becomes `remote_reset_max` in the stream config.
    pub remote_reset_stream_max: usize,
    /// Becomes `local_max_error_reset_streams`; `None` passes through unchanged.
    pub local_error_reset_streams_max: Option<usize>,
    /// Local SETTINGS; also read for push, extended CONNECT and
    /// MAX_CONCURRENT_STREAMS when building the stream config.
    pub settings: frame::Settings,
}
87
/// Lifecycle of a `Connection`, advanced by `Connection::poll`.
#[derive(Debug)]
enum State {
    /// Frames are being read and dispatched.
    Open,

    /// The codec is being shut down; records the closing reason and who
    /// initiated the close.
    Closing(Reason, Initiator),

    /// Terminal state; `poll` converts the stored reason into its final result.
    Closed(Reason, Initiator),
}
99
/// Operations shared by client and server connections.
impl<T, P, B> Connection<T, P, B>
where
    T: AsyncRead + AsyncWrite + Unpin,
    P: Peer,
    B: Buf,
{
    /// Creates a connection in the `Open` state around `codec`.
    ///
    /// The stream store is configured from `config`. SETTINGS left unset in
    /// `config.settings` fall back to push enabled and extended CONNECT
    /// disabled.
    pub fn new(codec: Codec<T, Prioritized<B>>, config: Config) -> Connection<T, P, B> {
        // Maps the connection-level `Config` onto the stream store's config.
        fn streams_config(config: &Config) -> streams::Config {
            streams::Config {
                initial_max_send_streams: config.initial_max_send_streams,
                local_max_buffer_size: config.max_send_buffer_size,
                local_next_stream_id: config.next_stream_id,
                // Unset ENABLE_PUSH is treated as enabled.
                local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
                // Unset ENABLE_CONNECT_PROTOCOL is treated as disabled.
                extended_connect_protocol_enabled: config
                    .settings
                    .is_extended_connect_protocol_enabled()
                    .unwrap_or(false),
                local_reset_duration: config.reset_stream_duration,
                local_reset_max: config.reset_stream_max,
                remote_reset_max: config.remote_reset_stream_max,
                // Start from the protocol default; presumably updated once the
                // peer's SETTINGS are received.
                remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
                // Our advertised MAX_CONCURRENT_STREAMS (if any) bounds the
                // streams the remote may initiate.
                remote_max_initiated: config
                    .settings
                    .max_concurrent_streams()
                    .map(|max| max as usize),
                local_max_error_reset_streams: config.local_error_reset_streams_max,
            }
        }
        let streams = Streams::new(streams_config(&config));
        Connection {
            codec,
            inner: ConnectionInner {
                state: State::Open,
                error: None,
                go_away: GoAway::new(),
                ping_pong: PingPong::new(),
                settings: Settings::new(config.settings),
                streams,
                span: tracing::debug_span!("Connection", peer = %P::NAME),
                _phantom: PhantomData,
            },
        }
    }

    /// Sets the target connection-level receive window size.
    ///
    /// Errors from the stream store are only checked by `debug_assert!`; in
    /// release builds they are silently ignored.
    pub(crate) fn set_target_window_size(&mut self, size: WindowSize) {
        let _res = self.inner.streams.set_target_connection_window_size(size);
        debug_assert!(_res.is_ok());
    }

    /// Queues a local SETTINGS frame that changes only the initial window size.
    ///
    /// # Errors
    /// Propagates the `UserError` from `Settings::send_settings` (conditions
    /// not visible here — TODO confirm, presumably a SETTINGS already pending).
    pub(crate) fn set_initial_window_size(&mut self, size: WindowSize) -> Result<(), UserError> {
        let mut settings = frame::Settings::default();
        settings.set_initial_window_size(Some(size));
        self.inner.settings.send_settings(settings)
    }

    /// Queues a local SETTINGS frame that enables the extended CONNECT
    /// protocol (setting value `1`).
    ///
    /// # Errors
    /// Propagates the `UserError` from `Settings::send_settings`.
    pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> {
        let mut settings = frame::Settings::default();
        settings.set_enable_connect_protocol(Some(1));
        self.inner.settings.send_settings(settings)
    }

    /// Current limit on locally initiated streams, as reported by the stream store.
    pub(crate) fn max_send_streams(&self) -> usize {
        self.inner.streams.max_send_streams()
    }

    /// Current limit on remotely initiated streams, as reported by the stream store.
    pub(crate) fn max_recv_streams(&self) -> usize {
        self.inner.streams.max_recv_streams()
    }

    /// Number of streams tracked by the stream store (unstable diagnostics API).
    #[cfg(feature = "unstable")]
    pub fn num_wired_streams(&self) -> usize {
        self.inner.streams.num_wired_streams()
    }

    /// Writes pending connection-level control frames in this order: a
    /// pending PONG, a pending PING, outgoing SETTINGS (via `Settings::poll_send`),
    /// then pending stream refusals.
    ///
    /// Returns `Pending` as soon as any step cannot complete, and propagates
    /// the first error.
    fn poll_ready(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
        let _e = self.inner.span.enter();
        let span = tracing::trace_span!("poll_ready");
        let _e = span.enter();
        ready!(self.inner.ping_pong.send_pending_pong(cx, &mut self.codec))?;
        ready!(self.inner.ping_pong.send_pending_ping(cx, &mut self.codec))?;
        ready!(self
            .inner
            .settings
            .poll_send(cx, &mut self.codec, &mut self.inner.streams))?;
        ready!(self.inner.streams.send_pending_refusal(cx, &mut self.codec))?;

        Poll::Ready(Ok(()))
    }

    /// Writes any pending GOAWAY frame to the codec.
    ///
    /// `poll2` treats `Some(reason)` as "a GOAWAY with `reason` has been
    /// written" and `None` as "nothing pending" (semantics defined by
    /// `GoAway::send_pending_go_away`).
    fn poll_go_away(&mut self, cx: &mut Context) -> Poll<Option<io::Result<Reason>>> {
        self.inner.go_away.send_pending_go_away(cx, &mut self.codec)
    }

    /// Abruptly closes the connection at the user's request with reason `e`.
    pub fn go_away_from_user(&mut self, e: Reason) {
        self.inner.as_dyn().go_away_from_user(e)
    }

    /// Converts the final close reason into the value returned by `poll`.
    ///
    /// `ours` is the reason this side closed with. The remote's reason comes
    /// from any GOAWAY it sent, which is consumed here. Outcomes:
    /// - both `NO_ERROR`: `Ok(())`;
    /// - only ours is an error: `Error::GoAway` with empty debug data;
    /// - the remote reported an error: that error with its debug data, which
    ///   takes precedence over ours.
    fn take_error(&mut self, ours: Reason, initiator: Initiator) -> Result<(), Error> {
        let (debug_data, theirs) = self
            .inner
            .error
            .take()
            .as_ref()
            .map_or((Bytes::new(), Reason::NO_ERROR), |frame| {
                (frame.debug_data().clone(), frame.reason())
            });

        match (ours, theirs) {
            (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()),
            (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)),
            // The remote's error wins; ours is assumed to be a consequence of it.
            (_, theirs) => Err(Error::remote_go_away(debug_data, theirs)),
        }
    }

    /// Queues an immediate `NO_ERROR` GOAWAY when the stream store reports no
    /// streams and no other references.
    pub fn maybe_close_connection_if_no_streams(&mut self) {
        if !self.inner.streams.has_streams_or_other_references() {
            self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
        }
    }

    /// Whether any streams, or other references to the stream store, remain.
    pub fn has_streams_or_other_references(&self) -> bool {
        self.inner.streams.has_streams_or_other_references()
    }

    /// Takes the user-facing PING handle, if it has not been taken already.
    pub(crate) fn take_user_pings(&mut self) -> Option<UserPings> {
        self.inner.ping_pong.take_user_pings()
    }

    /// Drives the connection until it closes or cannot make progress.
    ///
    /// While `Open`, frames are read and dispatched through `poll2`. Once a
    /// close is decided, the codec is shut down (`Closing`). The final
    /// `Closed` reason is turned into the result by `take_error`.
    ///
    /// When no frames can be read, window updates are flushed. If the remote
    /// sent a GOAWAY, or our GOAWAY asks to close on idle, and no streams
    /// remain, a `NO_ERROR` GOAWAY is queued and the loop runs again.
    pub fn poll(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
        // Enter a clone of the span: a guard borrowing `self.inner.span` would
        // conflict with the `&mut self` calls below. Cloning a span is cheap.
        let span = self.inner.span.clone();
        let _e = span.enter();
        let span = tracing::trace_span!("poll");
        let _e = span.enter();

        loop {
            tracing::trace!(connection.state = ?self.inner.state);
            match self.inner.state {
                State::Open => {
                    let result = match self.poll2(cx) {
                        Poll::Ready(result) => result,
                        Poll::Pending => {
                            // Flush pending stream work (including the codec)
                            // before deciding whether to close on idle.
                            ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;

                            if (self.inner.error.is_some()
                                || self.inner.go_away.should_close_on_idle())
                                && !self.inner.streams.has_streams()
                            {
                                self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
                                continue;
                            }

                            return Poll::Pending;
                        }
                    };

                    self.inner.as_dyn().handle_poll2_result(result)?
                }
                State::Closing(reason, initiator) => {
                    tracing::trace!("connection closing after flush");
                    // Shut down the codec, then record the terminal state.
                    ready!(self.codec.shutdown(cx))?;

                    self.inner.state = State::Closed(reason, initiator);
                }
                State::Closed(reason, initiator) => {
                    return Poll::Ready(self.take_error(reason, initiator));
                }
            }
        }
    }

    /// Reads and dispatches frames while the connection is `Open`.
    ///
    /// Returns `Ready(Ok(()))` when the codec reaches EOF or a user-initiated
    /// immediate GOAWAY has been written. Returns `Ready(Err(..))` for a
    /// library-initiated immediate GOAWAY or any frame-handling error, and
    /// `Pending` when any write or read cannot proceed.
    fn poll2(&mut self, cx: &mut Context) -> Poll<Result<(), Error>> {
        // Done once per call rather than once per frame.
        self.clear_expired_reset_streams();

        loop {
            // Flush a pending GOAWAY before `poll_ready` writes PINGs, so a
            // graceful-shutdown GOAWAY goes out ahead of the shutdown PING
            // queued alongside it in `go_away_gracefully`.
            if let Some(reason) = ready!(self.poll_go_away(cx)?) {
                if self.inner.go_away.should_close_now() {
                    if self.inner.go_away.is_user_initiated() {
                        // Don't echo the user's own GOAWAY back as an error.
                        return Poll::Ready(Ok(()));
                    } else {
                        return Poll::Ready(Err(Error::library_go_away(reason)));
                    }
                }
                // Only graceful (NO_ERROR) GOAWAYs wait for idle.
                debug_assert_eq!(
                    reason,
                    Reason::NO_ERROR,
                    "graceful GOAWAY should be NO_ERROR"
                );
            }
            ready!(self.poll_ready(cx))?;

            match self
                .inner
                .as_dyn()
                .recv_frame(ready!(Pin::new(&mut self.codec).poll_next(cx)?))?
            {
                // SETTINGS needs the codec and the full `Streams`, which the
                // peer-erased `DynConnection` does not have, so handle it here.
                ReceivedFrame::Settings(frame) => {
                    self.inner.settings.recv_settings(
                        frame,
                        &mut self.codec,
                        &mut self.inner.streams,
                    )?;
                }
                ReceivedFrame::Continue => (),
                ReceivedFrame::Done => {
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }

    /// Drops reset streams whose retention period has expired.
    fn clear_expired_reset_streams(&mut self) {
        self.inner.streams.clear_expired_reset_streams();
    }
}
366
367impl<P, B> ConnectionInner<P, B>
368where
369 P: Peer,
370 B: Buf,
371{
372 fn as_dyn(&mut self) -> DynConnection<'_, B> {
373 let ConnectionInner {
374 state,
375 go_away,
376 streams,
377 error,
378 ping_pong,
379 ..
380 } = self;
381 let streams = streams.as_dyn();
382 DynConnection {
383 state,
384 go_away,
385 streams,
386 error,
387 ping_pong,
388 }
389 }
390}
391
/// Connection logic that needs neither the transport type nor the peer type.
impl<B> DynConnection<'_, B>
where
    B: Buf,
{
    /// Queues a GOAWAY naming `id` as the last stream id, with reason `e`.
    /// The stream store is told a GOAWAY for `id` is being sent.
    ///
    /// Uses `GoAway::go_away`, not `go_away_now`; see `go_away_now` for the
    /// immediate variant.
    fn go_away(&mut self, id: StreamId, e: Reason) {
        let frame = frame::GoAway::new(id, e);
        self.streams.send_go_away(id);
        self.go_away.go_away(frame);
    }

    /// Queues an immediate GOAWAY with reason `e`, naming the last stream this
    /// side processed.
    fn go_away_now(&mut self, e: Reason) {
        let last_processed_id = self.streams.last_processed_id();
        let frame = frame::GoAway::new(last_processed_id, e);
        self.go_away.go_away_now(frame);
    }

    /// Like `go_away_now`, but attaches `data` as the GOAWAY debug data.
    fn go_away_now_data(&mut self, e: Reason, data: Bytes) {
        let last_processed_id = self.streams.last_processed_id();
        let frame = frame::GoAway::with_debug_data(last_processed_id, e, data);
        self.go_away.go_away_now(frame);
    }

    /// Queues a user-initiated GOAWAY with reason `e`, then fails every stream
    /// with `Error::user_go_away(e)`.
    fn go_away_from_user(&mut self, e: Reason) {
        let last_processed_id = self.streams.last_processed_id();
        let frame = frame::GoAway::new(last_processed_id, e);
        self.go_away.go_away_from_user(frame);

        // Let every stream know why the connection is closing.
        self.streams.handle_error(Error::user_go_away(e));
    }

    /// Applies the outcome of `Connection::poll2` to the connection state.
    ///
    /// Returns `Err` only for I/O errors that are not treated as a clean close.
    /// Every other outcome updates state and returns `Ok(())` so `poll` keeps
    /// looping.
    fn handle_poll2_result(&mut self, result: Result<(), Error>) -> Result<(), Error> {
        match result {
            // Normal shutdown: flush and close the codec.
            Ok(()) => {
                *self.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
                Ok(())
            }
            // Connection-level error: fail all streams and queue a GOAWAY.
            Err(Error::GoAway(debug_data, reason, initiator)) => {
                let e = Error::GoAway(debug_data.clone(), reason, initiator);
                tracing::debug!(error = ?e, "Connection::poll; connection error");

                // A GOAWAY with this reason is already in progress: don't send
                // another, just move on to closing.
                if self
                    .go_away
                    .going_away()
                    .map_or(false, |frame| frame.reason() == reason)
                {
                    tracing::trace!("    -> already going away");
                    *self.state = State::Closing(reason, initiator);
                    return Ok(());
                }

                // Fail the streams first, then queue the GOAWAY carrying the
                // original debug data.
                self.streams.handle_error(e);
                self.go_away_now_data(reason, debug_data);
                Ok(())
            }
            // Stream-level error: reset that stream and keep reading frames.
            Err(Error::Reset(id, reason, initiator)) => {
                debug_assert_eq!(initiator, Initiator::Library);
                tracing::trace!(?id, ?reason, "stream error");
                self.streams.send_reset(id, reason);
                Ok(())
            }
            // I/O error: fail all streams and surface the error, except for
            // the special case below.
            Err(Error::Io(kind, inner)) => {
                tracing::debug!(error = ?kind, "Connection::poll; IO error");
                let e = Error::Io(kind, inner);

                self.streams.handle_error(e.clone());

                // A server with nothing left to send treats an unexpected EOF
                // (a peer that dropped the connection without a GOAWAY) as a
                // clean close. It jumps straight to `Closed` with `NO_ERROR`,
                // skipping codec shutdown.
                // See https://github.com/hyperium/hyper/issues/3427
                if self.streams.is_server()
                    && self.streams.is_buffer_empty()
                    && matches!(kind, io::ErrorKind::UnexpectedEof)
                {
                    *self.state = State::Closed(Reason::NO_ERROR, Initiator::Library);
                    return Ok(());
                }

                Err(e)
            }
        }
    }

    /// Dispatches one frame read from the codec (`None` means the codec is
    /// exhausted).
    ///
    /// SETTINGS frames are handed back to the caller, because applying them
    /// needs the codec and the full `Streams`. A closed codec yields `Done`;
    /// everything else is handled here and yields `Continue`.
    fn recv_frame(&mut self, frame: Option<Frame>) -> Result<ReceivedFrame, Error> {
        use crate::frame::Frame::*;
        match frame {
            Some(Headers(frame)) => {
                tracing::trace!(?frame, "recv HEADERS");
                self.streams.recv_headers(frame)?;
            }
            Some(Data(frame)) => {
                tracing::trace!(?frame, "recv DATA");
                self.streams.recv_data(frame)?;
            }
            Some(Reset(frame)) => {
                tracing::trace!(?frame, "recv RST_STREAM");
                self.streams.recv_reset(frame)?;
            }
            Some(PushPromise(frame)) => {
                tracing::trace!(?frame, "recv PUSH_PROMISE");
                self.streams.recv_push_promise(frame)?;
            }
            Some(Settings(frame)) => {
                tracing::trace!(?frame, "recv SETTINGS");
                return Ok(ReceivedFrame::Settings(frame));
            }
            Some(GoAway(frame)) => {
                tracing::trace!(?frame, "recv GOAWAY");
                // Let the streams react first, then keep the frame.
                // `Connection::poll` closes once no streams remain, and
                // `take_error` reports the remote's reason.
                self.streams.recv_go_away(&frame)?;
                *self.error = Some(frame);
            }
            Some(Ping(frame)) => {
                tracing::trace!(?frame, "recv PING");
                let status = self.ping_pong.recv_ping(frame);
                if status.is_shutdown() {
                    // In this file the shutdown ping is only queued together
                    // with a GOAWAY (see `go_away_gracefully`), so one must
                    // already be in progress.
                    assert!(
                        self.go_away.is_going_away(),
                        "received unexpected shutdown ping"
                    );

                    // The shutdown ping round-trip has completed: send the
                    // follow-up GOAWAY naming the last stream actually processed.
                    let last_processed_id = self.streams.last_processed_id();
                    self.go_away(last_processed_id, Reason::NO_ERROR);
                }
            }
            Some(WindowUpdate(frame)) => {
                tracing::trace!(?frame, "recv WINDOW_UPDATE");
                self.streams.recv_window_update(frame)?;
            }
            Some(Priority(frame)) => {
                // PRIORITY frames are only logged; they are otherwise ignored.
                tracing::trace!(?frame, "recv PRIORITY");
            }
            None => {
                tracing::trace!("codec closed");
                // Per the message, this fails only if the streams mutex is poisoned.
                self.streams.recv_eof(false).expect("mutex poisoned");
                return Ok(ReceivedFrame::Done);
            }
        }
        Ok(ReceivedFrame::Continue)
    }
}
556
/// What `DynConnection::recv_frame` tells `Connection::poll2` to do next.
enum ReceivedFrame {
    /// A SETTINGS frame that the caller must apply.
    Settings(frame::Settings),
    /// The frame was handled; keep reading.
    Continue,
    /// The codec is closed; stop reading.
    Done,
}
562
563impl<T, B> Connection<T, client::Peer, B>
564where
565 T: AsyncRead + AsyncWrite,
566 B: Buf,
567{
568 pub(crate) fn streams(&self) -> &Streams<B, client::Peer> {
569 &self.inner.streams
570 }
571}
572
573impl<T, B> Connection<T, server::Peer, B>
574where
575 T: AsyncRead + AsyncWrite + Unpin,
576 B: Buf,
577{
578 pub fn next_incoming(&mut self) -> Option<StreamRef<B>> {
579 self.inner.streams.next_incoming()
580 }
581
582 pub fn go_away_gracefully(&mut self) {
584 if self.inner.go_away.is_going_away() {
585 return;
587 }
588
589 self.inner.as_dyn().go_away(StreamId::MAX, Reason::NO_ERROR);
601
602 self.inner.ping_pong.ping_shutdown();
605 }
606}
607
608impl<T, P, B> Drop for Connection<T, P, B>
609where
610 P: Peer,
611 B: Buf,
612{
613 fn drop(&mut self) {
614 let _ = self.inner.streams.recv_eof(true);
616 }
617}