1use std::cmp::Ordering;
13use std::fmt;
14
15use mz_ore::secure::Zeroizing;
16use openssl::pkey::{PKey, Private};
17use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
18use serde::ser::{SerializeStruct, Serializer};
19use serde::{Deserialize, Serialize};
20use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
21use ssh_key::public::Ed25519PublicKey;
22use ssh_key::{HashAlg, LineEnding, PrivateKey};
23
24#[derive(Debug, Clone)]
26pub struct SshKeyPair {
27 key_pair: PrivateKey,
30}
31
32impl SshKeyPair {
33 pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38 let openssl_key = PKey::<Private>::generate_ed25519()?;
39
40 let raw_private = Zeroizing::new(openssl_key.raw_private_key()?);
43 let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
44 public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
45 private: Ed25519PrivateKey::try_from(raw_private.as_slice())?,
46 });
47
48 let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
49
50 Ok(SshKeyPair { key_pair })
51 }
52
53 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
56 let set = SshKeyPairSet::from_bytes(data)?;
57 Ok(set.primary().clone())
58 }
59
60 fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
62 let private_key = PrivateKey::from_openssh(private_key)?;
63
64 Ok(SshKeyPair {
65 key_pair: private_key,
66 })
67 }
68
69 pub fn ssh_public_key(&self) -> String {
71 self.key_pair.public_key().to_string()
72 }
73
74 pub fn ssh_private_key(&self) -> Zeroizing<String> {
76 self.key_pair
77 .to_openssh(LineEnding::LF)
78 .expect("encoding as OpenSSH cannot fail")
79 }
80}
81
82impl PartialEq for SshKeyPair {
83 fn eq(&self, other: &Self) -> bool {
84 self.key_pair.fingerprint(HashAlg::default())
85 == other.key_pair.fingerprint(HashAlg::default())
86 }
87}
88
89impl Eq for SshKeyPair {}
90
91impl PartialOrd for SshKeyPair {
92 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93 Some(self.cmp(other))
94 }
95}
96
97impl Ord for SshKeyPair {
98 fn cmp(&self, other: &Self) -> Ordering {
99 self.key_pair
100 .fingerprint(HashAlg::default())
101 .cmp(&other.key_pair.fingerprint(HashAlg::default()))
102 }
103}
104
105impl Serialize for SshKeyPair {
106 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107 where
108 S: Serializer,
109 {
110 let mut state = serializer.serialize_struct("SshKeypair", 2)?;
111 state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
113 state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
114 state.end()
115 }
116}
117
118impl<'de> Deserialize<'de> for SshKeyPair {
119 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120 where
121 D: Deserializer<'de>,
122 {
123 #[derive(Deserialize)]
124 #[serde(field_identifier, rename_all = "snake_case")]
125 enum Field {
126 PublicKey,
127 PrivateKey,
128 }
129
130 struct SshKeyPairVisitor;
131
132 impl<'de> Visitor<'de> for SshKeyPairVisitor {
133 type Value = SshKeyPair;
134
135 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
136 formatter.write_str("struct SshKeypair")
137 }
138
139 fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
140 where
141 V: SeqAccess<'de>,
142 {
143 let _public_key: Vec<u8> = seq
145 .next_element()?
146 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
147 let private_key: Zeroizing<Vec<u8>> = seq
148 .next_element()?
149 .ok_or_else(|| de::Error::invalid_length(1, &self))?;
150 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
151 }
152
153 fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
154 where
155 V: MapAccess<'de>,
156 {
157 let mut _public_key: Option<Vec<u8>> = None;
159 let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
160 while let Some(key) = map.next_key()? {
161 match key {
162 Field::PublicKey => {
163 if _public_key.is_some() {
164 return Err(de::Error::duplicate_field("public_key"));
165 }
166 _public_key = Some(map.next_value()?);
167 }
168 Field::PrivateKey => {
169 if private_key.is_some() {
170 return Err(de::Error::duplicate_field("private_key"));
171 }
172 private_key = Some(map.next_value()?);
173 }
174 }
175 }
176 let private_key =
177 private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
178 SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
179 }
180 }
181
182 const FIELDS: &[&str] = &["public_key", "private_key"];
183 deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
184 }
185}
186
187#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
192pub struct SshKeyPairSet {
193 primary: SshKeyPair,
194 secondary: SshKeyPair,
195}
196
197impl SshKeyPairSet {
198 pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
200 Ok(SshKeyPairSet::from_parts(
201 SshKeyPair::new()?,
202 SshKeyPair::new()?,
203 ))
204 }
205
206 pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
209 SshKeyPairSet { primary, secondary }
210 }
211
212 pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
217 Ok(SshKeyPairSet {
218 primary: self.secondary.clone(),
219 secondary: SshKeyPair::new()?,
220 })
221 }
222
223 pub fn public_keys(&self) -> (String, String) {
225 let primary = self.primary().ssh_public_key();
226 let secondary = self.secondary().ssh_public_key();
227 (primary, secondary)
228 }
229
230 pub fn primary(&self) -> &SshKeyPair {
232 &self.primary
233 }
234
235 pub fn secondary(&self) -> &SshKeyPair {
237 &self.secondary
238 }
239
240 pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
244 Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
245 }
246
247 pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
250 Ok(serde_json::from_slice(data)?)
251 }
252}
253
#[cfg(test)]
mod tests {
    use openssl::pkey::{PKey, Private};
    use serde::de::DeserializeOwned;
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    /// How many freshly generated keys each test checks.
    const ITERATIONS: usize = 100;

    /// Serializes `value` to JSON and deserializes the bytes back as a `U`.
    ///
    /// `T` and `U` may differ. This lets the tests check compatibility between
    /// the current key types and the legacy stand-ins defined below.
    fn json_roundtrip<T, U>(value: &T) -> anyhow::Result<U>
    where
        T: Serialize,
        U: DeserializeOwned,
    {
        let bytes = serde_json::to_vec(value).expect("serialization of test value cannot fail");
        Ok(serde_json::from_slice(&bytes)?)
    }

    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_pair = SshKeyPair::new()?;

            let public_key = key_pair.ssh_public_key();
            assert!(public_key.starts_with("ssh-ed25519 "));

            let private_key = key_pair.ssh_private_key();
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = json_roundtrip(&key_pair)?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// Data in the legacy layout must deserialize into an `SshKeyPair` that
    /// carries the same private key.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let key_pair: SshKeyPair = json_roundtrip(&legacy_key_pair)?;

            assert_eq!(
                legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// An `SshKeyPair` must serialize into data the legacy type can read, with
    /// identical public and private key bytes.
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_pair = SshKeyPair::new()?;
            let legacy_key_pair: LegacySshKeyPair = json_roundtrip(&key_pair)?;

            assert_eq!(
                legacy_key_pair.public_key,
                key_pair.ssh_public_key().as_bytes()
            );
            assert_eq!(
                legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// A legacy key set must deserialize into an `SshKeyPairSet` with the same
    /// primary and secondary private keys.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let legacy_key_set = LegacySshKeyPairSet::new()?;
            let key_set: SshKeyPairSet = json_roundtrip(&legacy_key_set)?;

            assert_eq!(
                legacy_key_set.primary.private_key,
                key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_set.secondary.private_key,
                key_set.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// An `SshKeyPairSet` must serialize into data the legacy set type can read,
    /// with identical private keys.
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..ITERATIONS {
            let key_set = SshKeyPairSet::new()?;
            let legacy_key_set: LegacySshKeyPairSet = json_roundtrip(&key_set)?;

            assert_eq!(
                legacy_key_set.primary.private_key,
                key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_set.secondary.private_key,
                key_set.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Stand-in for the legacy serialized key pair: both keys stored as raw
    /// bytes of their OpenSSH encodings.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        /// Generates a random Ed25519 key pair directly in the legacy layout.
        fn new() -> Result<Self, anyhow::Error> {
            let openssl_key = PKey::<Private>::generate_ed25519()?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
                private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// Stand-in for the legacy serialized key set: a primary and secondary
    /// legacy key pair.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Generates a set of two random legacy key pairs.
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}
452}