der/
encoder.rs

1//! DER encoder.
2
3use crate::{
4    asn1::*, Encodable, EncodeValue, Error, ErrorKind, Header, Length, Result, Tag, TagMode,
5    TagNumber, Tagged,
6};
7
8/// DER encoder.
9#[derive(Debug)]
10pub struct Encoder<'a> {
11    /// Buffer into which DER-encoded message is written
12    bytes: Option<&'a mut [u8]>,
13
14    /// Total number of bytes written to buffer so far
15    position: Length,
16}
17
18impl<'a> Encoder<'a> {
19    /// Create a new encoder with the given byte slice as a backing buffer.
20    pub fn new(bytes: &'a mut [u8]) -> Self {
21        Self {
22            bytes: Some(bytes),
23            position: Length::ZERO,
24        }
25    }
26
27    /// Encode a value which impls the [`Encodable`] trait.
28    pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
29        if self.is_failed() {
30            self.error(ErrorKind::Failed)?;
31        }
32
33        encodable.encode(self).map_err(|e| {
34            self.bytes.take();
35            e.nested(self.position)
36        })
37    }
38
39    /// Return an error with the given [`ErrorKind`], annotating it with
40    /// context about where the error occurred.
41    // TODO(tarcieri): change return type to `Error`
42    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
43        self.bytes.take();
44        Err(kind.at(self.position))
45    }
46
47    /// Return an error for an invalid value with the given tag.
48    // TODO(tarcieri): compose this with `Encoder::error` after changing its return type
49    pub fn value_error(&mut self, tag: Tag) -> Error {
50        self.bytes.take();
51        tag.value_error().kind().at(self.position)
52    }
53
54    /// Did the decoding operation fail due to an error?
55    pub fn is_failed(&self) -> bool {
56        self.bytes.is_none()
57    }
58
59    /// Finish encoding to the buffer, returning a slice containing the data
60    /// written to the buffer.
61    pub fn finish(self) -> Result<&'a [u8]> {
62        let pos = self.position;
63        let range = ..usize::try_from(self.position)?;
64
65        match self.bytes {
66            Some(bytes) => bytes
67                .get(range)
68                .ok_or_else(|| ErrorKind::Overlength.at(pos)),
69            None => Err(ErrorKind::Failed.at(pos)),
70        }
71    }
72
73    /// Encode the provided value as an ASN.1 `BIT STRING`.
74    pub fn bit_string(&mut self, value: impl TryInto<BitString<'a>>) -> Result<()> {
75        value
76            .try_into()
77            .map_err(|_| self.value_error(Tag::BitString))
78            .and_then(|value| self.encode(&value))
79    }
80
81    /// Encode a `CONTEXT-SPECIFIC` field with `EXPLICIT` tagging.
82    pub fn context_specific<T>(
83        &mut self,
84        tag_number: TagNumber,
85        tag_mode: TagMode,
86        value: &T,
87    ) -> Result<()>
88    where
89        T: EncodeValue + Tagged,
90    {
91        ContextSpecificRef {
92            tag_number,
93            tag_mode,
94            value,
95        }
96        .encode(self)
97    }
98
99    /// Encode the provided value as an ASN.1 `GeneralizedTime`
100    pub fn generalized_time(&mut self, value: impl TryInto<GeneralizedTime>) -> Result<()> {
101        value
102            .try_into()
103            .map_err(|_| self.value_error(Tag::GeneralizedTime))
104            .and_then(|value| self.encode(&value))
105    }
106
107    /// Encode the provided value as an ASN.1 `IA5String`.
108    pub fn ia5_string(&mut self, value: impl TryInto<Ia5String<'a>>) -> Result<()> {
109        value
110            .try_into()
111            .map_err(|_| self.value_error(Tag::Ia5String))
112            .and_then(|value| self.encode(&value))
113    }
114
115    /// Encode an ASN.1 `NULL` value.
116    pub fn null(&mut self) -> Result<()> {
117        self.encode(&Null)
118    }
119
120    /// Encode the provided value as an ASN.1 `OCTET STRING`
121    pub fn octet_string(&mut self, value: impl TryInto<OctetString<'a>>) -> Result<()> {
122        value
123            .try_into()
124            .map_err(|_| self.value_error(Tag::OctetString))
125            .and_then(|value| self.encode(&value))
126    }
127
128    /// Encode an ASN.1 [`ObjectIdentifier`]
129    #[cfg(feature = "oid")]
130    #[cfg_attr(docsrs, doc(cfg(feature = "oid")))]
131    pub fn oid(&mut self, value: impl TryInto<ObjectIdentifier>) -> Result<()> {
132        value
133            .try_into()
134            .map_err(|_| self.value_error(Tag::ObjectIdentifier))
135            .and_then(|value| self.encode(&value))
136    }
137
138    /// Encode the provided value as an ASN.1 `PrintableString`
139    pub fn printable_string(&mut self, value: impl TryInto<PrintableString<'a>>) -> Result<()> {
140        value
141            .try_into()
142            .map_err(|_| self.value_error(Tag::PrintableString))
143            .and_then(|value| self.encode(&value))
144    }
145
146    /// Encode an ASN.1 `SEQUENCE` of the given length.
147    ///
148    /// Spawns a nested [`Encoder`] which is expected to be exactly the
149    /// specified length upon completion.
150    pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
151    where
152        F: FnOnce(&mut Encoder<'_>) -> Result<()>,
153    {
154        Header::new(Tag::Sequence, length).and_then(|header| header.encode(self))?;
155
156        let mut nested_encoder = Encoder::new(self.reserve(length)?);
157        f(&mut nested_encoder)?;
158
159        if nested_encoder.finish()?.len() == usize::try_from(length)? {
160            Ok(())
161        } else {
162            self.error(ErrorKind::Length { tag: Tag::Sequence })
163        }
164    }
165
166    /// Encode the provided value as an ASN.1 `UTCTime`
167    pub fn utc_time(&mut self, value: impl TryInto<UtcTime>) -> Result<()> {
168        value
169            .try_into()
170            .map_err(|_| self.value_error(Tag::UtcTime))
171            .and_then(|value| self.encode(&value))
172    }
173
174    /// Encode the provided value as an ASN.1 `Utf8String`
175    pub fn utf8_string(&mut self, value: impl TryInto<Utf8String<'a>>) -> Result<()> {
176        value
177            .try_into()
178            .map_err(|_| self.value_error(Tag::Utf8String))
179            .and_then(|value| self.encode(&value))
180    }
181
182    /// Reserve a portion of the internal buffer, updating the internal cursor
183    /// position and returning a mutable slice.
184    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
185        let len = len
186            .try_into()
187            .or_else(|_| self.error(ErrorKind::Overflow))?;
188
189        if len > self.remaining_len()? {
190            self.error(ErrorKind::Overlength)?;
191        }
192
193        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
194        let range = self.position.try_into()?..end.try_into()?;
195        let position = &mut self.position;
196
197        // TODO(tarcieri): non-panicking version of this code
198        // We ensure above that the buffer is untainted and there is sufficient
199        // space to perform this slicing operation, however it would be nice to
200        // have fully panic-free code.
201        //
202        // Unfortunately tainting the buffer on error is tricky to do when
203        // potentially holding a reference to the buffer, and failure to taint
204        // it would not uphold the invariant that any errors should taint it.
205        let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
206        *position = end;
207
208        Ok(slice)
209    }
210
211    /// Encode a single byte into the backing buffer.
212    pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
213        match self.reserve(1u8)?.first_mut() {
214            Some(b) => {
215                *b = byte;
216                Ok(())
217            }
218            None => self.error(ErrorKind::Overlength),
219        }
220    }
221
222    /// Encode the provided byte slice into the backing buffer.
223    pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
224        self.reserve(slice.len())?.copy_from_slice(slice);
225        Ok(())
226    }
227
228    /// Get the size of the buffer in bytes.
229    fn buffer_len(&self) -> Result<Length> {
230        self.bytes
231            .as_ref()
232            .map(|bytes| bytes.len())
233            .ok_or_else(|| ErrorKind::Failed.at(self.position))
234            .and_then(TryInto::try_into)
235    }
236
237    /// Get the number of bytes still remaining in the buffer.
238    fn remaining_len(&self) -> Result<Length> {
239        let buffer_len = usize::try_from(self.buffer_len()?)?;
240
241        buffer_len
242            .checked_sub(self.position.try_into()?)
243            .ok_or_else(|| ErrorKind::Overlength.at(self.position))
244            .and_then(TryInto::try_into)
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::Encoder;
    use crate::{asn1::BitString, Encodable, ErrorKind, Length, TagMode, TagNumber};
    use hex_literal::hex;

    /// Encoding anything into a zero-length buffer must report `Overlength`
    /// at position zero.
    #[test]
    fn overlength_message() {
        let mut empty_buffer: [u8; 0] = [];
        let mut encoder = Encoder::new(&mut empty_buffer);

        let error = false
            .encode(&mut encoder)
            .expect_err("encoding into an empty buffer must fail");

        assert_eq!(error.kind(), ErrorKind::Overlength);
        assert_eq!(error.position(), Some(Length::ZERO));
    }

    /// Round-trips the RFC 8410 example of an IMPLICIT `[1]` BIT STRING.
    #[test]
    fn context_specific_with_implicit_field() {
        // From RFC8410 Section 10.3:
        // <https://datatracker.ietf.org/doc/html/rfc8410#section-10.3>
        //
        //    81  33:   [1] 00 19 BF 44 09 69 84 CD FE 85 41 BA C1 67 DC 3B
        //                  96 C8 50 86 AA 30 B6 B6 CB 0C 5C 38 AD 70 31 66
        //                  E1
        const EXPECTED_BYTES: &[u8] =
            &hex!("81210019BF44096984CDFE8541BAC167DC3B96C85086AA30B6B6CB0C5C38AD703166E1");

        // Skip the tag (0x81), the length (0x21) and the leading 0x00
        // (the BIT STRING unused-bits octet) to get the raw bit payload.
        let payload = &EXPECTED_BYTES[3..];
        let bit_string = BitString::from_bytes(payload).unwrap();

        let mut output = [0u8; EXPECTED_BYTES.len()];
        let mut encoder = Encoder::new(&mut output);

        encoder
            .context_specific(TagNumber::new(1), TagMode::Implicit, &bit_string)
            .expect("context-specific field should encode");

        let encoded = encoder.finish().unwrap();
        assert_eq!(EXPECTED_BYTES, encoded);
    }
}
287}