1use crate::{
4 checked::CheckedSum, decode::Decode, encode::Encode, reader::Reader, writer::Writer, Error,
5 Result,
6};
7use alloc::vec::Vec;
8use core::fmt;
9use zeroize::Zeroize;
10
11#[cfg(feature = "rsa")]
12use zeroize::Zeroizing;
13
14#[cfg(feature = "subtle")]
15use subtle::{Choice, ConstantTimeEq};
16
/// Multiple precision integer, a.k.a. `mpint`, as used by the SSH wire format.
///
/// The value is stored as its big endian, two's complement body bytes:
/// zero is the empty byte string, and a leading `0x00` byte is only present
/// when needed to keep a positive value's most significant bit clear.
///
/// Note that the derived [`PartialOrd`]/[`Ord`] compare the stored bytes
/// lexicographically, which is *not* numeric ordering
/// (e.g. `[0x00, 0x80]` sorts before `[0x01]`).
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MPInt {
    /// Raw big endian `mpint` body bytes.
    inner: Vec<u8>,
}
52
53impl MPInt {
54 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
60 bytes.try_into()
61 }
62
63 pub fn from_positive_bytes(bytes: &[u8]) -> Result<Self> {
68 let mut inner = Vec::with_capacity(bytes.len());
69
70 match bytes.first().cloned() {
71 Some(0) => return Err(Error::FormatEncoding),
72 Some(n) if n >= 0x80 => inner.push(0),
73 _ => (),
74 }
75
76 inner.extend_from_slice(bytes);
77 inner.try_into()
78 }
79
80 pub fn as_bytes(&self) -> &[u8] {
86 &self.inner
87 }
88
89 pub fn as_positive_bytes(&self) -> Option<&[u8]> {
95 match self.as_bytes() {
96 [0x00, rest @ ..] => Some(rest),
97 [byte, ..] if *byte < 0x80 => Some(self.as_bytes()),
98 _ => None,
99 }
100 }
101}
102
103impl AsRef<[u8]> for MPInt {
104 fn as_ref(&self) -> &[u8] {
105 self.as_bytes()
106 }
107}
108
109impl Decode for MPInt {
110 fn decode(reader: &mut impl Reader) -> Result<Self> {
111 Vec::decode(reader)?.try_into()
112 }
113}
114
115impl Encode for MPInt {
116 fn encoded_len(&self) -> Result<usize> {
117 [4, self.as_bytes().len()].checked_sum()
118 }
119
120 fn encode(&self, writer: &mut impl Writer) -> Result<()> {
121 self.as_bytes().encode(writer)
122 }
123}
124
125impl TryFrom<&[u8]> for MPInt {
126 type Error = Error;
127
128 fn try_from(bytes: &[u8]) -> Result<Self> {
129 Vec::from(bytes).try_into()
130 }
131}
132
133impl TryFrom<Vec<u8>> for MPInt {
134 type Error = Error;
135
136 fn try_from(bytes: Vec<u8>) -> Result<Self> {
137 match bytes.as_slice() {
138 [0x00] => Err(Error::FormatEncoding),
140 [0x00, n, ..] if *n < 0x80 => Err(Error::FormatEncoding),
142 _ => Ok(Self { inner: bytes }),
143 }
144 }
145}
146
147impl Zeroize for MPInt {
148 fn zeroize(&mut self) {
149 self.inner.zeroize();
150 }
151}
152
153impl fmt::Debug for MPInt {
154 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155 write!(f, "MPInt({:X})", self)
156 }
157}
158
159impl fmt::Display for MPInt {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 write!(f, "{:X}", self)
162 }
163}
164
165impl fmt::LowerHex for MPInt {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 for byte in self.as_bytes() {
168 write!(f, "{:02x}", byte)?;
169 }
170 Ok(())
171 }
172}
173
174impl fmt::UpperHex for MPInt {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 for byte in self.as_bytes() {
177 write!(f, "{:02X}", byte)?;
178 }
179 Ok(())
180 }
181}
182
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<rsa::BigUint> for MPInt {
    type Error = Error;

    /// Convert an owned [`rsa::BigUint`] by delegating to the by-reference
    /// conversion.
    fn try_from(uint: rsa::BigUint) -> Result<MPInt> {
        Self::try_from(&uint)
    }
}
192
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<&rsa::BigUint> for MPInt {
    type Error = Error;

    /// Convert an unsigned big integer into its `mpint` encoding.
    ///
    /// `BigUint::to_bytes_be` serializes zero as `[0]`, whereas an `mpint`
    /// encodes zero as the empty string. Leading zero bytes are therefore
    /// stripped before calling [`MPInt::from_positive_bytes`]. This keeps the
    /// conversion correct whether or not that function accepts leading zeroes.
    fn try_from(uint: &rsa::BigUint) -> Result<MPInt> {
        // Keep the serialized magnitude in a `Zeroizing` buffer so it is
        // wiped when dropped.
        let bytes = Zeroizing::new(uint.to_bytes_be());
        let first_nonzero = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len());
        MPInt::from_positive_bytes(&bytes[first_nonzero..])
    }
}
203
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<MPInt> for rsa::BigUint {
    type Error = Error;

    /// Convert an owned [`MPInt`] by delegating to the by-reference
    /// conversion.
    fn try_from(mpint: MPInt) -> Result<rsa::BigUint> {
        Self::try_from(&mpint)
    }
}
213
#[cfg(feature = "rsa")]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
impl TryFrom<&MPInt> for rsa::BigUint {
    type Error = Error;

    /// Convert a positive [`MPInt`] into an unsigned big integer.
    ///
    /// # Errors
    /// Returns [`Error::Crypto`] when [`MPInt::as_positive_bytes`] yields
    /// `None`, i.e. for negative values and for zero.
    /// NOTE(review): rejecting zero here may be unintended — confirm.
    fn try_from(mpint: &MPInt) -> Result<rsa::BigUint> {
        let magnitude = mpint.as_positive_bytes().ok_or(Error::Crypto)?;
        Ok(rsa::BigUint::from_bytes_be(magnitude))
    }
}
226
#[cfg(feature = "subtle")]
#[cfg_attr(docsrs, doc(cfg(feature = "subtle")))]
impl ConstantTimeEq for MPInt {
    /// Compare the stored body bytes using `subtle`'s slice comparison.
    ///
    /// NOTE(review): per the `subtle` docs, slices of different lengths
    /// compare unequal without a constant-time guarantee on the length itself.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
234
#[cfg(test)]
mod tests {
    use super::MPInt;
    use hex_literal::hex;

    // The vectors below mirror the `mpint` examples table in RFC 4251 §5.

    /// Zero is encoded as the empty string.
    #[test]
    fn decode_0() {
        let zero = MPInt::from_bytes(b"").unwrap();
        assert!(zero.as_bytes().is_empty());
    }

    /// A zero byte that is not needed as sign padding must be rejected.
    #[test]
    fn reject_extra_leading_zeroes() {
        let lone_zero = MPInt::from_bytes(&hex!("00"));
        let double_zero = MPInt::from_bytes(&hex!("00 00"));
        let padded_one = MPInt::from_bytes(&hex!("00 01"));

        assert!(lone_zero.is_err());
        assert!(double_zero.is_err());
        assert!(padded_one.is_err());
    }

    #[test]
    fn decode_9a378f9b2e332a7() {
        let parsed = MPInt::from_bytes(&hex!("09 a3 78 f9 b2 e3 32 a7"));
        assert!(parsed.is_ok());
    }

    /// `0x80` needs a leading `0x00` so it is not read as negative.
    #[test]
    fn decode_80() {
        let n = MPInt::from_bytes(&hex!("00 80")).unwrap();
        assert_eq!(n.as_positive_bytes(), Some(&hex!("80")[..]));
    }

    /// `-0x1234` in two's complement.
    #[test]
    fn decode_neg_1234() {
        let n = MPInt::from_bytes(&hex!("ed cc")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }

    /// `-0xdeadbeef` in two's complement.
    #[test]
    fn decode_neg_deadbeef() {
        let n = MPInt::from_bytes(&hex!("ff 21 52 41 11")).unwrap();
        assert_eq!(n.as_positive_bytes(), None);
    }
}