1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::collections::BTreeMap;
use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use url::Url;

use crate::client::Client;
use crate::tls::{Certificate, Identity};

/// Credentials for HTTP basic authentication.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct Auth {
    /// The basic-auth username.
    pub username: String,
    /// The basic-auth password, if one is configured.
    pub password: Option<String>,
}

/// Configuration for a `Client`.
///
/// Start from [`ClientConfig::new`], adjust it with the builder-style setters,
/// and finish with [`ClientConfig::build`].
#[derive(Clone)]
pub struct ClientConfig {
    /// Callback producing the URL the client targets; a fixed URL from
    /// [`ClientConfig::new`] or a callback from [`ClientConfig::dynamic_url`].
    url: Arc<dyn Fn() -> Url + Send + Sync + 'static>,
    /// HTTP basic-auth credentials, if any.
    auth: Option<Auth>,
    /// Additional trusted root TLS certificates.
    root_certs: Vec<Certificate>,
    /// TLS client identity, if any.
    identity: Option<Identity>,
    /// Per-domain DNS resolution overrides.
    dns_overrides: BTreeMap<String, Vec<SocketAddr>>,
}

impl fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ClientConfig")
            .field("url", &"...")
            .field("root_certs", &self.root_certs)
            .field("identity", &self.identity)
            .field("auth", &self.auth)
            .field("dns_overrides", &self.dns_overrides)
            .finish()
    }
}

impl ClientConfig {
    /// Constructs a new `ClientConfig` that will target the schema registry at
    /// the specified URL.
    ///
    /// The returned configuration has no additional root certificates, no TLS
    /// client identity, no basic auth, and no DNS overrides.
    pub fn new(url: Url) -> ClientConfig {
        ClientConfig {
            url: Arc::new(move || url.clone()),
            root_certs: Vec::new(),
            identity: None,
            auth: None,
            dns_overrides: BTreeMap::new(),
        }
    }

    /// Adds a trusted root TLS certificate.
    ///
    /// Certificates in the system's certificate store are trusted by default.
    /// May be called repeatedly; every added certificate is trusted.
    pub fn add_root_certificate(mut self, cert: Certificate) -> ClientConfig {
        self.root_certs.push(cert);
        self
    }

    /// Enables TLS client authentication with the provided identity.
    ///
    /// Calling this again replaces any previously configured identity.
    pub fn identity(mut self, identity: Identity) -> ClientConfig {
        self.identity = Some(identity);
        self
    }

    /// Enables HTTP basic authentication with the specified username and
    /// optional password.
    ///
    /// Calling this again replaces any previously configured credentials.
    pub fn auth(mut self, username: String, password: Option<String>) -> ClientConfig {
        self.auth = Some(Auth { username, password });
        self
    }

    /// Overrides DNS resolution for specific domains to the provided IP
    /// addresses.
    ///
    /// Calling this again for the same `domain` replaces its earlier addresses.
    ///
    /// See [`reqwest::ClientBuilder::resolve_to_addrs`].
    pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientConfig {
        self.dns_overrides.insert(domain.into(), addrs.into());
        self
    }

    /// Sets a callback that will be used to dynamically override the url
    /// the client uses.
    ///
    /// This replaces the fixed URL given to [`ClientConfig::new`]; the callback
    /// is handed to the client built by [`ClientConfig::build`].
    // Note that this doesn't use native `reqwest` `Proxy`s because not all schema
    // registry implementations support them.
    pub fn dynamic_url<F: Fn() -> Url + Send + Sync + 'static>(
        mut self,
        callback: F,
    ) -> ClientConfig {
        self.url = Arc::new(callback);
        self
    }

    /// Builds the [`Client`].
    ///
    /// The underlying HTTP client does not follow redirects and uses a fixed
    /// 60-second request timeout.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `reqwest` client fails to build, or
    /// if `Client::new` fails.
    pub fn build(self) -> Result<Client, anyhow::Error> {
        let mut builder = reqwest::ClientBuilder::new();

        for root_cert in self.root_certs {
            builder = builder.add_root_certificate(root_cert.into());
        }

        if let Some(ident) = self.identity {
            builder = builder.identity(ident.into());
        }

        for (domain, addrs) in self.dns_overrides {
            builder = builder.resolve_to_addrs(&domain, &addrs);
        }

        // TODO(guswynn): make this configurable.
        let timeout = Duration::from_secs(60);

        let inner = builder
            // Do not follow redirects.
            .redirect(reqwest::redirect::Policy::none())
            .timeout(timeout)
            // This function already returns `Result`, so surface build
            // failures to the caller instead of panicking.
            .build()?;

        Client::new(inner, self.url, self.auth, timeout)
    }
}