1use crate::expiring_cache::ExpiringCache;
7use aws_smithy_async::future::timeout::Timeout;
8use aws_smithy_async::rt::sleep::{AsyncSleep, SharedAsyncSleep};
9use aws_smithy_async::time::{SharedTimeSource, TimeSource};
10use aws_smithy_runtime_api::box_error::BoxError;
11use aws_smithy_runtime_api::client::identity::{
12    Identity, IdentityCachePartition, IdentityFuture, ResolveCachedIdentity, ResolveIdentity,
13    SharedIdentityCache, SharedIdentityResolver,
14};
15use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
16use aws_smithy_runtime_api::shared::IntoShared;
17use aws_smithy_types::config_bag::ConfigBag;
18use aws_smithy_types::DateTime;
19use std::collections::HashMap;
20use std::fmt;
21use std::sync::RwLock;
22use std::time::Duration;
23use tracing::Instrument;
24
25const DEFAULT_LOAD_TIMEOUT: Duration = Duration::from_secs(5);
26const DEFAULT_EXPIRATION: Duration = Duration::from_secs(15 * 60);
27const DEFAULT_BUFFER_TIME: Duration = Duration::from_secs(10);
28const DEFAULT_BUFFER_TIME_JITTER_FRACTION: fn() -> f64 = || fastrand::f64() * 0.5;
29
/// Builder for a lazy identity cache; see [`LazyCacheBuilder::build`].
#[derive(Default, Debug)]
pub struct LazyCacheBuilder {
    // NOTE(review): `time_source` and `sleep_impl` are recorded by their setters, but `build`
    // never reads them. The cache takes both from the runtime components at resolve time.
    time_source: Option<SharedTimeSource>,
    sleep_impl: Option<SharedAsyncSleep>,
    // Each remaining `None` falls back to the matching `DEFAULT_*` constant in `build`.
    load_timeout: Option<Duration>,
    buffer_time: Option<Duration>,
    buffer_time_jitter_fraction: Option<fn() -> f64>,
    default_expiration: Option<Duration>,
}
40
41impl LazyCacheBuilder {
42    pub fn new() -> Self {
44        Default::default()
45    }
46
47    pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
49        self.set_time_source(time_source.into_shared());
50        self
51    }
52    pub fn set_time_source(&mut self, time_source: SharedTimeSource) -> &mut Self {
54        self.time_source = Some(time_source.into_shared());
55        self
56    }
57
58    pub fn sleep_impl(mut self, sleep_impl: impl AsyncSleep + 'static) -> Self {
60        self.set_sleep_impl(sleep_impl.into_shared());
61        self
62    }
63    pub fn set_sleep_impl(&mut self, sleep_impl: SharedAsyncSleep) -> &mut Self {
65        self.sleep_impl = Some(sleep_impl);
66        self
67    }
68
69    pub fn load_timeout(mut self, timeout: Duration) -> Self {
73        self.set_load_timeout(Some(timeout));
74        self
75    }
76
77    pub fn set_load_timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
81        self.load_timeout = timeout;
82        self
83    }
84
85    pub fn buffer_time(mut self, buffer_time: Duration) -> Self {
94        self.set_buffer_time(Some(buffer_time));
95        self
96    }
97
98    pub fn set_buffer_time(&mut self, buffer_time: Option<Duration>) -> &mut Self {
107        self.buffer_time = buffer_time;
108        self
109    }
110
111    #[allow(unused)]
119    #[cfg(test)]
120    fn buffer_time_jitter_fraction(mut self, buffer_time_jitter_fraction: fn() -> f64) -> Self {
121        self.set_buffer_time_jitter_fraction(Some(buffer_time_jitter_fraction));
122        self
123    }
124
125    #[allow(unused)]
133    #[cfg(test)]
134    fn set_buffer_time_jitter_fraction(
135        &mut self,
136        buffer_time_jitter_fraction: Option<fn() -> f64>,
137    ) -> &mut Self {
138        self.buffer_time_jitter_fraction = buffer_time_jitter_fraction;
139        self
140    }
141
142    pub fn default_expiration(mut self, duration: Duration) -> Self {
149        self.set_default_expiration(Some(duration));
150        self
151    }
152
153    pub fn set_default_expiration(&mut self, duration: Option<Duration>) -> &mut Self {
160        self.default_expiration = duration;
161        self
162    }
163
164    pub fn build(self) -> SharedIdentityCache {
170        let default_expiration = self.default_expiration.unwrap_or(DEFAULT_EXPIRATION);
171        assert!(
172            default_expiration >= DEFAULT_EXPIRATION,
173            "default_expiration must be at least 15 minutes"
174        );
175        LazyCache::new(
176            self.load_timeout.unwrap_or(DEFAULT_LOAD_TIMEOUT),
177            self.buffer_time.unwrap_or(DEFAULT_BUFFER_TIME),
178            self.buffer_time_jitter_fraction
179                .unwrap_or(DEFAULT_BUFFER_TIME_JITTER_FRACTION),
180            default_expiration,
181        )
182        .into_shared()
183    }
184}
185
/// Maps each identity cache partition to its own [`ExpiringCache`].
#[derive(Debug)]
struct CachePartitions {
    // Populated lazily by `partition`. The `RwLock` lets lookups of existing partitions
    // proceed under a shared read lock.
    partitions: RwLock<HashMap<IdentityCachePartition, ExpiringCache<Identity, BoxError>>>,
    // Buffer time given to every `ExpiringCache` created for a new partition.
    buffer_time: Duration,
}
191
192impl CachePartitions {
193    fn new(buffer_time: Duration) -> Self {
194        Self {
195            partitions: RwLock::new(HashMap::new()),
196            buffer_time,
197        }
198    }
199
200    fn partition(&self, key: IdentityCachePartition) -> ExpiringCache<Identity, BoxError> {
201        let mut partition = self.partitions.read().unwrap().get(&key).cloned();
202        if partition.is_none() {
205            let mut partitions = self.partitions.write().unwrap();
206            partitions
209                .entry(key)
210                .or_insert_with(|| ExpiringCache::new(self.buffer_time));
211            drop(partitions);
212
213            partition = self.partitions.read().unwrap().get(&key).cloned();
214        }
215        partition.expect("inserted above if not present")
216    }
217}
218
/// Identity cache that loads identities on demand, keeping a separate expiring cache for each
/// resolver's cache partition.
#[derive(Debug)]
struct LazyCache {
    // One `ExpiringCache` per identity cache partition (the partition comes from the resolver).
    partitions: CachePartitions,
    // Maximum duration of a single identity load before it is treated as timed out.
    load_timeout: Duration,
    // Expiration buffer. Also scaled by `buffer_time_jitter_fraction` to compute the jitter.
    buffer_time: Duration,
    // Returns the fraction of `buffer_time` added to each newly cached expiration.
    buffer_time_jitter_fraction: fn() -> f64,
    // Expiration (relative to the start of the resolve call) for identities that report none.
    default_expiration: Duration,
}
227
228impl LazyCache {
229    fn new(
230        load_timeout: Duration,
231        buffer_time: Duration,
232        buffer_time_jitter_fraction: fn() -> f64,
233        default_expiration: Duration,
234    ) -> Self {
235        Self {
236            partitions: CachePartitions::new(buffer_time),
237            load_timeout,
238            buffer_time,
239            buffer_time_jitter_fraction,
240            default_expiration,
241        }
242    }
243}
244
/// Expands to a `BoxError` stating that lazy identity caching needs `$thing` configured.
///
/// The message continues with the `$how` remediation hint and a suggestion to disable identity
/// caching instead. `BoxError` must be in scope where the macro is invoked.
macro_rules! required_err {
    ($thing:literal, $how:literal) => {{
        let message: &'static str = concat!(
            "Lazy identity caching requires ",
            $thing,
            " to be configured. ",
            $how,
            " If this isn't possible, then disable identity caching by ",
            "calling the `identity_cache` method on config with ",
            "`IdentityCache::no_cache()`",
        );
        BoxError::from(message)
    }};
}
/// Returns early from the enclosing function with an error if `$components` lacks a time source
/// or an async sleep implementation.
///
/// Accepts any value whose `time_source()` and `sleep_impl()` methods return `Option`s. Here it
/// is used with both `RuntimeComponents` and `RuntimeComponentsBuilder`. The enclosing function
/// must return `Result<_, BoxError>`.
macro_rules! validate_components {
    ($components:ident) => {
        if $components.time_source().is_none() {
            return Err(required_err!(
                "a time source",
                "Set a time source using the `time_source` method on config."
            ));
        }
        if $components.sleep_impl().is_none() {
            return Err(required_err!(
                "an async sleep implementation",
                "Set a sleep impl using the `sleep_impl` method on config."
            ));
        }
    };
}
273
274impl ResolveCachedIdentity for LazyCache {
275    fn validate_base_client_config(
276        &self,
277        runtime_components: &aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder,
278        _cfg: &ConfigBag,
279    ) -> Result<(), BoxError> {
280        validate_components!(runtime_components);
281        Ok(())
282    }
283
284    fn validate_final_config(
285        &self,
286        runtime_components: &RuntimeComponents,
287        _cfg: &ConfigBag,
288    ) -> Result<(), BoxError> {
289        validate_components!(runtime_components);
290        Ok(())
291    }
292
293    fn resolve_cached_identity<'a>(
294        &'a self,
295        resolver: SharedIdentityResolver,
296        runtime_components: &'a RuntimeComponents,
297        config_bag: &'a ConfigBag,
298    ) -> IdentityFuture<'a> {
299        let (time_source, sleep_impl) = (
300            runtime_components.time_source().expect("validated"),
301            runtime_components.sleep_impl().expect("validated"),
302        );
303
304        let now = time_source.now();
305        let timeout_future = sleep_impl.sleep(self.load_timeout);
306        let load_timeout = self.load_timeout;
307        let partition = resolver.cache_partition();
308        let cache = self.partitions.partition(partition);
309        let default_expiration = self.default_expiration;
310
311        IdentityFuture::new(async move {
312            if let Some(identity) = cache.yield_or_clear_if_expired(now).await {
314                tracing::debug!(
315                    buffer_time=?self.buffer_time,
316                    cached_expiration=?identity.expiration(),
317                    now=?now,
318                    "loaded identity from cache"
319                );
320                Ok(identity)
321            } else {
322                let start_time = time_source.now();
327                let result = cache
328                    .get_or_load(|| {
329                        let span = tracing::debug_span!("lazy_load_identity");
330                        async move {
331                            let fut = Timeout::new(
332                                resolver.resolve_identity(runtime_components, config_bag),
333                                timeout_future,
334                            );
335                            let identity = match fut.await {
336                                Ok(result) => result?,
337                                Err(_err) => match resolver.fallback_on_interrupt() {
338                                    Some(identity) => identity,
339                                    None => {
340                                        return Err(BoxError::from(TimedOutError(load_timeout)))
341                                    }
342                                },
343                            };
344                            let expiration =
346                                identity.expiration().unwrap_or(now + default_expiration);
347
348                            let jitter = self
349                                .buffer_time
350                                .mul_f64((self.buffer_time_jitter_fraction)());
351
352                            let printable = DateTime::from(expiration);
357                            tracing::debug!(
358                                new_expiration=%printable,
359                                valid_for=?expiration.duration_since(time_source.now()).unwrap_or_default(),
360                                partition=?partition,
361                                "identity cache miss occurred; added new identity (took {:?})",
362                                time_source.now().duration_since(start_time).unwrap_or_default()
363                            );
364
365                            Ok((identity, expiration + jitter))
366                        }
367                        .instrument(span)
370                    })
371                    .await;
372                tracing::debug!("loaded identity");
373                result
374            }
375        })
376    }
377}
378
/// Error returned when an identity load exceeds the load timeout and the resolver offers no
/// fallback identity. Holds the timeout that was exceeded.
#[derive(Debug)]
struct TimedOutError(Duration);

impl std::error::Error for TimedOutError {}

impl fmt::Display for TimedOutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(limit) = self;
        write!(f, "identity resolver timed out after {limit:?}")
    }
}
389
#[cfg(all(test, feature = "client", feature = "http-auth"))]
mod tests {
    use super::*;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::test_util::{instant_time_and_sleep, ManualTimeSource};
    use aws_smithy_async::time::TimeSource;
    use aws_smithy_runtime_api::client::identity::http::Token;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::{Duration, SystemTime, UNIX_EPOCH};
    use tracing::info;

    /// Jitter function that always returns zero, so cache expiry times are deterministic.
    const BUFFER_TIME_NO_JITTER: fn() -> f64 = || 0_f64;

    /// Adapts a closure into an identity resolver that calls the closure on every resolve.
    struct ResolverFn<F>(F);
    impl<F> fmt::Debug for ResolverFn<F> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("ResolverFn")
        }
    }
    impl<F> ResolveIdentity for ResolverFn<F>
    where
        F: Fn() -> IdentityFuture<'static> + Send + Sync,
    {
        fn resolve_identity<'a>(
            &'a self,
            _: &'a RuntimeComponents,
            _config_bag: &'a ConfigBag,
        ) -> IdentityFuture<'a> {
            (self.0)()
        }
    }

    /// Wraps `f` in a `SharedIdentityResolver`. The result is already shared, so callers
    /// should not wrap it again.
    fn resolver_fn<F>(f: F) -> SharedIdentityResolver
    where
        F: Fn() -> IdentityFuture<'static> + Send + Sync + 'static,
    {
        SharedIdentityResolver::new(ResolverFn(f))
    }

    /// Builds a cache with default settings (except the given jitter function) and a resolver
    /// that returns the entries of `load_list` in order, panicking once they run out.
    fn test_cache(
        buffer_time_jitter_fraction: fn() -> f64,
        load_list: Vec<Result<Identity, BoxError>>,
    ) -> (LazyCache, SharedIdentityResolver) {
        #[derive(Debug)]
        struct Resolver(Mutex<Vec<Result<Identity, BoxError>>>);
        impl ResolveIdentity for Resolver {
            fn resolve_identity<'a>(
                &'a self,
                _: &'a RuntimeComponents,
                _config_bag: &'a ConfigBag,
            ) -> IdentityFuture<'a> {
                let mut list = self.0.lock().unwrap();
                if !list.is_empty() {
                    let next = list.remove(0);
                    info!("refreshing the identity to {:?}", next);
                    IdentityFuture::ready(next)
                } else {
                    // Release the guard first so the panic does not poison the mutex.
                    drop(list);
                    panic!("no more identities")
                }
            }
        }

        let identity_resolver = SharedIdentityResolver::new(Resolver(Mutex::new(load_list)));
        let cache = LazyCache::new(
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_BUFFER_TIME,
            buffer_time_jitter_fraction,
            DEFAULT_EXPIRATION,
        );
        (cache, identity_resolver)
    }

    /// Returns the `SystemTime` that is `secs` seconds after the Unix epoch.
    fn epoch_secs(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    /// Builds a token identity that expires `expired_secs` seconds after the Unix epoch.
    fn test_identity(expired_secs: u64) -> Identity {
        let expiration = Some(epoch_secs(expired_secs));
        Identity::new(Token::new("test", expiration), expiration)
    }

    /// Resolves through `cache` and asserts the identity expires at `expired_secs` after the
    /// epoch.
    async fn expect_identity(
        expired_secs: u64,
        cache: &LazyCache,
        components: &RuntimeComponents,
        resolver: SharedIdentityResolver,
    ) {
        let config_bag = ConfigBag::base();
        let identity = cache
            .resolve_cached_identity(resolver, components, &config_bag)
            .await
            .expect("expected identity");
        assert_eq!(Some(epoch_secs(expired_secs)), identity.expiration());
    }

    // The first resolve on an empty cache loads from the resolver.
    #[tokio::test]
    async fn initial_populate_test_identity() {
        let time = ManualTimeSource::new(UNIX_EPOCH);
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let config_bag = ConfigBag::base();
        let resolver = resolver_fn(|| {
            info!("refreshing the test_identity");
            IdentityFuture::ready(Ok(test_identity(1000)))
        });
        let cache = LazyCache::new(
            DEFAULT_LOAD_TIMEOUT,
            DEFAULT_BUFFER_TIME,
            BUFFER_TIME_NO_JITTER,
            DEFAULT_EXPIRATION,
        );
        assert_eq!(
            epoch_secs(1000),
            cache
                .resolve_cached_identity(resolver, &components, &config_bag)
                .await
                .unwrap()
                .expiration()
                .unwrap()
        );
    }

    // Identities are served from the cache until expiry, then reloaded exactly once.
    #[tokio::test]
    async fn reload_expired_test_identity() {
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![
                Ok(test_identity(1000)),
                Ok(test_identity(2000)),
                Ok(test_identity(3000)),
            ],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        expect_identity(1000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(1500));
        expect_identity(2000, &cache, &components, resolver.clone()).await;
        expect_identity(2000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(2500));
        expect_identity(3000, &cache, &components, resolver.clone()).await;
        expect_identity(3000, &cache, &components, resolver.clone()).await;
    }

    // A resolver error during reload is surfaced to the caller.
    #[tokio::test]
    async fn load_failed_error() {
        let config_bag = ConfigBag::base();
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![Ok(test_identity(1000)), Err("failed".into())],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        time.set_time(epoch_secs(1500));
        assert!(cache
            .resolve_cached_identity(resolver.clone(), &components, &config_bag)
            .await
            .is_err());
    }

    // Many concurrent tasks share one cache. Each task advances the shared clock by 22s
    // (1100s per round of 50 tasks), and every identity it gets must not have expired
    // relative to the time that task observed. The resolver panics if asked for more than
    // the five identities supplied.
    #[test]
    fn load_contention() {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .worker_threads(16)
            .build()
            .unwrap();

        let time = ManualTimeSource::new(epoch_secs(0));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, resolver) = test_cache(
            BUFFER_TIME_NO_JITTER,
            vec![
                Ok(test_identity(500)),
                Ok(test_identity(1500)),
                Ok(test_identity(2500)),
                Ok(test_identity(3500)),
                Ok(test_identity(4500)),
            ],
        );
        let cache: SharedIdentityCache = cache.into_shared();

        for _ in 0..4 {
            let mut tasks = Vec::new();
            for _ in 0..50 {
                let resolver = resolver.clone();
                let cache = cache.clone();
                let time = time.clone();
                let components = components.clone();
                tasks.push(rt.spawn(async move {
                    let now = time.advance(Duration::from_secs(22));

                    let config_bag = ConfigBag::base();
                    let identity = cache
                        .resolve_cached_identity(resolver, &components, &config_bag)
                        .await
                        .unwrap();
                    assert!(
                        identity.expiration().unwrap() >= now,
                        "{:?} >= {:?}",
                        identity.expiration(),
                        now
                    );
                }));
            }
            for task in tasks {
                rt.block_on(task).unwrap();
            }
        }
    }

    // A resolver that never completes yields `TimedOutError` after exactly the load timeout.
    // The instant sleep advances the clock by the timeout when it fires.
    #[tokio::test]
    async fn load_timeout() {
        let config_bag = ConfigBag::base();
        let (time, sleep) = instant_time_and_sleep(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(sleep))
            .build()
            .unwrap();
        let resolver = resolver_fn(|| {
            IdentityFuture::new(async {
                aws_smithy_async::future::never::Never::new().await;
                Ok(test_identity(1000))
            })
        });
        let cache = LazyCache::new(
            Duration::from_secs(5),
            DEFAULT_BUFFER_TIME,
            BUFFER_TIME_NO_JITTER,
            DEFAULT_EXPIRATION,
        );

        let err: BoxError = cache
            .resolve_cached_identity(resolver, &components, &config_bag)
            .await
            .expect_err("it should return an error");
        let downcasted = err.downcast_ref::<TimedOutError>();
        assert!(
            downcasted.is_some(),
            "expected a BoxError of TimedOutError, but was {err:?}"
        );
        assert_eq!(time.now(), epoch_secs(105));
    }

    // With a fixed 0.5 jitter fraction, the effective buffer is 10s - 5s = 5s. The identity
    // is still served one second before that point and reloaded exactly at it.
    #[tokio::test]
    async fn buffer_time_jitter() {
        let time = ManualTimeSource::new(epoch_secs(100));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let buffer_time_jitter_fraction = || 0.5_f64;
        let (cache, resolver) = test_cache(
            buffer_time_jitter_fraction,
            vec![Ok(test_identity(1000)), Ok(test_identity(2000))],
        );

        expect_identity(1000, &cache, &components, resolver.clone()).await;
        let buffer_time_with_jitter =
            (DEFAULT_BUFFER_TIME.as_secs_f64() * buffer_time_jitter_fraction()) as u64;
        assert_eq!(buffer_time_with_jitter, 5);
        // One second before the jittered refresh point: still cached.
        let almost_expired_secs = 1000 - buffer_time_with_jitter - 1;
        time.set_time(epoch_secs(almost_expired_secs));
        expect_identity(1000, &cache, &components, resolver.clone()).await;
        // At the jittered refresh point: the next identity is loaded.
        let expired_secs = almost_expired_secs + 1;
        time.set_time(epoch_secs(expired_secs));
        expect_identity(2000, &cache, &components, resolver.clone()).await;
    }

    // Resolvers with different cache partitions get independent cache entries, and each
    // entry is loaded only once.
    #[tokio::test]
    async fn cache_partitioning() {
        let time = ManualTimeSource::new(epoch_secs(0));
        let components = RuntimeComponentsBuilder::for_tests()
            .with_time_source(Some(time.clone()))
            .with_sleep_impl(Some(TokioSleep::new()))
            .build()
            .unwrap();
        let (cache, _) = test_cache(BUFFER_TIME_NO_JITTER, Vec::new());

        // Wall-clock time is acceptable here: the identities only need to outlive the manual
        // clock, which starts at the epoch.
        #[allow(clippy::disallowed_methods)]
        let far_future = SystemTime::now() + Duration::from_secs(10_000);

        // Count resolver invocations to prove the second lookup of each partition is a hit.
        let resolver_a_calls = Arc::new(AtomicUsize::new(0));
        let resolver_b_calls = Arc::new(AtomicUsize::new(0));
        let resolver_a = resolver_fn({
            let calls = resolver_a_calls.clone();
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
                IdentityFuture::ready(Ok(Identity::new(
                    Token::new("A", Some(far_future)),
                    Some(far_future),
                )))
            }
        });
        let resolver_b = resolver_fn({
            let calls = resolver_b_calls.clone();
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
                IdentityFuture::ready(Ok(Identity::new(
                    Token::new("B", Some(far_future)),
                    Some(far_future),
                )))
            }
        });
        assert_ne!(
            resolver_a.cache_partition(),
            resolver_b.cache_partition(),
            "pre-condition: they should have different partition keys"
        );

        let config_bag = ConfigBag::base();

        // Partition A: the first call loads, the second is served from the cache.
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));

        // Partition B is independent of A: it loads its own identity once.
        let identity = cache
            .resolve_cached_identity(resolver_b.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("B", identity.data::<Token>().unwrap().token());
        let identity = cache
            .resolve_cached_identity(resolver_b.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("B", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));
        assert_eq!(1, resolver_b_calls.load(Ordering::Relaxed));

        // Loading B did not evict A.
        let identity = cache
            .resolve_cached_identity(resolver_a.clone(), &components, &config_bag)
            .await
            .unwrap();
        assert_eq!("A", identity.data::<Token>().unwrap().token());
        assert_eq!(1, resolver_a_calls.load(Ordering::Relaxed));
        assert_eq!(1, resolver_b_calls.load(Ordering::Relaxed));
    }
}