mz_persist/azure.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An Azure Blob Storage implementation of [Blob] storage.

use anyhow::{Context, anyhow};
use async_trait::async_trait;
use azure_core::{ExponentialRetryOptions, RetryOptions, StatusCode, TransportOptions};
use azure_identity::create_default_credential;
use azure_storage::{CloudLocation, EMULATOR_ACCOUNT, prelude::*};
use azure_storage_blobs::blob::operations::GetBlobResponse;
use azure_storage_blobs::prelude::*;
use bytes::Bytes;
use futures_util::StreamExt;
use futures_util::stream::FuturesOrdered;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
use url::Url;
use uuid::Uuid;

use mz_dyncfg::ConfigSet;
use mz_ore::bytes::SegmentedBytes;
use mz_ore::cast::CastFrom;
use mz_ore::lgbytes::MetricsRegion;
use mz_ore::metrics::MetricsRegistry;

use crate::cfg::BlobKnobs;
use crate::error::Error;
use crate::location::{Blob, BlobMetadata, Determinate, ExternalError};
use crate::metrics::S3BlobMetrics;

/// Configuration for opening an [AzureBlob].
#[derive(Clone, Debug)]
pub struct AzureBlobConfig {
    // The metrics struct here is a bit of a misnomer. We only need access
    // to the LgBytes metrics, which has an Azure-specific field. For now,
    // it saves considerable plumbing to reuse [S3BlobMetrics].
    //
    // TODO: spin up an AzureBlobMetrics and do the plumbing.
    metrics: S3BlobMetrics,
    client: ContainerClient,
    prefix: String,
    cfg: Arc<ConfigSet>,
}

impl AzureBlobConfig {
    const EXTERNAL_TESTS_AZURE_CONTAINER: &'static str =
        "MZ_PERSIST_EXTERNAL_STORAGE_TEST_AZURE_CONTAINER";

    /// Returns a new [AzureBlobConfig] for use in production.
    ///
    /// Stores objects in the given container, with keys prepended by the (possibly
    /// empty) prefix. Azure credentials must be available in the process or environment.
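    ///
    /// # Example
    ///
    /// A minimal, non-compiled sketch of constructing a config, modeled on
    /// [AzureBlobConfig::new_for_test]; `metrics` and `knobs` stand in for values the
    /// caller already has, and the container/prefix names are illustrative:
    ///
    /// ```ignore
    /// let config = AzureBlobConfig::new(
    ///     EMULATOR_ACCOUNT.to_string(),
    ///     "my-container".to_string(),
    ///     "my-prefix".to_string(),
    ///     metrics,                          // S3BlobMetrics
    ///     Url::parse("http://localhost:40111/my-container").expect("valid url"),
    ///     knobs,                            // Box<dyn BlobKnobs>
    ///     Arc::new(ConfigSet::default()),
    /// )?;
    /// ```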
    pub fn new(
        account: String,
        container: String,
        prefix: String,
        metrics: S3BlobMetrics,
        url: Url,
        knobs: Box<dyn BlobKnobs>,
        cfg: Arc<ConfigSet>,
    ) -> Result<Self, Error> {
        let client = if account == EMULATOR_ACCOUNT {
            info!("Connecting to Azure emulator");
            ClientBuilder::with_location(
                CloudLocation::Emulator {
                    address: url.domain().expect("domain for Azure emulator").to_string(),
                    port: url.port().expect("port for Azure emulator"),
                },
                StorageCredentials::emulator(),
            )
            .transport({
                // Azure uses reqwest / hyper internally, but we specify a client explicitly to
                // plumb through our timeouts.
                TransportOptions::new(Arc::new(
                    reqwest::ClientBuilder::new()
                        .timeout(knobs.operation_attempt_timeout())
                        .read_timeout(knobs.read_timeout())
                        .connect_timeout(knobs.connect_timeout())
                        .build()
                        .expect("valid config for azure HTTP client"),
                ))
            })
            .retry(RetryOptions::exponential(
                ExponentialRetryOptions::default().max_total_elapsed(knobs.operation_timeout()),
            ))
            .blob_service_client()
            .container_client(container)
        } else {
            let sas_credentials = match url.query() {
                Some(query) => Some(StorageCredentials::sas_token(query)),
                None => None,
            };

            let credentials = match sas_credentials {
                Some(Ok(credentials)) => credentials,
                Some(Err(err)) => {
                    warn!("Failed to parse SAS token: {err}");
                    // TODO: should we fall back here? Or can we fully rely on query params
                    // to determine whether a SAS token was provided?
                    StorageCredentials::token_credential(
                        create_default_credential().expect("Azure default credentials"),
                    )
                }
                None => {
                    // Fall back to the default credential stack to support auth modes like
                    // workload identity that are injected into the environment.
                    StorageCredentials::token_credential(
                        create_default_credential().expect("Azure default credentials"),
                    )
                }
            };

            let service_client = BlobServiceClient::new(account, credentials);
            service_client.container_client(container)
        };

        // TODO: some auth modes like user-delegated SAS tokens are time-limited
        // and need to be refreshed. This can be done through `service_client.update_credentials`
        // but there'll be a fair bit of plumbing needed to make each mode work

        Ok(AzureBlobConfig {
            metrics,
            client,
            cfg,
            prefix,
        })
    }

    /// Returns a new [AzureBlobConfig] for use in unit tests.
    pub fn new_for_test() -> Result<Option<Self>, Error> {
        struct TestBlobKnobs;
        impl Debug for TestBlobKnobs {
            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
                f.debug_struct("TestBlobKnobs").finish_non_exhaustive()
            }
        }
        impl BlobKnobs for TestBlobKnobs {
            fn operation_timeout(&self) -> Duration {
                Duration::from_secs(30)
            }

            fn operation_attempt_timeout(&self) -> Duration {
                Duration::from_secs(10)
            }

            fn connect_timeout(&self) -> Duration {
                Duration::from_secs(5)
            }

            fn read_timeout(&self) -> Duration {
                Duration::from_secs(5)
            }

            fn is_cc_active(&self) -> bool {
                false
            }
        }

        let container_name = match std::env::var(Self::EXTERNAL_TESTS_AZURE_CONTAINER) {
            Ok(container) => container,
            Err(_) => {
                if mz_ore::env::is_var_truthy("CI") {
                    panic!("CI is supposed to run this test but something has gone wrong!");
                }
                return Ok(None);
            }
        };

        let prefix = Uuid::new_v4().to_string();
        let metrics = S3BlobMetrics::new(&MetricsRegistry::new());

        let config = AzureBlobConfig::new(
            EMULATOR_ACCOUNT.to_string(),
            container_name.clone(),
            prefix,
            metrics,
            Url::parse(&format!("http://localhost:40111/{}", container_name)).expect("valid url"),
            Box::new(TestBlobKnobs),
            Arc::new(ConfigSet::default()),
        )?;

        Ok(Some(config))
    }
}

/// Implementation of [Blob] backed by Azure Blob Storage.
#[derive(Debug)]
pub struct AzureBlob {
    metrics: S3BlobMetrics,
    client: ContainerClient,
    prefix: String,
    _cfg: Arc<ConfigSet>,
}

impl AzureBlob {
    /// Opens the given location for non-exclusive read-write access.
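    ///
    /// A non-compiled sketch, assuming `config` was built via [AzureBlobConfig::new]
    /// and that the key name is illustrative:
    ///
    /// ```ignore
    /// let blob = AzureBlob::open(config).await?;
    /// let contents = blob.get("some-key").await?;
    /// ```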
    pub async fn open(config: AzureBlobConfig) -> Result<Self, ExternalError> {
        if config.client.service_client().account() == EMULATOR_ACCOUNT {
            // TODO: we could move this logic into the test harness.
            // it's currently here because it's surprisingly annoying to
            // create the container out-of-band
            if let Err(e) = config.client.create().await {
                warn!("Failed to create container: {e}");
            }
        }

        let ret = AzureBlob {
            metrics: config.metrics,
            client: config.client,
            prefix: config.prefix,
            _cfg: config.cfg,
        };

        Ok(ret)
    }

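    /// Returns the blob name for `key`, namespaced under this blob's configured prefix.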
    fn get_path(&self, key: &str) -> String {
        format!("{}/{}", self.prefix, key)
    }
}

#[async_trait]
impl Blob for AzureBlob {
    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

        /// Fetch the body of a single [`GetBlobResponse`].
        async fn fetch_chunk(
            response: GetBlobResponse,
            metrics: S3BlobMetrics,
        ) -> Result<Bytes, ExternalError> {
            let content_length = response.blob.properties.content_length;

            // Here we're being quite defensive. If `content_length` comes back
            // as 0 it's most likely incorrect. In that case we'll copy bytes
            // off the network into a growable buffer, then copy the entire
            // buffer into lgalloc.
            let mut buffer = match content_length {
                1.. => {
                    let region = metrics
                        .lgbytes
                        .persist_azure
                        .new_region(usize::cast_from(content_length));
                    PreSizedBuffer::Sized(region)
                }
                0 => PreSizedBuffer::Unknown(SegmentedBytes::new()),
            };

            let mut body = response.data;
            while let Some(value) = body.next().await {
                let value = value
                    .map_err(|e| ExternalError::from(e.context("azure blob get body error")))?;

                match &mut buffer {
                    PreSizedBuffer::Sized(region) => region.extend_from_slice(&value),
                    PreSizedBuffer::Unknown(segments) => segments.push(value),
                }
            }

            // Spill our bytes to lgalloc, if they aren't already.
            let lgbytes: Bytes = match buffer {
                PreSizedBuffer::Sized(region) => region.into(),
                // Now that we've collected all of the segments, we know the size of our region.
                PreSizedBuffer::Unknown(segments) => {
                    let mut region = metrics.lgbytes.persist_azure.new_region(segments.len());
                    for segment in segments.into_segments() {
                        region.extend_from_slice(segment.as_ref());
                    }
                    region.into()
                }
            };

            // Report if the content-length header didn't match the number of
            // bytes we read from the network.
            if content_length != u64::cast_from(lgbytes.len()) {
                metrics.get_invalid_resp.inc();
            }

            Ok(lgbytes)
        }

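        // Stream the blob as a sequence of chunked requests. `FuturesOrdered` yields
        // results in the order they were pushed, so the chunks are reassembled in the
        // original byte order.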
        let mut requests = FuturesOrdered::new();
        // TODO: the default chunk size is 1MB. We have not tried tuning it,
        // but making this configurable / running some benchmarks could be
        // valuable.
        let mut stream = blob.get().into_stream();

        while let Some(value) = stream.next().await {
            // Return early if any of the individual fetch requests return an error.
            let response = match value {
                Ok(v) => v,
                Err(e) => {
                    if let Some(e) = e.as_http_error() {
                        if e.status() == StatusCode::NotFound {
                            return Ok(None);
                        }
                    }

                    return Err(ExternalError::from(e.context("azure blob get error")));
                }
            };

            // Drive all of the fetch requests concurrently.
            let metrics = self.metrics.clone();
            requests.push_back(fetch_chunk(response, metrics));
        }

        // Await on all of our chunks.
        let mut segments = SegmentedBytes::with_capacity(requests.len());
        while let Some(body) = requests.next().await {
            let segment = body.context("azure blob get body err")?;
            segments.push(segment);
        }

        Ok(Some(segments))
    }

    async fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        let blob_key_prefix = self.get_path(key_prefix);
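        // Keys are reported to the caller relative to this blob's prefix, so compute
        // the leading component to strip from each returned blob name.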
        let strippable_root_prefix = format!("{}/", self.prefix);

        let mut stream = self
            .client
            .list_blobs()
            .prefix(blob_key_prefix.clone())
            .into_stream();

        while let Some(response) = stream.next().await {
            let response =
                response.map_err(|e| ExternalError::from(e.context("azure blob list error")))?;

            for blob in response.blobs.items {
                let azure_storage_blobs::container::operations::list_blobs::BlobItem::Blob(blob) =
                    blob
                else {
                    continue;
                };

                if let Some(key) = blob.name.strip_prefix(&strippable_root_prefix) {
                    let size_in_bytes = blob.properties.content_length;
                    f(BlobMetadata { key, size_in_bytes });
                }
            }
        }

        Ok(())
    }

    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

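        // Upload the value as a single block blob; Azure's Put Blob overwrites any
        // existing blob at this path.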
        blob.put_block_blob(value)
            .await
            .map_err(|e| ExternalError::from(e.context("azure blob put error")))?;

        Ok(())
    }

    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

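        // Fetch the blob's properties first so we can report the size of the data we
        // are about to delete; a missing blob maps to `Ok(None)`.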
        match blob.get_properties().await {
            Ok(props) => {
                let size = usize::cast_from(props.blob.properties.content_length);
                blob.delete()
                    .await
                    .map_err(|e| ExternalError::from(e.context("azure blob delete error")))?;
                Ok(Some(size))
            }
            Err(e) => {
                if let Some(e) = e.as_http_error() {
                    if e.status() == StatusCode::NotFound {
                        return Ok(None);
                    }
                }

                Err(ExternalError::from(e.context("azure blob error")))
            }
        }
    }

    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(&path);

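        // Restoring reduces to verifying that the key still exists; a missing key is
        // reported as a determinate error.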
        match blob.get_properties().await {
            Ok(_) => Ok(()),
            Err(e) => {
                if let Some(e) = e.as_http_error() {
                    if e.status() == StatusCode::NotFound {
                        return Err(Determinate::new(anyhow!(
                            "azure blob error: unable to restore non-existent key {key}"
                        ))
                        .into());
                    }
                }

                Err(ExternalError::from(e.context("azure blob error")))
            }
        }
    }
}

/// If possible we'll pre-allocate a chunk of memory in lgalloc and write into
/// that as we read bytes off the network.
enum PreSizedBuffer {
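    /// The blob's content length was known, so bytes are written directly into a
    /// pre-sized lgalloc-backed region.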
    Sized(MetricsRegion<u8>),
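    /// The content length was reported as 0 (likely incorrect), so network chunks are
    /// accumulated and copied into an lgalloc region once the total size is known.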
    Unknown(SegmentedBytes),
}

#[cfg(test)]
mod tests {
    use tracing::info;

    use crate::location::tests::blob_impl_test;

    use super::*;

    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    async fn azure_blob() -> Result<(), ExternalError> {
        let config = match AzureBlobConfig::new_for_test()? {
            Some(client) => client,
            None => {
                info!(
                    "{} env not set: skipping test that uses external service",
                    AzureBlobConfig::EXTERNAL_TESTS_AZURE_CONTAINER
                );
                return Ok(());
            }
        };

        blob_impl_test(move |_path| {
            let config = config.clone();
            async move {
                let config = AzureBlobConfig {
                    metrics: config.metrics.clone(),
                    client: config.client.clone(),
                    cfg: Arc::new(ConfigSet::default()),
                    prefix: config.prefix.clone(),
                };
                AzureBlob::open(config).await
            }
        })
        .await
    }
}