domain/base/
header.rs

1//! The header of a DNS message.
2//!
3//! Each DNS message starts with a twelve octet long header section
4//! containing some general information related to the message as well as
5//! the number of records in each of the four sections that follow the header.
6//! Its content and format are defined in section 4.1.1 of [RFC 1035].
7//!
8//! In order to reflect the fact that changing the section counts may
9//! invalidate the rest of the message whereas the other elements of the
10//! header section can safely be modified, the whole header has been split
//! into two separate types: [`Header`] contains the safely modifiable part
12//! at the beginning and [`HeaderCounts`] contains the section counts. In
13//! addition, the [`HeaderSection`] type wraps both of them into a single
14//! type.
15//!
16//! [RFC 1035]: https://tools.ietf.org/html/rfc1035
17
18use super::iana::{Opcode, Rcode};
19use super::wire::ParseError;
20use core::{fmt, mem, str::FromStr};
21use octseq::builder::OctetsBuilder;
22use octseq::parse::Parser;
23
24//------------ Header --------------------------------------------------
25
/// The first part of the header of a DNS message.
///
/// This type represents the information contained in the first four octets
/// of the header: the message ID, opcode, rcode, and the various flags. It
/// keeps those four octets in wire representation, i.e., in network byte
/// order. The data is laid out like this:
///
/// ```text
///                                 1  1  1  1  1  1
///   0  1  2  3  4  5  6  7  8  9  0  1  2  3  4  5
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// |                      ID                       |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// |QR|   Opcode  |AA|TC|RD|RA|Z |AD|CD|   RCODE   |
/// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
/// ```
///
/// Methods are available for accessing each of these fields. For more
/// information on the fields, see these methods in the section
/// [Field Access] below.
///
/// You can create owned values via the [`new`][Self::new] method or
/// the [`Default`] trait. However, more often the type will
/// be used via a reference into the octets of an actual message. The
/// functions [`for_message_slice`][Self::for_message_slice] and
/// [`for_message_slice_mut`][Self::for_message_slice_mut] create such
/// references from an octets slice.
///
/// The basic structure and most of the fields are defined in [RFC 1035],
/// except for the AD and CD flags, which are defined in [RFC 4035].
///
/// [Field Access]: #field-access
/// [RFC 1035]: https://tools.ietf.org/html/rfc1035
/// [RFC 4035]: https://tools.ietf.org/html/rfc4035
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(transparent)]
pub struct Header {
    /// The actual header in its wire format representation.
    ///
    /// This means that the ID field is in big endian.
    inner: [u8; 4],
}

/// # Creation and Conversion
///
impl Header {
    /// Creates a new header.
    ///
    /// The new header has all fields as either zero or false. Thus, the
    /// opcode will be [`Opcode::QUERY`] and the response code will be
    /// [`Rcode::NOERROR`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a header reference from an octets slice of a message.
    ///
    /// Only the first four octets of `s` are covered by the returned
    /// reference; any further octets are ignored.
    ///
    /// # Panics
    ///
    /// This function panics if the slice is less than four octets long.
    #[must_use]
    pub fn for_message_slice(s: &[u8]) -> &Header {
        assert!(s.len() >= mem::size_of::<Header>());

        // SAFETY: The pointer cast is sound because
        //  - Header is repr(transparent) over [u8; 4], so it has the same
        //    size and an alignment of one, and
        //  - the assertion above guarantees the slice is large enough.
        unsafe { &*(s.as_ptr() as *const Header) }
    }

    /// Creates a mutable header reference from an octets slice of a message.
    ///
    /// Only the first four octets of `s` are covered by the returned
    /// reference; any further octets are ignored.
    ///
    /// # Panics
    ///
    /// This function panics if the slice is less than four octets long.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut Header {
        assert!(s.len() >= mem::size_of::<Header>());

        // SAFETY: The pointer cast is sound because
        //  - Header is repr(transparent) over [u8; 4], so it has the same
        //    size and an alignment of one, and
        //  - the assertion above guarantees the slice is large enough.
        unsafe { &mut *(s.as_mut_ptr() as *mut Header) }
    }

    /// Returns a reference to the underlying octets slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
117
118/// # Field Access
119///
120impl Header {
121    /// Returns the value of the ID field.
122    ///
123    /// The ID field is an identifier chosen by whoever created a query
124    /// and is copied into a response by a server. It allows matching
125    /// incoming responses to their queries.
126    ///
127    /// When choosing an ID for an outgoing message, make sure it is random
128    /// to avoid spoofing by guessing the message ID. If `std` support
129    /// is enabled, the method
130    #[cfg_attr(
131        feature = "std",
132        doc = "[`set_random_id`][Self::set_random_id]"
133    )]
134    #[cfg_attr(not(feature = "std"), doc = "`set_random_id`")]
135    /// can be used for this purpose.
136    #[must_use]
137    pub fn id(self) -> u16 {
138        u16::from_be_bytes(self.inner[..2].try_into().unwrap())
139    }
140
141    /// Sets the value of the ID field.
142    pub fn set_id(&mut self, value: u16) {
143        self.inner[..2].copy_from_slice(&value.to_be_bytes())
144    }
145
146    /// Sets the value of the ID field to a randomly chosen number.
147    #[cfg(feature = "rand")]
148    pub fn set_random_id(&mut self) {
149        self.set_id(::rand::random())
150    }
151
152    /// Returns whether the [QR](Flags::qr) bit is set.
153    #[must_use]
154    pub fn qr(self) -> bool {
155        self.get_bit(2, 7)
156    }
157
158    /// Sets the value of the [QR](Flags::qr) bit.
159    pub fn set_qr(&mut self, set: bool) {
160        self.set_bit(2, 7, set)
161    }
162
163    /// Returns the value of the Opcode field.
164    ///
165    /// This field specifies the kind of query a message contains. See
166    /// the [`Opcode`] type for more information on the possible values and
167    /// their meaning. Normal queries have the variant [`Opcode::QUERY`]
168    /// which is also the default value when creating a new header.
169    #[must_use]
170    pub fn opcode(self) -> Opcode {
171        Opcode::from_int((self.inner[2] >> 3) & 0x0F)
172    }
173
174    /// Sets the value of the opcode field.
175    pub fn set_opcode(&mut self, opcode: Opcode) {
176        self.inner[2] = self.inner[2] & 0x87 | (opcode.to_int() << 3);
177    }
178
179    /// Returns all flags contained in the header.
180    ///
181    /// This is a virtual field composed of all the flag bits that are present
182    /// in the header. The returned [`Flags`] type can be useful when you're
183    /// working with all flags, rather than a single one, which can be easily
184    /// obtained from the header directly.
185    #[must_use]
186    pub fn flags(self) -> Flags {
187        Flags {
188            qr: self.qr(),
189            aa: self.aa(),
190            tc: self.tc(),
191            rd: self.rd(),
192            ra: self.ra(),
193            ad: self.ad(),
194            cd: self.cd(),
195        }
196    }
197
198    /// Sets all flag bits.
199    pub fn set_flags(&mut self, flags: Flags) {
200        self.set_qr(flags.qr);
201        self.set_aa(flags.aa);
202        self.set_tc(flags.tc);
203        self.set_rd(flags.rd);
204        self.set_ra(flags.ra);
205        self.set_ad(flags.ad);
206        self.set_cd(flags.cd);
207    }
208
209    /// Returns whether the [AA](Flags::aa) bit is set.
210    #[must_use]
211    pub fn aa(self) -> bool {
212        self.get_bit(2, 2)
213    }
214
215    /// Sets the value of the [AA](Flags::aa) bit.
216    pub fn set_aa(&mut self, set: bool) {
217        self.set_bit(2, 2, set)
218    }
219
220    /// Returns whether the [TC](Flags::tc) bit is set.
221    #[must_use]
222    pub fn tc(self) -> bool {
223        self.get_bit(2, 1)
224    }
225
226    /// Sets the value of the [TC](Flags::tc) bit.
227    pub fn set_tc(&mut self, set: bool) {
228        self.set_bit(2, 1, set)
229    }
230
231    /// Returns whether the [RD](Flags::rd) bit is set.
232    #[must_use]
233    pub fn rd(self) -> bool {
234        self.get_bit(2, 0)
235    }
236
237    /// Sets the value of the [RD](Flags::rd) bit.
238    pub fn set_rd(&mut self, set: bool) {
239        self.set_bit(2, 0, set)
240    }
241
242    /// Returns whether the [RA](Flags::ra) bit is set.
243    #[must_use]
244    pub fn ra(self) -> bool {
245        self.get_bit(3, 7)
246    }
247
248    /// Sets the value of the [RA](Flags::ra) bit.
249    pub fn set_ra(&mut self, set: bool) {
250        self.set_bit(3, 7, set)
251    }
252
253    /// Returns whether the reserved bit is set.
254    ///
255    /// This bit must be `false` in all queries and responses.
256    #[must_use]
257    pub fn z(self) -> bool {
258        self.get_bit(3, 6)
259    }
260
261    /// Sets the value of the reserved bit.
262    pub fn set_z(&mut self, set: bool) {
263        self.set_bit(3, 6, set)
264    }
265
266    /// Returns whether the [AD](Flags::ad) bit is set.
267    #[must_use]
268    pub fn ad(self) -> bool {
269        self.get_bit(3, 5)
270    }
271
272    /// Sets the value of the [AD](Flags::ad) bit.
273    pub fn set_ad(&mut self, set: bool) {
274        self.set_bit(3, 5, set)
275    }
276
277    /// Returns whether the [CD](Flags::cd) bit is set.
278    #[must_use]
279    pub fn cd(self) -> bool {
280        self.get_bit(3, 4)
281    }
282
283    /// Sets the value of the [CD](Flags::cd) bit.
284    pub fn set_cd(&mut self, set: bool) {
285        self.set_bit(3, 4, set)
286    }
287
288    /// Returns the value of the RCODE field.
289    ///
290    /// The *response code* is used in a response to indicate what happened
291    /// when processing the query. See the [`Rcode`] type for information on
292    /// possible values and their meaning.
293    ///
294    /// [`Rcode`]: ../../iana/rcode/enum.Rcode.html
295    #[must_use]
296    pub fn rcode(self) -> Rcode {
297        Rcode::masked_from_int(self.inner[3])
298    }
299
300    /// Sets the value of the RCODE field.
301    pub fn set_rcode(&mut self, rcode: Rcode) {
302        self.inner[3] = self.inner[3] & 0xF0 | (rcode.to_int() & 0x0F);
303    }
304
305    //--- Internal helpers
306
307    /// Returns the value of the bit at the given position.
308    ///
309    /// The argument `offset` gives the byte offset of the underlying bytes
310    /// slice and `bit` gives the number of the bit with the most significant
311    /// bit being 7.
312    fn get_bit(self, offset: usize, bit: usize) -> bool {
313        self.inner[offset] & (1 << bit) != 0
314    }
315
316    /// Sets or resets the given bit.
317    fn set_bit(&mut self, offset: usize, bit: usize, set: bool) {
318        if set {
319            self.inner[offset] |= 1 << bit
320        } else {
321            self.inner[offset] &= !(1 << bit)
322        }
323    }
324}
325
326//------------ Flags ---------------------------------------------------
327
/// The flags contained in the DNS message header.
///
/// This is a utility type that makes it easier to work with flags. It contains
/// only standard DNS message flags that are part of the [`Header`], i.e., EDNS
/// flags are not included.
///
/// This type has a text notation and can be created from it as well. Each
/// flag that is set is represented by a two-letter token, which is the
/// uppercase version of the flag name. If multiple flags are set, the tokens
/// are separated by a space. When parsing, the tokens are accepted in any
/// letter case.
///
/// ```
/// use core::str::FromStr;
/// use domain::base::header::Flags;
///
/// let flags = Flags::from_str("QR AA").unwrap();
/// assert!(flags.qr && flags.aa);
/// assert_eq!(format!("{}", flags), "QR AA");
/// ```
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Hash)]
pub struct Flags {
    /// The `QR` bit specifies whether a message is a query (`false`) or a
    /// response (`true`). In other words, this bit is actually stating whether
    /// the message is *not* a query. So, perhaps it might be good to read ‘QR’
    /// as ‘query response.’
    pub qr: bool,

    /// Using the `AA` bit, a name server generating a response states whether
    /// it is authoritative for the requested domain name, i.e., whether this
    /// response is an *authoritative answer.* The field has no meaning in a
    /// query.
    pub aa: bool,

    /// The *truncation* (`TC`) bit is set if there was more data available than
    /// fit into the message. This is typically used when employing datagram
    /// transports such as UDP to signal that the answer didn’t fit into a
    /// response and the query should be tried again using a stream transport
    /// such as TCP.
    pub tc: bool,

    /// The *recursion desired* (`RD`) bit may be set in a query to ask the name
    /// server to try and recursively gather a response if it doesn’t have the
    /// data available locally. The bit’s value is copied into the response.
    pub rd: bool,

    /// In a response, the *recursion available* (`RA`) bit denotes whether the
    /// responding name server supports recursion. It has no meaning in a query.
    pub ra: bool,

    /// The *authentic data* (`AD`) bit is used by security-aware recursive name
    /// servers to indicate that they consider all RRsets in their response to
    /// be authentic, i.e., to have successfully passed DNSSEC validation.
    pub ad: bool,

    /// The *checking disabled* (`CD`) bit is used by a security-aware resolver
    /// to indicate that it does not want upstream name servers to perform
    /// verification but rather would like to verify everything itself.
    pub cd: bool,
}

/// # Creation and Conversion
///
impl Flags {
    /// Creates new flags.
    ///
    /// All flags will be unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

//--- Display & FromStr

impl fmt::Display for Flags {
    /// Writes the tokens of all set flags, separated by single spaces.
    ///
    /// The tokens appear in header order: QR, AA, TC, RD, RA, AD, CD. If no
    /// flag is set, nothing is written.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let tokens = [
            (self.qr, "QR"),
            (self.aa, "AA"),
            (self.tc, "TC"),
            (self.rd, "RD"),
            (self.ra, "RA"),
            (self.ad, "AD"),
            (self.cd, "CD"),
        ];
        let mut sep = "";
        for &(set, token) in tokens.iter() {
            if set {
                f.write_str(sep)?;
                f.write_str(token)?;
                sep = " ";
            }
        }
        Ok(())
    }
}
435
436impl FromStr for Flags {
437    type Err = FlagsFromStrError;
438
439    fn from_str(s: &str) -> Result<Self, Self::Err> {
440        let mut flags = Flags::new();
441        for token in s.split(' ') {
442            match token {
443                "QR" | "Qr" | "qR" | "qr" => flags.qr = true,
444                "AA" | "Aa" | "aA" | "aa" => flags.aa = true,
445                "TC" | "Tc" | "tC" | "tc" => flags.tc = true,
446                "RD" | "Rd" | "rD" | "rd" => flags.rd = true,
447                "RA" | "Ra" | "rA" | "ra" => flags.ra = true,
448                "AD" | "Ad" | "aD" | "ad" => flags.ad = true,
449                "CD" | "Cd" | "cD" | "cd" => flags.cd = true,
450                "" => {}
451                _ => return Err(FlagsFromStrError(())),
452            }
453        }
454        Ok(flags)
455    }
456}
457
458//------------ HeaderCounts -------------------------------------------------
459
460/// The section count part of the header section of a DNS message.
461///
462/// This part consists of four 16 bit counters for the number of entries in
463/// the four sections of a DNS message. The type contains the sequence of
464/// these for values in wire format, i.e., in network byte order.
465///
466/// The counters are arranged in the same order as the sections themselves:
467/// QDCOUNT for the question section, ANCOUNT for the answer section,
468/// NSCOUNT for the authority section, and ARCOUNT for the additional section.
469/// These are defined in [RFC 1035].
470///
471/// Like with the other header part, you can create an owned value via the
472/// [`new`][Self::new] method or the `Default` trait or can get a reference
473/// to the value atop a message slice via
474/// [`for_message_slice`][Self::for_message_slice] or
475/// [`for_message_slice_mut`][Self::for_message_slice_mut].
476///
477/// For each field there are three methods for getting, setting, and
478/// incrementing.
479///
480/// [RFC 2136] defines the UPDATE method and reuses the four section for
481/// different purposes. Here the counters are ZOCOUNT for the zone section,
482/// PRCOUNT for the prerequisite section, UPCOUNT for the update section,
483/// and ADCOUNT for the additional section. The type has convenience methods
484/// for these fields as well so you don’t have to remember which is which.
485///
486/// [RFC 1035]: https://tools.ietf.org/html/rfc1035
487/// [RFC 2136]: https://tools.ietf.org/html/rfc2136
488#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
489#[repr(transparent)]
490pub struct HeaderCounts {
491    /// The actual headers in their wire-format representation.
492    ///
493    /// Ie., all values are stored big endian.
494    inner: [u8; 8],
495}
496
497/// # Creation and Conversion
498///
499impl HeaderCounts {
500    /// Creates a new value with all counters set to zero.
501    #[must_use]
502    pub fn new() -> Self {
503        Self::default()
504    }
505
506    /// Creates a header counts reference from the octets slice of a message.
507    ///
508    /// The slice `message` mut be the whole message, i.e., start with the
509    /// bytes of the [`Header`](struct.Header.html).
510    ///
511    /// # Panics
512    ///
513    /// This function panics if the octets slice is shorter than 24 octets.
514    #[must_use]
515    pub fn for_message_slice(message: &[u8]) -> &Self {
516        assert!(message.len() >= mem::size_of::<HeaderSection>());
517
518        // SAFETY: The pointer cast is sound because
519        //  - HeaderCounts has repr(transparent) and
520        //  - the slice is large enough for a HeaderSection, which contains
521        //    both a Header (which we trim) and a HeaderCounts.
522        unsafe {
523            &*((message[mem::size_of::<Header>()..].as_ptr())
524                as *const HeaderCounts)
525        }
526    }
527
528    /// Creates a mutable counts reference from the octets slice of a message.
529    ///
530    /// The slice `message` mut be the whole message, i.e., start with the
531    /// bytes of the [`Header`].
532    ///
533    /// # Panics
534    ///
535    /// This function panics if the octets slice is shorter than 24 octets.
536    pub fn for_message_slice_mut(message: &mut [u8]) -> &mut Self {
537        assert!(message.len() >= mem::size_of::<HeaderSection>());
538
539        // SAFETY: The pointer cast is sound because
540        //  - HeaderCounts has repr(transparent) and
541        //  - the slice is large enough for a HeaderSection, which contains
542        //    both a Header (which we trim) and a HeaderCounts.
543        unsafe {
544            &mut *((message[mem::size_of::<Header>()..].as_mut_ptr())
545                as *mut HeaderCounts)
546        }
547    }
548
549    /// Returns a reference to the raw octets slice of the header counts.
550    #[must_use]
551    pub fn as_slice(&self) -> &[u8] {
552        &self.inner
553    }
554
555    /// Returns a mutable reference to the octets slice of the header counts.
556    pub fn as_slice_mut(&mut self) -> &mut [u8] {
557        &mut self.inner
558    }
559
560    /// Sets the counts to those from `counts`.
561    pub fn set(&mut self, counts: HeaderCounts) {
562        self.as_slice_mut().copy_from_slice(counts.as_slice())
563    }
564}
565
566/// # Field Access
567///
568impl HeaderCounts {
569    //--- Count fields in regular messages
570
571    /// Returns the value of the QDCOUNT field.
572    ///
573    /// This field contains the number of questions in the first
574    /// section of the message, normally the question section.
575    #[must_use]
576    pub fn qdcount(self) -> u16 {
577        self.get_u16(0)
578    }
579
580    /// Sets the value of the QDCOUNT field.
581    pub fn set_qdcount(&mut self, value: u16) {
582        self.set_u16(0, value)
583    }
584
585    /// Increases the value of the QDCOUNT field by one.
586    ///
587    /// If increasing the counter would result in an overflow, returns an
588    /// error.
589    pub fn inc_qdcount(&mut self) -> Result<(), CountOverflow> {
590        match self.qdcount().checked_add(1) {
591            Some(count) => {
592                self.set_qdcount(count);
593                Ok(())
594            }
595            None => Err(CountOverflow(())),
596        }
597    }
598
599    /// Decreases the value of the QDCOUNT field by one.
600    ///
601    /// # Panics
602    ///
603    /// This method panics if the count is already zero.
604    pub fn dec_qdcount(&mut self) {
605        let count = self.qdcount();
606        assert!(count > 0);
607        self.set_qdcount(count - 1);
608    }
609
610    /// Returns the value of the ANCOUNT field.
611    ///
612    /// This field contains the number of resource records in the second
613    /// section of the message, normally the answer section.
614    #[must_use]
615    pub fn ancount(self) -> u16 {
616        self.get_u16(2)
617    }
618
619    /// Sets the value of the ANCOUNT field.
620    pub fn set_ancount(&mut self, value: u16) {
621        self.set_u16(2, value)
622    }
623
624    /// Increases the value of the ANCOUNT field by one.
625    ///
626    /// If increasing the counter would result in an overflow, returns an
627    /// error.
628    pub fn inc_ancount(&mut self) -> Result<(), CountOverflow> {
629        match self.ancount().checked_add(1) {
630            Some(count) => {
631                self.set_ancount(count);
632                Ok(())
633            }
634            None => Err(CountOverflow(())),
635        }
636    }
637
638    /// Decreases the value of the ANCOUNT field by one.
639    ///
640    /// # Panics
641    ///
642    /// This method panics if the count is already zero.
643    pub fn dec_ancount(&mut self) {
644        let count = self.ancount();
645        assert!(count > 0);
646        self.set_ancount(count - 1);
647    }
648
649    /// Returns the value of the NSCOUNT field.
650    ///
651    /// This field contains the number of resource records in the third
652    /// section of the message, normally the authority section.
653    #[must_use]
654    pub fn nscount(self) -> u16 {
655        self.get_u16(4)
656    }
657
658    /// Sets the value of the NSCOUNT field.
659    pub fn set_nscount(&mut self, value: u16) {
660        self.set_u16(4, value)
661    }
662
663    /// Increases the value of the NSCOUNT field by one.
664    ///
665    /// If increasing the counter would result in an overflow, returns an
666    /// error.
667    pub fn inc_nscount(&mut self) -> Result<(), CountOverflow> {
668        match self.nscount().checked_add(1) {
669            Some(count) => {
670                self.set_nscount(count);
671                Ok(())
672            }
673            None => Err(CountOverflow(())),
674        }
675    }
676
677    /// Decreases the value of the NSCOUNT field by one.
678    ///
679    /// # Panics
680    ///
681    /// This method panics if the count is already zero.
682    pub fn dec_nscount(&mut self) {
683        let count = self.nscount();
684        assert!(count > 0);
685        self.set_nscount(count - 1);
686    }
687
688    /// Returns the value of the ARCOUNT field.
689    ///
690    /// This field contains the number of resource records in the fourth
691    /// section of the message, normally the additional section.
692    #[must_use]
693    pub fn arcount(self) -> u16 {
694        self.get_u16(6)
695    }
696
697    /// Sets the value of the ARCOUNT field.
698    pub fn set_arcount(&mut self, value: u16) {
699        self.set_u16(6, value)
700    }
701
702    /// Increases the value of the ARCOUNT field by one.
703    ///
704    /// If increasing the counter would result in an overflow, returns an
705    /// error.
706    pub fn inc_arcount(&mut self) -> Result<(), CountOverflow> {
707        match self.arcount().checked_add(1) {
708            Some(count) => {
709                self.set_arcount(count);
710                Ok(())
711            }
712            None => Err(CountOverflow(())),
713        }
714    }
715
716    /// Decreases the value of the ARCOUNT field by one.
717    ///
718    /// # Panics
719    ///
720    /// This method panics if the count is already zero.
721    pub fn dec_arcount(&mut self) {
722        let count = self.arcount();
723        assert!(count > 0);
724        self.set_arcount(count - 1);
725    }
726
727    //--- Count fields in UPDATE messages
728
729    /// Returns the value of the ZOCOUNT field.
730    ///
731    /// This is the same as the `qdcount()`. It is used in UPDATE queries
732    /// where the first section is the zone section.
733    #[must_use]
734    pub fn zocount(self) -> u16 {
735        self.qdcount()
736    }
737
738    /// Sets the value of the ZOCOUNT field.
739    pub fn set_zocount(&mut self, value: u16) {
740        self.set_qdcount(value)
741    }
742
743    /// Returns the value of the PRCOUNT field.
744    ///
745    /// This is the same as the `ancount()`. It is used in UPDATE queries
746    /// where the first section is the prerequisite section.
747    #[must_use]
748    pub fn prcount(self) -> u16 {
749        self.ancount()
750    }
751
752    /// Sete the value of the PRCOUNT field.
753    pub fn set_prcount(&mut self, value: u16) {
754        self.set_ancount(value)
755    }
756
757    /// Returns the value of the UPCOUNT field.
758    ///
759    /// This is the same as the `nscount()`. It is used in UPDATE queries
760    /// where the first section is the update section.
761    #[must_use]
762    pub fn upcount(self) -> u16 {
763        self.nscount()
764    }
765
766    /// Sets the value of the UPCOUNT field.
767    pub fn set_upcount(&mut self, value: u16) {
768        self.set_nscount(value)
769    }
770
771    /// Returns the value of the ADCOUNT field.
772    ///
773    /// This is the same as the `arcount()`. It is used in UPDATE queries
774    /// where the first section is the additional section.
775    #[must_use]
776    pub fn adcount(self) -> u16 {
777        self.arcount()
778    }
779
780    /// Sets the value of the ADCOUNT field.
781    pub fn set_adcount(&mut self, value: u16) {
782        self.set_arcount(value)
783    }
784
785    //--- Internal helpers
786
787    /// Returns the value of the 16 bit integer starting at a given offset.
788    fn get_u16(self, offset: usize) -> u16 {
789        u16::from_be_bytes(self.inner[offset..offset + 2].try_into().unwrap())
790    }
791
792    /// Sets the value of the 16 bit integer starting at a given offset.
793    fn set_u16(&mut self, offset: usize, value: u16) {
794        self.inner[offset..offset + 2].copy_from_slice(&value.to_be_bytes())
795    }
796}
797
798//------------ HeaderSection -------------------------------------------------
799
/// The complete header section of a DNS message.
///
/// Consists of a [`Header`] directly followed by a [`HeaderCounts`].
///
/// You can create an owned value via the [`new`][Self::new] function or the
/// `Default` trait and acquire a reference to the header section of
/// an existing DNS message via the
/// [`for_message_slice`][Self::for_message_slice] or
/// [`for_message_slice_mut`][Self::for_message_slice_mut]
/// functions.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(transparent)]
pub struct HeaderSection {
    /// The twelve octets of the header section in wire format: four octets
    /// of header followed by eight octets of section counts.
    inner: [u8; 12],
}

/// # Creation and Conversion
///
impl HeaderSection {
    /// Creates a new header section.
    ///
    /// The value will have all header and header counts fields set to zero
    /// or false.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a reference from the octets slice of a message.
    ///
    /// Only the first twelve octets of `s` are covered by the returned
    /// reference; any further octets are ignored.
    ///
    /// # Panics
    ///
    /// This function panics if the octets slice is shorter than 12 octets.
    #[must_use]
    pub fn for_message_slice(s: &[u8]) -> &HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());

        // SAFETY: The pointer cast is sound because
        //  - HeaderSection is repr(transparent) over [u8; 12], so it has the
        //    same size and an alignment of one, and
        //  - the assertion above guarantees the slice is large enough.
        unsafe { &*(s.as_ptr() as *const HeaderSection) }
    }

    /// Creates a mutable reference from the octets slice of a message.
    ///
    /// Only the first twelve octets of `s` are covered by the returned
    /// reference; any further octets are ignored.
    ///
    /// # Panics
    ///
    /// This function panics if the octets slice is shorter than 12 octets.
    pub fn for_message_slice_mut(s: &mut [u8]) -> &mut HeaderSection {
        assert!(s.len() >= mem::size_of::<HeaderSection>());

        // SAFETY: The pointer cast is sound because
        //  - HeaderSection is repr(transparent) over [u8; 12], so it has the
        //    same size and an alignment of one, and
        //  - the assertion above guarantees the slice is large enough.
        unsafe { &mut *(s.as_mut_ptr() as *mut HeaderSection) }
    }

    /// Returns a reference to the underlying octets slice.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        &self.inner
    }
}
855
856/// # Access to Header and Counts
857///
858impl HeaderSection {
859    /// Returns a reference to the header.
860    #[must_use]
861    pub fn header(&self) -> &Header {
862        Header::for_message_slice(&self.inner)
863    }
864
865    /// Returns a mutable reference to the header.
866    pub fn header_mut(&mut self) -> &mut Header {
867        Header::for_message_slice_mut(&mut self.inner)
868    }
869
870    /// Returns a reference to the header counts.
871    #[must_use]
872    pub fn counts(&self) -> &HeaderCounts {
873        HeaderCounts::for_message_slice(&self.inner)
874    }
875
876    /// Returns a mutable reference to the header counts.
877    pub fn counts_mut(&mut self) -> &mut HeaderCounts {
878        HeaderCounts::for_message_slice_mut(&mut self.inner)
879    }
880}
881
882/// # Parsing and Composing
883///
884impl HeaderSection {
885    pub fn parse<Octs: AsRef<[u8]>>(
886        parser: &mut Parser<Octs>,
887    ) -> Result<Self, ParseError> {
888        let mut res = Self::default();
889        parser.parse_buf(&mut res.inner)?;
890        Ok(res)
891    }
892
893    pub fn compose<Target: OctetsBuilder + ?Sized>(
894        &self,
895        target: &mut Target,
896    ) -> Result<(), Target::AppendError> {
897        target.append_slice(&self.inner)
898    }
899}
900
901//--- AsRef and AsMut
902
903impl AsRef<Header> for HeaderSection {
904    fn as_ref(&self) -> &Header {
905        self.header()
906    }
907}
908
909impl AsMut<Header> for HeaderSection {
910    fn as_mut(&mut self) -> &mut Header {
911        self.header_mut()
912    }
913}
914
915impl AsRef<HeaderCounts> for HeaderSection {
916    fn as_ref(&self) -> &HeaderCounts {
917        self.counts()
918    }
919}
920
921impl AsMut<HeaderCounts> for HeaderSection {
922    fn as_mut(&mut self) -> &mut HeaderCounts {
923        self.counts_mut()
924    }
925}
926
927//============ Error Types ===================================================
928
929//------------ FlagsFromStrError --------------------------------------------
930
/// The error produced when a string cannot be converted into flags.
///
/// The type carries no details; its display text is a fixed message.
#[derive(Debug)]
pub struct FlagsFromStrError(());

impl fmt::Display for FlagsFromStrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("illegal flags token")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for FlagsFromStrError {}
943
944//------------ CountOverflow -------------------------------------------------
945
/// The error produced when increasing a header count would overflow it.
///
/// Once a count holds its largest value (`0xffff`), a further increment
/// is reported with this error. The type carries no further details.
#[derive(Debug)]
pub struct CountOverflow(());

impl fmt::Display for CountOverflow {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Past tense ("led"): the overflow has already been detected.
        f.write_str("increasing a header count led to an overflow")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CountOverflow {}
958
959//============ Testing ======================================================
960
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    #[cfg(feature = "std")]
    fn for_slice() {
        use std::vec::Vec;

        // A complete section: four octets of header followed by eight
        // octets of counts.
        let section = b"\x01\x02\x00\x00\x12\x34\x56\x78\x9a\xbc\xde\xf0";
        let header_part = b"\x01\x02\x00\x00";
        let counts_part = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0";
        let mut owned = Vec::from(&section[..]);

        // Shared views over the literal.
        assert_eq!(
            Header::for_message_slice(section).as_slice(),
            header_part
        );
        assert_eq!(
            HeaderCounts::for_message_slice(section).as_slice(),
            counts_part
        );
        assert_eq!(
            HeaderSection::for_message_slice(section).as_slice(),
            section
        );

        // Mutable views over an owned copy must see the same octets.
        assert_eq!(
            Header::for_message_slice_mut(owned.as_mut_slice()).as_slice(),
            header_part
        );
        assert_eq!(
            HeaderCounts::for_message_slice_mut(owned.as_mut_slice())
                .as_slice(),
            counts_part
        );
        assert_eq!(
            HeaderSection::for_message_slice_mut(owned.as_mut_slice())
                .as_slice(),
            section
        );
    }

    #[test]
    #[should_panic]
    fn short_header() {
        // Three octets are not enough for the four-octet header.
        let input: &[u8] = b"134";
        let _ = Header::for_message_slice(input);
    }

    #[test]
    #[should_panic]
    fn short_header_counts() {
        // The counts follow the header, so eight octets are too short.
        let input: &[u8] = b"12345678";
        let _ = HeaderCounts::for_message_slice(input);
    }

    #[test]
    #[should_panic]
    fn short_header_section() {
        let input: &[u8] = b"1234";
        let _ = HeaderSection::for_message_slice(input);
    }

    /// Checks a single header field.
    ///
    /// A fresh `Header` must report `$default` via `$get`. Each `$value`
    /// written through `$set` on its own fresh header must be read back
    /// unchanged.
    macro_rules! check_field {
        ($get:ident, $set:ident, $default:expr, $($value:expr),*) => {{
            $(
                {
                    let mut header = Header::new();
                    assert_eq!(header.$get(), $default);
                    header.$set($value);
                    assert_eq!(header.$get(), $value);
                }
            )*
        }};
    }

    #[test]
    fn header() {
        check_field!(id, set_id, 0, 0x1234);
        check_field!(qr, set_qr, false, true, false);
        check_field!(opcode, set_opcode, Opcode::QUERY, Opcode::NOTIFY);
        check_field!(
            flags,
            set_flags,
            Flags::new(),
            Flags {
                qr: true,
                ..Default::default()
            }
        );
        check_field!(aa, set_aa, false, true, false);
        check_field!(tc, set_tc, false, true, false);
        check_field!(rd, set_rd, false, true, false);
        check_field!(ra, set_ra, false, true, false);
        check_field!(z, set_z, false, true, false);
        check_field!(ad, set_ad, false, true, false);
        check_field!(cd, set_cd, false, true, false);
        check_field!(rcode, set_rcode, Rcode::NOERROR, Rcode::REFUSED);
    }

    #[test]
    fn counts() {
        let mut counts = HeaderCounts {
            inner: [1, 2, 3, 4, 5, 6, 7, 8],
        };

        // Each count is read from two consecutive octets, most
        // significant first.
        assert_eq!(
            [
                counts.qdcount(),
                counts.ancount(),
                counts.nscount(),
                counts.arcount()
            ],
            [0x0102, 0x0304, 0x0506, 0x0708]
        );

        // Incrementing touches only the low octet of each count here.
        counts.inc_qdcount().unwrap();
        counts.inc_ancount().unwrap();
        counts.inc_nscount().unwrap();
        counts.inc_arcount().unwrap();
        assert_eq!(counts.inner, [1, 3, 3, 5, 5, 7, 7, 9]);

        counts.set_qdcount(0x0807);
        counts.set_ancount(0x0605);
        counts.set_nscount(0x0403);
        counts.set_arcount(0x0201);
        assert_eq!(counts.inner, [8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn update_counts() {
        let mut counts = HeaderCounts {
            inner: [1, 2, 3, 4, 5, 6, 7, 8],
        };

        // The alternative names address the same octets as the
        // qd/an/ns/ar counts.
        assert_eq!(
            [
                counts.zocount(),
                counts.prcount(),
                counts.upcount(),
                counts.adcount()
            ],
            [0x0102, 0x0304, 0x0506, 0x0708]
        );

        counts.set_zocount(0x0807);
        counts.set_prcount(0x0605);
        counts.set_upcount(0x0403);
        counts.set_adcount(0x0201);
        assert_eq!(counts.inner, [8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn inc_qdcount() {
        // QDCOUNT starts at 0xfffe: one step reaches the maximum, the
        // next one must fail.
        let mut counts = HeaderCounts {
            inner: [0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
        };
        counts.inc_qdcount().unwrap();
        counts.inc_qdcount().unwrap_err();
    }

    #[test]
    fn inc_ancount() {
        let mut counts = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff],
        };
        counts.inc_ancount().unwrap();
        counts.inc_ancount().unwrap_err();
    }

    #[test]
    fn inc_nscount() {
        let mut counts = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff],
        };
        counts.inc_nscount().unwrap();
        counts.inc_nscount().unwrap_err();
    }

    #[test]
    fn inc_arcount() {
        let mut counts = HeaderCounts {
            inner: [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe],
        };
        counts.inc_arcount().unwrap();
        counts.inc_arcount().unwrap_err();
    }

    #[cfg(feature = "std")]
    #[test]
    fn flags_display() {
        // No flags set renders as an empty string.
        assert_eq!(format!("{}", Flags::new()), "");

        let every = Flags {
            qr: true,
            aa: true,
            tc: true,
            rd: true,
            ra: true,
            ad: true,
            cd: true,
        };
        assert_eq!(format!("{}", every), "QR AA TC RD RA AD CD");

        let some = Flags {
            rd: true,
            cd: true,
            ..Flags::new()
        };
        assert_eq!(format!("{}", some), "RD CD");
    }

    #[cfg(feature = "std")]
    #[test]
    fn flags_from_str() {
        // An empty string yields no flags.
        assert_eq!("".parse::<Flags>().unwrap(), Flags::new());

        // Every flag, in canonical order and upper case.
        let every = Flags {
            qr: true,
            aa: true,
            tc: true,
            rd: true,
            ra: true,
            ad: true,
            cd: true,
        };
        assert_eq!("QR AA TC RD RA AD CD".parse::<Flags>().unwrap(), every);

        // Mixed case and non-canonical order are accepted.
        let some = Flags {
            aa: true,
            tc: true,
            rd: true,
            cd: true,
            ..Default::default()
        };
        assert_eq!("tC Aa CD rd".parse::<Flags>().unwrap(), some);

        // Unknown tokens are rejected.
        assert!("XXXX".parse::<Flags>().is_err());
    }
}