mz_ssh_util/
keys.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Utilities for generating and managing SSH keys.
11
12use std::cmp::Ordering;
13use std::fmt;
14
15use openssl::pkey::{PKey, Private};
16use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
17use serde::ser::{SerializeStruct, Serializer};
18use serde::{Deserialize, Serialize};
19use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
20use ssh_key::public::Ed25519PublicKey;
21use ssh_key::{HashAlg, LineEnding, PrivateKey};
22use zeroize::Zeroizing;
23
24/// A SSH key pair consisting of a public and private key.
25#[derive(Debug, Clone)]
26pub struct SshKeyPair {
27    // Even though the type is called PrivateKey, it includes the full key pair,
28    // and zeroes memory on `Drop`.
29    key_pair: PrivateKey,
30}
31
32impl SshKeyPair {
33    /// Generates a new SSH key pair.
34    ///
35    /// Ed25519 keys are generated via OpenSSL, using [`ssh_key`] to convert
36    /// them into the OpenSSH format.
37    pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38        let openssl_key = PKey::<Private>::generate_ed25519()?;
39
40        let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
41            public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
42            private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
43        });
44
45        let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
46
47        Ok(SshKeyPair { key_pair })
48    }
49
50    /// Deserializes a key pair from a key pair set that was serialized with
51    /// [`SshKeyPairSet::serialize`].
52    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
53        let set = SshKeyPairSet::from_bytes(data)?;
54        Ok(set.primary().clone())
55    }
56
57    /// Deserializes a key pair from an OpenSSH-formatted private key.
58    fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
59        let private_key = PrivateKey::from_openssh(private_key)?;
60
61        Ok(SshKeyPair {
62            key_pair: private_key,
63        })
64    }
65
66    /// Returns the public key encoded in the OpenSSH format.
67    pub fn ssh_public_key(&self) -> String {
68        self.key_pair.public_key().to_string()
69    }
70
71    /// Return the private key encoded in the OpenSSH format.
72    pub fn ssh_private_key(&self) -> Zeroizing<String> {
73        self.key_pair
74            .to_openssh(LineEnding::LF)
75            .expect("encoding as OpenSSH cannot fail")
76    }
77}
78
79impl PartialEq for SshKeyPair {
80    fn eq(&self, other: &Self) -> bool {
81        self.key_pair.fingerprint(HashAlg::default())
82            == other.key_pair.fingerprint(HashAlg::default())
83    }
84}
85
86impl Eq for SshKeyPair {}
87
88impl PartialOrd for SshKeyPair {
89    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
90        Some(self.cmp(other))
91    }
92}
93
94impl Ord for SshKeyPair {
95    fn cmp(&self, other: &Self) -> Ordering {
96        self.key_pair
97            .fingerprint(HashAlg::default())
98            .cmp(&other.key_pair.fingerprint(HashAlg::default()))
99    }
100}
101
102impl Serialize for SshKeyPair {
103    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
104    where
105        S: Serializer,
106    {
107        let mut state = serializer.serialize_struct("SshKeypair", 2)?;
108        // Public key is still encoded for backwards compatibility, but it is not used anymore
109        state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
110        state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
111        state.end()
112    }
113}
114
115impl<'de> Deserialize<'de> for SshKeyPair {
116    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
117    where
118        D: Deserializer<'de>,
119    {
120        #[derive(Deserialize)]
121        #[serde(field_identifier, rename_all = "snake_case")]
122        enum Field {
123            PublicKey,
124            PrivateKey,
125        }
126
127        struct SshKeyPairVisitor;
128
129        impl<'de> Visitor<'de> for SshKeyPairVisitor {
130            type Value = SshKeyPair;
131
132            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
133                formatter.write_str("struct SshKeypair")
134            }
135
136            fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
137            where
138                V: SeqAccess<'de>,
139            {
140                // Public key is still read for backwards compatibility, but it is not used anymore
141                let _public_key: Vec<u8> = seq
142                    .next_element()?
143                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
144                let private_key: Zeroizing<Vec<u8>> = seq
145                    .next_element()?
146                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
147                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
148            }
149
150            fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
151            where
152                V: MapAccess<'de>,
153            {
154                // Public key is still read for backwards compatibility, but it is not used anymore
155                let mut _public_key: Option<Vec<u8>> = None;
156                let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
157                while let Some(key) = map.next_key()? {
158                    match key {
159                        Field::PublicKey => {
160                            if _public_key.is_some() {
161                                return Err(de::Error::duplicate_field("public_key"));
162                            }
163                            _public_key = Some(map.next_value()?);
164                        }
165                        Field::PrivateKey => {
166                            if private_key.is_some() {
167                                return Err(de::Error::duplicate_field("private_key"));
168                            }
169                            private_key = Some(map.next_value()?);
170                        }
171                    }
172                }
173                let private_key =
174                    private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
175                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
176            }
177        }
178
179        const FIELDS: &[&str] = &["public_key", "private_key"];
180        deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
181    }
182}
183
184/// A set of two SSH key pairs, used to support key rotation.
185///
186/// When a key pair set is rotated, the secondary key pair becomes the new
187/// primary key pair, and a new secondary key pair is generated.
188#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
189pub struct SshKeyPairSet {
190    primary: SshKeyPair,
191    secondary: SshKeyPair,
192}
193
194impl SshKeyPairSet {
195    // Generates a new key pair set with random key pairs.
196    pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
197        Ok(SshKeyPairSet::from_parts(
198            SshKeyPair::new()?,
199            SshKeyPair::new()?,
200        ))
201    }
202
203    /// Creates a new key pair set from an existing primary and secondary key
204    /// pair.
205    pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
206        SshKeyPairSet { primary, secondary }
207    }
208
209    /// Rotate the key pairs in the set.
210    ///
211    /// The rotation promotes the secondary key_pair to primary and generates a
212    /// new random secondary key pair.
213    pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
214        Ok(SshKeyPairSet {
215            primary: self.secondary.clone(),
216            secondary: SshKeyPair::new()?,
217        })
218    }
219
220    /// Returns the primary and secondary public keys in the set.
221    pub fn public_keys(&self) -> (String, String) {
222        let primary = self.primary().ssh_public_key();
223        let secondary = self.secondary().ssh_public_key();
224        (primary, secondary)
225    }
226
227    /// Return the primary key pair.
228    pub fn primary(&self) -> &SshKeyPair {
229        &self.primary
230    }
231
232    /// Returns the secondary pair.
233    pub fn secondary(&self) -> &SshKeyPair {
234        &self.secondary
235    }
236
237    /// Serializes the key pair set to an unspecified encoding.
238    ///
239    /// You can deserialize a key pair set with [`SshKeyPairSet::deserialize`].
240    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
241        Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
242    }
243
244    /// Deserializes a key pair set that was serialized with
245    /// [`SshKeyPairSet::serialize`].
246    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
247        Ok(serde_json::from_slice(data)?)
248    }
249}
250
#[cfg(test)]
mod tests {
    use openssl::pkey::{PKey, Private};
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;

            let public_key = key_pair.ssh_public_key();
            // Public keys should be in ASCII
            assert!(public_key.is_ascii());
            // Public keys should be in the OpenSSH format
            assert!(public_key.starts_with("ssh-ed25519 "));

            let private_key = key_pair.ssh_private_key();
            // Private keys should be in ASCII
            assert!(private_key.is_ascii());
            // Private keys should also be in the OpenSSH format
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    /// `SshKeyPair::from_bytes` reads the primary key pair out of a
    /// serialized key pair set.
    #[mz_ore::test]
    fn test_key_pair_from_key_set_bytes() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let key_pair = SshKeyPair::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set.primary(), &key_pair);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated Keypairs
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(
                legacy_key_pair.private_key,
                roundtripped_key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated Keypairs, e.g. if we have to rollback
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated Keysets
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_set = LegacySshKeyPairSet::new()?;
            let roundtripped_key_set: SshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_set)
                    .expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                legacy_key_set.primary.private_key,
                roundtripped_key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_set.secondary.private_key,
                roundtripped_key_set
                    .secondary()
                    .ssh_private_key()
                    .as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated Keysets, e.g. if we have to rollback
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_legacy_key_set: LegacySshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&key_set).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_set.primary.private_key,
                key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                roundtripped_legacy_key_set.secondary.private_key,
                key_set.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// The previously used Keypair struct, here to test serialization logic across versions
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        fn new() -> Result<Self, anyhow::Error> {
            let openssl_key = PKey::<Private>::generate_ed25519()?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
                private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// The previously used Keyset struct, here to test serialization logic across versions
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Generate a new key_set with random keys
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}