1use std::cmp::Ordering;
13use std::fmt;
14
15use aws_lc_rs::signature::{Ed25519KeyPair, KeyPair};
16use mz_ore::secure::Zeroizing;
17use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
18use serde::ser::{SerializeStruct, Serializer};
19use serde::{Deserialize, Serialize};
20use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
21use ssh_key::public::Ed25519PublicKey;
22use ssh_key::{HashAlg, LineEnding, PrivateKey};
23
24#[derive(Debug, Clone)]
26pub struct SshKeyPair {
27 key_pair: PrivateKey,
30}
31
32impl SshKeyPair {
33 pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38 let mut seed = Zeroizing::new([0u8; 32]);
42 aws_lc_rs::rand::fill(&mut *seed)
43 .map_err(|_| anyhow::anyhow!("random generation failed"))?;
44
45 let aws_key = Ed25519KeyPair::from_seed_unchecked(&*seed)
47 .map_err(|_| anyhow::anyhow!("Ed25519 key generation failed"))?;
48
49 let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
50 public: Ed25519PublicKey::try_from(aws_key.public_key().as_ref())?,
51 private: Ed25519PrivateKey::try_from(seed.as_slice())?,
52 });
53
54 let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
55
56 Ok(SshKeyPair { key_pair })
57 }
58
59 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
62 let set = SshKeyPairSet::from_bytes(data)?;
63 Ok(set.primary().clone())
64 }
65
66 fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
68 let private_key = PrivateKey::from_openssh(private_key)?;
69
70 Ok(SshKeyPair {
71 key_pair: private_key,
72 })
73 }
74
75 pub fn ssh_public_key(&self) -> String {
77 self.key_pair.public_key().to_string()
78 }
79
80 pub fn ssh_private_key(&self) -> Zeroizing<String> {
82 self.key_pair
83 .to_openssh(LineEnding::LF)
84 .expect("encoding as OpenSSH cannot fail")
85 }
86}
87
88impl PartialEq for SshKeyPair {
89 fn eq(&self, other: &Self) -> bool {
90 self.key_pair.fingerprint(HashAlg::default())
91 == other.key_pair.fingerprint(HashAlg::default())
92 }
93}
94
95impl Eq for SshKeyPair {}
96
97impl PartialOrd for SshKeyPair {
98 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
99 Some(self.cmp(other))
100 }
101}
102
103impl Ord for SshKeyPair {
104 fn cmp(&self, other: &Self) -> Ordering {
105 self.key_pair
106 .fingerprint(HashAlg::default())
107 .cmp(&other.key_pair.fingerprint(HashAlg::default()))
108 }
109}
110
111impl Serialize for SshKeyPair {
112 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
113 where
114 S: Serializer,
115 {
116 let mut state = serializer.serialize_struct("SshKeypair", 2)?;
117 state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
119 state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
120 state.end()
121 }
122}
123
124impl<'de> Deserialize<'de> for SshKeyPair {
125 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126 where
127 D: Deserializer<'de>,
128 {
129 #[derive(Deserialize)]
130 #[serde(field_identifier, rename_all = "snake_case")]
131 enum Field {
132 PublicKey,
133 PrivateKey,
134 }
135
136 struct SshKeyPairVisitor;
137
138 impl<'de> Visitor<'de> for SshKeyPairVisitor {
139 type Value = SshKeyPair;
140
141 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
142 formatter.write_str("struct SshKeypair")
143 }
144
145 fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
146 where
147 V: SeqAccess<'de>,
148 {
149 let _public_key: Vec<u8> = seq
151 .next_element()?
152 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
153 let private_key: Zeroizing<Vec<u8>> = seq
154 .next_element()?
155 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
156 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
157 }
158
159 fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
160 where
161 V: MapAccess<'de>,
162 {
163 let mut _public_key: Option<Vec<u8>> = None;
165 let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
166 while let Some(key) = map.next_key()? {
167 match key {
168 Field::PublicKey => {
169 if _public_key.is_some() {
170 return Err(de::Error::duplicate_field("public_key"));
171 }
172 _public_key = Some(map.next_value()?);
173 }
174 Field::PrivateKey => {
175 if private_key.is_some() {
176 return Err(de::Error::duplicate_field("private_key"));
177 }
178 private_key = Some(map.next_value()?);
179 }
180 }
181 }
182 let private_key =
183 private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
184 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
185 }
186 }
187
188 const FIELDS: &[&str] = &["public_key", "private_key"];
189 deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
190 }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
198pub struct SshKeyPairSet {
199 primary: SshKeyPair,
200 secondary: SshKeyPair,
201}
202
203impl SshKeyPairSet {
204 pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
206 Ok(SshKeyPairSet::from_parts(
207 SshKeyPair::new()?,
208 SshKeyPair::new()?,
209 ))
210 }
211
212 pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
215 SshKeyPairSet { primary, secondary }
216 }
217
218 pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
223 Ok(SshKeyPairSet {
224 primary: self.secondary.clone(),
225 secondary: SshKeyPair::new()?,
226 })
227 }
228
229 pub fn public_keys(&self) -> (String, String) {
231 let primary = self.primary().ssh_public_key();
232 let secondary = self.secondary().ssh_public_key();
233 (primary, secondary)
234 }
235
236 pub fn primary(&self) -> &SshKeyPair {
238 &self.primary
239 }
240
241 pub fn secondary(&self) -> &SshKeyPair {
243 &self.secondary
244 }
245
246 pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
250 Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
251 }
252
253 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
256 Ok(serde_json::from_slice(data)?)
257 }
258}
259
#[cfg(test)]
mod tests {
    use aws_lc_rs::signature::{Ed25519KeyPair, KeyPair};
    use mz_ore::secure::Zeroizing;
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    /// Generated keys encode to the expected OpenSSH public and private formats.
    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;

            let public_key = key_pair.ssh_public_key();
            // Ed25519 public keys in OpenSSH one-line form start with the key type.
            assert!(public_key.starts_with("ssh-ed25519 "));

            let private_key = key_pair.ssh_private_key();
            // Private keys use the OpenSSH armor header.
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    /// A freshly generated set never has identical primary and secondary keys.
    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    /// A single key pair survives a JSON serialize/deserialize round trip.
    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    /// A key set survives a `to_bytes` / `from_bytes` round trip.
    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    /// Rotation promotes the old secondary and introduces a brand-new secondary.
    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// JSON written in the legacy key pair format deserializes into `SshKeyPair`.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_set cannot fail"),
            )?;

            // Only `private_key` is compared: `SshKeyPair`'s deserializer
            // discards the `public_key` field.
            assert_eq!(
                legacy_key_pair.private_key,
                roundtripped_key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written by `SshKeyPair` can still be read as the legacy format.
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written in the legacy key set format deserializes into `SshKeyPairSet`.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPairSet::new()?;
            let roundtripped_key_pair: SshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_set cannot fail"),
            )?;

            // Both halves of the set must carry over the same private key bytes.
            assert_eq!(
                legacy_key_pair.primary.private_key,
                roundtripped_key_pair.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_pair.secondary.private_key,
                roundtripped_key_pair
                    .secondary()
                    .ssh_private_key()
                    .as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written by `SshKeyPairSet` can still be read as the legacy set format.
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPairSet::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.primary.private_key,
                key_pair.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                roundtripped_legacy_key_pair.secondary.private_key,
                key_pair.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Test-only model of the legacy serialization format: both keys are stored
    /// as the raw bytes of their OpenSSH encodings. Field names, order and derives
    /// define the wire format these tests check compatibility against.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        /// Generates a key the same way `SshKeyPair::new` does, then stores the
        /// OpenSSH public and private encodings as byte vectors.
        fn new() -> Result<Self, anyhow::Error> {
            let mut seed = Zeroizing::new([0u8; 32]);
            aws_lc_rs::rand::fill(&mut *seed)
                .map_err(|_| anyhow::anyhow!("random generation failed"))?;
            let aws_key = Ed25519KeyPair::from_seed_unchecked(&*seed)
                .map_err(|_| anyhow::anyhow!("Ed25519 key generation failed"))?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(aws_key.public_key().as_ref())?,
                private: Ed25519PrivateKey::try_from(seed.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// Test-only model of the legacy key set format: a primary and a secondary
    /// `LegacySshKeyPair`.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Generates two independent legacy key pairs.
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}