reqsign/aws/
credential.rs

1use std::fmt::Debug;
2use std::fmt::Write;
3use std::fs;
4use std::sync::Arc;
5
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use http::header::CONTENT_LENGTH;
10use log::debug;
11use quick_xml::de;
12use reqwest::Client;
13use serde::Deserialize;
14
15use super::config::Config;
16use super::constants::X_AMZ_CONTENT_SHA_256;
17use super::v4::Signer;
18use crate::time::now;
19use crate::time::parse_rfc3339;
20use crate::time::DateTime;
21use tokio::sync::Mutex;
22
/// Lower-case hex SHA-256 digest of the empty byte string.
///
/// Used as the `x-amz-content-sha256` header value for requests that carry no
/// body, such as the STS `AssumeRole` call made by `AssumeRoleLoader`.
pub const EMPTY_STRING_SHA256: &str =
    "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
25
26/// Credential that holds the `access_key` and `secret_key`.
27#[derive(Default, Clone)]
28#[cfg_attr(test, derive(Debug))]
29pub struct Credential {
30    /// Access key id for aws services.
31    pub access_key_id: String,
32    /// Secret access key for aws services.
33    pub secret_access_key: String,
34    /// Session token for aws services.
35    pub session_token: Option<String>,
36    /// Expiration time for this credential.
37    pub expires_in: Option<DateTime>,
38}
39
40impl Credential {
41    /// is current cred is valid?
42    ///
43    /// # Panics
44    ///
45    /// Panics if the time delta calculation overflows (which should not happen in practice).
46    #[must_use]
47    pub fn is_valid(&self) -> bool {
48        if (self.access_key_id.is_empty() || self.secret_access_key.is_empty())
49            && self.session_token.is_none()
50        {
51            return false;
52        }
53        // Take 120s as buffer to avoid edge cases.
54        if let Some(valid) = self
55            .expires_in
56            .map(|v| v > now() + chrono::TimeDelta::try_minutes(2).expect("in bounds"))
57        {
58            return valid;
59        }
60
61        true
62    }
63}
64
65/// Loader trait will try to load credential from different sources.
66#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
67#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
68pub trait CredentialLoad: 'static + Send + Sync {
69    /// Load credential from sources.
70    ///
71    /// - If succeed, return `Ok(Some(cred))`
72    /// - If not found, return `Ok(None)`
73    /// - If unexpected errors happened, return `Err(err)`
74    async fn load_credential(&self, client: Client) -> Result<Option<Credential>>;
75}
76
77/// `CredentialLoader` will load credential from different methods.
78pub struct DefaultLoader {
79    client: Client,
80    config: Config,
81    credential: Arc<Mutex<Option<Credential>>>,
82    imds_v2_loader: Option<IMDSv2Loader>,
83}
84
85impl DefaultLoader {
86    /// Create a new `CredentialLoader`
87    #[must_use]
88    pub fn new(client: Client, config: Config) -> Self {
89        let imds_v2_loader = if config.ec2_metadata_disabled {
90            None
91        } else {
92            Some(IMDSv2Loader::new(client.clone()))
93        };
94        Self {
95            client,
96            config,
97            credential: Arc::default(),
98            imds_v2_loader,
99        }
100    }
101
102    /// Disable load from ec2 metadata.
103    #[must_use]
104    pub fn with_disable_ec2_metadata(mut self) -> Self {
105        self.imds_v2_loader = None;
106        self
107    }
108
109    /// Load credential.
110    ///
111    /// Resolution order:
112    /// 1. Environment variables
113    /// 2. Shared config (`~/.aws/config`, `~/.aws/credentials`)
114    /// 3. Web Identity Tokens
115    /// 4. ECS (IAM Roles for Tasks) & General HTTP credentials:
116    /// 5. EC2 `IMDSv2`
117    ///
118    /// # Errors
119    ///
120    /// Will return an error if credential loading fails from any source.
121    pub async fn load(&self) -> Result<Option<Credential>> {
122        let mut lock = self.credential.lock().await;
123
124        // Return cached credential if it has been loaded and is still valid
125        if let Some(ref cred) = *lock {
126            if cred.is_valid() {
127                return Ok(Some(cred.clone()));
128            }
129        }
130
131        // Load new credential while holding the lock
132        // This ensures only one thread refreshes at a time
133        let new_cred = self.load_inner().await?;
134        lock.clone_from(&new_cred);
135
136        Ok(new_cred)
137    }
138
139    async fn load_inner(&self) -> Result<Option<Credential>> {
140        if let Some(cred) = self.load_via_config() {
141            return Ok(Some(cred));
142        }
143
144        if let Some(cred) = self
145            .load_via_assume_role_with_web_identity()
146            .await
147            .map_err(|err| {
148                debug!("load credential via assume_role_with_web_identity failed: {err:?}");
149                err
150            })?
151        {
152            return Ok(Some(cred));
153        }
154
155        if let Some(cred) = self.load_via_imds_v2().await.map_err(|err| {
156            debug!("load credential via imds_v2 failed: {err:?}");
157            err
158        })? {
159            return Ok(Some(cred));
160        }
161
162        Ok(None)
163    }
164
165    fn load_via_config(&self) -> Option<Credential> {
166        if let (Some(ak), Some(sk)) = (&self.config.access_key_id, &self.config.secret_access_key) {
167            Some(Credential {
168                access_key_id: ak.clone(),
169                secret_access_key: sk.clone(),
170                session_token: self.config.session_token.clone(),
171                // Set expires_in to 10 minutes to enforce re-read
172                // from file.
173                expires_in: Some(now() + chrono::TimeDelta::try_minutes(10).expect("in bounds")),
174            })
175        } else {
176            None
177        }
178    }
179
180    async fn load_via_imds_v2(&self) -> Result<Option<Credential>> {
181        let Some(loader) = &self.imds_v2_loader else {
182            return Ok(None);
183        };
184
185        loader.load().await
186    }
187
188    async fn load_via_assume_role_with_web_identity(&self) -> Result<Option<Credential>> {
189        let (Some(token_file), Some(role_arn)) =
190            (&self.config.web_identity_token_file, &self.config.role_arn)
191        else {
192            return Ok(None);
193        };
194
195        let token = fs::read_to_string(token_file)?;
196        let role_session_name = &self.config.role_session_name;
197
198        let endpoint = self.sts_endpoint()?;
199
200        // Construct request to AWS STS Service.
201        let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}");
202        let req = self.client.get(&url).header(
203            http::header::CONTENT_TYPE.as_str(),
204            "application/x-www-form-urlencoded",
205        );
206
207        let resp = req.send().await?;
208        if resp.status() != http::StatusCode::OK {
209            let content = resp.text().await?;
210            return Err(anyhow!("request to AWS STS Services failed: {content}"));
211        }
212
213        let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.text().await?)?;
214        let resp_cred = resp.result.credentials;
215
216        let cred = Credential {
217            access_key_id: resp_cred.access_key_id,
218            secret_access_key: resp_cred.secret_access_key,
219            session_token: Some(resp_cred.session_token),
220            expires_in: Some(parse_rfc3339(&resp_cred.expiration)?),
221        };
222
223        Ok(Some(cred))
224    }
225
226    /// Get the sts endpoint.
227    ///
228    /// The returning format may look like `sts.{region}.amazonaws.com`
229    ///
230    /// # Notes
231    ///
232    /// AWS could have different sts endpoint based on it's region.
233    /// We can check them by region name.
234    ///
235    /// ref: <https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs>
236    fn sts_endpoint(&self) -> Result<String> {
237        // use regional sts if sts_regional_endpoints has been set.
238        if self.config.sts_regional_endpoints == "regional" {
239            let region = self.config.region.clone().ok_or_else(|| {
240                anyhow!("sts_regional_endpoints set to reginal, but region is not set")
241            })?;
242            if region.starts_with("cn-") {
243                Ok(format!("sts.{region}.amazonaws.com.cn"))
244            } else {
245                Ok(format!("sts.{region}.amazonaws.com"))
246            }
247        } else {
248            let region = self.config.region.clone().unwrap_or_default();
249            if region.starts_with("cn") {
250                // TODO: seems aws china doesn't support global sts?
251                Ok("sts.amazonaws.com.cn".to_string())
252            } else {
253                Ok("sts.amazonaws.com".to_string())
254            }
255        }
256    }
257}
258
259#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
260#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
261impl CredentialLoad for DefaultLoader {
262    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
263        self.load().await
264    }
265}
266
267pub struct IMDSv2Loader {
268    client: Client,
269
270    token: Arc<Mutex<(String, DateTime)>>,
271}
272
273impl IMDSv2Loader {
274    /// Create a new `IMDSv2Loader`.
275    #[must_use]
276    pub fn new(client: Client) -> Self {
277        Self {
278            client,
279            token: Arc::new(Mutex::new((String::new().to_string(), DateTime::MIN_UTC))),
280        }
281    }
282
283    pub async fn load(&self) -> Result<Option<Credential>> {
284        let token = self.load_ec2_metadata_token().await?;
285
286        // List all credentials that node has.
287        let url = "http://169.254.169.254/latest/meta-data/iam/security-credentials/";
288        let req = self
289            .client
290            .get(url)
291            .header("x-aws-ec2-metadata-token", &token);
292        let resp = req.send().await?;
293        if resp.status() != http::StatusCode::OK {
294            let content = resp.text().await?;
295            return Err(anyhow!(
296                "request to AWS EC2 Metadata Services failed: {content}"
297            ));
298        }
299        let profile_name = resp.text().await?;
300
301        // Get the credentials via role_name.
302        let url = format!(
303            "http://169.254.169.254/latest/meta-data/iam/security-credentials/{profile_name}"
304        );
305        let req = self
306            .client
307            .get(&url)
308            .header("x-aws-ec2-metadata-token", &token);
309        let resp = req.send().await?;
310        if resp.status() != http::StatusCode::OK {
311            let content = resp.text().await?;
312            return Err(anyhow!(
313                "request to AWS EC2 Metadata Services failed: {content}"
314            ));
315        }
316
317        let content = resp.text().await?;
318        let resp: Ec2MetadataIamSecurityCredentials = serde_json::from_str(&content)?;
319        if resp.code != "Success" {
320            return Err(anyhow!(
321                "request to AWS EC2 Metadata Services failed: {content}"
322            ));
323        }
324
325        let cred = Credential {
326            access_key_id: resp.access_key_id,
327            secret_access_key: resp.secret_access_key,
328            session_token: Some(resp.token),
329            expires_in: Some(parse_rfc3339(&resp.expiration)?),
330        };
331
332        Ok(Some(cred))
333    }
334
335    /// `load_ec2_metadata_token` will load ec2 metadata token from IMDS.
336    ///
337    /// Return value is (token, `expires_in`).
338    async fn load_ec2_metadata_token(&self) -> Result<String> {
339        let mut lock = self.token.lock().await;
340        let (ref token, expires_in) = &*lock;
341
342        // Return cached token if still valid
343        if expires_in > &now() {
344            return Ok(token.clone());
345        }
346
347        // Refresh token while holding the lock
348        // This ensures only one thread refreshes at a time
349        let url = "http://169.254.169.254/latest/api/token";
350        #[allow(unused_mut)]
351        let mut req = self
352            .client
353            .put(url)
354            .header(CONTENT_LENGTH, "0")
355            // 21600s (6h) is recommended by AWS.
356            .header("x-aws-ec2-metadata-token-ttl-seconds", "21600");
357
358        // Set timeout to 1s to avoid hanging on non-s3 env.
359        #[cfg(not(target_arch = "wasm32"))]
360        {
361            req = req.timeout(std::time::Duration::from_secs(1));
362        }
363
364        let resp = req.send().await?;
365        if resp.status() != http::StatusCode::OK {
366            let content = resp.text().await?;
367            return Err(anyhow!(
368                "request to AWS EC2 Metadata Services failed: {content}"
369            ));
370        }
371        let ec2_token = resp.text().await?;
372        // Set expires_in to 10 minutes to enforce re-read.
373        let expires_in = now() + chrono::TimeDelta::try_seconds(21600).expect("in bounds")
374            - chrono::TimeDelta::try_seconds(600).expect("in bounds");
375
376        *lock = (ec2_token.clone(), expires_in);
377
378        Ok(ec2_token)
379    }
380}
381
382#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
383#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
384impl CredentialLoad for IMDSv2Loader {
385    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
386        self.load().await
387    }
388}
389
390/// `AssumeRoleLoader` will load credential via assume role.
391pub struct AssumeRoleLoader {
392    client: Client,
393    config: Config,
394
395    source_credential: Box<dyn CredentialLoad>,
396    sts_signer: Signer,
397    credential: Arc<Mutex<Option<Credential>>>,
398}
399
400impl AssumeRoleLoader {
401    /// Create a new assume role loader.
402    ///
403    /// # Errors
404    ///
405    /// Returns an error if the region is not configured.
406    pub fn new(
407        client: Client,
408        config: Config,
409        source_credential: Box<dyn CredentialLoad>,
410    ) -> Result<Self> {
411        let region = config.region.clone().ok_or_else(|| {
412            anyhow!("assume role loader requires region, but not found, please check your configuration")
413        })?;
414
415        Ok(Self {
416            client,
417            config,
418            source_credential,
419
420            sts_signer: Signer::new("sts", &region),
421            credential: Arc::default(),
422        })
423    }
424
425    /// Load credential via assume role.
426    ///
427    /// # Errors
428    ///
429    /// Returns an error if `role_arn` is not configured or if the STS request fails.
430    pub async fn load(&self) -> Result<Option<Credential>> {
431        let mut lock = self.credential.lock().await;
432
433        // Return cached credential if it has been loaded and is still valid
434        if let Some(ref cred) = *lock {
435            if cred.is_valid() {
436                return Ok(Some(cred.clone()));
437            }
438        }
439
440        // Load new credential while holding the lock
441        // This ensures only one thread refreshes at a time
442        let new_cred = self.load_inner().await?;
443        lock.clone_from(&new_cred);
444
445        Ok(new_cred)
446    }
447
448    async fn load_inner(&self) -> Result<Option<Credential>> {
449        let role_arn =self.config.role_arn.clone().ok_or_else(|| {
450            anyhow!("assume role loader requires role_arn, but not found, please check your configuration")
451        })?;
452
453        let role_session_name = &self.config.role_session_name;
454
455        let endpoint = self.sts_endpoint()?;
456
457        // Construct request to AWS STS Service.
458        let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}");
459        if let Some(external_id) = &self.config.external_id {
460            write!(url, "&ExternalId={external_id}")?;
461        }
462        if let Some(duration_seconds) = &self.config.duration_seconds {
463            write!(url, "&DurationSeconds={duration_seconds}")?;
464        }
465        if let Some(tags) = &self.config.tags {
466            for (idx, (key, value)) in tags.iter().enumerate() {
467                let tag_index = idx + 1;
468                write!(
469                    url,
470                    "&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}"
471                )?;
472            }
473        }
474
475        let mut req = self
476            .client
477            .get(&url)
478            .header(
479                http::header::CONTENT_TYPE.as_str(),
480                "application/x-www-form-urlencoded",
481            )
482            // Set content sha to empty string.
483            .header(X_AMZ_CONTENT_SHA_256, EMPTY_STRING_SHA256)
484            .build()?;
485
486        let source_cred = self
487            .source_credential
488            .load_credential(self.client.clone())
489            .await?
490            .ok_or_else(|| {
491                anyhow!("source credential is required for AssumeRole, but not found, please check your configuration")
492            })?;
493
494        self.sts_signer.sign(&mut req, &source_cred)?;
495
496        let resp = self.client.execute(req).await?;
497        if resp.status() != http::StatusCode::OK {
498            let content = resp.text().await?;
499            return Err(anyhow!("request to AWS STS Services failed: {content}"));
500        }
501
502        let resp: AssumeRoleResponse = de::from_str(&resp.text().await?)?;
503        let resp_cred = resp.result.credentials;
504
505        let cred = Credential {
506            access_key_id: resp_cred.access_key_id,
507            secret_access_key: resp_cred.secret_access_key,
508            session_token: Some(resp_cred.session_token),
509            expires_in: Some(parse_rfc3339(&resp_cred.expiration)?),
510        };
511
512        Ok(Some(cred))
513    }
514
515    /// Get the sts endpoint.
516    ///
517    /// The returning format may look like `sts.{region}.amazonaws.com`
518    ///
519    /// # Notes
520    ///
521    /// AWS could have different sts endpoint based on it's region.
522    /// We can check them by region name.
523    ///
524    /// ref: <https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs>
525    fn sts_endpoint(&self) -> Result<String> {
526        // use regional sts if sts_regional_endpoints has been set.
527        if self.config.sts_regional_endpoints == "regional" {
528            let region = self.config.region.clone().ok_or_else(|| {
529                anyhow!("sts_regional_endpoints set to reginal, but region is not set")
530            })?;
531            if region.starts_with("cn-") {
532                Ok(format!("sts.{region}.amazonaws.com.cn"))
533            } else {
534                Ok(format!("sts.{region}.amazonaws.com"))
535            }
536        } else {
537            let region = self.config.region.clone().unwrap_or_default();
538            if region.starts_with("cn") {
539                // TODO: seems aws china doesn't support global sts?
540                Ok("sts.amazonaws.com.cn".to_string())
541            } else {
542                Ok("sts.amazonaws.com".to_string())
543            }
544        }
545    }
546}
547
548#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
549#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
550impl CredentialLoad for AssumeRoleLoader {
551    async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
552        self.load().await
553    }
554}
555
556#[derive(Default, Debug, Deserialize)]
557#[serde(default, rename_all = "PascalCase")]
558struct AssumeRoleWithWebIdentityResponse {
559    #[serde(rename = "AssumeRoleWithWebIdentityResult")]
560    result: AssumeRoleWithWebIdentityResult,
561}
562
563#[derive(Default, Debug, Deserialize)]
564#[serde(default, rename_all = "PascalCase")]
565struct AssumeRoleWithWebIdentityResult {
566    credentials: AssumeRoleWithWebIdentityCredentials,
567}
568
569#[derive(Default, Debug, Deserialize)]
570#[serde(default, rename_all = "PascalCase")]
571struct AssumeRoleWithWebIdentityCredentials {
572    access_key_id: String,
573    secret_access_key: String,
574    session_token: String,
575    expiration: String,
576}
577
578#[derive(Default, Debug, Deserialize)]
579#[serde(default, rename_all = "PascalCase")]
580struct AssumeRoleResponse {
581    #[serde(rename = "AssumeRoleResult")]
582    result: AssumeRoleResult,
583}
584
585#[derive(Default, Debug, Deserialize)]
586#[serde(default, rename_all = "PascalCase")]
587struct AssumeRoleResult {
588    credentials: AssumeRoleCredentials,
589}
590
591#[derive(Default, Debug, Deserialize)]
592#[serde(default, rename_all = "PascalCase")]
593struct AssumeRoleCredentials {
594    access_key_id: String,
595    secret_access_key: String,
596    session_token: String,
597    expiration: String,
598}
599
600#[derive(Default, Debug, Deserialize)]
601#[serde(default, rename_all = "PascalCase")]
602struct Ec2MetadataIamSecurityCredentials {
603    access_key_id: String,
604    secret_access_key: String,
605    token: String,
606    expiration: String,
607
608    code: String,
609}
610
611#[cfg(test)]
612mod tests {
613    use std::env;
614    use std::str::FromStr;
615    use std::vec;
616
617    use anyhow::Result;
618    use http::Request;
619    use http::StatusCode;
620    use once_cell::sync::Lazy;
621    use quick_xml::de;
622    use reqwest::Client;
623    use tokio::runtime::Runtime;
624
625    use super::*;
626    use crate::aws::constants::*;
627    use crate::aws::v4::Signer;
628
629    static RUNTIME: Lazy<Runtime> = Lazy::new(|| {
630        tokio::runtime::Builder::new_multi_thread()
631            .enable_all()
632            .build()
633            .expect("Should create a tokio runtime")
634    });
635
636    #[test]
637    fn test_credential_env_loader_without_env() {
638        let _ = env_logger::builder().is_test(true).try_init();
639
640        temp_env::with_vars_unset(vec![AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY], || {
641            RUNTIME.block_on(async {
642                let l = DefaultLoader::new(reqwest::Client::new(), Config::default())
643                    .with_disable_ec2_metadata();
644                let x = l.load().await.expect("load must succeed");
645                assert!(x.is_none());
646            })
647        });
648    }
649
650    #[test]
651    fn test_credential_env_loader_with_env() {
652        let _ = env_logger::builder().is_test(true).try_init();
653
654        temp_env::with_vars(
655            vec![
656                (AWS_ACCESS_KEY_ID, Some("access_key_id")),
657                (AWS_SECRET_ACCESS_KEY, Some("secret_access_key")),
658            ],
659            || {
660                RUNTIME.block_on(async {
661                    let l = DefaultLoader::new(Client::new(), Config::default().from_env());
662                    let x = l.load().await.expect("load must succeed");
663
664                    let x = x.expect("must load succeed");
665                    assert_eq!("access_key_id", x.access_key_id);
666                    assert_eq!("secret_access_key", x.secret_access_key);
667                })
668            },
669        );
670    }
671
672    #[test]
673    fn test_credential_profile_loader_from_config() {
674        let _ = env_logger::builder().is_test(true).try_init();
675
676        temp_env::with_vars(
677            vec![
678                (AWS_ACCESS_KEY_ID, None),
679                (AWS_SECRET_ACCESS_KEY, None),
680                (
681                    AWS_CONFIG_FILE,
682                    Some(format!(
683                        "{}/testdata/services/aws/default_config",
684                        env::current_dir()
685                            .expect("current_dir must exist")
686                            .to_string_lossy()
687                    )),
688                ),
689                (
690                    AWS_SHARED_CREDENTIALS_FILE,
691                    Some(format!(
692                        "{}/testdata/services/aws/not_exist",
693                        env::current_dir()
694                            .expect("current_dir must exist")
695                            .to_string_lossy()
696                    )),
697                ),
698            ],
699            || {
700                RUNTIME.block_on(async {
701                    let l = DefaultLoader::new(
702                        Client::new(),
703                        Config::default().from_env().from_profile(),
704                    );
705                    let x = l.load().await.unwrap().unwrap();
706                    assert_eq!("config_access_key_id", x.access_key_id);
707                    assert_eq!("config_secret_access_key", x.secret_access_key);
708                })
709            },
710        );
711    }
712
713    #[test]
714    fn test_credential_profile_loader_from_shared() {
715        let _ = env_logger::builder().is_test(true).try_init();
716
717        temp_env::with_vars(
718            vec![
719                (AWS_ACCESS_KEY_ID, None),
720                (AWS_SECRET_ACCESS_KEY, None),
721                (
722                    AWS_CONFIG_FILE,
723                    Some(format!(
724                        "{}/testdata/services/aws/not_exist",
725                        env::current_dir()
726                            .expect("load must exist")
727                            .to_string_lossy()
728                    )),
729                ),
730                (
731                    AWS_SHARED_CREDENTIALS_FILE,
732                    Some(format!(
733                        "{}/testdata/services/aws/default_credential",
734                        env::current_dir()
735                            .expect("load must exist")
736                            .to_string_lossy()
737                    )),
738                ),
739            ],
740            || {
741                RUNTIME.block_on(async {
742                    let l = DefaultLoader::new(
743                        Client::new(),
744                        Config::default().from_env().from_profile(),
745                    );
746                    let x = l.load().await.unwrap().unwrap();
747                    assert_eq!("shared_access_key_id", x.access_key_id);
748                    assert_eq!("shared_secret_access_key", x.secret_access_key);
749                })
750            },
751        );
752    }
753
754    /// AWS_SHARED_CREDENTIALS_FILE should be taken first.
755    #[test]
756    fn test_credential_profile_loader_from_both() {
757        let _ = env_logger::builder().is_test(true).try_init();
758
759        temp_env::with_vars(
760            vec![
761                (AWS_ACCESS_KEY_ID, None),
762                (AWS_SECRET_ACCESS_KEY, None),
763                (
764                    AWS_CONFIG_FILE,
765                    Some(format!(
766                        "{}/testdata/services/aws/default_config",
767                        env::current_dir()
768                            .expect("current_dir must exist")
769                            .to_string_lossy()
770                    )),
771                ),
772                (
773                    AWS_SHARED_CREDENTIALS_FILE,
774                    Some(format!(
775                        "{}/testdata/services/aws/default_credential",
776                        env::current_dir()
777                            .expect("current_dir must exist")
778                            .to_string_lossy()
779                    )),
780                ),
781            ],
782            || {
783                RUNTIME.block_on(async {
784                    let l = DefaultLoader::new(
785                        Client::new(),
786                        Config::default().from_env().from_profile(),
787                    );
788                    let x = l.load().await.expect("load must success").unwrap();
789                    assert_eq!("shared_access_key_id", x.access_key_id);
790                    assert_eq!("shared_secret_access_key", x.secret_access_key);
791                })
792            },
793        );
794    }
795
796    #[test]
797    fn test_signer_with_web_loader() -> Result<()> {
798        let _ = env_logger::builder().is_test(true).try_init();
799
800        dotenv::from_filename(".env").ok();
801
802        if env::var("REQSIGN_AWS_S3_TEST").is_err()
803            || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on"
804        {
805            return Ok(());
806        }
807
808        // Ignore test if role_arn not set
809        let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") {
810            v
811        } else {
812            return Ok(());
813        };
814
815        // let provider_arn = env::var("REQSIGN_AWS_PROVIDER_ARN").expect("REQSIGN_AWS_PROVIDER_ARN not exist");
816        let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist");
817
818        let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist");
819        let file_path = format!(
820            "{}/testdata/services/aws/web_identity_token_file",
821            env::current_dir()
822                .expect("current_dir must exist")
823                .to_string_lossy()
824        );
825        fs::write(&file_path, github_token)?;
826
827        temp_env::with_vars(
828            vec![
829                (AWS_REGION, Some(&region)),
830                (AWS_ROLE_ARN, Some(&role_arn)),
831                (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)),
832            ],
833            || {
834                RUNTIME.block_on(async {
835                    let config = Config::default().from_env();
836                    let loader = DefaultLoader::new(reqwest::Client::new(), config);
837
838                    let signer = Signer::new("s3", &region);
839
840                    let endpoint = format!("https://s3.{region}.amazonaws.com/opendal-testing");
841                    let mut req = Request::new("");
842                    *req.method_mut() = http::Method::GET;
843                    *req.uri_mut() =
844                        http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap();
845
846                    let cred = loader
847                        .load()
848                        .await
849                        .expect("credential must be valid")
850                        .unwrap();
851
852                    signer.sign(&mut req, &cred).expect("sign must success");
853
854                    debug!("signed request url: {:?}", req.uri().to_string());
855                    debug!("signed request: {req:?}");
856
857                    let client = Client::new();
858                    let resp = client.execute(req.try_into().unwrap()).await.unwrap();
859
860                    let status = resp.status();
861                    debug!("got response: {resp:?}");
862                    debug!("got response content: {:?}", resp.text().await.unwrap());
863                    assert_eq!(status, StatusCode::NOT_FOUND);
864                })
865            },
866        );
867
868        Ok(())
869    }
870
871    #[test]
872    fn test_signer_with_web_loader_assume_role() -> Result<()> {
873        let _ = env_logger::builder().is_test(true).try_init();
874
875        dotenv::from_filename(".env").ok();
876
877        if env::var("REQSIGN_AWS_S3_TEST").is_err()
878            || env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on"
879        {
880            return Ok(());
881        }
882
883        // Ignore test if role_arn not set
884        let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") {
885            v
886        } else {
887            return Ok(());
888        };
889        // Ignore test if assume_role_arn not set
890        let assume_role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") {
891            v
892        } else {
893            return Ok(());
894        };
895
896        let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist");
897
898        let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist");
899        let file_path = format!(
900            "{}/testdata/services/aws/web_identity_token_file",
901            env::current_dir()
902                .expect("current_dir must exist")
903                .to_string_lossy()
904        );
905        fs::write(&file_path, github_token)?;
906
907        temp_env::with_vars(
908            vec![
909                (AWS_REGION, Some(&region)),
910                (AWS_ROLE_ARN, Some(&role_arn)),
911                (AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)),
912            ],
913            || {
914                RUNTIME.block_on(async {
915                    let client = reqwest::Client::new();
916                    let default_loader =
917                        DefaultLoader::new(client.clone(), Config::default().from_env())
918                            .with_disable_ec2_metadata();
919
920                    let cfg = Config {
921                        role_arn: Some(assume_role_arn.clone()),
922                        region: Some(region.clone()),
923                        sts_regional_endpoints: "regional".to_string(),
924                        ..Default::default()
925                    };
926                    let loader =
927                        AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader))
928                            .expect("AssumeRoleLoader must be valid");
929
930                    let signer = Signer::new("s3", &region);
931                    let endpoint = format!("https://s3.{region}.amazonaws.com/opendal-testing");
932                    let mut req = Request::new("");
933                    *req.method_mut() = http::Method::GET;
934                    *req.uri_mut() =
935                        http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap();
936                    let cred = loader
937                        .load()
938                        .await
939                        .expect("credential must be valid")
940                        .unwrap();
941                    signer.sign(&mut req, &cred).expect("sign must success");
942                    debug!("signed request url: {:?}", req.uri().to_string());
943                    debug!("signed request: {req:?}");
944                    let client = Client::new();
945                    let resp = client.execute(req.try_into().unwrap()).await.unwrap();
946                    let status = resp.status();
947                    debug!("got response: {resp:?}");
948                    debug!("got response content: {:?}", resp.text().await.unwrap());
949                    assert_eq!(status, StatusCode::NOT_FOUND);
950                })
951            },
952        );
953        Ok(())
954    }
955
956    #[test]
957    fn test_parse_assume_role_with_web_identity_response() -> Result<()> {
958        let _ = env_logger::builder().is_test(true).try_init();
959
960        let content = r#"<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
961  <AssumeRoleWithWebIdentityResult>
962    <Audience>test_audience</Audience>
963    <AssumedRoleUser>
964      <AssumedRoleId>role_id:reqsign</AssumedRoleId>
965      <Arn>arn:aws:sts::123:assumed-role/reqsign/reqsign</Arn>
966    </AssumedRoleUser>
967    <Provider>arn:aws:iam::123:oidc-provider/example.com/</Provider>
968    <Credentials>
969      <AccessKeyId>access_key_id</AccessKeyId>
970      <SecretAccessKey>secret_access_key</SecretAccessKey>
971      <SessionToken>session_token</SessionToken>
972      <Expiration>2022-05-25T11:45:17Z</Expiration>
973    </Credentials>
974    <SubjectFromWebIdentityToken>subject</SubjectFromWebIdentityToken>
975  </AssumeRoleWithWebIdentityResult>
976  <ResponseMetadata>
977    <RequestId>b1663ad1-23ab-45e9-b465-9af30b202eba</RequestId>
978  </ResponseMetadata>
979</AssumeRoleWithWebIdentityResponse>"#;
980
981        let resp: AssumeRoleWithWebIdentityResponse =
982            de::from_str(content).expect("xml deserialize must success");
983
984        assert_eq!(&resp.result.credentials.access_key_id, "access_key_id");
985        assert_eq!(
986            &resp.result.credentials.secret_access_key,
987            "secret_access_key"
988        );
989        assert_eq!(&resp.result.credentials.session_token, "session_token");
990        assert_eq!(&resp.result.credentials.expiration, "2022-05-25T11:45:17Z");
991
992        Ok(())
993    }
994
995    #[test]
996    fn test_parse_assume_role_response() -> Result<()> {
997        let _ = env_logger::builder().is_test(true).try_init();
998
999        let content = r#"<AssumeRoleResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
1000  <AssumeRoleResult>
1001  <SourceIdentity>Alice</SourceIdentity>
1002    <AssumedRoleUser>
1003      <Arn>arn:aws:sts::123456789012:assumed-role/demo/TestAR</Arn>
1004      <AssumedRoleId>ARO123EXAMPLE123:TestAR</AssumedRoleId>
1005    </AssumedRoleUser>
1006    <Credentials>
1007      <AccessKeyId>ASIAIOSFODNN7EXAMPLE</AccessKeyId>
1008      <SecretAccessKey>wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY</SecretAccessKey>
1009      <SessionToken>
1010       AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW
1011       LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd
1012       QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU
1013       9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz
1014       +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==
1015      </SessionToken>
1016      <Expiration>2019-11-09T13:34:41Z</Expiration>
1017    </Credentials>
1018    <PackedPolicySize>6</PackedPolicySize>
1019  </AssumeRoleResult>
1020  <ResponseMetadata>
1021    <RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
1022  </ResponseMetadata>
1023</AssumeRoleResponse>"#;
1024
1025        let resp: AssumeRoleResponse = de::from_str(content).expect("xml deserialize must success");
1026
1027        assert_eq!(
1028            &resp.result.credentials.access_key_id,
1029            "ASIAIOSFODNN7EXAMPLE"
1030        );
1031        assert_eq!(
1032            &resp.result.credentials.secret_access_key,
1033            "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"
1034        );
1035        assert_eq!(
1036            &resp.result.credentials.session_token,
1037            "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQW
1038       LWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd
1039       QrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU
1040       9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz
1041       +scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA=="
1042        );
1043        assert_eq!(&resp.result.credentials.expiration, "2019-11-09T13:34:41Z");
1044
1045        Ok(())
1046    }
1047}