1use crate::frame::{self, Frame, Kind, Reason};
2use crate::frame::{
3 DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
4};
5use crate::proto::Error;
6
7use crate::hpack;
8
9use futures_core::Stream;
10
11use bytes::BytesMut;
12
13use std::io;
14
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::io::AsyncRead;
18use tokio_util::codec::FramedRead as InnerFramedRead;
19use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20
/// Header list size limit (in bytes) used until
/// `FramedRead::set_max_header_list_size` is called: 16 MiB.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23
/// Reads length-delimited byte chunks from `T` and decodes them into
/// [`Frame`]s.
///
/// A header block that is split across several frames is buffered in
/// `partial` until its final frame arrives.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Length-delimited reader that yields one raw frame per item.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    /// HPACK decoder state, shared by every header block decoded here.
    hpack: hpack::Decoder,

    /// Limit handed to `load_hpack` when decoding header blocks; also used
    /// to bound how many bytes are buffered for an over-size header block.
    max_header_list_size: usize,

    /// Maximum number of non-final CONTINUATION frames accepted for a single
    /// header block. Recomputed whenever `max_header_list_size` or the
    /// maximum frame size changes.
    max_continuation_frames: usize,

    /// Header block that has been started but not yet terminated, if any.
    partial: Option<Partial>,
}
37
/// A header block whose final frame has not been received yet.
#[derive(Debug)]
struct Partial {
    /// The HEADERS or PUSH_PROMISE frame that opened the header block.
    frame: Continuable,

    /// Header block bytes left over after the last `load_hpack` call; the
    /// payload of the next CONTINUATION frame is appended to it.
    buf: BytesMut,

    /// Number of non-final CONTINUATION frames seen so far for this block.
    continuation_frames_count: usize,
}
49
/// The frame kinds whose header block may be extended by CONTINUATION frames.
#[derive(Debug)]
enum Continuable {
    /// A HEADERS frame awaiting the rest of its header block.
    Headers(frame::Headers),
    /// A PUSH_PROMISE frame awaiting the rest of its header block.
    PushPromise(frame::PushPromise),
}
55
56impl<T> FramedRead<T> {
57 pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
58 let max_header_list_size = DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE;
59 let max_continuation_frames =
60 calc_max_continuation_frames(max_header_list_size, inner.decoder().max_frame_length());
61 FramedRead {
62 inner,
63 hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
64 max_header_list_size,
65 max_continuation_frames,
66 partial: None,
67 }
68 }
69
70 pub fn get_ref(&self) -> &T {
71 self.inner.get_ref()
72 }
73
74 pub fn get_mut(&mut self) -> &mut T {
75 self.inner.get_mut()
76 }
77
78 #[inline]
80 pub fn max_frame_size(&self) -> usize {
81 self.inner.decoder().max_frame_length()
82 }
83
84 #[inline]
88 pub fn set_max_frame_size(&mut self, val: usize) {
89 assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
90 self.inner.decoder_mut().set_max_frame_length(val);
91 self.max_continuation_frames = calc_max_continuation_frames(self.max_header_list_size, val);
93 }
94
95 #[inline]
97 pub fn set_max_header_list_size(&mut self, val: usize) {
98 self.max_header_list_size = val;
99 self.max_continuation_frames = calc_max_continuation_frames(val, self.max_frame_size());
101 }
102
103 #[inline]
105 pub fn set_header_table_size(&mut self, val: usize) {
106 self.hpack.queue_size_update(val);
107 }
108}
109
/// Computes how many non-final CONTINUATION frames to accept for one header
/// block.
///
/// The budget is the number of `frame_max`-sized frames needed to carry
/// `header_max` bytes (at least one), plus 25% slack for frames that are not
/// completely filled, and never less than 5. The addition saturates instead
/// of overflowing.
///
/// A `frame_max` of zero is treated as one so that a degenerate codec
/// configuration cannot trigger a division-by-zero panic.
fn calc_max_continuation_frames(header_max: usize, frame_max: usize) -> usize {
    // At least this many full frames are needed to use the whole header list
    // budget. The divisor is clamped so a zero frame size cannot panic.
    let min_frames_for_list = (header_max / frame_max.max(1)).max(1);
    // 25% padding for imperfectly packed frames, computed without floats.
    let padding = min_frames_for_list >> 2;
    min_frames_for_list.saturating_add(padding).max(5)
}
118
/// Decodes one raw frame into a [`Frame`].
///
/// `bytes` holds a single complete frame, starting with its
/// `frame::HEADER_LEN`-byte frame header.
///
/// * `hpack` - decoder used for HEADERS, PUSH_PROMISE and CONTINUATION
///   payloads.
/// * `max_header_list_size` - limit forwarded to `load_hpack`; also bounds
///   the bytes buffered for an over-size header block.
/// * `max_continuation_frames` - how many non-final CONTINUATION frames a
///   single header block may use.
/// * `partial_inout` - header block currently being assembled. It is set when
///   a header block starts without END_HEADERS and cleared once the block is
///   completed (or an error is returned from the CONTINUATION path).
///
/// # Returns
///
/// * `Ok(Some(frame))` - a complete frame was decoded.
/// * `Ok(None)` - the input was consumed without producing a frame: either
///   the header block is still incomplete or the frame kind is unknown.
///
/// # Errors
///
/// Returns a connection-level (`library_go_away`) or stream-level
/// (`library_reset`) [`Error`] when the frame is invalid.
fn decode_frame(
    hpack: &mut hpack::Decoder,
    max_header_list_size: usize,
    max_continuation_frames: usize,
    partial_inout: &mut Option<Partial>,
    mut bytes: BytesMut,
) -> Result<Option<Frame>, Error> {
    let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len());
    let _e = span.enter();

    tracing::trace!("decoding frame from {}B", bytes.len());

    // Parse the frame header.
    let head = frame::Head::parse(&bytes);

    // While a header block is in progress, only CONTINUATION frames are
    // acceptable; anything else is a connection error.
    if partial_inout.is_some() && head.kind() != Kind::Continuation {
        proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
        return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
    }

    let kind = head.kind();

    tracing::trace!(frame.kind = ?kind);

    // Shared decoding path for HEADERS and PUSH_PROMISE. Evaluates to the
    // finished frame, or returns early from `decode_frame` (with an error, or
    // with `Ok(None)` after stashing the partial block in `partial_inout`).
    macro_rules! header_block {
        ($frame:ident, $head:ident, $bytes:ident) => ({
            // Drop the frame header so only the payload remains.
            let _ = $bytes.split_to(frame::HEADER_LEN);

            // Load the frame's fixed fields; the header block itself is
            // decoded separately below.
            let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
                Ok(res) => res,
                Err(frame::Error::InvalidDependencyId) => {
                    proto_err!(stream: "invalid HEADERS dependency ID");
                    // An invalid dependency only resets the stream, not the
                    // whole connection.
                    return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed to load frame; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            let is_end_headers = frame.is_end_headers();

            // Decode the HPACK-encoded header block.
            match frame.load_hpack(&mut payload, max_header_list_size, hpack) {
                Ok(_) => {},
                // Running out of input is only acceptable when more frames
                // of this header block are still to come.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
                Err(frame::Error::MalformedMessage) => {
                    let id = $head.stream_id();
                    proto_err!(stream: "malformed header block; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                },
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                frame.into()
            } else {
                tracing::trace!("loaded partial header block");
                // Defer returning the frame until the block is terminated by
                // a CONTINUATION frame; keep the undecoded remainder.
                *partial_inout = Some(Partial {
                    frame: Continuable::$frame(frame),
                    buf: payload,
                    continuation_frames_count: 0,
                });

                return Ok(None);
            }
        });
    }

    let frame = match kind {
        Kind::Settings => {
            let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Ping => {
            let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load PING frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::WindowUpdate => {
            let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);

            res.map_err(|e| {
                proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Data => {
            // DATA takes ownership of the payload: drop the header and freeze
            // the rest instead of borrowing a slice.
            let _ = bytes.split_to(frame::HEADER_LEN);
            let res = frame::Data::load(head, bytes.freeze());

            // NOTE(review): every DATA load failure is reported as a
            // connection-level error; confirm that is intended.
            res.map_err(|e| {
                proto_err!(conn: "failed to load DATA frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::Headers => header_block!(Headers, head, bytes),
        Kind::Reset => {
            let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load RESET frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::GoAway => {
            let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
            res.map_err(|e| {
                proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
                Error::library_go_away(Reason::PROTOCOL_ERROR)
            })?
            .into()
        }
        Kind::PushPromise => header_block!(PushPromise, head, bytes),
        Kind::Priority => {
            if head.stream_id() == 0 {
                // PRIORITY on stream 0 is rejected as a connection error.
                proto_err!(conn: "invalid stream ID 0");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
                Ok(frame) => frame.into(),
                Err(frame::Error::InvalidDependencyId) => {
                    // An invalid dependency only resets the affected stream.
                    let id = head.stream_id();
                    proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }
        }
        Kind::Continuation => {
            // Bit 0x4 of the flags marks the last frame of the header block.
            let is_end_headers = (head.flag() & 0x4) == 0x4;

            // Take the in-progress block out of `partial_inout`; it is put
            // back below only if the block is still incomplete.
            let mut partial = match partial_inout.take() {
                Some(partial) => partial,
                None => {
                    proto_err!(conn: "received unexpected CONTINUATION frame");
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            };

            // The stream identifiers must match.
            if partial.frame.stream_id() != head.stream_id() {
                proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
                return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
            }

            // Guard against CONTINUATION floods: only a bounded number of
            // non-final CONTINUATION frames is accepted per header block.
            if is_end_headers {
                partial.continuation_frames_count = 0;
            } else {
                let cnt = partial.continuation_frames_count + 1;
                if cnt > max_continuation_frames {
                    tracing::debug!("too_many_continuations, max = {}", max_continuation_frames);
                    return Err(Error::library_go_away_data(
                        Reason::ENHANCE_YOUR_CALM,
                        "too_many_continuations",
                    ));
                } else {
                    partial.continuation_frames_count = cnt;
                }
            }

            // Extend the buffered header block with this frame's payload.
            if partial.buf.is_empty() {
                // Nothing left over: reuse this frame's payload as the buffer
                // instead of copying it.
                partial.buf = bytes.split_off(frame::HEADER_LEN);
            } else {
                if partial.frame.is_over_size() {
                    // The block is already flagged as over size, so bound how
                    // much is buffered for it. `bytes` still includes the
                    // frame header at this point, so the check is slightly
                    // conservative.
                    if partial.buf.len() + bytes.len() > max_header_list_size {
                        proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
                        return Err(Error::library_go_away(Reason::COMPRESSION_ERROR));
                    }
                }
                partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
            }

            match partial
                .frame
                .load_hpack(&mut partial.buf, max_header_list_size, hpack)
            {
                Ok(_) => {}
                // Running out of input is only acceptable when more
                // CONTINUATION frames are still to come.
                Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}
                Err(frame::Error::MalformedMessage) => {
                    let id = head.stream_id();
                    proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
                    return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR));
                }
                Err(e) => {
                    proto_err!(conn: "failed HPACK decoding; err={:?}", e);
                    return Err(Error::library_go_away(Reason::PROTOCOL_ERROR));
                }
            }

            if is_end_headers {
                partial.frame.into()
            } else {
                // Still incomplete: put the block back and wait for more.
                *partial_inout = Some(partial);
                return Ok(None);
            }
        }
        Kind::Unknown => {
            // Unknown frame kinds are ignored.
            return Ok(None);
        }
    };

    Ok(Some(frame))
}
372
373impl<T> Stream for FramedRead<T>
374where
375 T: AsyncRead + Unpin,
376{
377 type Item = Result<Frame, Error>;
378
379 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
380 let span = tracing::trace_span!("FramedRead::poll_next");
381 let _e = span.enter();
382 loop {
383 tracing::trace!("poll");
384 let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
385 Some(Ok(bytes)) => bytes,
386 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
387 None => return Poll::Ready(None),
388 };
389
390 tracing::trace!(read.bytes = bytes.len());
391 let Self {
392 ref mut hpack,
393 max_header_list_size,
394 ref mut partial,
395 max_continuation_frames,
396 ..
397 } = *self;
398 if let Some(frame) = decode_frame(
399 hpack,
400 max_header_list_size,
401 max_continuation_frames,
402 partial,
403 bytes,
404 )? {
405 tracing::debug!(?frame, "received");
406 return Poll::Ready(Some(Ok(frame)));
407 }
408 }
409 }
410}
411
412fn map_err(err: io::Error) -> Error {
413 if let io::ErrorKind::InvalidData = err.kind() {
414 if let Some(custom) = err.get_ref() {
415 if custom.is::<LengthDelimitedCodecError>() {
416 return Error::library_go_away(Reason::FRAME_SIZE_ERROR);
417 }
418 }
419 }
420 err.into()
421}
422
423impl Continuable {
426 fn stream_id(&self) -> frame::StreamId {
427 match *self {
428 Continuable::Headers(ref h) => h.stream_id(),
429 Continuable::PushPromise(ref p) => p.stream_id(),
430 }
431 }
432
433 fn is_over_size(&self) -> bool {
434 match *self {
435 Continuable::Headers(ref h) => h.is_over_size(),
436 Continuable::PushPromise(ref p) => p.is_over_size(),
437 }
438 }
439
440 fn load_hpack(
441 &mut self,
442 src: &mut BytesMut,
443 max_header_list_size: usize,
444 decoder: &mut hpack::Decoder,
445 ) -> Result<(), frame::Error> {
446 match *self {
447 Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
448 Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
449 }
450 }
451}
452
453impl<T> From<Continuable> for Frame<T> {
454 fn from(cont: Continuable) -> Self {
455 match cont {
456 Continuable::Headers(mut headers) => {
457 headers.set_end_headers();
458 headers.into()
459 }
460 Continuable::PushPromise(mut push) => {
461 push.set_end_headers();
462 push.into()
463 }
464 }
465 }
466}