Skip to main content

mz_ssh_util/
keys.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utilities for generating and managing SSH keys.
11
12use std::cmp::Ordering;
13use std::fmt;
14
15use aws_lc_rs::signature::{Ed25519KeyPair, KeyPair};
16use mz_ore::secure::Zeroizing;
17use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
18use serde::ser::{SerializeStruct, Serializer};
19use serde::{Deserialize, Serialize};
20use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
21use ssh_key::public::Ed25519PublicKey;
22use ssh_key::{HashAlg, LineEnding, PrivateKey};
23
24/// A SSH key pair consisting of a public and private key.
25#[derive(Debug, Clone)]
26pub struct SshKeyPair {
27    // Even though the type is called PrivateKey, it includes the full key pair,
28    // and zeroes memory on `Drop`.
29    key_pair: PrivateKey,
30}
31
32impl SshKeyPair {
33    /// Generates a new SSH key pair.
34    ///
35    /// Ed25519 keys are generated via aws-lc-rs, using [`ssh_key`] to convert
36    /// them into the OpenSSH format.
37    pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38        // Generate a random 32-byte Ed25519 seed and wrap in Zeroizing so it
39        // is erased on drop, preventing the intermediate buffer from lingering
40        // in memory.
41        let mut seed = Zeroizing::new([0u8; 32]);
42        aws_lc_rs::rand::fill(&mut *seed)
43            .map_err(|_| anyhow::anyhow!("random generation failed"))?;
44
45        // Derive the public key from the seed using aws-lc-rs.
46        let aws_key = Ed25519KeyPair::from_seed_unchecked(&*seed)
47            .map_err(|_| anyhow::anyhow!("Ed25519 key generation failed"))?;
48
49        let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
50            public: Ed25519PublicKey::try_from(aws_key.public_key().as_ref())?,
51            private: Ed25519PrivateKey::try_from(seed.as_slice())?,
52        });
53
54        let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
55
56        Ok(SshKeyPair { key_pair })
57    }
58
59    /// Deserializes a key pair from a key pair set that was serialized with
60    /// [`SshKeyPairSet::serialize`].
61    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
62        let set = SshKeyPairSet::from_bytes(data)?;
63        Ok(set.primary().clone())
64    }
65
66    /// Deserializes a key pair from an OpenSSH-formatted private key.
67    fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
68        let private_key = PrivateKey::from_openssh(private_key)?;
69
70        Ok(SshKeyPair {
71            key_pair: private_key,
72        })
73    }
74
75    /// Returns the public key encoded in the OpenSSH format.
76    pub fn ssh_public_key(&self) -> String {
77        self.key_pair.public_key().to_string()
78    }
79
80    /// Return the private key encoded in the OpenSSH format.
81    pub fn ssh_private_key(&self) -> Zeroizing<String> {
82        self.key_pair
83            .to_openssh(LineEnding::LF)
84            .expect("encoding as OpenSSH cannot fail")
85    }
86}
87
88impl PartialEq for SshKeyPair {
89    fn eq(&self, other: &Self) -> bool {
90        self.key_pair.fingerprint(HashAlg::default())
91            == other.key_pair.fingerprint(HashAlg::default())
92    }
93}
94
95impl Eq for SshKeyPair {}
96
97impl PartialOrd for SshKeyPair {
98    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
99        Some(self.cmp(other))
100    }
101}
102
103impl Ord for SshKeyPair {
104    fn cmp(&self, other: &Self) -> Ordering {
105        self.key_pair
106            .fingerprint(HashAlg::default())
107            .cmp(&other.key_pair.fingerprint(HashAlg::default()))
108    }
109}
110
111impl Serialize for SshKeyPair {
112    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
113    where
114        S: Serializer,
115    {
116        let mut state = serializer.serialize_struct("SshKeypair", 2)?;
117        // Public key is still encoded for backwards compatibility, but it is not used anymore
118        state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
119        state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
120        state.end()
121    }
122}
123
124impl<'de> Deserialize<'de> for SshKeyPair {
125    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
126    where
127        D: Deserializer<'de>,
128    {
129        #[derive(Deserialize)]
130        #[serde(field_identifier, rename_all = "snake_case")]
131        enum Field {
132            PublicKey,
133            PrivateKey,
134        }
135
136        struct SshKeyPairVisitor;
137
138        impl<'de> Visitor<'de> for SshKeyPairVisitor {
139            type Value = SshKeyPair;
140
141            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
142                formatter.write_str("struct SshKeypair")
143            }
144
145            fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
146            where
147                V: SeqAccess<'de>,
148            {
149                // Public key is still read for backwards compatibility, but it is not used anymore
150                let _public_key: Vec<u8> = seq
151                    .next_element()?
152                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
153                let private_key: Zeroizing<Vec<u8>> = seq
154                    .next_element()?
155                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
156                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
157            }
158
159            fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
160            where
161                V: MapAccess<'de>,
162            {
163                // Public key is still read for backwards compatibility, but it is not used anymore
164                let mut _public_key: Option<Vec<u8>> = None;
165                let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
166                while let Some(key) = map.next_key()? {
167                    match key {
168                        Field::PublicKey => {
169                            if _public_key.is_some() {
170                                return Err(de::Error::duplicate_field("public_key"));
171                            }
172                            _public_key = Some(map.next_value()?);
173                        }
174                        Field::PrivateKey => {
175                            if private_key.is_some() {
176                                return Err(de::Error::duplicate_field("private_key"));
177                            }
178                            private_key = Some(map.next_value()?);
179                        }
180                    }
181                }
182                let private_key =
183                    private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
184                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
185            }
186        }
187
188        const FIELDS: &[&str] = &["public_key", "private_key"];
189        deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
190    }
191}
192
193/// A set of two SSH key pairs, used to support key rotation.
194///
195/// When a key pair set is rotated, the secondary key pair becomes the new
196/// primary key pair, and a new secondary key pair is generated.
197#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
198pub struct SshKeyPairSet {
199    primary: SshKeyPair,
200    secondary: SshKeyPair,
201}
202
203impl SshKeyPairSet {
204    // Generates a new key pair set with random key pairs.
205    pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
206        Ok(SshKeyPairSet::from_parts(
207            SshKeyPair::new()?,
208            SshKeyPair::new()?,
209        ))
210    }
211
212    /// Creates a new key pair set from an existing primary and secondary key
213    /// pair.
214    pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
215        SshKeyPairSet { primary, secondary }
216    }
217
218    /// Rotate the key pairs in the set.
219    ///
220    /// The rotation promotes the secondary key_pair to primary and generates a
221    /// new random secondary key pair.
222    pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
223        Ok(SshKeyPairSet {
224            primary: self.secondary.clone(),
225            secondary: SshKeyPair::new()?,
226        })
227    }
228
229    /// Returns the primary and secondary public keys in the set.
230    pub fn public_keys(&self) -> (String, String) {
231        let primary = self.primary().ssh_public_key();
232        let secondary = self.secondary().ssh_public_key();
233        (primary, secondary)
234    }
235
236    /// Return the primary key pair.
237    pub fn primary(&self) -> &SshKeyPair {
238        &self.primary
239    }
240
241    /// Returns the secondary pair.
242    pub fn secondary(&self) -> &SshKeyPair {
243        &self.secondary
244    }
245
246    /// Serializes the key pair set to an unspecified encoding.
247    ///
248    /// You can deserialize a key pair set with [`SshKeyPairSet::deserialize`].
249    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
250        Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
251    }
252
253    /// Deserializes a key pair set that was serialized with
254    /// [`SshKeyPairSet::serialize`].
255    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
256        Ok(serde_json::from_slice(data)?)
257    }
258}
259
#[cfg(test)]
mod tests {
    use aws_lc_rs::signature::{Ed25519KeyPair, KeyPair};
    use mz_ore::secure::Zeroizing;
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;

            let public_key = key_pair.ssh_public_key();
            // Public keys should be in ASCII
            assert!(public_key.is_ascii());
            // Public keys should be in the OpenSSH format
            assert!(public_key.starts_with("ssh-ed25519 "));

            let private_key = key_pair.ssh_private_key();
            // Private keys should be in ASCII
            assert!(private_key.is_ascii());
            // Private keys should also be in the OpenSSH format
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    /// `SshKeyPair::from_bytes` should return the primary key pair of a
    /// serialized key set.
    #[mz_ore::test]
    fn test_key_pair_from_key_set_bytes() -> anyhow::Result<()> {
        for _ in 0..10 {
            let key_set = SshKeyPairSet::new()?;
            let key_pair = SshKeyPair::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(&key_pair, key_set.primary());
            assert_ne!(&key_pair, key_set.secondary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated Keypairs
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of legacy key_pair cannot fail"),
            )?;

            assert_eq!(
                legacy_key_pair.private_key,
                roundtripped_key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated Keypairs, e.g. if we have to rollback
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated Keysets
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_set = LegacySshKeyPairSet::new()?;
            let roundtripped_key_set: SshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_set)
                    .expect("serialization of legacy key_set cannot fail"),
            )?;

            assert_eq!(
                legacy_key_set.primary.private_key,
                roundtripped_key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_set.secondary.private_key,
                roundtripped_key_set
                    .secondary()
                    .ssh_private_key()
                    .as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated Keysets, e.g. if we have to rollback
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_legacy_key_set: LegacySshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&key_set).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_set.primary.private_key,
                key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                roundtripped_legacy_key_set.secondary.private_key,
                key_set.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// The previously used Keypair struct, here to test serialization logic across versions
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        fn new() -> Result<Self, anyhow::Error> {
            let mut seed = Zeroizing::new([0u8; 32]);
            aws_lc_rs::rand::fill(&mut *seed)
                .map_err(|_| anyhow::anyhow!("random generation failed"))?;
            let aws_key = Ed25519KeyPair::from_seed_unchecked(&*seed)
                .map_err(|_| anyhow::anyhow!("Ed25519 key generation failed"))?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(aws_key.public_key().as_ref())?,
                private: Ed25519PrivateKey::try_from(seed.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// The previously used Keyset struct, here to test serialization logic across versions
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Generate a new key_set with random keys
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}
463}