der/
decoder.rs

//! DER decoder.
2
3use crate::{
4    asn1::*, ByteSlice, Choice, Decodable, DecodeValue, Error, ErrorKind, FixedTag, Header, Length,
5    Result, Tag, TagMode, TagNumber,
6};
7
8/// DER decoder.
9#[derive(Clone, Debug)]
10pub struct Decoder<'a> {
11    /// Byte slice being decoded.
12    ///
13    /// In the event an error was previously encountered this will be set to
14    /// `None` to prevent further decoding while in a bad state.
15    bytes: Option<ByteSlice<'a>>,
16
17    /// Position within the decoded slice.
18    position: Length,
19}
20
21impl<'a> Decoder<'a> {
22    /// Create a new decoder for the given byte slice.
23    pub fn new(bytes: &'a [u8]) -> Result<Self> {
24        Ok(Self {
25            bytes: Some(ByteSlice::new(bytes)?),
26            position: Length::ZERO,
27        })
28    }
29
30    /// Decode a value which impls the [`Decodable`] trait.
31    pub fn decode<T: Decodable<'a>>(&mut self) -> Result<T> {
32        if self.is_failed() {
33            return Err(self.error(ErrorKind::Failed));
34        }
35
36        T::decode(self).map_err(|e| {
37            self.bytes.take();
38            e.nested(self.position)
39        })
40    }
41
42    /// Return an error with the given [`ErrorKind`], annotating it with
43    /// context about where the error occurred.
44    pub fn error(&mut self, kind: ErrorKind) -> Error {
45        self.bytes.take();
46        kind.at(self.position)
47    }
48
49    /// Return an error for an invalid value with the given tag.
50    pub fn value_error(&mut self, tag: Tag) -> Error {
51        self.error(tag.value_error().kind())
52    }
53
54    /// Did the decoding operation fail due to an error?
55    pub fn is_failed(&self) -> bool {
56        self.bytes.is_none()
57    }
58
59    /// Get the position within the buffer.
60    pub fn position(&self) -> Length {
61        self.position
62    }
63
64    /// Peek at the next byte in the decoder without modifying the cursor.
65    pub fn peek_byte(&self) -> Option<u8> {
66        self.remaining()
67            .ok()
68            .and_then(|bytes| bytes.get(0).cloned())
69    }
70
71    /// Peek at the next byte in the decoder and attempt to decode it as a
72    /// [`Tag`] value.
73    ///
74    /// Does not modify the decoder's state.
75    pub fn peek_tag(&self) -> Result<Tag> {
76        match self.peek_byte() {
77            Some(byte) => byte.try_into(),
78            None => {
79                let actual_len = self.input_len()?;
80                let expected_len = (actual_len + Length::ONE)?;
81                Err(ErrorKind::Incomplete {
82                    expected_len,
83                    actual_len,
84                }
85                .into())
86            }
87        }
88    }
89
90    /// Peek forward in the decoder, attempting to decode a [`Header`] from
91    /// the data at the current position in the decoder.
92    ///
93    /// Does not modify the decoder's state.
94    pub fn peek_header(&self) -> Result<Header> {
95        Header::decode(&mut self.clone())
96    }
97
98    /// Finish decoding, returning the given value if there is no
99    /// remaining data, or an error otherwise
100    pub fn finish<T>(self, value: T) -> Result<T> {
101        if self.is_failed() {
102            Err(ErrorKind::Failed.at(self.position))
103        } else if !self.is_finished() {
104            Err(ErrorKind::TrailingData {
105                decoded: self.position,
106                remaining: self.remaining_len()?,
107            }
108            .at(self.position))
109        } else {
110            Ok(value)
111        }
112    }
113
114    /// Have we decoded all of the bytes in this [`Decoder`]?
115    ///
116    /// Returns `false` if we're not finished decoding or if a fatal error
117    /// has occurred.
118    pub fn is_finished(&self) -> bool {
119        self.remaining().map(|rem| rem.is_empty()).unwrap_or(false)
120    }
121
122    /// Attempt to decode an ASN.1 `ANY` value.
123    pub fn any(&mut self) -> Result<Any<'a>> {
124        self.decode()
125    }
126
127    /// Attempt to decode an `OPTIONAL` ASN.1 `ANY` value.
128    pub fn any_optional(&mut self) -> Result<Option<Any<'a>>> {
129        self.decode()
130    }
131
132    /// Attempt to decode ASN.1 `INTEGER` as `i8`
133    pub fn int8(&mut self) -> Result<i8> {
134        self.decode()
135    }
136
137    /// Attempt to decode ASN.1 `INTEGER` as `i16`
138    pub fn int16(&mut self) -> Result<i16> {
139        self.decode()
140    }
141
142    /// Attempt to decode unsigned ASN.1 `INTEGER` as `u8`
143    pub fn uint8(&mut self) -> Result<u8> {
144        self.decode()
145    }
146
147    /// Attempt to decode unsigned ASN.1 `INTEGER` as `u16`
148    pub fn uint16(&mut self) -> Result<u16> {
149        self.decode()
150    }
151
152    /// Attempt to decode an ASN.1 `INTEGER` as a [`UIntBytes`].
153    #[cfg(feature = "bigint")]
154    #[cfg_attr(docsrs, doc(cfg(feature = "bigint")))]
155    pub fn uint_bytes(&mut self) -> Result<UIntBytes<'a>> {
156        self.decode()
157    }
158
159    /// Attempt to decode an ASN.1 `BIT STRING`.
160    pub fn bit_string(&mut self) -> Result<BitString<'a>> {
161        self.decode()
162    }
163
164    /// Attempt to decode an ASN.1 `CONTEXT-SPECIFIC` field with the
165    /// provided [`TagNumber`].
166    pub fn context_specific<T>(
167        &mut self,
168        tag_number: TagNumber,
169        tag_mode: TagMode,
170    ) -> Result<Option<T>>
171    where
172        T: DecodeValue<'a> + FixedTag,
173    {
174        Ok(match tag_mode {
175            TagMode::Explicit => ContextSpecific::<T>::decode_explicit(self, tag_number)?,
176            TagMode::Implicit => ContextSpecific::<T>::decode_implicit(self, tag_number)?,
177        }
178        .map(|field| field.value))
179    }
180
181    /// Attempt to decode an ASN.1 `GeneralizedTime`.
182    pub fn generalized_time(&mut self) -> Result<GeneralizedTime> {
183        self.decode()
184    }
185
186    /// Attempt to decode an ASN.1 `IA5String`.
187    pub fn ia5_string(&mut self) -> Result<Ia5String<'a>> {
188        self.decode()
189    }
190
191    /// Attempt to decode an ASN.1 `NULL` value.
192    pub fn null(&mut self) -> Result<Null> {
193        self.decode()
194    }
195
196    /// Attempt to decode an ASN.1 `OCTET STRING`.
197    pub fn octet_string(&mut self) -> Result<OctetString<'a>> {
198        self.decode()
199    }
200
201    /// Attempt to decode an ASN.1 `OBJECT IDENTIFIER`.
202    #[cfg(feature = "oid")]
203    #[cfg_attr(docsrs, doc(cfg(feature = "oid")))]
204    pub fn oid(&mut self) -> Result<ObjectIdentifier> {
205        self.decode()
206    }
207
208    /// Attempt to decode an ASN.1 `OPTIONAL` value.
209    pub fn optional<T: Choice<'a>>(&mut self) -> Result<Option<T>> {
210        self.decode()
211    }
212
213    /// Attempt to decode an ASN.1 `PrintableString`.
214    pub fn printable_string(&mut self) -> Result<PrintableString<'a>> {
215        self.decode()
216    }
217
218    /// Attempt to decode an ASN.1 `UTCTime`.
219    pub fn utc_time(&mut self) -> Result<UtcTime> {
220        self.decode()
221    }
222
223    /// Attempt to decode an ASN.1 `UTF8String`.
224    pub fn utf8_string(&mut self) -> Result<Utf8String<'a>> {
225        self.decode()
226    }
227
228    /// Attempt to decode an ASN.1 `SEQUENCE`, creating a new nested
229    /// [`Decoder`] and calling the provided argument with it.
230    pub fn sequence<F, T>(&mut self, f: F) -> Result<T>
231    where
232        F: FnOnce(&mut Decoder<'a>) -> Result<T>,
233    {
234        Tag::try_from(self.byte()?)?.assert_eq(Tag::Sequence)?;
235        let len = Length::decode(self)?;
236        self.decode_nested(len, f)
237    }
238
239    /// Decode a single byte, updating the internal cursor.
240    pub(crate) fn byte(&mut self) -> Result<u8> {
241        match self.bytes(1u8)? {
242            [byte] => Ok(*byte),
243            _ => {
244                let actual_len = self.input_len()?;
245                let expected_len = (actual_len + Length::ONE)?;
246                Err(self.error(ErrorKind::Incomplete {
247                    expected_len,
248                    actual_len,
249                }))
250            }
251        }
252    }
253
254    /// Obtain a slice of bytes of the given length from the current cursor
255    /// position, or return an error if we have insufficient data.
256    pub(crate) fn bytes(&mut self, len: impl TryInto<Length>) -> Result<&'a [u8]> {
257        if self.is_failed() {
258            return Err(self.error(ErrorKind::Failed));
259        }
260
261        let len = len
262            .try_into()
263            .map_err(|_| self.error(ErrorKind::Overflow))?;
264
265        match self.remaining()?.get(..len.try_into()?) {
266            Some(result) => {
267                self.position = (self.position + len)?;
268                Ok(result)
269            }
270            None => {
271                let actual_len = self.input_len()?;
272                let expected_len = (actual_len + len)?;
273                Err(self.error(ErrorKind::Incomplete {
274                    expected_len,
275                    actual_len,
276                }))
277            }
278        }
279    }
280
281    /// Get the length of the input, if decoding hasn't failed.
282    pub(crate) fn input_len(&self) -> Result<Length> {
283        Ok(self.bytes.ok_or(ErrorKind::Failed)?.len())
284    }
285
286    /// Get the number of bytes still remaining in the buffer.
287    pub(crate) fn remaining_len(&self) -> Result<Length> {
288        self.remaining()?.len().try_into()
289    }
290
291    /// Create a nested decoder which operates over the provided [`Length`].
292    ///
293    /// The nested decoder is passed to the provided callback function which is
294    /// expected to decode a value of type `T` with it.
295    fn decode_nested<F, T>(&mut self, length: Length, f: F) -> Result<T>
296    where
297        F: FnOnce(&mut Self) -> Result<T>,
298    {
299        let start_pos = self.position();
300        let end_pos = (start_pos + length)?;
301        let bytes = match self.bytes {
302            Some(slice) => {
303                slice
304                    .as_bytes()
305                    .get(..end_pos.try_into()?)
306                    .ok_or(ErrorKind::Incomplete {
307                        expected_len: end_pos,
308                        actual_len: self.input_len()?,
309                    })?
310            }
311            None => return Err(self.error(ErrorKind::Failed)),
312        };
313
314        let mut nested_decoder = Self {
315            bytes: Some(ByteSlice::new(bytes)?),
316            position: start_pos,
317        };
318
319        self.position = end_pos;
320        let result = f(&mut nested_decoder)?;
321        nested_decoder.finish(result)
322    }
323
324    /// Obtain the remaining bytes in this decoder from the current cursor
325    /// position.
326    fn remaining(&self) -> Result<&'a [u8]> {
327        let pos = usize::try_from(self.position)?;
328
329        match self.bytes.and_then(|slice| slice.as_bytes().get(pos..)) {
330            Some(result) => Ok(result),
331            None => {
332                let actual_len = self.input_len()?;
333                let expected_len = (actual_len + Length::ONE)?;
334                Err(ErrorKind::Incomplete {
335                    expected_len,
336                    actual_len,
337                }
338                .at(self.position))
339            }
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::Decoder;
    use crate::{Decodable, ErrorKind, Length, Tag};
    use hex_literal::hex;

    /// DER encoding of `INTEGER 42`, followed by one trailing `0x00` byte.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut decoder = Decoder::new(&[]).unwrap();
        let error = bool::decode(&mut decoder).expect_err("empty input must not decode");

        assert_eq!(err_position(&error), Some(Length::ZERO));
        assert_eq!(
            error.kind(),
            ErrorKind::Incomplete {
                expected_len: 1u8.into(),
                actual_len: 0u8.into(),
            }
        );
    }

    #[test]
    fn invalid_field_length() {
        let truncated = &EXAMPLE_MSG[..2];
        let mut decoder = Decoder::new(truncated).unwrap();
        let error = i8::decode(&mut decoder).expect_err("truncated input must not decode");

        assert_eq!(err_position(&error), Some(Length::from(2u8)));
        assert_eq!(
            error.kind(),
            ErrorKind::Incomplete {
                expected_len: 3u8.into(),
                actual_len: 2u8.into(),
            }
        );
    }

    #[test]
    fn trailing_data() {
        let mut decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        let value: i8 = decoder.decode().unwrap();
        assert_eq!(value, 42);

        let error = decoder
            .finish(value)
            .expect_err("trailing byte must be reported");
        assert_eq!(err_position(&error), Some(Length::from(3u8)));
        assert_eq!(
            error.kind(),
            ErrorKind::TrailingData {
                decoded: 3u8.into(),
                remaining: 1u8.into(),
            }
        );
    }

    #[test]
    fn peek_tag() {
        let decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        let before = decoder.position();
        assert_eq!(before, Length::ZERO);

        let tag = decoder.peek_tag().unwrap();
        assert_eq!(tag, Tag::Integer);

        // Peeking must leave the cursor where it was.
        assert_eq!(decoder.position(), before);
    }

    #[test]
    fn peek_header() {
        let decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        let before = decoder.position();
        assert_eq!(before, Length::ZERO);

        let header = decoder.peek_header().unwrap();
        assert_eq!(header.tag, Tag::Integer);
        assert_eq!(header.length, Length::ONE);

        // Peeking must leave the cursor where it was.
        assert_eq!(decoder.position(), before);
    }

    /// Shorthand for reading the position an error was annotated with.
    fn err_position(error: &crate::Error) -> Option<Length> {
        error.position()
    }
}