1use std::cmp::Ordering;
13use std::fmt;
14
15use openssl::pkey::{PKey, Private};
16use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
17use serde::ser::{SerializeStruct, Serializer};
18use serde::{Deserialize, Serialize};
19use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
20use ssh_key::public::Ed25519PublicKey;
21use ssh_key::{HashAlg, LineEnding, PrivateKey};
22use zeroize::Zeroizing;
23
/// An SSH key pair backed by an `ssh_key` [`PrivateKey`].
///
/// Only the private key is stored; the public half is obtained from it via
/// `PrivateKey::public_key` when needed. Equality and ordering are defined by
/// the key's fingerprint (see the `PartialEq`/`Ord` impls in this module).
#[derive(Debug, Clone)]
pub struct SshKeyPair {
    /// The wrapped private key, which also carries the public key.
    key_pair: PrivateKey,
}
31
32impl SshKeyPair {
33 pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38 let openssl_key = PKey::<Private>::generate_ed25519()?;
39
40 let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
41 public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
42 private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
43 });
44
45 let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
46
47 Ok(SshKeyPair { key_pair })
48 }
49
50 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
53 let set = SshKeyPairSet::from_bytes(data)?;
54 Ok(set.primary().clone())
55 }
56
57 fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
59 let private_key = PrivateKey::from_openssh(private_key)?;
60
61 Ok(SshKeyPair {
62 key_pair: private_key,
63 })
64 }
65
66 pub fn ssh_public_key(&self) -> String {
68 self.key_pair.public_key().to_string()
69 }
70
71 pub fn ssh_private_key(&self) -> Zeroizing<String> {
73 self.key_pair
74 .to_openssh(LineEnding::LF)
75 .expect("encoding as OpenSSH cannot fail")
76 }
77}
78
79impl PartialEq for SshKeyPair {
80 fn eq(&self, other: &Self) -> bool {
81 self.key_pair.fingerprint(HashAlg::default())
82 == other.key_pair.fingerprint(HashAlg::default())
83 }
84}
85
/// Marker impl: equality compares fingerprint values, so it is treated as a
/// full equivalence relation.
impl Eq for SshKeyPair {}
87
88impl PartialOrd for SshKeyPair {
89 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90 Some(self.cmp(other))
91 }
92}
93
94impl Ord for SshKeyPair {
95 fn cmp(&self, other: &Self) -> Ordering {
96 self.key_pair
97 .fingerprint(HashAlg::default())
98 .cmp(&other.key_pair.fingerprint(HashAlg::default()))
99 }
100}
101
102impl Serialize for SshKeyPair {
103 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104 where
105 S: Serializer,
106 {
107 let mut state = serializer.serialize_struct("SshKeypair", 2)?;
108 state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
110 state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
111 state.end()
112 }
113}
114
115impl<'de> Deserialize<'de> for SshKeyPair {
116 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
117 where
118 D: Deserializer<'de>,
119 {
120 #[derive(Deserialize)]
121 #[serde(field_identifier, rename_all = "snake_case")]
122 enum Field {
123 PublicKey,
124 PrivateKey,
125 }
126
127 struct SshKeyPairVisitor;
128
129 impl<'de> Visitor<'de> for SshKeyPairVisitor {
130 type Value = SshKeyPair;
131
132 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
133 formatter.write_str("struct SshKeypair")
134 }
135
136 fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
137 where
138 V: SeqAccess<'de>,
139 {
140 let _public_key: Vec<u8> = seq
142 .next_element()?
143 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
144 let private_key: Zeroizing<Vec<u8>> = seq
145 .next_element()?
146 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
147 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
148 }
149
150 fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
151 where
152 V: MapAccess<'de>,
153 {
154 let mut _public_key: Option<Vec<u8>> = None;
156 let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
157 while let Some(key) = map.next_key()? {
158 match key {
159 Field::PublicKey => {
160 if _public_key.is_some() {
161 return Err(de::Error::duplicate_field("public_key"));
162 }
163 _public_key = Some(map.next_value()?);
164 }
165 Field::PrivateKey => {
166 if private_key.is_some() {
167 return Err(de::Error::duplicate_field("private_key"));
168 }
169 private_key = Some(map.next_value()?);
170 }
171 }
172 }
173 let private_key =
174 private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
175 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
176 }
177 }
178
179 const FIELDS: &[&str] = &["public_key", "private_key"];
180 deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
181 }
182}
183
/// A primary and a secondary [`SshKeyPair`] that can be rotated together.
///
/// Serialization is derived, so the field names below are part of the
/// serialized format and must not be renamed.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SshKeyPairSet {
    /// Current primary key pair. `rotate` replaces it with the current
    /// secondary, and `SshKeyPair::from_bytes` returns it.
    primary: SshKeyPair,
    /// Current secondary key pair. `rotate` promotes it to primary.
    secondary: SshKeyPair,
}
193
194impl SshKeyPairSet {
195 pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
197 Ok(SshKeyPairSet::from_parts(
198 SshKeyPair::new()?,
199 SshKeyPair::new()?,
200 ))
201 }
202
203 pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
206 SshKeyPairSet { primary, secondary }
207 }
208
209 pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
214 Ok(SshKeyPairSet {
215 primary: self.secondary.clone(),
216 secondary: SshKeyPair::new()?,
217 })
218 }
219
220 pub fn public_keys(&self) -> (String, String) {
222 let primary = self.primary().ssh_public_key();
223 let secondary = self.secondary().ssh_public_key();
224 (primary, secondary)
225 }
226
227 pub fn primary(&self) -> &SshKeyPair {
229 &self.primary
230 }
231
232 pub fn secondary(&self) -> &SshKeyPair {
234 &self.secondary
235 }
236
237 pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
241 Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
242 }
243
244 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
247 Ok(serde_json::from_slice(data)?)
248 }
249}
250
#[cfg(test)]
mod tests {
    use openssl::pkey::{PKey, Private};
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    /// Generated pairs encode as an `ssh-ed25519` public key and an OpenSSH
    /// private key.
    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;

            let public_key = key_pair.ssh_public_key();
            assert!(public_key.starts_with("ssh-ed25519 "));
            let private_key = key_pair.ssh_private_key();
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    /// A freshly created set never uses the same key for primary and secondary.
    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    /// A single key pair survives a JSON serialize/deserialize round trip.
    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    /// `SshKeyPairSet::to_bytes` and `SshKeyPairSet::from_bytes` round-trip a set.
    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    /// Rotation promotes the old secondary and introduces a new secondary
    /// distinct from both existing keys.
    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// JSON written from the plain derived `LegacySshKeyPair` struct
    /// deserializes into an `SshKeyPair` holding the same private key.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                legacy_key_pair.private_key,
                roundtripped_key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written from `SshKeyPair` can still be read by the plain derived
    /// `LegacySshKeyPair` struct.
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written from `LegacySshKeyPairSet` deserializes into an
    /// `SshKeyPairSet` with matching primary and secondary private keys.
    /// (Despite their names, the locals below hold key *sets*.)
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPairSet::new()?;
            let roundtripped_key_pair: SshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                legacy_key_pair.primary.private_key,
                roundtripped_key_pair.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_pair.secondary.private_key,
                roundtripped_key_pair
                    .secondary()
                    .ssh_private_key()
                    .as_bytes()
            );
        }
        Ok(())
    }

    /// JSON written from `SshKeyPairSet` can still be read by the plain derived
    /// `LegacySshKeyPairSet` struct.
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPairSet::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.primary.private_key,
                key_pair.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                roundtripped_legacy_key_pair.secondary.private_key,
                key_pair.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Test stand-in for the legacy key-pair format.
    ///
    /// It is a plain derived-serde struct holding both keys as the bytes of
    /// their OpenSSH text encodings.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        /// Generates an Ed25519 key, following the same steps as
        /// `SshKeyPair::new`, and stores its OpenSSH encodings as bytes.
        fn new() -> Result<Self, anyhow::Error> {
            let openssl_key = PKey::<Private>::generate_ed25519()?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
                private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// Test stand-in for the legacy key-set format: two `LegacySshKeyPair`s
    /// under the same field names as `SshKeyPairSet`.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Builds a legacy set from two freshly generated legacy key pairs.
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}