1#![allow(missing_docs)]
2use crate::{
8 errors::{self, Error, ErrorKind},
9 Algorithm,
10};
11use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
12use std::{fmt, str::FromStr};
13
/// Intended use of a public key, as carried by a JWK's `use` member.
///
/// The hand-written (de)serializers in this module map `"sig"` and `"enc"` onto the named
/// variants; any other string is kept verbatim in [`PublicKeyUse::Other`].
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
pub enum PublicKeyUse {
    /// Wire value `"sig"`.
    Signature,
    /// Wire value `"enc"`.
    Encryption,
    /// Any other `use` value, stored exactly as received.
    Other(String),
}
24
25impl Serialize for PublicKeyUse {
26 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
27 where
28 S: Serializer,
29 {
30 let string = match self {
31 PublicKeyUse::Signature => "sig",
32 PublicKeyUse::Encryption => "enc",
33 PublicKeyUse::Other(other) => other,
34 };
35
36 serializer.serialize_str(string)
37 }
38}
39
40impl<'de> Deserialize<'de> for PublicKeyUse {
41 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
42 where
43 D: Deserializer<'de>,
44 {
45 struct PublicKeyUseVisitor;
46 impl<'de> de::Visitor<'de> for PublicKeyUseVisitor {
47 type Value = PublicKeyUse;
48
49 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(formatter, "a string")
51 }
52
53 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
54 where
55 E: de::Error,
56 {
57 Ok(match v {
58 "sig" => PublicKeyUse::Signature,
59 "enc" => PublicKeyUse::Encryption,
60 other => PublicKeyUse::Other(other.to_string()),
61 })
62 }
63 }
64
65 deserializer.deserialize_string(PublicKeyUseVisitor)
66 }
67}
68
/// An operation a key may be used for, as listed in a JWK's `key_ops` member.
///
/// The hand-written (de)serializers in this module map the camel-case wire names
/// (`"sign"`, `"wrapKey"`, …) onto the named variants; anything else is kept verbatim in
/// [`KeyOperations::Other`].
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
pub enum KeyOperations {
    /// Wire value `"sign"`.
    Sign,
    /// Wire value `"verify"`.
    Verify,
    /// Wire value `"encrypt"`.
    Encrypt,
    /// Wire value `"decrypt"`.
    Decrypt,
    /// Wire value `"wrapKey"`.
    WrapKey,
    /// Wire value `"unwrapKey"`.
    UnwrapKey,
    /// Wire value `"deriveKey"`.
    DeriveKey,
    /// Wire value `"deriveBits"`.
    DeriveBits,
    /// Any other operation string, stored exactly as received.
    Other(String),
}
91
92impl Serialize for KeyOperations {
93 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
94 where
95 S: Serializer,
96 {
97 let string = match self {
98 KeyOperations::Sign => "sign",
99 KeyOperations::Verify => "verify",
100 KeyOperations::Encrypt => "encrypt",
101 KeyOperations::Decrypt => "decrypt",
102 KeyOperations::WrapKey => "wrapKey",
103 KeyOperations::UnwrapKey => "unwrapKey",
104 KeyOperations::DeriveKey => "deriveKey",
105 KeyOperations::DeriveBits => "deriveBits",
106 KeyOperations::Other(other) => other,
107 };
108
109 serializer.serialize_str(string)
110 }
111}
112
113impl<'de> Deserialize<'de> for KeyOperations {
114 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115 where
116 D: Deserializer<'de>,
117 {
118 struct KeyOperationsVisitor;
119 impl<'de> de::Visitor<'de> for KeyOperationsVisitor {
120 type Value = KeyOperations;
121
122 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123 write!(formatter, "a string")
124 }
125
126 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
127 where
128 E: de::Error,
129 {
130 Ok(match v {
131 "sign" => KeyOperations::Sign,
132 "verify" => KeyOperations::Verify,
133 "encrypt" => KeyOperations::Encrypt,
134 "decrypt" => KeyOperations::Decrypt,
135 "wrapKey" => KeyOperations::WrapKey,
136 "unwrapKey" => KeyOperations::UnwrapKey,
137 "deriveKey" => KeyOperations::DeriveKey,
138 "deriveBits" => KeyOperations::DeriveBits,
139 other => KeyOperations::Other(other.to_string()),
140 })
141 }
142 }
143
144 deserializer.deserialize_string(KeyOperationsVisitor)
145 }
146}
147
148#[allow(non_camel_case_types, clippy::upper_case_acronyms)]
150#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Serialize, Deserialize)]
151pub enum KeyAlgorithm {
152 HS256,
154 HS384,
156 HS512,
158
159 ES256,
161 ES384,
163
164 RS256,
166 RS384,
168 RS512,
170
171 PS256,
173 PS384,
175 PS512,
177
178 EdDSA,
180
181 RSA1_5,
183
184 #[serde(rename = "RSA-OAEP")]
186 RSA_OAEP,
187
188 #[serde(rename = "RSA-OAEP-256")]
190 RSA_OAEP_256,
191}
192
193impl FromStr for KeyAlgorithm {
194 type Err = Error;
195 fn from_str(s: &str) -> errors::Result<Self> {
196 match s {
197 "HS256" => Ok(KeyAlgorithm::HS256),
198 "HS384" => Ok(KeyAlgorithm::HS384),
199 "HS512" => Ok(KeyAlgorithm::HS512),
200 "ES256" => Ok(KeyAlgorithm::ES256),
201 "ES384" => Ok(KeyAlgorithm::ES384),
202 "RS256" => Ok(KeyAlgorithm::RS256),
203 "RS384" => Ok(KeyAlgorithm::RS384),
204 "PS256" => Ok(KeyAlgorithm::PS256),
205 "PS384" => Ok(KeyAlgorithm::PS384),
206 "PS512" => Ok(KeyAlgorithm::PS512),
207 "RS512" => Ok(KeyAlgorithm::RS512),
208 "EdDSA" => Ok(KeyAlgorithm::EdDSA),
209 "RSA1_5" => Ok(KeyAlgorithm::RSA1_5),
210 "RSA-OAEP" => Ok(KeyAlgorithm::RSA_OAEP),
211 "RSA-OAEP-256" => Ok(KeyAlgorithm::RSA_OAEP_256),
212 _ => Err(ErrorKind::InvalidAlgorithmName.into()),
213 }
214 }
215}
216
217impl fmt::Display for KeyAlgorithm {
218 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219 write!(f, "{:?}", self)
220 }
221}
222
223impl KeyAlgorithm {
224 fn to_algorithm(self) -> errors::Result<Algorithm> {
225 Algorithm::from_str(self.to_string().as_str())
226 }
227}
228
229#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
231pub struct CommonParameters {
232 #[serde(rename = "use", skip_serializing_if = "Option::is_none", default)]
235 pub public_key_use: Option<PublicKeyUse>,
236
237 #[serde(rename = "key_ops", skip_serializing_if = "Option::is_none", default)]
244 pub key_operations: Option<Vec<KeyOperations>>,
245
246 #[serde(rename = "alg", skip_serializing_if = "Option::is_none", default)]
248 pub key_algorithm: Option<KeyAlgorithm>,
249
250 #[serde(rename = "kid", skip_serializing_if = "Option::is_none", default)]
252 pub key_id: Option<String>,
253
254 #[serde(rename = "x5u", skip_serializing_if = "Option::is_none")]
258 pub x509_url: Option<String>,
259
260 #[serde(rename = "x5c", skip_serializing_if = "Option::is_none")]
264 pub x509_chain: Option<Vec<String>>,
265
266 #[serde(rename = "x5t", skip_serializing_if = "Option::is_none")]
270 pub x509_sha1_fingerprint: Option<String>,
271
272 #[serde(rename = "x5t#S256", skip_serializing_if = "Option::is_none")]
276 pub x509_sha256_fingerprint: Option<String>,
277}
278
279#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
282pub enum EllipticCurveKeyType {
283 #[default]
285 EC,
286}
287
288#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
291pub enum EllipticCurve {
292 #[serde(rename = "P-256")]
294 #[default]
295 P256,
296 #[serde(rename = "P-384")]
298 P384,
299 #[serde(rename = "P-521")]
301 P521,
302 #[serde(rename = "Ed25519")]
304 Ed25519,
305}
306
307#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
309pub struct EllipticCurveKeyParameters {
310 #[serde(rename = "kty")]
312 pub key_type: EllipticCurveKeyType,
313 #[serde(rename = "crv")]
316 pub curve: EllipticCurve,
317 pub x: String,
320 pub y: String,
323}
324
325#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
328pub enum RSAKeyType {
329 #[default]
331 RSA,
332}
333
334#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
336pub struct RSAKeyParameters {
337 #[serde(rename = "kty")]
339 pub key_type: RSAKeyType,
340
341 pub n: String,
344
345 pub e: String,
348}
349
350#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
353pub enum OctetKeyType {
354 #[serde(rename = "oct")]
356 #[default]
357 Octet,
358}
359
360#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
362pub struct OctetKeyParameters {
363 #[serde(rename = "kty")]
365 pub key_type: OctetKeyType,
366 #[serde(rename = "k")]
368 pub value: String,
369}
370
371#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize, Hash)]
374pub enum OctetKeyPairType {
375 #[serde(rename = "OKP")]
377 #[default]
378 OctetKeyPair,
379}
380
381#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Default, Hash)]
383pub struct OctetKeyPairParameters {
384 #[serde(rename = "kty")]
386 pub key_type: OctetKeyPairType,
387 #[serde(rename = "crv")]
390 pub curve: EllipticCurve,
391 pub x: String,
393}
394
395#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
397#[serde(untagged)]
398pub enum AlgorithmParameters {
399 EllipticCurve(EllipticCurveKeyParameters),
400 RSA(RSAKeyParameters),
401 OctetKey(OctetKeyParameters),
402 OctetKeyPair(OctetKeyPairParameters),
403}
404
405#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
406pub struct Jwk {
407 #[serde(flatten)]
408 pub common: CommonParameters,
409 #[serde(flatten)]
411 pub algorithm: AlgorithmParameters,
412}
413
414impl Jwk {
415 pub fn is_supported(&self) -> bool {
417 self.common.key_algorithm.unwrap().to_algorithm().is_ok()
418 }
419}
420
421#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
423pub struct JwkSet {
424 pub keys: Vec<Jwk>,
425}
426
427impl JwkSet {
428 pub fn find(&self, kid: &str) -> Option<&Jwk> {
430 self.keys
431 .iter()
432 .find(|jwk| jwk.common.key_id.is_some() && jwk.common.key_id.as_ref().unwrap() == kid)
433 }
434}
435
#[cfg(test)]
mod tests {
    use crate::jwk::{AlgorithmParameters, JwkSet, OctetKeyType};
    use crate::serialization::b64_encode;
    use crate::Algorithm;
    use serde_json::json;
    use wasm_bindgen_test::wasm_bindgen_test;

    /// Parses a one-key HS256 JWK Set and checks every member round-trips into the model.
    #[test]
    #[wasm_bindgen_test]
    fn check_hs256() {
        // The encoded secret gets its own name. Previously it was `key`, which later
        // bindings shadowed, so the final assertion compared a value with itself.
        let encoded_key = b64_encode("abcdefghijklmnopqrstuvwxyz012345");
        let jwks_json = json!({
            "keys": [
                {
                    "kty": "oct",
                    "alg": "HS256",
                    "kid": "abc123",
                    "k": encoded_key
                }
            ]
        });

        let set: JwkSet = serde_json::from_value(jwks_json).expect("Failed HS256 check");
        assert_eq!(set.keys.len(), 1);
        let jwk = &set.keys[0];
        assert_eq!(jwk.common.key_id, Some("abc123".to_string()));
        let algorithm = jwk.common.key_algorithm.unwrap().to_algorithm().unwrap();
        assert_eq!(algorithm, Algorithm::HS256);

        // `find` should locate the key by `kid` and reject unknown ids.
        assert!(set.find("abc123").is_some());
        assert!(set.find("unknown").is_none());

        match &jwk.algorithm {
            AlgorithmParameters::OctetKey(params) => {
                assert_eq!(params.key_type, OctetKeyType::Octet);
                assert_eq!(params.value, encoded_key);
            }
            _ => panic!("Unexpected key algorithm"),
        }
    }
}