mz_persist/
azure.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An Azure Blob Storage implementation of [Blob] storage.
11
12use anyhow::{Context, anyhow};
13use async_trait::async_trait;
14use azure_core::{ExponentialRetryOptions, RetryOptions, StatusCode, TransportOptions};
15use azure_identity::create_default_credential;
16use azure_storage::{CloudLocation, EMULATOR_ACCOUNT, prelude::*};
17use azure_storage_blobs::blob::operations::GetBlobResponse;
18use azure_storage_blobs::prelude::*;
19use bytes::Bytes;
20use futures_util::StreamExt;
21use futures_util::stream::FuturesOrdered;
22use std::fmt::{Debug, Formatter};
23use std::sync::Arc;
24use std::time::Duration;
25use tracing::{info, warn};
26use url::Url;
27use uuid::Uuid;
28
29use mz_dyncfg::ConfigSet;
30use mz_ore::bytes::SegmentedBytes;
31use mz_ore::cast::CastFrom;
32use mz_ore::lgbytes::MetricsRegion;
33use mz_ore::metrics::MetricsRegistry;
34
35use crate::cfg::BlobKnobs;
36use crate::error::Error;
37use crate::location::{Blob, BlobMetadata, Determinate, ExternalError};
38use crate::metrics::S3BlobMetrics;
39
40/// Configuration for opening an [AzureBlob].
41#[derive(Clone, Debug)]
42pub struct AzureBlobConfig {
43    // The metrics struct here is a bit of a misnomer. We only need access
44    // to the LgBytes metrics, which has an Azure-specific field. For now,
45    // it saves considerable plumbing to reuse [S3BlobMetrics].
46    //
47    // TODO: spin up an AzureBlobMetrics and do the plumbing.
48    metrics: S3BlobMetrics,
49    client: ContainerClient,
50    prefix: String,
51    cfg: Arc<ConfigSet>,
52}
53
54impl AzureBlobConfig {
55    const EXTERNAL_TESTS_AZURE_CONTAINER: &'static str =
56        "MZ_PERSIST_EXTERNAL_STORAGE_TEST_AZURE_CONTAINER";
57
58    /// Returns a new [AzureBlobConfig] for use in production.
59    ///
60    /// Stores objects in the given container prepended with the (possibly empty)
61    /// prefix. Azure credentials must be available in the process or environment.
62    pub fn new(
63        account: String,
64        container: String,
65        prefix: String,
66        metrics: S3BlobMetrics,
67        url: Url,
68        knobs: Box<dyn BlobKnobs>,
69        cfg: Arc<ConfigSet>,
70    ) -> Result<Self, Error> {
71        let client = if account == EMULATOR_ACCOUNT {
72            info!("Connecting to Azure emulator");
73            ClientBuilder::with_location(
74                CloudLocation::Emulator {
75                    address: url.domain().expect("domain for Azure emulator").to_string(),
76                    port: url.port().expect("port for Azure emulator"),
77                },
78                StorageCredentials::emulator(),
79            )
80            .transport({
81                // Azure uses reqwest / hyper internally, but we specify a client explicitly to
82                // plumb through our timeouts.
83                TransportOptions::new(Arc::new(
84                    reqwest::ClientBuilder::new()
85                        .timeout(knobs.operation_attempt_timeout())
86                        .read_timeout(knobs.read_timeout())
87                        .connect_timeout(knobs.connect_timeout())
88                        .build()
89                        .expect("valid config for azure HTTP client"),
90                ))
91            })
92            .retry(RetryOptions::exponential(
93                ExponentialRetryOptions::default().max_total_elapsed(knobs.operation_timeout()),
94            ))
95            .blob_service_client()
96            .container_client(container)
97        } else {
98            let sas_credentials = match url.query() {
99                Some(query) => Some(StorageCredentials::sas_token(query)),
100                None => None,
101            };
102
103            let credentials = match sas_credentials {
104                Some(Ok(credentials)) => credentials,
105                Some(Err(err)) => {
106                    warn!("Failed to parse SAS token: {err}");
107                    // TODO: should we fallback here? Or can we fully rely on query params
108                    // to determine whether a SAS token was provided?
109                    StorageCredentials::token_credential(
110                        create_default_credential().expect("Azure default credentials"),
111                    )
112                }
113                None => {
114                    // Fall back to default credential stack to support auth modes like
115                    // workload identity that are injected into the environment
116                    StorageCredentials::token_credential(
117                        create_default_credential().expect("Azure default credentials"),
118                    )
119                }
120            };
121
122            let service_client = BlobServiceClient::new(account, credentials);
123            service_client.container_client(container)
124        };
125
126        // TODO: some auth modes like user-delegated SAS tokens are time-limited
127        // and need to be refreshed. This can be done through `service_client.update_credentials`
128        // but there'll be a fair bit of plumbing needed to make each mode work
129
130        Ok(AzureBlobConfig {
131            metrics,
132            client,
133            cfg,
134            prefix,
135        })
136    }
137
138    /// Returns a new [AzureBlobConfig] for use in unit tests.
139    pub fn new_for_test() -> Result<Option<Self>, Error> {
140        struct TestBlobKnobs;
141        impl Debug for TestBlobKnobs {
142            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
143                f.debug_struct("TestBlobKnobs").finish_non_exhaustive()
144            }
145        }
146        impl BlobKnobs for TestBlobKnobs {
147            fn operation_timeout(&self) -> Duration {
148                Duration::from_secs(30)
149            }
150
151            fn operation_attempt_timeout(&self) -> Duration {
152                Duration::from_secs(10)
153            }
154
155            fn connect_timeout(&self) -> Duration {
156                Duration::from_secs(5)
157            }
158
159            fn read_timeout(&self) -> Duration {
160                Duration::from_secs(5)
161            }
162
163            fn is_cc_active(&self) -> bool {
164                false
165            }
166        }
167
168        let container_name = match std::env::var(Self::EXTERNAL_TESTS_AZURE_CONTAINER) {
169            Ok(container) => container,
170            Err(_) => {
171                if mz_ore::env::is_var_truthy("CI") {
172                    panic!("CI is supposed to run this test but something has gone wrong!");
173                }
174                return Ok(None);
175            }
176        };
177
178        let prefix = Uuid::new_v4().to_string();
179        let metrics = S3BlobMetrics::new(&MetricsRegistry::new());
180
181        let config = AzureBlobConfig::new(
182            EMULATOR_ACCOUNT.to_string(),
183            container_name.clone(),
184            prefix,
185            metrics,
186            Url::parse(&format!("http://localhost:40111/{}", container_name)).expect("valid url"),
187            Box::new(TestBlobKnobs),
188            Arc::new(ConfigSet::default()),
189        )?;
190
191        Ok(Some(config))
192    }
193}
194
195/// Implementation of [Blob] backed by Azure Blob Storage.
196#[derive(Debug)]
197pub struct AzureBlob {
198    metrics: S3BlobMetrics,
199    client: ContainerClient,
200    prefix: String,
201    _cfg: Arc<ConfigSet>,
202}
203
204impl AzureBlob {
205    /// Opens the given location for non-exclusive read-write access.
206    pub async fn open(config: AzureBlobConfig) -> Result<Self, ExternalError> {
207        if config.client.service_client().account() == EMULATOR_ACCOUNT {
208            // TODO: we could move this logic into the test harness.
209            // it's currently here because it's surprisingly annoying to
210            // create the container out-of-band
211            if let Err(error) = config.client.create().await {
212                info!(
213                    ?error,
214                    "failed to create emulator container; this is expected on repeat runs"
215                );
216            }
217        }
218
219        let ret = AzureBlob {
220            metrics: config.metrics,
221            client: config.client,
222            prefix: config.prefix,
223            _cfg: config.cfg,
224        };
225
226        Ok(ret)
227    }
228
229    fn get_path(&self, key: &str) -> String {
230        format!("{}/{}", self.prefix, key)
231    }
232}
233
234#[async_trait]
235impl Blob for AzureBlob {
236    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
237        let path = self.get_path(key);
238        let blob = self.client.blob_client(path);
239
240        /// Fetch the body of a single [`GetBlobResponse`].
241        async fn fetch_chunk(
242            response: GetBlobResponse,
243            metrics: S3BlobMetrics,
244        ) -> Result<Bytes, ExternalError> {
245            let content_length = response.blob.properties.content_length;
246
247            // Here we're being quite defensive. If `content_length` comes back
248            // as 0 it's most likely incorrect. In that case we'll copy bytes
249            // of the network into a growable buffer, then copy the entire
250            // buffer into lgalloc.
251            let mut buffer = match content_length {
252                1.. => {
253                    let region = metrics
254                        .lgbytes
255                        .persist_azure
256                        .new_region(usize::cast_from(content_length));
257                    PreSizedBuffer::Sized(region)
258                }
259                0 => PreSizedBuffer::Unknown(SegmentedBytes::new()),
260            };
261
262            let mut body = response.data;
263            while let Some(value) = body.next().await {
264                let value = value
265                    .map_err(|e| ExternalError::from(e.context("azure blob get body error")))?;
266
267                match &mut buffer {
268                    PreSizedBuffer::Sized(region) => region.extend_from_slice(&value),
269                    PreSizedBuffer::Unknown(segments) => segments.push(value),
270                }
271            }
272
273            // Spill our bytes to lgalloc, if they aren't already.
274            let lgbytes: Bytes = match buffer {
275                PreSizedBuffer::Sized(region) => region.into(),
276                // Now that we've collected all of the segments, we know the size of our region.
277                PreSizedBuffer::Unknown(segments) => {
278                    let mut region = metrics.lgbytes.persist_azure.new_region(segments.len());
279                    for segment in segments.into_segments() {
280                        region.extend_from_slice(segment.as_ref());
281                    }
282                    region.into()
283                }
284            };
285
286            // Report if the content-length header didn't match the number of
287            // bytes we read from the network.
288            if content_length != u64::cast_from(lgbytes.len()) {
289                metrics.get_invalid_resp.inc();
290            }
291
292            Ok(lgbytes)
293        }
294
295        let mut requests = FuturesOrdered::new();
296        // TODO: the default chunk size is 1MB. We have not tried tuning it,
297        // but making this configurable / running some benchmarks could be
298        // valuable.
299        let mut stream = blob.get().into_stream();
300
301        while let Some(value) = stream.next().await {
302            // Return early if any of the individual fetch requests return an error.
303            let response = match value {
304                Ok(v) => v,
305                Err(e) => {
306                    if let Some(e) = e.as_http_error() {
307                        if e.status() == StatusCode::NotFound {
308                            return Ok(None);
309                        }
310                    }
311
312                    return Err(ExternalError::from(e.context("azure blob get error")));
313                }
314            };
315
316            // Drive all of the fetch requests concurrently.
317            let metrics = self.metrics.clone();
318            requests.push_back(fetch_chunk(response, metrics));
319        }
320
321        // Await on all of our chunks.
322        let mut segments = SegmentedBytes::with_capacity(requests.len());
323        while let Some(body) = requests.next().await {
324            let segment = body.context("azure blob get body err")?;
325            segments.push(segment);
326        }
327
328        Ok(Some(segments))
329    }
330
331    async fn list_keys_and_metadata(
332        &self,
333        key_prefix: &str,
334        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
335    ) -> Result<(), ExternalError> {
336        let blob_key_prefix = self.get_path(key_prefix);
337        let strippable_root_prefix = format!("{}/", self.prefix);
338
339        let mut stream = self
340            .client
341            .list_blobs()
342            .prefix(blob_key_prefix.clone())
343            .into_stream();
344
345        while let Some(response) = stream.next().await {
346            let response =
347                response.map_err(|e| ExternalError::from(e.context("azure blob list error")))?;
348
349            for blob in response.blobs.items {
350                let azure_storage_blobs::container::operations::list_blobs::BlobItem::Blob(blob) =
351                    blob
352                else {
353                    continue;
354                };
355
356                if let Some(key) = blob.name.strip_prefix(&strippable_root_prefix) {
357                    let size_in_bytes = blob.properties.content_length;
358                    f(BlobMetadata { key, size_in_bytes });
359                }
360            }
361        }
362
363        Ok(())
364    }
365
366    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
367        let path = self.get_path(key);
368        let blob = self.client.blob_client(path);
369
370        blob.put_block_blob(value)
371            .await
372            .map_err(|e| ExternalError::from(e.context("azure blob put error")))?;
373
374        Ok(())
375    }
376
377    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
378        let path = self.get_path(key);
379        let blob = self.client.blob_client(path);
380
381        match blob.get_properties().await {
382            Ok(props) => {
383                let size = usize::cast_from(props.blob.properties.content_length);
384                blob.delete()
385                    .await
386                    .map_err(|e| ExternalError::from(e.context("azure blob delete error")))?;
387                Ok(Some(size))
388            }
389            Err(e) => {
390                if let Some(e) = e.as_http_error() {
391                    if e.status() == StatusCode::NotFound {
392                        return Ok(None);
393                    }
394                }
395
396                Err(ExternalError::from(e.context("azure blob error")))
397            }
398        }
399    }
400
401    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
402        let path = self.get_path(key);
403        let blob = self.client.blob_client(&path);
404
405        match blob.get_properties().await {
406            Ok(_) => Ok(()),
407            Err(e) => {
408                if let Some(e) = e.as_http_error() {
409                    if e.status() == StatusCode::NotFound {
410                        return Err(Determinate::new(anyhow!(
411                            "azure blob error: unable to restore non-existent key {key}"
412                        ))
413                        .into());
414                    }
415                }
416
417                Err(ExternalError::from(e.context("azure blob error")))
418            }
419        }
420    }
421}
422
423/// If possible we'll pre-allocate a chunk of memory in lgalloc and write into
424/// that as we read bytes off the network.
425enum PreSizedBuffer {
426    Sized(MetricsRegion<u8>),
427    Unknown(SegmentedBytes),
428}
429
#[cfg(test)]
mod tests {
    use tracing::info;

    use crate::location::tests::blob_impl_test;

    use super::*;

    /// Runs the shared [Blob] test suite against the Azure emulator, or skips
    /// when no test container is configured.
    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    async fn azure_blob() -> Result<(), ExternalError> {
        let Some(config) = AzureBlobConfig::new_for_test()? else {
            info!(
                "{} env not set: skipping test that uses external service",
                AzureBlobConfig::EXTERNAL_TESTS_AZURE_CONTAINER
            );
            return Ok(());
        };

        blob_impl_test(move |_path| {
            // Every opened blob shares the metrics, client and prefix of the
            // base config but gets its own default config set.
            let AzureBlobConfig {
                metrics,
                client,
                prefix,
                ..
            } = config.clone();
            async move {
                let config = AzureBlobConfig {
                    metrics,
                    client,
                    cfg: Arc::new(ConfigSet::default()),
                    prefix,
                };
                AzureBlob::open(config).await
            }
        })
        .await
    }
}
466}