h2/frame/headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13
/// Output buffer used when encoding header-carrying frames.
///
/// The `Limit` wrapper caps how many bytes may be written;
/// `EncodingHeaderBlock::encode` consults `remaining_mut()` on it to decide
/// when the header block must be split into CONTINUATION frames.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15
/// Header frame
///
/// This could be either a request or a response. It is also used for
/// trailers (see `Headers::trailers`). The header block is held in decoded
/// form; HPACK encoding happens in `Headers::encode`.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any.
    ///
    /// Only populated by `Headers::load` when the PRIORITY flag is set.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// Raw flag byte of a HEADERS frame (END_STREAM, END_HEADERS, PADDED, PRIORITY).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36
/// PUSH_PROMISE frame: sent on `stream_id` to reserve `promised_id`, carrying
/// the headers of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// Raw flag byte of a PUSH_PROMISE frame (END_HEADERS, PADDED).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// The part of an HPACK-encoded header block that did not fit into the
/// previously written HEADERS / PUSH_PROMISE / CONTINUATION frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The HPACK bytes that still have to be written.
    header_block: EncodingHeaderBlock,
}
62
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo-header fields of a request or a response.
///
/// These are kept apart from the regular header fields. When a block is
/// decoded, a pseudo header that follows a regular field makes the block
/// malformed.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    /// `:method`
    pub method: Option<Method>,
    /// `:scheme`
    pub scheme: Option<BytesStr>,
    /// `:authority`
    pub authority: Option<BytesStr>,
    /// `:path`
    pub path: Option<BytesStr>,
    /// `:protocol`. When this is absent, a CONNECT request built by
    /// `Pseudo::request` carries no `:scheme` or `:path`.
    pub protocol: Option<Protocol>,

    // Response
    /// `:status`
    pub status: Option<StatusCode>,
}
76
/// Iterator handed to the HPACK encoder. It yields the pseudo headers first,
/// then the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers. Set to `None` once every pseudo header has been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85
/// A decoded header block, plus the bookkeeping used while decoding it.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons
    ///
    /// Each field counts as name length + value length + 32 (see
    /// `decoded_header_size`). Pseudo headers are not included.
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}

/// An HPACK-encoded header block that is waiting to be written into frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes not yet written. `encode` splits off whatever fits in the
    /// current frame.
    hpack: Bytes,
}
106
// Flag bits used by the HEADERS, PUSH_PROMISE and CONTINUATION frames in this
// file.
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
// The union of all flag bits above.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112
113// ===== impl Headers =====
114
115impl Headers {
116    /// Create a new HEADERS frame
117    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118        Headers {
119            stream_id,
120            stream_dep: None,
121            header_block: HeaderBlock {
122                field_size: calculate_headermap_size(&fields),
123                fields,
124                is_over_size: false,
125                pseudo,
126            },
127            flags: HeadersFlag::default(),
128        }
129    }
130
131    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132        let mut flags = HeadersFlag::default();
133        flags.set_end_stream();
134
135        Headers {
136            stream_id,
137            stream_dep: None,
138            header_block: HeaderBlock {
139                field_size: calculate_headermap_size(&fields),
140                fields,
141                is_over_size: false,
142                pseudo: Pseudo::default(),
143            },
144            flags,
145        }
146    }
147
148    /// Loads the header frame but doesn't actually do HPACK decoding.
149    ///
150    /// HPACK decoding is done in the `load_hpack` step.
151    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152        let flags = HeadersFlag(head.flag());
153        let mut pad = 0;
154
155        tracing::trace!("loading headers; flags={:?}", flags);
156
157        if head.stream_id().is_zero() {
158            return Err(Error::InvalidStreamId);
159        }
160
161        // Read the padding length
162        if flags.is_padded() {
163            if src.is_empty() {
164                return Err(Error::MalformedMessage);
165            }
166            pad = src[0] as usize;
167
168            // Drop the padding
169            let _ = src.split_to(1);
170        }
171
172        // Read the stream dependency
173        let stream_dep = if flags.is_priority() {
174            if src.len() < 5 {
175                return Err(Error::MalformedMessage);
176            }
177            let stream_dep = StreamDependency::load(&src[..5])?;
178
179            if stream_dep.dependency_id() == head.stream_id() {
180                return Err(Error::InvalidDependencyId);
181            }
182
183            // Drop the next 5 bytes
184            let _ = src.split_to(5);
185
186            Some(stream_dep)
187        } else {
188            None
189        };
190
191        if pad > 0 {
192            if pad > src.len() {
193                return Err(Error::TooMuchPadding);
194            }
195
196            let len = src.len() - pad;
197            src.truncate(len);
198        }
199
200        let headers = Headers {
201            stream_id: head.stream_id(),
202            stream_dep,
203            header_block: HeaderBlock {
204                fields: HeaderMap::new(),
205                field_size: 0,
206                is_over_size: false,
207                pseudo: Pseudo::default(),
208            },
209            flags,
210        };
211
212        Ok((headers, src))
213    }
214
215    pub fn load_hpack(
216        &mut self,
217        src: &mut BytesMut,
218        max_header_list_size: usize,
219        decoder: &mut hpack::Decoder,
220    ) -> Result<(), Error> {
221        self.header_block.load(src, max_header_list_size, decoder)
222    }
223
224    pub fn stream_id(&self) -> StreamId {
225        self.stream_id
226    }
227
228    pub fn is_end_headers(&self) -> bool {
229        self.flags.is_end_headers()
230    }
231
232    pub fn set_end_headers(&mut self) {
233        self.flags.set_end_headers();
234    }
235
236    pub fn is_end_stream(&self) -> bool {
237        self.flags.is_end_stream()
238    }
239
240    pub fn set_end_stream(&mut self) {
241        self.flags.set_end_stream()
242    }
243
244    pub fn is_over_size(&self) -> bool {
245        self.header_block.is_over_size
246    }
247
248    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249        (self.header_block.pseudo, self.header_block.fields)
250    }
251
252    #[cfg(feature = "unstable")]
253    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254        &mut self.header_block.pseudo
255    }
256
257    /// Whether it has status 1xx
258    pub(crate) fn is_informational(&self) -> bool {
259        self.header_block.pseudo.is_informational()
260    }
261
262    pub fn fields(&self) -> &HeaderMap {
263        &self.header_block.fields
264    }
265
266    pub fn into_fields(self) -> HeaderMap {
267        self.header_block.fields
268    }
269
270    pub fn encode(
271        self,
272        encoder: &mut hpack::Encoder,
273        dst: &mut EncodeBuf<'_>,
274    ) -> Option<Continuation> {
275        // At this point, the `is_end_headers` flag should always be set
276        debug_assert!(self.flags.is_end_headers());
277
278        // Get the HEADERS frame head
279        let head = self.head();
280
281        self.header_block
282            .into_encoding(encoder)
283            .encode(&head, dst, |_| {})
284    }
285
286    fn head(&self) -> Head {
287        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288    }
289}
290
291impl<T> From<Headers> for Frame<T> {
292    fn from(src: Headers) -> Self {
293        Frame::Headers(src)
294    }
295}
296
297impl fmt::Debug for Headers {
298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299        let mut builder = f.debug_struct("Headers");
300        builder
301            .field("stream_id", &self.stream_id)
302            .field("flags", &self.flags);
303
304        if let Some(ref protocol) = self.header_block.pseudo.protocol {
305            builder.field("protocol", protocol);
306        }
307
308        if let Some(ref dep) = self.stream_dep {
309            builder.field("stream_dep", dep);
310        }
311
312        // `fields` and `pseudo` purposefully not included
313        builder.finish()
314    }
315}
316
317// ===== util =====
318
/// Error returned by `parse_u64` when the input is not a valid decimal `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parse an unsigned decimal integer, such as a `content-length` value.
///
/// The input must be one or more ASCII digits whose value fits in a `u64`.
/// Leading zeros are allowed.
///
/// # Errors
///
/// Returns `ParseU64Error` if `src` is empty, contains anything other than
/// ASCII digits (signs and whitespace included), or overflows `u64`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // `Content-Length = 1*DIGIT`: an empty value is not a number, so it must
    // not silently become 0.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    // Checked arithmetic instead of a length cap. This accepts every value up
    // to u64::MAX (and any run of leading zeros) while still rejecting overflow.
    src.iter().try_fold(0u64, |acc, &d| {
        if !d.is_ascii_digit() {
            return Err(ParseU64Error);
        }

        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(d - b'0')))
            .ok_or(ParseU64Error)
    })
}
341
342// ===== impl PushPromise =====
343
/// Reasons a request cannot be used as the promised request of a PUSH_PROMISE
/// (returned by `PushPromise::validate_request`).
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// `content-length` is present but did not parse to `0`. Carries the
    /// parse result.
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is neither GET nor HEAD.
    NotSafeAndCacheable,
}
349
350impl PushPromise {
351    pub fn new(
352        stream_id: StreamId,
353        promised_id: StreamId,
354        pseudo: Pseudo,
355        fields: HeaderMap,
356    ) -> Self {
357        PushPromise {
358            flags: PushPromiseFlag::default(),
359            header_block: HeaderBlock {
360                field_size: calculate_headermap_size(&fields),
361                fields,
362                is_over_size: false,
363                pseudo,
364            },
365            promised_id,
366            stream_id,
367        }
368    }
369
370    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
371        use PushPromiseHeaderError::*;
372        // The spec has some requirements for promised request headers
373        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
374
375        // A promised request "that indicates the presence of a request body
376        // MUST reset the promised stream with a stream error"
377        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
378            let parsed_length = parse_u64(content_length.as_bytes());
379            if parsed_length != Ok(0) {
380                return Err(InvalidContentLength(parsed_length));
381            }
382        }
383        // "The server MUST include a method in the :method pseudo-header field
384        // that is safe and cacheable"
385        if !Self::safe_and_cacheable(req.method()) {
386            return Err(NotSafeAndCacheable);
387        }
388
389        Ok(())
390    }
391
392    fn safe_and_cacheable(method: &Method) -> bool {
393        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
394        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
395        method == Method::GET || method == Method::HEAD
396    }
397
398    pub fn fields(&self) -> &HeaderMap {
399        &self.header_block.fields
400    }
401
402    #[cfg(feature = "unstable")]
403    pub fn into_fields(self) -> HeaderMap {
404        self.header_block.fields
405    }
406
407    /// Loads the push promise frame but doesn't actually do HPACK decoding.
408    ///
409    /// HPACK decoding is done in the `load_hpack` step.
410    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
411        let flags = PushPromiseFlag(head.flag());
412        let mut pad = 0;
413
414        if head.stream_id().is_zero() {
415            return Err(Error::InvalidStreamId);
416        }
417
418        // Read the padding length
419        if flags.is_padded() {
420            if src.is_empty() {
421                return Err(Error::MalformedMessage);
422            }
423
424            // TODO: Ensure payload is sized correctly
425            pad = src[0] as usize;
426
427            // Drop the padding
428            let _ = src.split_to(1);
429        }
430
431        if src.len() < 5 {
432            return Err(Error::MalformedMessage);
433        }
434
435        let (promised_id, _) = StreamId::parse(&src[..4]);
436        // Drop promised_id bytes
437        let _ = src.split_to(4);
438
439        if pad > 0 {
440            if pad > src.len() {
441                return Err(Error::TooMuchPadding);
442            }
443
444            let len = src.len() - pad;
445            src.truncate(len);
446        }
447
448        let frame = PushPromise {
449            flags,
450            header_block: HeaderBlock {
451                fields: HeaderMap::new(),
452                field_size: 0,
453                is_over_size: false,
454                pseudo: Pseudo::default(),
455            },
456            promised_id,
457            stream_id: head.stream_id(),
458        };
459        Ok((frame, src))
460    }
461
462    pub fn load_hpack(
463        &mut self,
464        src: &mut BytesMut,
465        max_header_list_size: usize,
466        decoder: &mut hpack::Decoder,
467    ) -> Result<(), Error> {
468        self.header_block.load(src, max_header_list_size, decoder)
469    }
470
471    pub fn stream_id(&self) -> StreamId {
472        self.stream_id
473    }
474
475    pub fn promised_id(&self) -> StreamId {
476        self.promised_id
477    }
478
479    pub fn is_end_headers(&self) -> bool {
480        self.flags.is_end_headers()
481    }
482
483    pub fn set_end_headers(&mut self) {
484        self.flags.set_end_headers();
485    }
486
487    pub fn is_over_size(&self) -> bool {
488        self.header_block.is_over_size
489    }
490
491    pub fn encode(
492        self,
493        encoder: &mut hpack::Encoder,
494        dst: &mut EncodeBuf<'_>,
495    ) -> Option<Continuation> {
496        // At this point, the `is_end_headers` flag should always be set
497        debug_assert!(self.flags.is_end_headers());
498
499        let head = self.head();
500        let promised_id = self.promised_id;
501
502        self.header_block
503            .into_encoding(encoder)
504            .encode(&head, dst, |dst| {
505                dst.put_u32(promised_id.into());
506            })
507    }
508
509    fn head(&self) -> Head {
510        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
511    }
512
513    /// Consume `self`, returning the parts of the frame
514    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
515        (self.header_block.pseudo, self.header_block.fields)
516    }
517}
518
519impl<T> From<PushPromise> for Frame<T> {
520    fn from(src: PushPromise) -> Self {
521        Frame::PushPromise(src)
522    }
523}
524
525impl fmt::Debug for PushPromise {
526    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
527        f.debug_struct("PushPromise")
528            .field("stream_id", &self.stream_id)
529            .field("promised_id", &self.promised_id)
530            .field("flags", &self.flags)
531            // `fields` and `pseudo` purposefully not included
532            .finish()
533    }
534}
535
536// ===== impl Continuation =====
537
538impl Continuation {
539    fn head(&self) -> Head {
540        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
541    }
542
543    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
544        // Get the CONTINUATION frame head
545        let head = self.head();
546
547        self.header_block.encode(&head, dst, |_| {})
548    }
549}
550
551// ===== impl Pseudo =====
552
553impl Pseudo {
554    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
555        let parts = uri::Parts::from(uri);
556
557        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
558            (None, None)
559        } else {
560            let path = parts
561                .path_and_query
562                .map(|v| BytesStr::from(v.as_str()))
563                .unwrap_or(BytesStr::from_static(""));
564
565            let path = if !path.is_empty() {
566                path
567            } else {
568                if method == Method::OPTIONS {
569                    BytesStr::from_static("*")
570                } else {
571                    BytesStr::from_static("/")
572                }
573            };
574
575            (parts.scheme, Some(path))
576        };
577
578        let mut pseudo = Pseudo {
579            method: Some(method),
580            scheme: None,
581            authority: None,
582            path,
583            protocol,
584            status: None,
585        };
586
587        // If the URI includes a scheme component, add it to the pseudo headers
588        if let Some(scheme) = scheme {
589            pseudo.set_scheme(scheme);
590        }
591
592        // If the URI includes an authority component, add it to the pseudo
593        // headers
594        if let Some(authority) = parts.authority {
595            pseudo.set_authority(BytesStr::from(authority.as_str()));
596        }
597
598        pseudo
599    }
600
601    pub fn response(status: StatusCode) -> Self {
602        Pseudo {
603            method: None,
604            scheme: None,
605            authority: None,
606            path: None,
607            protocol: None,
608            status: Some(status),
609        }
610    }
611
612    #[cfg(feature = "unstable")]
613    pub fn set_status(&mut self, value: StatusCode) {
614        self.status = Some(value);
615    }
616
617    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
618        let bytes_str = match scheme.as_str() {
619            "http" => BytesStr::from_static("http"),
620            "https" => BytesStr::from_static("https"),
621            s => BytesStr::from(s),
622        };
623        self.scheme = Some(bytes_str);
624    }
625
626    #[cfg(feature = "unstable")]
627    pub fn set_protocol(&mut self, protocol: Protocol) {
628        self.protocol = Some(protocol);
629    }
630
631    pub fn set_authority(&mut self, authority: BytesStr) {
632        self.authority = Some(authority);
633    }
634
635    /// Whether it has status 1xx
636    pub(crate) fn is_informational(&self) -> bool {
637        self.status
638            .map_or(false, |status| status.is_informational())
639    }
640}
641
642// ===== impl EncodingHeaderBlock =====
643
impl EncodingHeaderBlock {
    /// Write one frame into `dst`: the frame head, any frame-specific prefix
    /// written by `f`, then as many HPACK bytes as `dst`'s limit allows.
    ///
    /// `f` writes fields that precede the header block fragment; for example,
    /// `PushPromise::encode` uses it to write the promised stream ID. The frame
    /// length in the head is patched after the payload has been written.
    ///
    /// Returns `Some(Continuation)` with the unwritten HPACK bytes when they did
    /// not all fit. In that case END_HEADERS is cleared in the written head.
    /// Returns `None` when the whole block was written.
    ///
    /// # Panics
    ///
    /// Panics if the payload length does not fit in the 3-byte length field.
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
    where
        F: FnOnce(&mut EncodeBuf<'_>),
    {
        let head_pos = dst.get_ref().len();

        // At this point, we don't know how big the h2 frame will be.
        // So, we write the head with length 0, then write the body, and
        // finally write the length once we know the size.
        head.encode(0, dst);

        let payload_pos = dst.get_ref().len();

        // Frame-specific fields that come before the header block fragment.
        f(dst);

        // Now, encode the header payload
        let continuation = if self.hpack.len() > dst.remaining_mut() {
            // Only part of the block fits. Write that part and keep the rest
            // for a CONTINUATION frame on the same stream.
            dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));

            Some(Continuation {
                stream_id: head.stream_id(),
                header_block: self,
            })
        } else {
            dst.put_slice(&self.hpack);

            None
        };

        // Compute the header block length
        // (this also covers any bytes written by `f`).
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;

        // Write the frame length
        // The length field is 3 bytes: the upper 5 bytes of the big-endian u64
        // must be zero, and the low 3 bytes are copied into the head.
        let payload_len_be = payload_len.to_be_bytes();
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);

        if continuation.is_some() {
            // There will be continuation frames, so the `is_end_headers` flag
            // must be unset
            // (byte 4 of the written head is the flags byte).
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);

            dst.get_mut()[head_pos + 4] -= END_HEADERS;
        }

        continuation
    }
}
693
694// ===== impl Iter =====
695
696impl Iterator for Iter {
697    type Item = hpack::Header<Option<HeaderName>>;
698
699    fn next(&mut self) -> Option<Self::Item> {
700        use crate::hpack::Header::*;
701
702        if let Some(ref mut pseudo) = self.pseudo {
703            if let Some(method) = pseudo.method.take() {
704                return Some(Method(method));
705            }
706
707            if let Some(scheme) = pseudo.scheme.take() {
708                return Some(Scheme(scheme));
709            }
710
711            if let Some(authority) = pseudo.authority.take() {
712                return Some(Authority(authority));
713            }
714
715            if let Some(path) = pseudo.path.take() {
716                return Some(Path(path));
717            }
718
719            if let Some(protocol) = pseudo.protocol.take() {
720                return Some(Protocol(protocol));
721            }
722
723            if let Some(status) = pseudo.status.take() {
724                return Some(Status(status));
725            }
726        }
727
728        self.pseudo = None;
729
730        self.fields
731            .next()
732            .map(|(name, value)| Field { name, value })
733    }
734}
735
736// ===== impl HeadersFlag =====
737
738impl HeadersFlag {
739    pub fn empty() -> HeadersFlag {
740        HeadersFlag(0)
741    }
742
743    pub fn load(bits: u8) -> HeadersFlag {
744        HeadersFlag(bits & ALL)
745    }
746
747    pub fn is_end_stream(&self) -> bool {
748        self.0 & END_STREAM == END_STREAM
749    }
750
751    pub fn set_end_stream(&mut self) {
752        self.0 |= END_STREAM;
753    }
754
755    pub fn is_end_headers(&self) -> bool {
756        self.0 & END_HEADERS == END_HEADERS
757    }
758
759    pub fn set_end_headers(&mut self) {
760        self.0 |= END_HEADERS;
761    }
762
763    pub fn is_padded(&self) -> bool {
764        self.0 & PADDED == PADDED
765    }
766
767    pub fn is_priority(&self) -> bool {
768        self.0 & PRIORITY == PRIORITY
769    }
770}
771
772impl Default for HeadersFlag {
773    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
774    fn default() -> Self {
775        HeadersFlag(END_HEADERS)
776    }
777}
778
779impl From<HeadersFlag> for u8 {
780    fn from(src: HeadersFlag) -> u8 {
781        src.0
782    }
783}
784
785impl fmt::Debug for HeadersFlag {
786    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
787        util::debug_flags(fmt, self.0)
788            .flag_if(self.is_end_headers(), "END_HEADERS")
789            .flag_if(self.is_end_stream(), "END_STREAM")
790            .flag_if(self.is_padded(), "PADDED")
791            .flag_if(self.is_priority(), "PRIORITY")
792            .finish()
793    }
794}
795
796// ===== impl PushPromiseFlag =====
797
798impl PushPromiseFlag {
799    pub fn empty() -> PushPromiseFlag {
800        PushPromiseFlag(0)
801    }
802
803    pub fn load(bits: u8) -> PushPromiseFlag {
804        PushPromiseFlag(bits & ALL)
805    }
806
807    pub fn is_end_headers(&self) -> bool {
808        self.0 & END_HEADERS == END_HEADERS
809    }
810
811    pub fn set_end_headers(&mut self) {
812        self.0 |= END_HEADERS;
813    }
814
815    pub fn is_padded(&self) -> bool {
816        self.0 & PADDED == PADDED
817    }
818}
819
820impl Default for PushPromiseFlag {
821    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
822    fn default() -> Self {
823        PushPromiseFlag(END_HEADERS)
824    }
825}
826
827impl From<PushPromiseFlag> for u8 {
828    fn from(src: PushPromiseFlag) -> u8 {
829        src.0
830    }
831}
832
833impl fmt::Debug for PushPromiseFlag {
834    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
835        util::debug_flags(fmt, self.0)
836            .flag_if(self.is_end_headers(), "END_HEADERS")
837            .flag_if(self.is_padded(), "PADDED")
838            .finish()
839    }
840}
841
842// ===== HeaderBlock =====
843
844impl HeaderBlock {
845    fn load(
846        &mut self,
847        src: &mut BytesMut,
848        max_header_list_size: usize,
849        decoder: &mut hpack::Decoder,
850    ) -> Result<(), Error> {
851        let mut reg = !self.fields.is_empty();
852        let mut malformed = false;
853        let mut headers_size = self.calculate_header_list_size();
854
855        macro_rules! set_pseudo {
856            ($field:ident, $val:expr) => {{
857                if reg {
858                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
859                    malformed = true;
860                } else if self.pseudo.$field.is_some() {
861                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
862                    malformed = true;
863                } else {
864                    let __val = $val;
865                    headers_size +=
866                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
867                    if headers_size < max_header_list_size {
868                        self.pseudo.$field = Some(__val);
869                    } else if !self.is_over_size {
870                        tracing::trace!("load_hpack; header list size over max");
871                        self.is_over_size = true;
872                    }
873                }
874            }};
875        }
876
877        let mut cursor = Cursor::new(src);
878
879        // If the header frame is malformed, we still have to continue decoding
880        // the headers. A malformed header frame is a stream level error, but
881        // the hpack state is connection level. In order to maintain correct
882        // state for other streams, the hpack decoding process must complete.
883        let res = decoder.decode(&mut cursor, |header| {
884            use crate::hpack::Header::*;
885
886            match header {
887                Field { name, value } => {
888                    // Connection level header fields are not supported and must
889                    // result in a protocol error.
890
891                    if name == header::CONNECTION
892                        || name == header::TRANSFER_ENCODING
893                        || name == header::UPGRADE
894                        || name == "keep-alive"
895                        || name == "proxy-connection"
896                    {
897                        tracing::trace!("load_hpack; connection level header");
898                        malformed = true;
899                    } else if name == header::TE && value != "trailers" {
900                        tracing::trace!(
901                            "load_hpack; TE header not set to trailers; val={:?}",
902                            value
903                        );
904                        malformed = true;
905                    } else {
906                        reg = true;
907
908                        headers_size += decoded_header_size(name.as_str().len(), value.len());
909                        if headers_size < max_header_list_size {
910                            self.field_size +=
911                                decoded_header_size(name.as_str().len(), value.len());
912                            self.fields.append(name, value);
913                        } else if !self.is_over_size {
914                            tracing::trace!("load_hpack; header list size over max");
915                            self.is_over_size = true;
916                        }
917                    }
918                }
919                Authority(v) => set_pseudo!(authority, v),
920                Method(v) => set_pseudo!(method, v),
921                Scheme(v) => set_pseudo!(scheme, v),
922                Path(v) => set_pseudo!(path, v),
923                Protocol(v) => set_pseudo!(protocol, v),
924                Status(v) => set_pseudo!(status, v),
925            }
926        });
927
928        if let Err(e) = res {
929            tracing::trace!("hpack decoding error; err={:?}", e);
930            return Err(e.into());
931        }
932
933        if malformed {
934            tracing::trace!("malformed message");
935            return Err(Error::MalformedMessage);
936        }
937
938        Ok(())
939    }
940
941    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
942        let mut hpack = BytesMut::new();
943        let headers = Iter {
944            pseudo: Some(self.pseudo),
945            fields: self.fields.into_iter(),
946        };
947
948        encoder.encode(headers, &mut hpack);
949
950        EncodingHeaderBlock {
951            hpack: hpack.freeze(),
952        }
953    }
954
955    /// Calculates the size of the currently decoded header list.
956    ///
957    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
958    ///
959    /// > The value is based on the uncompressed size of header fields,
960    /// > including the length of the name and value in octets plus an
961    /// > overhead of 32 octets for each header field.
962    fn calculate_header_list_size(&self) -> usize {
963        macro_rules! pseudo_size {
964            ($name:ident) => {{
965                self.pseudo
966                    .$name
967                    .as_ref()
968                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
969                    .unwrap_or(0)
970            }};
971        }
972
973        pseudo_size!(method)
974            + pseudo_size!(scheme)
975            + pseudo_size!(status)
976            + pseudo_size!(authority)
977            + pseudo_size!(path)
978            + self.field_size
979    }
980}
981
982fn calculate_headermap_size(map: &HeaderMap) -> usize {
983    map.iter()
984        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
985        .sum::<usize>()
986}
987
/// Size of one header field for header-list size accounting: name length plus
/// value length plus a fixed 32-octet per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    name + value + PER_FIELD_OVERHEAD
}
991
// Unit tests for HEADERS/CONTINUATION encoding and for the pseudo-header
// fields produced by `Pseudo::request` for different methods and URIs.
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    // Encodes three values for the same header name (`hello`) into a buffer
    // that can hold only 8 bytes of HEADERS payload. This forces the header
    // block to be split in the middle of the first value, and the rest is
    // emitted in a CONTINUATION frame. The test then checks that the values
    // encoded after resuming still reference the `hello` name. (The
    // "nameless" in the test name presumably refers to repeated values whose
    // name is implied by the previous field; confirm against the encoder.)
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        // The `Default::default()` pseudo-headers contribute no encoded
        // fields; the byte assertions below see only the `hello` fields.
        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Limit the destination to one frame header plus 8 payload bytes.
        // The block does not fit, so `encode` hands back the leftover
        // `Continuation` (hence the `unwrap`).
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // HEADERS frame: a 9-byte frame header plus an 8-byte payload.
        assert_eq!(17, dst.len());
        // length = 8, type = 0x1 (HEADERS), flags = 0 (END_HEADERS not set
        // because a CONTINUATION follows), stream id = 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40: literal header field with incremental indexing and a new
        // (literal) name. 0x80 | 4: Huffman-coded name of 4 bytes.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Value length prefix: Huffman-coded value of 4 bytes ("world").
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first Huffman byte of "world" fit in this frame; keep it
        // so the value can be reassembled from the continuation.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // With room for 16 payload bytes the rest of the block fits, so no
        // further continuation is returned.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The continuation payload starts with the remaining 3 bytes of
        // "world".
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame: a 9-byte frame header plus a 15-byte payload.
        assert_eq!(24, dst.len());
        // length = 15, type = 0x9 (CONTINUATION), flags = 0x4 (END_HEADERS),
        // stream id = 0.
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed: 0x0F is "literal header field without
        // indexing" with a saturated 4-bit name-index prefix. The following
        // byte (47) gives name index 15 + 47 = 62, the first dynamic-table
        // entry, i.e. the `hello` name inserted by the 0x40 field above.
        // 0x80 | 3: Huffman-coded value of 3 bytes.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    // Huffman-decodes `src` into a fresh buffer; panics if decoding fails.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5

        // Only :method and :authority are expected, whether or not the URI
        // has a scheme, an explicit port, or a path.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        // The URI path ("/test") is dropped for a plain CONNECT.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        // Authority-form URI (no scheme).
        assert_eq!(
            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        // On requests that contain the :protocol pseudo-header field, the
        // :scheme and :path pseudo-header fields of the target URI (see
        // Section 5) MUST also be included.
        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4

        // No path in the URI: :path is expected to be "/".
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        // Explicit path is carried through to :path.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        // "http" scheme and a multi-segment path.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
        // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]).
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
        // The URI here has no scheme, so no :scheme is expected either.
        assert_eq!(
            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    // For ordinary (non-CONNECT, non-OPTIONS) methods, :scheme, :authority
    // and :path all come from the URI. An empty path is expected as "/".
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }
}