ssh_key/
mpint.rs

1//! Multiple precision integer
2
3use crate::{
4    checked::CheckedSum, decode::Decode, encode::Encode, reader::Reader, writer::Writer, Error,
5    Result,
6};
7use alloc::vec::Vec;
8use core::fmt;
9use zeroize::Zeroize;
10
11#[cfg(feature = "rsa")]
12use zeroize::Zeroizing;
13
14#[cfg(feature = "subtle")]
15use subtle::{Choice, ConstantTimeEq};
16
/// Multiple precision integer, a.k.a. "mpint".
///
/// This type is used for representing the big integer components of
/// DSA and RSA keys.
///
/// Described in [RFC4251 § 5](https://datatracker.ietf.org/doc/html/rfc4251#section-5):
///
/// > Represents multiple precision integers in two's complement format,
/// > stored as a string, 8 bits per byte, MSB first.  Negative numbers
/// > have the value 1 as the most significant bit of the first byte of
/// > the data partition.  If the most significant bit would be set for
/// > a positive number, the number MUST be preceded by a zero byte.
/// > Unnecessary leading bytes with the value 0 or 255 MUST NOT be
/// > included.  The value zero MUST be stored as a string with zero
/// > bytes of data.
/// >
/// > By convention, a number that is used in modular computations in
/// > Z_n SHOULD be represented in the range 0 <= x < n.
///
/// ## Examples
///
/// | value (hex)     | representation (hex)                  |
/// |-----------------|---------------------------------------|
/// | 0               | `00 00 00 00`                         |
/// | 9a378f9b2e332a7 | `00 00 00 08 09 a3 78 f9 b2 e3 32 a7` |
/// | 80              | `00 00 00 02 00 80`                   |
/// | -1234           | `00 00 00 02 ed cc`                   |
/// | -deadbeef       | `00 00 00 05 ff 21 52 41 11`          |
///
/// ## Comparison
///
/// The derived [`PartialOrd`] / [`Ord`] impls compare the *encoded bytes*
/// lexicographically, which is **not** numeric ordering: e.g. `01 00` (256)
/// sorts before `02` (2), and negative values (first byte >= `0x80`) sort
/// after non-negative ones. The derived [`PartialEq`] is not guaranteed to
/// run in constant time; use `ConstantTimeEq` (with the `subtle` feature)
/// when comparing secret values.
// TODO(tarcieri): support for heapless platforms
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
pub struct MPInt {
    /// Inner big endian-serialized integer value
    inner: Vec<u8>,
}
52
53impl MPInt {
54    /// Create a new multiple precision integer from the given
55    /// big endian-encoded byte slice.
56    ///
57    /// Note that this method expects a leading zero on positive integers whose
58    /// MSB is set, but does *NOT* expect a 4-byte length prefix.
59    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
60        bytes.try_into()
61    }
62
63    /// Create a new multiple precision integer from the given big endian
64    /// encoded byte slice representing a positive integer.
65    ///
66    /// The integer should not start with any leading zeroes.
67    pub fn from_positive_bytes(bytes: &[u8]) -> Result<Self> {
68        let mut inner = Vec::with_capacity(bytes.len());
69
70        match bytes.first().cloned() {
71            Some(0) => return Err(Error::FormatEncoding),
72            Some(n) if n >= 0x80 => inner.push(0),
73            _ => (),
74        }
75
76        inner.extend_from_slice(bytes);
77        inner.try_into()
78    }
79
80    /// Get the big integer data encoded as big endian bytes.
81    ///
82    /// This slice will contain a leading zero if the value is positive but the
83    /// MSB is also set. Use [`MPInt::as_positive_bytes`] to ensure the number
84    /// is positive and strip the leading zero byte if it exists.
85    pub fn as_bytes(&self) -> &[u8] {
86        &self.inner
87    }
88
89    /// Get the bytes of a positive integer.
90    ///
91    /// # Returns
92    /// - `Some(bytes)` if the number is positive. The leading zero byte will be stripped.
93    /// - `None` if the value is negative
94    pub fn as_positive_bytes(&self) -> Option<&[u8]> {
95        match self.as_bytes() {
96            [0x00, rest @ ..] => Some(rest),
97            [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
98            _ => None,
99        }
100    }
101}
102
103impl AsRef<[u8]> for MPInt {
104    fn as_ref(&self) -> &[u8] {
105        self.as_bytes()
106    }
107}
108
109impl Decode for MPInt {
110    fn decode(reader: &mut impl Reader) -> Result<Self> {
111        Vec::decode(reader)?.try_into()
112    }
113}
114
115impl Encode for MPInt {
116    fn encoded_len(&self) -> Result<usize> {
117        [4, self.as_bytes().len()].checked_sum()
118    }
119
120    fn encode(&self, writer: &mut impl Writer) -> Result<()> {
121        self.as_bytes().encode(writer)
122    }
123}
124
125impl TryFrom<&[u8]> for MPInt {
126    type Error = Error;
127
128    fn try_from(bytes: &[u8]) -> Result<Self> {
129        Vec::from(bytes).try_into()
130    }
131}
132
133impl TryFrom<Vec<u8>> for MPInt {
134    type Error = Error;
135
136    fn try_from(bytes: Vec<u8>) -> Result<Self> {
137        match bytes.as_slice() {
138            // Unnecessary leading 0
139            [0x00] => Err(Error::FormatEncoding),
140            // Unnecessary leading 0
141            [0x00, n, ..] if *n < 0x80 => Err(Error::FormatEncoding),
142            _ => Ok(Self { inner: bytes }),
143        }
144    }
145}
146
147impl Zeroize for MPInt {
148    fn zeroize(&mut self) {
149        self.inner.zeroize();
150    }
151}
152
153impl fmt::Debug for MPInt {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        write!(f, "MPInt({:X})", self)
156    }
157}
158
159impl fmt::Display for MPInt {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        write!(f, "{:X}", self)
162    }
163}
164
165impl fmt::LowerHex for MPInt {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        for byte in self.as_bytes() {
168            write!(f, "{:02x}", byte)?;
169        }
170        Ok(())
171    }
172}
173
174impl fmt::UpperHex for MPInt {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        for byte in self.as_bytes() {
177            write!(f, "{:02X}", byte)?;
178        }
179        Ok(())
180    }
181}
182
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<rsa::BigUint> for MPInt {
    type Error = Error;

    /// Convert an owned big unsigned integer by delegating to the
    /// reference-based conversion.
    fn try_from(uint: rsa::BigUint) -> Result<MPInt> {
        Self::try_from(&uint)
    }
}
192
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<&rsa::BigUint> for MPInt {
    type Error = Error;

    /// Serialize `uint` as big endian bytes and build a positive mpint from them.
    ///
    /// The temporary byte buffer is wrapped in `Zeroizing` so it is wiped when
    /// dropped.
    fn try_from(uint: &rsa::BigUint) -> Result<MPInt> {
        let be_bytes: Zeroizing<Vec<u8>> = Zeroizing::new(uint.to_bytes_be());
        MPInt::from_positive_bytes(&be_bytes)
    }
}
203
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<MPInt> for rsa::BigUint {
    type Error = Error;

    /// Convert an owned mpint by delegating to the reference-based conversion.
    fn try_from(mpint: MPInt) -> Result<rsa::BigUint> {
        Self::try_from(&mpint)
    }
}
213
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<&MPInt> for rsa::BigUint {
    type Error = Error;

    /// Interpret the mpint's magnitude as a big endian unsigned integer.
    ///
    /// Fails with `Error::Crypto` when `MPInt::as_positive_bytes` yields `None`
    /// (i.e. the value is negative).
    fn try_from(mpint: &MPInt) -> Result<rsa::BigUint> {
        match mpint.as_positive_bytes() {
            Some(magnitude) => Ok(rsa::BigUint::from_bytes_be(magnitude)),
            None => Err(Error::Crypto),
        }
    }
}
226
#[cfg(feature = "subtle")]
#[cfg_attr(docsrs, doc(cfg(feature = "subtle")))]
impl ConstantTimeEq for MPInt {
    /// Compare the encoded bytes using `subtle`'s slice comparison.
    ///
    /// Per `subtle`'s slice impl, only the contents are compared in constant
    /// time; the lengths are checked up front and are not hidden.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = self.as_bytes();
        let rhs: &[u8] = other.as_bytes();
        lhs.ct_eq(rhs)
    }
}
234
#[cfg(test)]
mod tests {
    use super::MPInt;
    use hex_literal::hex;

    /// Zero is represented by an empty encoding.
    #[test]
    fn decode_0() {
        let zero = MPInt::from_bytes(b"").unwrap();
        assert_eq!(b"", zero.as_bytes())
    }

    /// Redundant leading zero bytes must be rejected.
    #[test]
    fn reject_extra_leading_zeroes() {
        let non_canonical: [&[u8]; 3] = [&hex!("00"), &hex!("00 00"), &hex!("00 01")];

        for bytes in non_canonical.iter() {
            assert!(MPInt::from_bytes(bytes).is_err());
        }
    }

    /// Example value from RFC 4251 § 5.
    #[test]
    fn decode_9a378f9b2e332a7() {
        let encoded = hex!("09 a3 78 f9 b2 e3 32 a7");
        assert!(MPInt::from_bytes(&encoded).is_ok());
    }

    /// A positive value with its MSB set carries a leading zero.
    #[test]
    fn decode_80() {
        let n = MPInt::from_bytes(&hex!("00 80")).unwrap();
        let magnitude = n.as_positive_bytes().unwrap();

        // Leading zero stripped
        assert_eq!(&hex!("80"), magnitude)
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_1234() {
        let negative = MPInt::from_bytes(&hex!("ed cc")).unwrap();
        assert_eq!(negative.as_positive_bytes(), None);
    }

    // TODO(tarcieri): drop support for negative numbers?
    #[test]
    fn decode_neg_deadbeef() {
        let negative = MPInt::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert_eq!(negative.as_positive_bytes(), None);
    }
}