1use std::fmt::{Debug, Formatter};
2
3use base64::{Engine, engine::general_purpose::STANDARD};
4use serde::de::DeserializeOwned;
5
6use crate::algorithms::AlgorithmFamily;
7use crate::crypto::{CryptoProvider, JwtVerifier};
8use crate::errors::{ErrorKind, Result, new_error};
9use crate::header::Header;
10use crate::jwk::{AlgorithmParameters, Jwk};
11#[cfg(feature = "use_pem")]
12use crate::pem::decoder::PemEncodedKey;
13use crate::serialization::{DecodedJwtPartClaims, b64_decode};
14use crate::validation::{Validation, validate};
15
16#[derive(Debug)]
18pub struct TokenData<T> {
19 pub header: Header,
21 pub claims: T,
23}
24
25impl<T> Clone for TokenData<T>
26where
27 T: Clone,
28{
29 fn clone(&self) -> Self {
30 Self { header: self.header.clone(), claims: self.claims.clone() }
31 }
32}
33
/// Takes an iterator that must yield exactly two items and evaluates to them as a tuple.
///
/// If the iterator yields fewer or more than two items, the *enclosing function* returns
/// `Err(new_error(ErrorKind::InvalidToken))`.
macro_rules! expect_two {
    ($iter:expr) => {{
        let mut parts = $iter;
        // Pull three items so that a surplus third segment is detected, not just a missing one.
        if let (Some(left), Some(right), None) = (parts.next(), parts.next(), parts.next()) {
            (left, right)
        } else {
            return Err(new_error(ErrorKind::InvalidToken));
        }
    }};
}
45
/// The raw key material held by a `DecodingKey`.
#[derive(Clone)]
pub enum DecodingKeyKind {
    /// Key material held as a single byte string: an HMAC secret, bytes taken from a PEM/DER
    /// key, or raw public-key bytes assembled from JWK components.
    SecretOrDer(Vec<u8>),
    /// An RSA public key kept as separate modulus and exponent byte strings.
    RsaModulusExponent {
        /// The RSA modulus bytes.
        n: Vec<u8>,
        /// The RSA public exponent bytes.
        e: Vec<u8>,
    },
}

impl Debug for DecodingKeyKind {
    /// Prints only the variant shape; key material is always replaced by a placeholder so it
    /// never leaks into logs.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[redacted]";
        match self {
            DecodingKeyKind::SecretOrDer(_) => {
                let mut tuple = f.debug_tuple("SecretOrDer");
                tuple.field(&REDACTED);
                tuple.finish()
            }
            DecodingKeyKind::RsaModulusExponent { n: _, e: _ } => {
                let mut fields = f.debug_struct("RsaModulusExponent");
                fields.field("n", &REDACTED);
                fields.field("e", &REDACTED);
                fields.finish()
            }
        }
    }
}
72
73#[derive(Clone, Debug)]
76pub struct DecodingKey {
77 family: AlgorithmFamily,
78 kind: DecodingKeyKind,
79}
80
81impl DecodingKey {
82 pub fn family(&self) -> AlgorithmFamily {
84 self.family
85 }
86
87 pub fn kind(&self) -> &DecodingKeyKind {
89 &self.kind
90 }
91
92 pub fn from_secret(secret: &[u8]) -> Self {
94 DecodingKey {
95 family: AlgorithmFamily::Hmac,
96 kind: DecodingKeyKind::SecretOrDer(secret.to_vec()),
97 }
98 }
99
100 pub fn from_base64_secret(secret: &str) -> Result<Self> {
102 let out = STANDARD.decode(secret)?;
103 Ok(DecodingKey { family: AlgorithmFamily::Hmac, kind: DecodingKeyKind::SecretOrDer(out) })
104 }
105
106 #[cfg(feature = "use_pem")]
109 pub fn from_rsa_pem(key: &[u8]) -> Result<Self> {
110 let pem_key = PemEncodedKey::new(key)?;
111 let content = pem_key.as_rsa_key()?;
112 Ok(DecodingKey {
113 family: AlgorithmFamily::Rsa,
114 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
115 })
116 }
117
118 pub fn from_rsa_components(modulus: &str, exponent: &str) -> Result<Self> {
120 let n = b64_decode(modulus)?;
121 let e = b64_decode(exponent)?;
122 Ok(DecodingKey {
123 family: AlgorithmFamily::Rsa,
124 kind: DecodingKeyKind::RsaModulusExponent { n, e },
125 })
126 }
127
128 pub fn from_rsa_raw_components(modulus: &[u8], exponent: &[u8]) -> Self {
130 DecodingKey {
131 family: AlgorithmFamily::Rsa,
132 kind: DecodingKeyKind::RsaModulusExponent { n: modulus.to_vec(), e: exponent.to_vec() },
133 }
134 }
135
136 #[cfg(feature = "use_pem")]
139 pub fn from_ec_pem(key: &[u8]) -> Result<Self> {
140 let pem_key = PemEncodedKey::new(key)?;
141 let content = pem_key.as_ec_public_key()?;
142 Ok(DecodingKey {
143 family: AlgorithmFamily::Ec,
144 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
145 })
146 }
147
148 pub fn from_ec_components(x: &str, y: &str) -> Result<Self> {
150 let x_cmp = b64_decode(x)?;
151 let y_cmp = b64_decode(y)?;
152
153 let mut public_key = Vec::with_capacity(1 + x.len() + y.len());
154 public_key.push(0x04);
155 public_key.extend_from_slice(&x_cmp);
156 public_key.extend_from_slice(&y_cmp);
157
158 Ok(DecodingKey {
159 family: AlgorithmFamily::Ec,
160 kind: DecodingKeyKind::SecretOrDer(public_key),
161 })
162 }
163
164 #[cfg(feature = "use_pem")]
167 pub fn from_ed_pem(key: &[u8]) -> Result<Self> {
168 let pem_key = PemEncodedKey::new(key)?;
169 let content = pem_key.as_ed_public_key()?;
170 Ok(DecodingKey {
171 family: AlgorithmFamily::Ed,
172 kind: DecodingKeyKind::SecretOrDer(content.to_vec()),
173 })
174 }
175
176 pub fn from_rsa_der(der: &[u8]) -> Self {
178 DecodingKey {
179 family: AlgorithmFamily::Rsa,
180 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
181 }
182 }
183
184 pub fn from_ec_der(der: &[u8]) -> Self {
186 DecodingKey {
187 family: AlgorithmFamily::Ec,
188 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
189 }
190 }
191
192 pub fn from_ed_der(der: &[u8]) -> Self {
194 DecodingKey {
195 family: AlgorithmFamily::Ed,
196 kind: DecodingKeyKind::SecretOrDer(der.to_vec()),
197 }
198 }
199
200 pub fn from_ed_components(x: &str) -> Result<Self> {
202 let x_decoded = b64_decode(x)?;
203 Ok(DecodingKey {
204 family: AlgorithmFamily::Ed,
205 kind: DecodingKeyKind::SecretOrDer(x_decoded),
206 })
207 }
208
209 pub fn from_jwk(jwk: &Jwk) -> Result<Self> {
211 match &jwk.algorithm {
212 AlgorithmParameters::RSA(params) => {
213 DecodingKey::from_rsa_components(¶ms.n, ¶ms.e)
214 }
215 AlgorithmParameters::EllipticCurve(params) => {
216 DecodingKey::from_ec_components(¶ms.x, ¶ms.y)
217 }
218 AlgorithmParameters::OctetKeyPair(params) => DecodingKey::from_ed_components(¶ms.x),
219 AlgorithmParameters::OctetKey(params) => {
220 let out = b64_decode(¶ms.value)?;
221 Ok(DecodingKey {
222 family: AlgorithmFamily::Hmac,
223 kind: DecodingKeyKind::SecretOrDer(out),
224 })
225 }
226 }
227 }
228
229 pub fn as_bytes(&self) -> &[u8] {
231 match &self.kind {
232 DecodingKeyKind::SecretOrDer(b) => b,
233 DecodingKeyKind::RsaModulusExponent { .. } => unreachable!(),
234 }
235 }
236
237 pub fn try_get_hmac_secret(&self) -> Result<&[u8]> {
239 if self.family == AlgorithmFamily::Hmac {
240 Ok(self.as_bytes())
241 } else {
242 Err(new_error(ErrorKind::InvalidKeyFormat))
243 }
244 }
245}
246
247impl TryFrom<&Jwk> for DecodingKey {
248 type Error = crate::errors::Error;
249
250 fn try_from(jwk: &Jwk) -> Result<Self> {
251 Self::from_jwk(jwk)
252 }
253}
254
255pub fn decode<T: DeserializeOwned>(
274 token: impl AsRef<[u8]>,
275 key: &DecodingKey,
276 validation: &Validation,
277) -> Result<TokenData<T>> {
278 let token = token.as_ref();
279 let header = decode_header(token)?;
280
281 if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
282 return Err(new_error(ErrorKind::InvalidAlgorithm));
283 }
284
285 let verifying_provider = (CryptoProvider::get_default().verifier_factory)(&header.alg, key)?;
286
287 let (header, claims) = verify_signature(token, validation, verifying_provider)?;
288
289 let decoded_claims = DecodedJwtPartClaims::from_jwt_part_claims(claims)?;
290 let claims = decoded_claims.deserialize()?;
291 validate(decoded_claims.deserialize()?, validation)?;
292
293 Ok(TokenData { header, claims })
294}
295
296pub fn insecure_decode<T: DeserializeOwned>(token: impl AsRef<[u8]>) -> Result<TokenData<T>> {
300 let token = token.as_ref();
301
302 let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
303 let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
304
305 let header = Header::from_encoded(header)?;
306 let claims = DecodedJwtPartClaims::from_jwt_part_claims(payload)?.deserialize()?;
307
308 Ok(TokenData { header, claims })
309}
310
311pub fn decode_header(token: impl AsRef<[u8]>) -> Result<Header> {
322 let token = token.as_ref();
323 let (_, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
324 let (_, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
325 Header::from_encoded(header)
326}
327
328pub(crate) fn verify_signature_body(
329 message: &[u8],
330 signature: &[u8],
331 header: &Header,
332 validation: &Validation,
333 verifying_provider: Box<dyn JwtVerifier>,
334) -> Result<()> {
335 if validation.validate_signature && validation.algorithms.is_empty() {
336 return Err(new_error(ErrorKind::MissingAlgorithm));
337 }
338
339 if validation.validate_signature {
340 for alg in &validation.algorithms {
341 if verifying_provider.algorithm().family() != alg.family() {
342 return Err(new_error(ErrorKind::InvalidAlgorithm));
343 }
344 }
345 }
346
347 if validation.validate_signature && !validation.algorithms.contains(&header.alg) {
348 return Err(new_error(ErrorKind::InvalidAlgorithm));
349 }
350
351 if validation.validate_signature
352 && verifying_provider.verify(message, &b64_decode(signature)?).is_err()
353 {
354 return Err(new_error(ErrorKind::InvalidSignature));
355 }
356
357 Ok(())
358}
359
360fn verify_signature<'a>(
364 token: &'a [u8],
365 validation: &Validation,
366 verifying_provider: Box<dyn JwtVerifier>,
367) -> Result<(Header, &'a [u8])> {
368 let (signature, message) = expect_two!(token.rsplitn(2, |b| *b == b'.'));
369 let (payload, header) = expect_two!(message.rsplitn(2, |b| *b == b'.'));
370 let header = Header::from_encoded(header)?;
371 verify_signature_body(message, signature, &header, validation, verifying_provider)?;
372
373 Ok((header, payload))
374}