mz/
context.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository, or online at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Context types for command implementations.
//!
//! The implementation of each command in the [`crate::command`] module takes exactly
//! one of these context types, depending on whether it requires access to a
//! valid authentication profile and an active region.

use std::path::PathBuf;
use std::sync::Arc;

use crate::config_file::ConfigFile;
use crate::error::Error;
use crate::sql_client::{Client as SqlClient, ClientConfig as SqlClientConfig};
use crate::ui::{OutputFormat, OutputFormatter};
use mz_cloud_api::client::Client as CloudClient;
use mz_cloud_api::client::cloud_provider::CloudProvider;
use mz_cloud_api::client::region::{Region, RegionInfo};
use mz_cloud_api::config::{
    ClientBuilder as CloudClientBuilder, ClientConfig as CloudClientConfig,
};
use mz_frontegg_client::client::{Authentication, Client as AdminClient};
use mz_frontegg_client::config::{
    ClientBuilder as AdminClientBuilder, ClientConfig as AdminClientConfig,
};
use url::{ParseError, Url};

41/// Arguments for [`Context::load`].
42pub struct ContextLoadArgs {
43    /// An override for the configuration file path to laod.
44    ///
45    /// If unspecified, the default configuration file path is used.
46    pub config_file_path: Option<PathBuf>,
47    /// The output format to use.
48    pub output_format: OutputFormat,
49    /// Whether to suppress color output.
50    pub no_color: bool,
51    /// Global optional region.
52    pub region: Option<String>,
53    /// Global optional profile.
54    pub profile: Option<String>,
55}
56
57/// Context for a basic command.
58#[derive(Clone)]
59pub struct Context {
60    config_file: ConfigFile,
61    output_formatter: OutputFormatter,
62    region: Option<String>,
63    profile: Option<String>,
64}
65
66impl Context {
67    /// Loads the context from the provided arguments.
68    pub async fn load(
69        ContextLoadArgs {
70            config_file_path,
71            output_format,
72            no_color,
73            region,
74            profile,
75        }: ContextLoadArgs,
76    ) -> Result<Context, Error> {
77        let config_file_path = match config_file_path {
78            None => ConfigFile::default_path()?,
79            Some(path) => path,
80        };
81        let config_file = ConfigFile::load(config_file_path).await?;
82        Ok(Context {
83            config_file,
84            output_formatter: OutputFormatter::new(output_format, no_color),
85            region,
86            profile,
87        })
88    }
89
90    /// Retrieves the admin endpoint from the configuration file.
91    ///
92    /// - If an admin-endpoint is provided, it uses its value.
93    /// - If only a cloud-endpoint is provided, it constructs the admin endpoint based on it.
94    /// - If neither an admin-endpoint nor a cloud-endpoint is provided, default values are used.
95    pub fn get_admin_endpoint(
96        &self,
97        cloud_endpoint: Option<Url>,
98        admin_endpoint: Option<&str>,
99    ) -> Result<Option<Url>, ParseError> {
100        if let Some(admin_endpoint) = admin_endpoint {
101            return Ok(Some(admin_endpoint.parse()?));
102        } else if let Some(cloud_endpoint) = cloud_endpoint {
103            let mut admin_endpoint_url = cloud_endpoint;
104
105            if let Some(host) = admin_endpoint_url.host_str().as_mut() {
106                if let Some(host) = host.strip_prefix("api.") {
107                    admin_endpoint_url.set_host(Some(&format!("admin.{}", host)))?;
108                }
109                return Ok(Some(admin_endpoint_url));
110            }
111        }
112
113        Ok(None)
114    }
115
116    /// Returns the global profile option.
117    pub fn get_global_profile(&self) -> Option<String> {
118        self.profile.clone()
119    }
120
121    /// Converts this context into a [`ProfileContext`].
122    ///
123    /// If a profile is not specified, the default profile is activated.
124    pub fn activate_profile(self) -> Result<ProfileContext, Error> {
125        let profile_name = self
126            .profile
127            .clone()
128            .unwrap_or_else(|| self.config_file.profile().into());
129        let config_file = self.config_file.clone();
130
131        let profile = config_file.load_profile(&profile_name)?;
132
133        // Parse the endpoint form the string in the config to URL.
134        let cloud_endpoint = match profile.cloud_endpoint() {
135            Some(endpoint) => Some(endpoint.parse::<Url>()?),
136            None => None,
137        };
138
139        // Build clients
140        let mut admin_client_builder = AdminClientBuilder::default();
141
142        if let Ok(Some(admin_endpoint)) =
143            self.get_admin_endpoint(cloud_endpoint.clone(), profile.admin_endpoint())
144        {
145            admin_client_builder = admin_client_builder.endpoint(admin_endpoint);
146        }
147
148        let admin_client: Arc<AdminClient> =
149            Arc::new(admin_client_builder.build(AdminClientConfig {
150                authentication: Authentication::AppPassword(
151                    profile.app_password(config_file.vault())?.parse()?,
152                ),
153            }));
154
155        let mut cloud_client_builder = CloudClientBuilder::default();
156
157        if let Some(cloud_endpoint) = cloud_endpoint {
158            cloud_client_builder = cloud_client_builder.endpoint(cloud_endpoint);
159        }
160
161        let cloud_client = cloud_client_builder.build(CloudClientConfig {
162            auth_client: Arc::clone(&admin_client),
163        });
164
165        // The sql client is created here to avoid having to handle the config around. E.g. reading config from config_file
166        // this happens because profile is 'a static, and adding it to the profile context would make also the context 'a, etc.
167        let sql_client = SqlClient::new(SqlClientConfig {
168            app_password: profile.app_password(config_file.vault())?.parse()?,
169        });
170
171        Ok(ProfileContext {
172            context: self,
173            profile_name,
174            admin_client,
175            cloud_client,
176            sql_client,
177        })
178    }
179
180    /// Returns the configuration file loaded by this context.
181    pub fn config_file(&self) -> &ConfigFile {
182        &self.config_file
183    }
184
185    /// Returns the output_formatter associated with this context.
186    pub fn output_formatter(&self) -> &OutputFormatter {
187        &self.output_formatter
188    }
189}
190
191/// Context for a command that requires a valid authentication profile.
192pub struct ProfileContext {
193    context: Context,
194    profile_name: String,
195    admin_client: Arc<AdminClient>,
196    cloud_client: CloudClient,
197    sql_client: SqlClient,
198}
199
200impl ProfileContext {
201    /// Loads the profile and returns a region context.
202    pub fn activate_region(self) -> Result<RegionContext, Error> {
203        let profile = self
204            .context
205            .config_file
206            .load_profile(&self.profile_name)
207            .unwrap();
208
209        // Region must be lower case.
210        // Cloud's API response returns the region in
211        // lower case.
212        let region_name = self
213            .context
214            .region
215            .clone()
216            .or(profile.region().map(|r| r.to_string()))
217            .ok_or_else(|| panic!("no region configured"))
218            .unwrap()
219            .to_lowercase();
220        Ok(RegionContext {
221            context: self,
222            region_name,
223        })
224    }
225
226    /// Returns the admin API client associated with this context.
227    pub fn admin_client(&self) -> &AdminClient {
228        &self.admin_client
229    }
230
231    /// Returns the cloud API client associated with this context.
232    pub fn cloud_client(&self) -> &CloudClient {
233        &self.cloud_client
234    }
235
236    /// Returns the configuration file loaded by this context.
237    pub fn config_file(&self) -> &ConfigFile {
238        &self.context.config_file
239    }
240
241    /// Returns the output_formatter associated with this context.
242    pub fn output_formatter(&self) -> &OutputFormatter {
243        &self.context.output_formatter
244    }
245
246    /// Returns the context profile.
247    /// If a global profile has been set, it will return the global profile.
248    /// Otherwise returns the config's profile.
249    pub fn get_profile(&self) -> String {
250        self.context
251            .get_global_profile()
252            .unwrap_or(self.config_file().profile().to_string())
253    }
254}
255
256/// Context for a command that requires a valid authentication profile
257/// and an active region.
258pub struct RegionContext {
259    context: ProfileContext,
260    region_name: String,
261}
262
263impl RegionContext {
264    /// Returns the admin API client associated with this context.
265    pub fn admin_client(&self) -> &AdminClient {
266        &self.context.admin_client
267    }
268
269    /// Returns the admin API client associated with this context.
270    pub fn cloud_client(&self) -> &CloudClient {
271        &self.context.cloud_client
272    }
273
274    /// Returns a SQL client connected to region associated with this context.
275    pub fn sql_client(&self) -> &SqlClient {
276        &self.context.sql_client
277    }
278
279    /// Returns the cloud provider from the profile context.
280    pub async fn get_cloud_provider(&self) -> Result<CloudProvider, Error> {
281        let client = &self.context.cloud_client;
282        let cloud_providers = client.list_cloud_regions().await?;
283
284        let provider = cloud_providers
285            .into_iter()
286            .find(|x| x.id == self.region_name)
287            .ok_or(Error::CloudRegionMissing)?;
288
289        Ok(provider)
290    }
291
292    /// Returns the cloud provider region of the context.
293    pub async fn get_region(&self) -> Result<Region, Error> {
294        let client = self.cloud_client();
295        let cloud_provider = self.get_cloud_provider().await?;
296        let region = client.get_region(cloud_provider).await?;
297
298        Ok(region)
299    }
300
301    /// Returns the cloud provider region of the context.
302    pub async fn get_region_info(&self) -> Result<RegionInfo, Error> {
303        let client = self.cloud_client();
304        let cloud_provider = self.get_cloud_provider().await?;
305        let region = client.get_region(cloud_provider).await?;
306
307        region.region_info.ok_or_else(|| Error::NotReadyRegion)
308    }
309
310    /// Returns the configuration file loaded by this context.
311    pub fn config_file(&self) -> &ConfigFile {
312        self.context.config_file()
313    }
314
315    /// Returns the region profile.
316    /// As in the context, if a global profile has been set,
317    /// it will return the global profile.
318    /// Otherwise returns the config's profile.
319    pub fn get_profile(&self) -> String {
320        self.context.get_profile()
321    }
322
323    /// Returns the output_formatter associated with this context.
324    pub fn output_formatter(&self) -> &OutputFormatter {
325        self.context.output_formatter()
326    }
327}