Skip to main content

mz_ssh_util/
keys.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Utilities for generating and managing SSH keys.
11
12use std::cmp::Ordering;
13use std::fmt;
14
15use mz_ore::secure::Zeroizing;
16use openssl::pkey::{PKey, Private};
17use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor};
18use serde::ser::{SerializeStruct, Serializer};
19use serde::{Deserialize, Serialize};
20use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
21use ssh_key::public::Ed25519PublicKey;
22use ssh_key::{HashAlg, LineEnding, PrivateKey};
23
24/// A SSH key pair consisting of a public and private key.
25#[derive(Debug, Clone)]
26pub struct SshKeyPair {
27    // Even though the type is called PrivateKey, it includes the full key pair,
28    // and zeroes memory on `Drop`.
29    key_pair: PrivateKey,
30}
31
32impl SshKeyPair {
33    /// Generates a new SSH key pair.
34    ///
35    /// Ed25519 keys are generated via OpenSSL, using [`ssh_key`] to convert
36    /// them into the OpenSSH format.
37    pub fn new() -> Result<SshKeyPair, anyhow::Error> {
38        let openssl_key = PKey::<Private>::generate_ed25519()?;
39
40        // Wrap the raw private key bytes in Zeroizing so they are erased on drop,
41        // preventing the intermediate buffer from lingering in memory.
42        let raw_private = Zeroizing::new(openssl_key.raw_private_key()?);
43        let key_pair_data = KeypairData::Ed25519(Ed25519Keypair {
44            public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
45            private: Ed25519PrivateKey::try_from(raw_private.as_slice())?,
46        });
47
48        let key_pair = PrivateKey::new(key_pair_data, "materialize")?;
49
50        Ok(SshKeyPair { key_pair })
51    }
52
53    /// Deserializes a key pair from a key pair set that was serialized with
54    /// [`SshKeyPairSet::serialize`].
55    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPair> {
56        let set = SshKeyPairSet::from_bytes(data)?;
57        Ok(set.primary().clone())
58    }
59
60    /// Deserializes a key pair from an OpenSSH-formatted private key.
61    fn from_private_key(private_key: &[u8]) -> Result<SshKeyPair, anyhow::Error> {
62        let private_key = PrivateKey::from_openssh(private_key)?;
63
64        Ok(SshKeyPair {
65            key_pair: private_key,
66        })
67    }
68
69    /// Returns the public key encoded in the OpenSSH format.
70    pub fn ssh_public_key(&self) -> String {
71        self.key_pair.public_key().to_string()
72    }
73
74    /// Return the private key encoded in the OpenSSH format.
75    pub fn ssh_private_key(&self) -> Zeroizing<String> {
76        self.key_pair
77            .to_openssh(LineEnding::LF)
78            .expect("encoding as OpenSSH cannot fail")
79    }
80}
81
82impl PartialEq for SshKeyPair {
83    fn eq(&self, other: &Self) -> bool {
84        self.key_pair.fingerprint(HashAlg::default())
85            == other.key_pair.fingerprint(HashAlg::default())
86    }
87}
88
89impl Eq for SshKeyPair {}
90
91impl PartialOrd for SshKeyPair {
92    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93        Some(self.cmp(other))
94    }
95}
96
97impl Ord for SshKeyPair {
98    fn cmp(&self, other: &Self) -> Ordering {
99        self.key_pair
100            .fingerprint(HashAlg::default())
101            .cmp(&other.key_pair.fingerprint(HashAlg::default()))
102    }
103}
104
105impl Serialize for SshKeyPair {
106    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
107    where
108        S: Serializer,
109    {
110        let mut state = serializer.serialize_struct("SshKeypair", 2)?;
111        // Public key is still encoded for backwards compatibility, but it is not used anymore
112        state.serialize_field("public_key", self.ssh_public_key().as_bytes())?;
113        state.serialize_field("private_key", self.ssh_private_key().as_bytes())?;
114        state.end()
115    }
116}
117
118impl<'de> Deserialize<'de> for SshKeyPair {
119    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
120    where
121        D: Deserializer<'de>,
122    {
123        #[derive(Deserialize)]
124        #[serde(field_identifier, rename_all = "snake_case")]
125        enum Field {
126            PublicKey,
127            PrivateKey,
128        }
129
130        struct SshKeyPairVisitor;
131
132        impl<'de> Visitor<'de> for SshKeyPairVisitor {
133            type Value = SshKeyPair;
134
135            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
136                formatter.write_str("struct SshKeypair")
137            }
138
139            fn visit_seq<V>(self, mut seq: V) -> Result<SshKeyPair, V::Error>
140            where
141                V: SeqAccess<'de>,
142            {
143                // Public key is still read for backwards compatibility, but it is not used anymore
144                let _public_key: Vec<u8> = seq
145                    .next_element()?
146                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
147                let private_key: Zeroizing<Vec<u8>> = seq
148                    .next_element()?
149                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;
150                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
151            }
152
153            fn visit_map<V>(self, mut map: V) -> Result<SshKeyPair, V::Error>
154            where
155                V: MapAccess<'de>,
156            {
157                // Public key is still read for backwards compatibility, but it is not used anymore
158                let mut _public_key: Option<Vec<u8>> = None;
159                let mut private_key: Option<Zeroizing<Vec<u8>>> = None;
160                while let Some(key) = map.next_key()? {
161                    match key {
162                        Field::PublicKey => {
163                            if _public_key.is_some() {
164                                return Err(de::Error::duplicate_field("public_key"));
165                            }
166                            _public_key = Some(map.next_value()?);
167                        }
168                        Field::PrivateKey => {
169                            if private_key.is_some() {
170                                return Err(de::Error::duplicate_field("private_key"));
171                            }
172                            private_key = Some(map.next_value()?);
173                        }
174                    }
175                }
176                let private_key =
177                    private_key.ok_or_else(|| de::Error::missing_field("private_key"))?;
178                SshKeyPair::from_private_key(&private_key).map_err(de::Error::custom)
179            }
180        }
181
182        const FIELDS: &[&str] = &["public_key", "private_key"];
183        deserializer.deserialize_struct("SshKeypair", FIELDS, SshKeyPairVisitor)
184    }
185}
186
187/// A set of two SSH key pairs, used to support key rotation.
188///
189/// When a key pair set is rotated, the secondary key pair becomes the new
190/// primary key pair, and a new secondary key pair is generated.
191#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
192pub struct SshKeyPairSet {
193    primary: SshKeyPair,
194    secondary: SshKeyPair,
195}
196
197impl SshKeyPairSet {
198    // Generates a new key pair set with random key pairs.
199    pub fn new() -> Result<SshKeyPairSet, anyhow::Error> {
200        Ok(SshKeyPairSet::from_parts(
201            SshKeyPair::new()?,
202            SshKeyPair::new()?,
203        ))
204    }
205
206    /// Creates a new key pair set from an existing primary and secondary key
207    /// pair.
208    pub fn from_parts(primary: SshKeyPair, secondary: SshKeyPair) -> SshKeyPairSet {
209        SshKeyPairSet { primary, secondary }
210    }
211
212    /// Rotate the key pairs in the set.
213    ///
214    /// The rotation promotes the secondary key_pair to primary and generates a
215    /// new random secondary key pair.
216    pub fn rotate(&self) -> Result<SshKeyPairSet, anyhow::Error> {
217        Ok(SshKeyPairSet {
218            primary: self.secondary.clone(),
219            secondary: SshKeyPair::new()?,
220        })
221    }
222
223    /// Returns the primary and secondary public keys in the set.
224    pub fn public_keys(&self) -> (String, String) {
225        let primary = self.primary().ssh_public_key();
226        let secondary = self.secondary().ssh_public_key();
227        (primary, secondary)
228    }
229
230    /// Return the primary key pair.
231    pub fn primary(&self) -> &SshKeyPair {
232        &self.primary
233    }
234
235    /// Returns the secondary pair.
236    pub fn secondary(&self) -> &SshKeyPair {
237        &self.secondary
238    }
239
240    /// Serializes the key pair set to an unspecified encoding.
241    ///
242    /// You can deserialize a key pair set with [`SshKeyPairSet::deserialize`].
243    pub fn to_bytes(&self) -> Zeroizing<Vec<u8>> {
244        Zeroizing::new(serde_json::to_vec(self).expect("serialization of key_set cannot fail"))
245    }
246
247    /// Deserializes a key pair set that was serialized with
248    /// [`SshKeyPairSet::serialize`].
249    pub fn from_bytes(data: &[u8]) -> anyhow::Result<SshKeyPairSet> {
250        Ok(serde_json::from_slice(data)?)
251    }
252}
253
#[cfg(test)]
mod tests {
    use openssl::pkey::{PKey, Private};
    use serde::{Deserialize, Serialize};
    use ssh_key::private::{Ed25519Keypair, Ed25519PrivateKey, KeypairData};
    use ssh_key::public::Ed25519PublicKey;
    use ssh_key::{LineEnding, PrivateKey};

    use super::{SshKeyPair, SshKeyPairSet};

    #[mz_ore::test]
    fn test_key_pair_generation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;

            // Public keys should be ASCII, in the OpenSSH format.
            let public_key = key_pair.ssh_public_key();
            assert!(public_key.is_ascii());
            assert!(public_key.starts_with("ssh-ed25519 "));

            // Private keys should also be ASCII, in the OpenSSH format.
            let private_key = key_pair.ssh_private_key();
            assert!(private_key.is_ascii());
            assert!(private_key.starts_with("-----BEGIN OPENSSH PRIVATE KEY-----"));
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_unique_keys() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            assert_ne!(key_set.primary(), key_set.secondary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_pair_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(key_pair, roundtripped_key_pair);
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_set_serialization_roundtrip() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_key_set = SshKeyPairSet::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(key_set, roundtripped_key_set);
        }
        Ok(())
    }

    /// `SshKeyPair::from_bytes` should read the primary key pair out of a
    /// serialized key pair set.
    #[mz_ore::test]
    fn test_key_pair_from_key_set_bytes() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let key_pair = SshKeyPair::from_bytes(key_set.to_bytes().as_slice())?;

            assert_eq!(&key_pair, key_set.primary());
        }
        Ok(())
    }

    #[mz_ore::test]
    fn test_key_rotation() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let rotated_key_set = key_set.rotate()?;

            assert_eq!(key_set.secondary(), rotated_key_set.primary());
            assert_ne!(key_set.primary(), rotated_key_set.secondary());
            assert_ne!(rotated_key_set.primary(), rotated_key_set.secondary());
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated key pairs.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_pair = LegacySshKeyPair::new()?;
            let roundtripped_key_pair: SshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_pair)
                    .expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(
                legacy_key_pair.private_key,
                roundtripped_key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated key pairs, e.g. if we
    /// have to roll back.
    #[mz_ore::test]
    fn test_serializing_legacy_key_pairs() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_pair = SshKeyPair::new()?;
            let roundtripped_legacy_key_pair: LegacySshKeyPair = serde_json::from_slice(
                &serde_json::to_vec(&key_pair).expect("serialization of key_pair cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_pair.private_key,
                key_pair.ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the new code can read legacy generated key sets.
    #[mz_ore::test]
    fn test_deserializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let legacy_key_set = LegacySshKeyPairSet::new()?;
            let roundtripped_key_set: SshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&legacy_key_set)
                    .expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                legacy_key_set.primary.private_key,
                roundtripped_key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                legacy_key_set.secondary.private_key,
                roundtripped_key_set
                    .secondary()
                    .ssh_private_key()
                    .as_bytes()
            );
        }
        Ok(())
    }

    /// Ensure the legacy code can read newly generated key sets, e.g. if we
    /// have to roll back.
    #[mz_ore::test]
    fn test_serializing_legacy_key_sets() -> anyhow::Result<()> {
        for _ in 0..100 {
            let key_set = SshKeyPairSet::new()?;
            let roundtripped_legacy_key_set: LegacySshKeyPairSet = serde_json::from_slice(
                &serde_json::to_vec(&key_set).expect("serialization of key_set cannot fail"),
            )?;

            assert_eq!(
                roundtripped_legacy_key_set.primary.private_key,
                key_set.primary().ssh_private_key().as_bytes()
            );
            assert_eq!(
                roundtripped_legacy_key_set.secondary.private_key,
                key_set.secondary().ssh_private_key().as_bytes()
            );
        }
        Ok(())
    }

    /// The previously used key pair struct, kept here to test serialization
    /// compatibility across versions.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPair {
        public_key: Vec<u8>,
        private_key: Vec<u8>,
    }

    impl LegacySshKeyPair {
        /// Generates a random key pair the way the legacy code did.
        fn new() -> Result<Self, anyhow::Error> {
            let openssl_key = PKey::<Private>::generate_ed25519()?;

            let key_pair = KeypairData::Ed25519(Ed25519Keypair {
                public: Ed25519PublicKey::try_from(openssl_key.raw_public_key()?.as_slice())?,
                private: Ed25519PrivateKey::try_from(openssl_key.raw_private_key()?.as_slice())?,
            });

            let private_key_wrapper = PrivateKey::new(key_pair, "materialize")?;
            let openssh_private_key = &*private_key_wrapper.to_openssh(LineEnding::LF)?;
            let openssh_public_key = private_key_wrapper.public_key().to_openssh()?;

            Ok(Self {
                public_key: openssh_public_key.as_bytes().to_vec(),
                private_key: openssh_private_key.as_bytes().to_vec(),
            })
        }
    }

    /// The previously used key set struct, kept here to test serialization
    /// compatibility across versions.
    #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Clone, Serialize, Deserialize)]
    struct LegacySshKeyPairSet {
        primary: LegacySshKeyPair,
        secondary: LegacySshKeyPair,
    }

    impl LegacySshKeyPairSet {
        /// Generates a new key set with random keys.
        fn new() -> Result<Self, anyhow::Error> {
            Ok(Self {
                primary: LegacySshKeyPair::new()?,
                secondary: LegacySshKeyPair::new()?,
            })
        }
    }
}