aws_config/sso/
cache.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use aws_runtime::fs_util::{home_dir, Os};
7use aws_smithy_json::deserialize::token::skip_value;
8use aws_smithy_json::deserialize::Token;
9use aws_smithy_json::deserialize::{json_token_iter, EscapeError};
10use aws_smithy_json::serialize::JsonObjectWriter;
11use aws_smithy_types::date_time::{DateTimeFormatError, Format};
12use aws_smithy_types::DateTime;
13use aws_types::os_shim_internal::{Env, Fs};
14use ring::digest;
15use std::borrow::Cow;
16use std::error::Error as StdError;
17use std::fmt;
18use std::path::PathBuf;
19use std::time::SystemTime;
20use zeroize::Zeroizing;
21
/// In-memory form of a cached SSO token file.
///
/// Secret values (`access_token`, `client_secret`, `refresh_token`) are wrapped in
/// [`Zeroizing`], and the hand-written `Debug` impl in this file never prints them.
/// Equality comparison is only derived in test builds.
#[cfg_attr(test, derive(Eq, PartialEq))]
#[derive(Clone)]
pub(super) struct CachedSsoToken {
    /// Value of the required `accessToken` field.
    pub(super) access_token: Zeroizing<String>,
    /// Value of the optional `clientId` field.
    pub(super) client_id: Option<String>,
    /// Value of the optional `clientSecret` field.
    pub(super) client_secret: Option<Zeroizing<String>>,
    /// Value of the required `expiresAt` field, converted to a `SystemTime`.
    pub(super) expires_at: SystemTime,
    /// Value of the optional `refreshToken` field.
    pub(super) refresh_token: Option<Zeroizing<String>>,
    /// Value of the optional `region` field.
    pub(super) region: Option<String>,
    /// Value of the optional `registrationExpiresAt` field, converted to a `SystemTime`.
    pub(super) registration_expires_at: Option<SystemTime>,
    /// Value of the optional `startUrl` field.
    pub(super) start_url: Option<String>,
}
34
35impl CachedSsoToken {
36    /// True if the information required to refresh this token is present.
37    ///
38    /// The expiration times are not considered by this function.
39    pub(super) fn refreshable(&self) -> bool {
40        self.client_id.is_some()
41            && self.client_secret.is_some()
42            && self.refresh_token.is_some()
43            && self.registration_expires_at.is_some()
44    }
45}
46
47impl fmt::Debug for CachedSsoToken {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        f.debug_struct("CachedSsoToken")
50            .field("access_token", &"** redacted **")
51            .field("client_id", &self.client_id)
52            .field("client_secret", &"** redacted **")
53            .field("expires_at", &self.expires_at)
54            .field("refresh_token", &"** redacted **")
55            .field("region", &self.region)
56            .field("registration_expires_at", &self.registration_expires_at)
57            .field("start_url", &self.start_url)
58            .finish()
59    }
60}
61
/// Errors produced while loading, parsing, or saving a cached SSO token.
#[derive(Debug)]
pub(super) enum CachedSsoTokenError {
    /// A timestamp could not be formatted as a string (converted from `DateTimeFormatError`).
    FailedToFormatDateTime {
        source: Box<dyn StdError + Send + Sync>,
    },
    /// A field was present in the cached token file but its value could not be interpreted.
    InvalidField {
        /// JSON name of the offending field.
        field: &'static str,
        source: Box<dyn StdError + Send + Sync>,
    },
    /// A file system operation on the cached token file failed.
    IoError {
        /// Verb describing the failed operation (`"read"` or `"write"` in this file).
        what: &'static str,
        /// Path of the file the operation was attempted on.
        path: PathBuf,
        source: std::io::Error,
    },
    /// The cached token file did not contain valid JSON.
    JsonError(Box<dyn StdError + Send + Sync>),
    /// A required field (named by its JSON key) was absent from the cached token file.
    MissingField(&'static str),
    /// No home directory could be resolved, so the cache location is unknown.
    NoHomeDirectory,
    /// Any other failure, described by a free-form message.
    Other(Cow<'static, str>),
}
81
82impl fmt::Display for CachedSsoTokenError {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match self {
85            Self::FailedToFormatDateTime { .. } => write!(f, "failed to format date time"),
86            Self::InvalidField { field, .. } => write!(
87                f,
88                "invalid value for the `{field}` field in the cached SSO token file"
89            ),
90            Self::IoError { what, path, .. } => write!(f, "failed to {what} `{}`", path.display()),
91            Self::JsonError(_) => write!(f, "invalid JSON in cached SSO token file"),
92            Self::MissingField(field) => {
93                write!(f, "missing field `{field}` in cached SSO token file")
94            }
95            Self::NoHomeDirectory => write!(f, "couldn't resolve a home directory"),
96            Self::Other(message) => f.write_str(message),
97        }
98    }
99}
100
101impl StdError for CachedSsoTokenError {
102    fn source(&self) -> Option<&(dyn StdError + 'static)> {
103        match self {
104            Self::FailedToFormatDateTime { source } => Some(source.as_ref()),
105            Self::InvalidField { source, .. } => Some(source.as_ref()),
106            Self::IoError { source, .. } => Some(source),
107            Self::JsonError(source) => Some(source.as_ref()),
108            Self::MissingField(_) => None,
109            Self::NoHomeDirectory => None,
110            Self::Other(_) => None,
111        }
112    }
113}
114
115impl From<EscapeError> for CachedSsoTokenError {
116    fn from(err: EscapeError) -> Self {
117        Self::JsonError(err.into())
118    }
119}
120
121impl From<aws_smithy_json::deserialize::error::DeserializeError> for CachedSsoTokenError {
122    fn from(err: aws_smithy_json::deserialize::error::DeserializeError) -> Self {
123        Self::JsonError(err.into())
124    }
125}
126
127impl From<DateTimeFormatError> for CachedSsoTokenError {
128    fn from(value: DateTimeFormatError) -> Self {
129        Self::FailedToFormatDateTime {
130            source: value.into(),
131        }
132    }
133}
134
135/// Determine the SSO cached token path for a given identifier.
136///
137/// The `identifier` is the `sso_start_url` for credentials providers, and `sso_session_name` for token providers.
138fn cached_token_path(identifier: &str, home: &str) -> PathBuf {
139    // hex::encode returns a lowercase string
140    let mut out = PathBuf::with_capacity(home.len() + "/.aws/sso/cache".len() + ".json".len() + 40);
141    out.push(home);
142    out.push(".aws/sso/cache");
143    out.push(&hex::encode(digest::digest(
144        &digest::SHA1_FOR_LEGACY_USE_ONLY,
145        identifier.as_bytes(),
146    )));
147    out.set_extension("json");
148    out
149}
150
151/// Load the token for `identifier` from `~/.aws/sso/cache/<hashofidentifier>.json`
152///
153/// The `identifier` is the `sso_start_url` for credentials providers, and `sso_session_name` for token providers.
154pub(super) async fn load_cached_token(
155    env: &Env,
156    fs: &Fs,
157    identifier: &str,
158) -> Result<CachedSsoToken, CachedSsoTokenError> {
159    let home = home_dir(env, Os::real()).ok_or(CachedSsoTokenError::NoHomeDirectory)?;
160    let path = cached_token_path(identifier, &home);
161    let data = Zeroizing::new(fs.read_to_end(&path).await.map_err(|source| {
162        CachedSsoTokenError::IoError {
163            what: "read",
164            path,
165            source,
166        }
167    })?);
168    parse_cached_token(&data)
169}
170
171/// Parse SSO token JSON from input
172fn parse_cached_token(
173    cached_token_file_contents: &[u8],
174) -> Result<CachedSsoToken, CachedSsoTokenError> {
175    use CachedSsoTokenError as Error;
176
177    let mut access_token = None;
178    let mut expires_at = None;
179    let mut client_id = None;
180    let mut client_secret = None;
181    let mut refresh_token = None;
182    let mut region = None;
183    let mut registration_expires_at = None;
184    let mut start_url = None;
185    json_parse_loop(cached_token_file_contents, |key, value| {
186        match (key, value) {
187            /*
188            // Required fields:
189            "accessToken": "string",
190            "expiresAt": "2019-11-14T04:05:45Z",
191
192            // Optional fields:
193            "refreshToken": "string",
194            "clientId": "ABCDEFG323242423121312312312312312",
195            "clientSecret": "ABCDE123",
196            "registrationExpiresAt": "2022-03-06T19:53:17Z",
197            "region": "us-west-2",
198            "startUrl": "https://d-abc123.awsapps.com/start"
199            */
200            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("accessToken") => {
201                access_token = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
202            }
203            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("expiresAt") => {
204                expires_at = Some(value.to_unescaped()?);
205            }
206            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("clientId") => {
207                client_id = Some(value.to_unescaped()?);
208            }
209            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("clientSecret") => {
210                client_secret = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
211            }
212            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("refreshToken") => {
213                refresh_token = Some(Zeroizing::new(value.to_unescaped()?.into_owned()));
214            }
215            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("region") => {
216                region = Some(value.to_unescaped()?.into_owned());
217            }
218            (key, Token::ValueString { value, .. })
219                if key.eq_ignore_ascii_case("registrationExpiresAt") =>
220            {
221                registration_expires_at = Some(value.to_unescaped()?);
222            }
223            (key, Token::ValueString { value, .. }) if key.eq_ignore_ascii_case("startUrl") => {
224                start_url = Some(value.to_unescaped()?.into_owned());
225            }
226            _ => {}
227        };
228        Ok(())
229    })?;
230
231    Ok(CachedSsoToken {
232        access_token: access_token.ok_or(Error::MissingField("accessToken"))?,
233        expires_at: expires_at
234            .ok_or(Error::MissingField("expiresAt"))
235            .and_then(|expires_at| {
236                DateTime::from_str(expires_at.as_ref(), Format::DateTime)
237                    .map_err(|err| Error::InvalidField { field: "expiresAt", source: err.into() })
238                    .and_then(|date_time| {
239                        SystemTime::try_from(date_time).map_err(|_| {
240                            Error::Other(
241                                "SSO token expiration time cannot be represented by a SystemTime"
242                                    .into(),
243                            )
244                        })
245                    })
246            })?,
247        client_id: client_id.map(Cow::into_owned),
248        client_secret,
249        refresh_token,
250        region,
251        registration_expires_at: Ok(registration_expires_at).and_then(|maybe_expires_at| {
252            if let Some(expires_at) = maybe_expires_at {
253                Some(
254                    DateTime::from_str(expires_at.as_ref(), Format::DateTime)
255                        .map_err(|err| Error::InvalidField { field: "registrationExpiresAt", source: err.into()})
256                        .and_then(|date_time| {
257                            SystemTime::try_from(date_time).map_err(|_| {
258                                Error::Other(
259                                    "SSO registration expiration time cannot be represented by a SystemTime"
260                                        .into(),
261                                )
262                            })
263                        }),
264                )
265                .transpose()
266            } else {
267                Ok(None)
268            }
269        })?,
270        start_url,
271    })
272}
273
274fn json_parse_loop<'a>(
275    input: &'a [u8],
276    mut f: impl FnMut(Cow<'a, str>, &Token<'a>) -> Result<(), CachedSsoTokenError>,
277) -> Result<(), CachedSsoTokenError> {
278    use CachedSsoTokenError as Error;
279    let mut tokens = json_token_iter(input).peekable();
280    if !matches!(tokens.next().transpose()?, Some(Token::StartObject { .. })) {
281        return Err(Error::Other(
282            "expected a JSON document starting with `{`".into(),
283        ));
284    }
285    loop {
286        match tokens.next().transpose()? {
287            Some(Token::EndObject { .. }) => break,
288            Some(Token::ObjectKey { key, .. }) => {
289                if let Some(Ok(token)) = tokens.peek() {
290                    let key = key.to_unescaped()?;
291                    f(key, token)?
292                }
293                skip_value(&mut tokens)?;
294            }
295            other => {
296                return Err(Error::Other(
297                    format!("expected object key, found: {:?}", other).into(),
298                ));
299            }
300        }
301    }
302    if tokens.next().is_some() {
303        return Err(Error::Other(
304            "found more JSON tokens after completing parsing".into(),
305        ));
306    }
307    Ok(())
308}
309
310pub(super) async fn save_cached_token(
311    env: &Env,
312    fs: &Fs,
313    identifier: &str,
314    token: &CachedSsoToken,
315) -> Result<(), CachedSsoTokenError> {
316    let expires_at = DateTime::from(token.expires_at).fmt(Format::DateTime)?;
317    let registration_expires_at = token
318        .registration_expires_at
319        .map(|time| DateTime::from(time).fmt(Format::DateTime))
320        .transpose()?;
321
322    let mut out = Zeroizing::new(String::new());
323    let mut writer = JsonObjectWriter::new(&mut out);
324    writer.key("accessToken").string(&token.access_token);
325    writer.key("expiresAt").string(&expires_at);
326    if let Some(refresh_token) = &token.refresh_token {
327        writer.key("refreshToken").string(refresh_token);
328    }
329    if let Some(client_id) = &token.client_id {
330        writer.key("clientId").string(client_id);
331    }
332    if let Some(client_secret) = &token.client_secret {
333        writer.key("clientSecret").string(client_secret);
334    }
335    if let Some(registration_expires_at) = registration_expires_at {
336        writer
337            .key("registrationExpiresAt")
338            .string(&registration_expires_at);
339    }
340    if let Some(region) = &token.region {
341        writer.key("region").string(region);
342    }
343    if let Some(start_url) = &token.start_url {
344        writer.key("startUrl").string(start_url);
345    }
346    writer.finish();
347
348    let home = home_dir(env, Os::real()).ok_or(CachedSsoTokenError::NoHomeDirectory)?;
349    let path = cached_token_path(identifier, &home);
350    fs.write(&path, out.as_bytes())
351        .await
352        .map_err(|err| CachedSsoTokenError::IoError {
353            what: "write",
354            path,
355            source: err,
356        })?;
357    Ok(())
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Token with every field populated, shared by the save and round-trip tests.
    fn fully_populated_token() -> CachedSsoToken {
        CachedSsoToken {
            access_token: Zeroizing::new("access-token".into()),
            client_id: Some("client-id".into()),
            client_secret: Some(Zeroizing::new("client-secret".into())),
            expires_at: SystemTime::UNIX_EPOCH + Duration::from_secs(50_000_000),
            refresh_token: Some(Zeroizing::new("refresh-token".into())),
            region: Some("region".into()),
            registration_expires_at: Some(
                SystemTime::UNIX_EPOCH + Duration::from_secs(100_000_000),
            ),
            start_url: Some("start-url".into()),
        }
    }

    #[test]
    fn redact_fields_in_token_debug() {
        let token = CachedSsoToken {
            access_token: Zeroizing::new("!!SENSITIVE!!".into()),
            client_id: Some("clientid".into()),
            client_secret: Some(Zeroizing::new("!!SENSITIVE!!".into())),
            expires_at: SystemTime::now(),
            refresh_token: Some(Zeroizing::new("!!SENSITIVE!!".into())),
            region: Some("region".into()),
            registration_expires_at: Some(SystemTime::now()),
            start_url: Some("starturl".into()),
        };
        let debug_str = format!("{token:?}");
        assert!(!debug_str.contains("!!SENSITIVE!!"), "The `Debug` impl for `CachedSsoToken` isn't properly redacting sensitive fields: {debug_str}");
    }

    // A token is refreshable only when all four refresh-related fields are present
    #[test]
    fn refreshable_requires_all_refresh_fields() {
        let complete = fully_populated_token();
        assert!(complete.refreshable());

        let mut missing = complete.clone();
        missing.client_id = None;
        assert!(!missing.refreshable());

        let mut missing = complete.clone();
        missing.client_secret = None;
        assert!(!missing.refreshable());

        let mut missing = complete.clone();
        missing.refresh_token = None;
        assert!(!missing.refreshable());

        let mut missing = complete;
        missing.registration_expires_at = None;
        assert!(!missing.refreshable());
    }

    // Valid token with all fields
    #[test]
    fn parse_valid_token() {
        let file_contents = r#"
        {
            "startUrl": "https://d-123.awsapps.com/start",
            "region": "us-west-2",
            "accessToken": "cachedtoken",
            "expiresAt": "2021-12-25T21:30:00Z",
            "clientId": "clientid",
            "clientSecret": "YSBzZWNyZXQ=",
            "registrationExpiresAt": "2022-12-25T13:30:00Z",
            "refreshToken": "cachedrefreshtoken"
        }
        "#;
        let cached = parse_cached_token(file_contents.as_bytes()).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        assert_eq!("clientid", cached.client_id.expect("client id is present"));
        assert_eq!(
            "YSBzZWNyZXQ=",
            cached
                .client_secret
                .expect("client secret is present")
                .as_str()
        );
        assert_eq!(
            "cachedrefreshtoken",
            cached
                .refresh_token
                .expect("refresh token is present")
                .as_str()
        );
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1671975000),
            cached
                .registration_expires_at
                .expect("registration expiration is present")
        );
        assert_eq!("us-west-2", cached.region.expect("region is present"));
        assert_eq!(
            "https://d-123.awsapps.com/start",
            cached.start_url.expect("startUrl is present")
        );
    }

    // Minimal valid cached token
    #[test]
    fn parse_valid_token_with_optional_fields_absent() {
        let file_contents = r#"
        {
            "accessToken": "cachedtoken",
            "expiresAt": "2021-12-25T21:30:00Z"
        }
        "#;
        let cached = parse_cached_token(file_contents.as_bytes()).expect("success");
        assert_eq!("cachedtoken", cached.access_token.as_str());
        assert_eq!(
            SystemTime::UNIX_EPOCH + Duration::from_secs(1640467800),
            cached.expires_at
        );
        // Every optional field must come back empty, not just the refresh-related ones.
        assert!(cached.client_id.is_none());
        assert!(cached.client_secret.is_none());
        assert!(cached.refresh_token.is_none());
        assert!(cached.registration_expires_at.is_none());
        assert!(cached.region.is_none());
        assert!(cached.start_url.is_none());
    }

    #[test]
    fn parse_invalid_timestamp() {
        let token = br#"
        {
            "accessToken": "base64string",
            "expiresAt": "notatimestamp",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("invalid timestamp");
        let expected = "invalid value for the `expiresAt` field in the cached SSO token file";
        let actual = format!("{err}");
        assert!(
            actual.contains(expected),
            "expected error to contain `{expected}`, but was `{actual}`",
        );
    }

    #[test]
    fn parse_invalid_registration_timestamp() {
        let token = br#"
        {
            "accessToken": "base64string",
            "expiresAt": "2021-12-25T21:30:00Z",
            "registrationExpiresAt": "notatimestamp"
        }"#;
        let err = parse_cached_token(token).expect_err("invalid registration timestamp");
        let expected =
            "invalid value for the `registrationExpiresAt` field in the cached SSO token file";
        let actual = format!("{err}");
        assert!(
            actual.contains(expected),
            "expected error to contain `{expected}`, but was `{actual}`",
        );
    }

    #[test]
    fn parse_missing_fields() {
        // Token missing accessToken field
        let token = br#"
        {
            "expiresAt": "notatimestamp",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("missing accessToken");
        assert!(
            matches!(err, CachedSsoTokenError::MissingField("accessToken")),
            "incorrect error: {err:?}",
        );

        // Token missing expiresAt field
        let token = br#"
        {
            "accessToken": "akid",
            "region": "us-west-2",
            "startUrl": "https://d-abc123.awsapps.com/start"
        }"#;
        let err = parse_cached_token(token).expect_err("missing expiresAt");
        assert!(
            matches!(err, CachedSsoTokenError::MissingField("expiresAt")),
            "incorrect error: {err:?}",
        );
    }

    #[tokio::test]
    async fn gracefully_handle_missing_files() {
        let err = load_cached_token(
            &Env::from_slice(&[("HOME", "/home")]),
            &Fs::from_slice(&[]),
            "asdf",
        )
        .await
        .expect_err("should fail, file is missing");
        assert!(
            matches!(err, CachedSsoTokenError::IoError { .. }),
            "should be io error, got {err}",
        );
    }

    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
    #[cfg_attr(windows, ignore)]
    #[test]
    fn determine_correct_cache_filenames() {
        assert_eq!(
            "/home/someuser/.aws/sso/cache/d033e22ae348aeb5660fc2140aec35850c4da997.json",
            cached_token_path("admin", "/home/someuser").as_os_str()
        );
        assert_eq!(
            "/home/someuser/.aws/sso/cache/75e4d41276d8bd17f85986fc6cccef29fd725ce3.json",
            cached_token_path("dev-scopes", "/home/someuser").as_os_str()
        );
        assert_eq!(
            "/home/me/.aws/sso/cache/13f9d35043871d073ab260e020f0ffde092cb14b.json",
            cached_token_path("https://d-92671207e4.awsapps.com/start", "/home/me").as_os_str(),
        );
        // A trailing slash on the home directory must not change the result.
        assert_eq!(
            "/home/me/.aws/sso/cache/13f9d35043871d073ab260e020f0ffde092cb14b.json",
            cached_token_path("https://d-92671207e4.awsapps.com/start", "/home/me/").as_os_str(),
        );
    }

    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
    #[cfg_attr(windows, ignore)]
    #[tokio::test]
    async fn save_cached_token_writes_expected_json() {
        let token = fully_populated_token();

        let env = Env::from_slice(&[("HOME", "/home/user")]);
        let fs = Fs::from_map(HashMap::<_, Vec<u8>>::new());
        save_cached_token(&env, &fs, "test", &token)
            .await
            .expect("success");

        let contents = fs
            .read_to_end("/home/user/.aws/sso/cache/a94a8fe5ccb19ba61c4c0873d391e987982fbbd3.json")
            .await
            .expect("correct file written");
        let contents_str = String::from_utf8(contents).expect("valid utf8");
        assert_eq!(
            r#"{"accessToken":"access-token","expiresAt":"1971-08-02T16:53:20Z","refreshToken":"refresh-token","clientId":"client-id","clientSecret":"client-secret","registrationExpiresAt":"1973-03-03T09:46:40Z","region":"region","startUrl":"start-url"}"#,
            contents_str,
        );
    }

    #[tokio::test]
    async fn round_trip_token() {
        let original = fully_populated_token();

        let env = Env::from_slice(&[("HOME", "/home/user")]);
        let fs = Fs::from_map(HashMap::<_, Vec<u8>>::new());

        save_cached_token(&env, &fs, "test", &original)
            .await
            .expect("token is saved");

        let roundtripped = load_cached_token(&env, &fs, "test")
            .await
            .expect("token is loaded");
        assert_eq!(original, roundtripped)
    }
}
598}