const_oid/
encoder.rs

1//! OID encoder with `const` support.
2
3use crate::{
4    arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND},
5    Arc, Error, ObjectIdentifier, Result,
6};
7
8/// BER/DER encoder
9pub(crate) struct Encoder {
10    /// Current state
11    state: State,
12
13    /// Bytes of the OID being encoded in-progress
14    bytes: [u8; ObjectIdentifier::MAX_SIZE],
15
16    /// Current position within the byte buffer
17    cursor: usize,
18}
19
20/// Current state of the encoder
21enum State {
22    /// Initial state - no arcs yet encoded
23    Initial,
24
25    /// First arc parsed
26    FirstArc(Arc),
27
28    /// Encoding base 128 body of the OID
29    Body,
30}
31
32impl Encoder {
33    /// Create a new encoder initialized to an empty default state
34    pub(crate) const fn new() -> Self {
35        Self {
36            state: State::Initial,
37            bytes: [0u8; ObjectIdentifier::MAX_SIZE],
38            cursor: 0,
39        }
40    }
41
42    /// Encode an [`Arc`] as base 128 into the internal buffer
43    pub(crate) const fn encode(mut self, arc: Arc) -> Self {
44        match self.state {
45            State::Initial => {
46                const_assert!(arc <= ARC_MAX_FIRST, "invalid first arc (must be 0-2)");
47                self.state = State::FirstArc(arc);
48                self
49            }
50            State::FirstArc(first_arc) => {
51                const_assert!(arc <= ARC_MAX_SECOND, "invalid second arc (must be 0-39)");
52                self.state = State::Body;
53                self.bytes[0] = (first_arc * (ARC_MAX_SECOND + 1)) as u8 + arc as u8;
54                self.cursor = 1;
55                self
56            }
57            State::Body => {
58                // Total number of bytes in encoded arc - 1
59                let nbytes = base128_len(arc);
60
61                const_assert!(
62                    self.cursor + nbytes + 1 < ObjectIdentifier::MAX_SIZE,
63                    "OID too long (exceeded max DER bytes)"
64                );
65
66                let new_cursor = self.cursor + nbytes + 1;
67                let mut result = self.encode_base128_byte(arc, nbytes, false);
68                result.cursor = new_cursor;
69                result
70            }
71        }
72    }
73
74    /// Finish encoding an OID
75    pub(crate) const fn finish(self) -> ObjectIdentifier {
76        const_assert!(self.cursor >= 2, "OID too short (minimum 3 arcs)");
77        ObjectIdentifier {
78            bytes: self.bytes,
79            length: self.cursor as u8,
80        }
81    }
82
83    /// Encode a single byte of a base128 value
84    const fn encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Self {
85        let mask = if continued { 0b10000000 } else { 0 };
86
87        if n > 0x80 {
88            self.bytes[self.cursor + i] = (n & 0b1111111) as u8 | mask;
89            n >>= 7;
90
91            const_assert!(i > 0, "Base 128 offset miscalculation");
92            self.encode_base128_byte(n, i.saturating_sub(1), true)
93        } else {
94            self.bytes[self.cursor] = n as u8 | mask;
95            self
96        }
97    }
98}
99
100/// Compute the length - 1 of an arc when encoded in base 128
101const fn base128_len(arc: Arc) -> usize {
102    match arc {
103        0..=0x7f => 0,
104        0x80..=0x3fff => 1,
105        0x4000..=0x1fffff => 2,
106        0x200000..=0x1fffffff => 3,
107        _ => 4,
108    }
109}
110
111/// Write the given unsigned integer in base 128
112// TODO(tarcieri): consolidate encoding logic with `encode_base128_byte`
113pub(crate) fn write_base128(bytes: &mut [u8], mut n: Arc) -> Result<usize> {
114    let nbytes = base128_len(n);
115    let mut i = nbytes;
116    let mut mask = 0;
117
118    while n > 0x80 {
119        let byte = bytes.get_mut(i).ok_or(Error)?;
120        *byte = (n & 0b1111111 | mask) as u8;
121        n >>= 7;
122        i = i.checked_sub(1).expect("overflow");
123        mask = 0b10000000;
124    }
125
126    bytes[0] = (n | mask) as u8;
127
128    Ok(nbytes + 1)
129}
130
#[cfg(test)]
mod tests {
    use super::Encoder;
    use hex_literal::hex;

    /// Expected BER/DER bytes for the OID `1.2.840.10045.2.1`.
    const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201");

    #[test]
    fn encode() {
        let arcs = [1, 2, 840, 10045, 2, 1];
        let encoder = arcs
            .iter()
            .fold(Encoder::new(), |acc, &arc| acc.encode(arc));

        let written = &encoder.bytes[..encoder.cursor];
        assert_eq!(written, EXAMPLE_OID_BER);
    }
}