1use crate::{
4 asn1::*, ByteSlice, Choice, Decodable, DecodeValue, Error, ErrorKind, FixedTag, Header, Length,
5 Result, Tag, TagMode, TagNumber,
6};
7
8#[derive(Clone, Debug)]
10pub struct Decoder<'a> {
11 bytes: Option<ByteSlice<'a>>,
16
17 position: Length,
19}
20
21impl<'a> Decoder<'a> {
22 pub fn new(bytes: &'a [u8]) -> Result<Self> {
24 Ok(Self {
25 bytes: Some(ByteSlice::new(bytes)?),
26 position: Length::ZERO,
27 })
28 }
29
30 pub fn decode<T: Decodable<'a>>(&mut self) -> Result<T> {
32 if self.is_failed() {
33 return Err(self.error(ErrorKind::Failed));
34 }
35
36 T::decode(self).map_err(|e| {
37 self.bytes.take();
38 e.nested(self.position)
39 })
40 }
41
42 pub fn error(&mut self, kind: ErrorKind) -> Error {
45 self.bytes.take();
46 kind.at(self.position)
47 }
48
49 pub fn value_error(&mut self, tag: Tag) -> Error {
51 self.error(tag.value_error().kind())
52 }
53
54 pub fn is_failed(&self) -> bool {
56 self.bytes.is_none()
57 }
58
59 pub fn position(&self) -> Length {
61 self.position
62 }
63
64 pub fn peek_byte(&self) -> Option<u8> {
66 self.remaining()
67 .ok()
68 .and_then(|bytes| bytes.get(0).cloned())
69 }
70
71 pub fn peek_tag(&self) -> Result<Tag> {
76 match self.peek_byte() {
77 Some(byte) => byte.try_into(),
78 None => {
79 let actual_len = self.input_len()?;
80 let expected_len = (actual_len + Length::ONE)?;
81 Err(ErrorKind::Incomplete {
82 expected_len,
83 actual_len,
84 }
85 .into())
86 }
87 }
88 }
89
90 pub fn peek_header(&self) -> Result<Header> {
95 Header::decode(&mut self.clone())
96 }
97
98 pub fn finish<T>(self, value: T) -> Result<T> {
101 if self.is_failed() {
102 Err(ErrorKind::Failed.at(self.position))
103 } else if !self.is_finished() {
104 Err(ErrorKind::TrailingData {
105 decoded: self.position,
106 remaining: self.remaining_len()?,
107 }
108 .at(self.position))
109 } else {
110 Ok(value)
111 }
112 }
113
114 pub fn is_finished(&self) -> bool {
119 self.remaining().map(|rem| rem.is_empty()).unwrap_or(false)
120 }
121
122 pub fn any(&mut self) -> Result<Any<'a>> {
124 self.decode()
125 }
126
127 pub fn any_optional(&mut self) -> Result<Option<Any<'a>>> {
129 self.decode()
130 }
131
132 pub fn int8(&mut self) -> Result<i8> {
134 self.decode()
135 }
136
137 pub fn int16(&mut self) -> Result<i16> {
139 self.decode()
140 }
141
142 pub fn uint8(&mut self) -> Result<u8> {
144 self.decode()
145 }
146
147 pub fn uint16(&mut self) -> Result<u16> {
149 self.decode()
150 }
151
152 #[cfg(feature = "bigint")]
154 #[cfg_attr(docsrs, doc(cfg(feature = "bigint")))]
155 pub fn uint_bytes(&mut self) -> Result<UIntBytes<'a>> {
156 self.decode()
157 }
158
159 pub fn bit_string(&mut self) -> Result<BitString<'a>> {
161 self.decode()
162 }
163
164 pub fn context_specific<T>(
167 &mut self,
168 tag_number: TagNumber,
169 tag_mode: TagMode,
170 ) -> Result<Option<T>>
171 where
172 T: DecodeValue<'a> + FixedTag,
173 {
174 Ok(match tag_mode {
175 TagMode::Explicit => ContextSpecific::<T>::decode_explicit(self, tag_number)?,
176 TagMode::Implicit => ContextSpecific::<T>::decode_implicit(self, tag_number)?,
177 }
178 .map(|field| field.value))
179 }
180
181 pub fn generalized_time(&mut self) -> Result<GeneralizedTime> {
183 self.decode()
184 }
185
186 pub fn ia5_string(&mut self) -> Result<Ia5String<'a>> {
188 self.decode()
189 }
190
191 pub fn null(&mut self) -> Result<Null> {
193 self.decode()
194 }
195
196 pub fn octet_string(&mut self) -> Result<OctetString<'a>> {
198 self.decode()
199 }
200
201 #[cfg(feature = "oid")]
203 #[cfg_attr(docsrs, doc(cfg(feature = "oid")))]
204 pub fn oid(&mut self) -> Result<ObjectIdentifier> {
205 self.decode()
206 }
207
208 pub fn optional<T: Choice<'a>>(&mut self) -> Result<Option<T>> {
210 self.decode()
211 }
212
213 pub fn printable_string(&mut self) -> Result<PrintableString<'a>> {
215 self.decode()
216 }
217
218 pub fn utc_time(&mut self) -> Result<UtcTime> {
220 self.decode()
221 }
222
223 pub fn utf8_string(&mut self) -> Result<Utf8String<'a>> {
225 self.decode()
226 }
227
228 pub fn sequence<F, T>(&mut self, f: F) -> Result<T>
231 where
232 F: FnOnce(&mut Decoder<'a>) -> Result<T>,
233 {
234 Tag::try_from(self.byte()?)?.assert_eq(Tag::Sequence)?;
235 let len = Length::decode(self)?;
236 self.decode_nested(len, f)
237 }
238
239 pub(crate) fn byte(&mut self) -> Result<u8> {
241 match self.bytes(1u8)? {
242 [byte] => Ok(*byte),
243 _ => {
244 let actual_len = self.input_len()?;
245 let expected_len = (actual_len + Length::ONE)?;
246 Err(self.error(ErrorKind::Incomplete {
247 expected_len,
248 actual_len,
249 }))
250 }
251 }
252 }
253
254 pub(crate) fn bytes(&mut self, len: impl TryInto<Length>) -> Result<&'a [u8]> {
257 if self.is_failed() {
258 return Err(self.error(ErrorKind::Failed));
259 }
260
261 let len = len
262 .try_into()
263 .map_err(|_| self.error(ErrorKind::Overflow))?;
264
265 match self.remaining()?.get(..len.try_into()?) {
266 Some(result) => {
267 self.position = (self.position + len)?;
268 Ok(result)
269 }
270 None => {
271 let actual_len = self.input_len()?;
272 let expected_len = (actual_len + len)?;
273 Err(self.error(ErrorKind::Incomplete {
274 expected_len,
275 actual_len,
276 }))
277 }
278 }
279 }
280
281 pub(crate) fn input_len(&self) -> Result<Length> {
283 Ok(self.bytes.ok_or(ErrorKind::Failed)?.len())
284 }
285
286 pub(crate) fn remaining_len(&self) -> Result<Length> {
288 self.remaining()?.len().try_into()
289 }
290
291 fn decode_nested<F, T>(&mut self, length: Length, f: F) -> Result<T>
296 where
297 F: FnOnce(&mut Self) -> Result<T>,
298 {
299 let start_pos = self.position();
300 let end_pos = (start_pos + length)?;
301 let bytes = match self.bytes {
302 Some(slice) => {
303 slice
304 .as_bytes()
305 .get(..end_pos.try_into()?)
306 .ok_or(ErrorKind::Incomplete {
307 expected_len: end_pos,
308 actual_len: self.input_len()?,
309 })?
310 }
311 None => return Err(self.error(ErrorKind::Failed)),
312 };
313
314 let mut nested_decoder = Self {
315 bytes: Some(ByteSlice::new(bytes)?),
316 position: start_pos,
317 };
318
319 self.position = end_pos;
320 let result = f(&mut nested_decoder)?;
321 nested_decoder.finish(result)
322 }
323
324 fn remaining(&self) -> Result<&'a [u8]> {
327 let pos = usize::try_from(self.position)?;
328
329 match self.bytes.and_then(|slice| slice.as_bytes().get(pos..)) {
330 Some(result) => Ok(result),
331 None => {
332 let actual_len = self.input_len()?;
333 let expected_len = (actual_len + Length::ONE)?;
334 Err(ErrorKind::Incomplete {
335 expected_len,
336 actual_len,
337 }
338 .at(self.position))
339 }
340 }
341 }
342}
343
/// Unit tests for error reporting and the non-consuming `peek_*` helpers.
#[cfg(test)]
mod tests {
    use super::Decoder;
    use crate::{Decodable, ErrorKind, Length, Tag};
    use hex_literal::hex;

    // An integer with value 42 (`02 01 2A`) followed by one extra `00` byte.
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    /// Decoding from empty input reports that one byte was needed at offset 0.
    #[test]
    fn empty_message() {
        let mut decoder = Decoder::new(&[]).unwrap();
        let err = bool::decode(&mut decoder).err().unwrap();
        assert_eq!(Some(Length::ZERO), err.position());

        match err.kind() {
            ErrorKind::Incomplete {
                expected_len,
                actual_len,
            } => {
                assert_eq!(expected_len, 1u8.into());
                assert_eq!(actual_len, 0u8.into());
            }
            other => panic!("unexpected error kind: {:?}", other),
        }
    }

    /// A header announcing one value byte that is missing from the input
    /// fails at offset 2 with an `Incomplete` error.
    #[test]
    fn invalid_field_length() {
        let mut decoder = Decoder::new(&EXAMPLE_MSG[..2]).unwrap();
        let err = i8::decode(&mut decoder).err().unwrap();
        assert_eq!(Some(Length::from(2u8)), err.position());

        match err.kind() {
            ErrorKind::Incomplete {
                expected_len,
                actual_len,
            } => {
                assert_eq!(expected_len, 3u8.into());
                assert_eq!(actual_len, 2u8.into());
            }
            other => panic!("unexpected error kind: {:?}", other),
        }
    }

    /// `finish` rejects the unread byte that follows the integer.
    #[test]
    fn trailing_data() {
        let mut decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        let x = decoder.decode().unwrap();
        assert_eq!(42i8, x);

        let err = decoder.finish(x).err().unwrap();
        assert_eq!(Some(Length::from(3u8)), err.position());

        assert_eq!(
            ErrorKind::TrailingData {
                decoded: 3u8.into(),
                remaining: 1u8.into()
            },
            err.kind()
        );
    }

    /// `peek_tag` reports the tag and leaves the position untouched.
    #[test]
    fn peek_tag() {
        let decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        assert_eq!(decoder.position(), Length::ZERO);
        assert_eq!(decoder.peek_tag().unwrap(), Tag::Integer);
        assert_eq!(decoder.position(), Length::ZERO);
    }

    /// `peek_header` reports tag and length and leaves the position untouched.
    #[test]
    fn peek_header() {
        let decoder = Decoder::new(EXAMPLE_MSG).unwrap();
        assert_eq!(decoder.position(), Length::ZERO);

        let header = decoder.peek_header().unwrap();
        assert_eq!(header.tag, Tag::Integer);
        assert_eq!(header.length, Length::ONE);
        assert_eq!(decoder.position(), Length::ZERO);
    }
}