// mz_persist/src/azure.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An Azure Blob Storage implementation of [Blob] storage.

12use anyhow::{Context, anyhow};
13use async_trait::async_trait;
14use azure_core::{ExponentialRetryOptions, RetryOptions, StatusCode, TransportOptions};
15use azure_identity::create_default_credential;
16use azure_storage::{CloudLocation, EMULATOR_ACCOUNT, prelude::*};
17use azure_storage_blobs::blob::operations::GetBlobResponse;
18use azure_storage_blobs::prelude::*;
19use bytes::Bytes;
20use futures_util::StreamExt;
21use futures_util::stream::FuturesOrdered;
22use std::fmt::{Debug, Formatter};
23use std::sync::Arc;
24use std::time::Duration;
25use tracing::{info, warn};
26use url::Url;
27use uuid::Uuid;
28
29use mz_dyncfg::ConfigSet;
30use mz_ore::bytes::SegmentedBytes;
31use mz_ore::cast::CastFrom;
32use mz_ore::lgbytes::MetricsRegion;
33use mz_ore::metrics::MetricsRegistry;
34
35use crate::cfg::BlobKnobs;
36use crate::error::Error;
37use crate::location::{Blob, BlobMetadata, Determinate, ExternalError};
38use crate::metrics::S3BlobMetrics;
39
/// Configuration for opening an [AzureBlob].
#[derive(Clone, Debug)]
pub struct AzureBlobConfig {
    // The metrics struct here is a bit of a misnomer. We only need access
    // to the LgBytes metrics, which has an Azure-specific field. For now,
    // it saves considerable plumbing to reuse [S3BlobMetrics].
    //
    // TODO: spin up an AzureBlobMetrics and do the plumbing.
    metrics: S3BlobMetrics,
    // Client already scoped to the target container.
    client: ContainerClient,
    // Prefix prepended (with a `/` separator) to every blob key; may be empty.
    prefix: String,
    // Dynamic config handle; stored as `_cfg` (unused) by [AzureBlob] —
    // presumably held for future runtime knobs. TODO confirm.
    cfg: Arc<ConfigSet>,
}
53
54impl AzureBlobConfig {
55    const EXTERNAL_TESTS_AZURE_CONTAINER: &'static str =
56        "MZ_PERSIST_EXTERNAL_STORAGE_TEST_AZURE_CONTAINER";
57
58    /// Returns a new [AzureBlobConfig] for use in production.
59    ///
60    /// Stores objects in the given container prepended with the (possibly empty)
61    /// prefix. Azure credentials must be available in the process or environment.
62    pub fn new(
63        account: String,
64        container: String,
65        prefix: String,
66        metrics: S3BlobMetrics,
67        url: Url,
68        knobs: Box<dyn BlobKnobs>,
69        cfg: Arc<ConfigSet>,
70    ) -> Result<Self, Error> {
71        let client = if account == EMULATOR_ACCOUNT {
72            info!("Connecting to Azure emulator");
73            ClientBuilder::with_location(
74                CloudLocation::Emulator {
75                    address: url.domain().expect("domain for Azure emulator").to_string(),
76                    port: url.port().expect("port for Azure emulator"),
77                },
78                StorageCredentials::emulator(),
79            )
80            .transport({
81                // Azure uses reqwest / hyper internally, but we specify a client explicitly to
82                // plumb through our timeouts.
83                TransportOptions::new(Arc::new(
84                    reqwest::ClientBuilder::new()
85                        .timeout(knobs.operation_attempt_timeout())
86                        .read_timeout(knobs.read_timeout())
87                        .connect_timeout(knobs.connect_timeout())
88                        .build()
89                        .expect("valid config for azure HTTP client"),
90                ))
91            })
92            .retry(RetryOptions::exponential(
93                ExponentialRetryOptions::default().max_total_elapsed(knobs.operation_timeout()),
94            ))
95            .blob_service_client()
96            .container_client(container)
97        } else {
98            let sas_credentials = match url.query() {
99                Some(query) => Some(StorageCredentials::sas_token(query)),
100                None => None,
101            };
102
103            let credentials = match sas_credentials {
104                Some(Ok(credentials)) => credentials,
105                Some(Err(err)) => {
106                    warn!("Failed to parse SAS token: {err}");
107                    // TODO: should we fallback here? Or can we fully rely on query params
108                    // to determine whether a SAS token was provided?
109                    StorageCredentials::token_credential(
110                        create_default_credential().expect("Azure default credentials"),
111                    )
112                }
113                None => {
114                    // Fall back to default credential stack to support auth modes like
115                    // workload identity that are injected into the environment
116                    StorageCredentials::token_credential(
117                        create_default_credential().expect("Azure default credentials"),
118                    )
119                }
120            };
121
122            let service_client = BlobServiceClient::new(account, credentials);
123            service_client.container_client(container)
124        };
125
126        // TODO: some auth modes like user-delegated SAS tokens are time-limited
127        // and need to be refreshed. This can be done through `service_client.update_credentials`
128        // but there'll be a fair bit of plumbing needed to make each mode work
129
130        Ok(AzureBlobConfig {
131            metrics,
132            client,
133            cfg,
134            prefix,
135        })
136    }
137
138    /// Returns a new [AzureBlobConfig] for use in unit tests.
139    pub fn new_for_test() -> Result<Option<Self>, Error> {
140        struct TestBlobKnobs;
141        impl Debug for TestBlobKnobs {
142            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
143                f.debug_struct("TestBlobKnobs").finish_non_exhaustive()
144            }
145        }
146        impl BlobKnobs for TestBlobKnobs {
147            fn operation_timeout(&self) -> Duration {
148                Duration::from_secs(30)
149            }
150
151            fn operation_attempt_timeout(&self) -> Duration {
152                Duration::from_secs(10)
153            }
154
155            fn connect_timeout(&self) -> Duration {
156                Duration::from_secs(5)
157            }
158
159            fn read_timeout(&self) -> Duration {
160                Duration::from_secs(5)
161            }
162
163            fn is_cc_active(&self) -> bool {
164                false
165            }
166        }
167
168        let container_name = match std::env::var(Self::EXTERNAL_TESTS_AZURE_CONTAINER) {
169            Ok(container) => container,
170            Err(_) => {
171                assert!(
172                    !mz_ore::env::is_var_truthy("CI"),
173                    "CI is supposed to run this test but something has gone wrong!"
174                );
175                return Ok(None);
176            }
177        };
178
179        let prefix = Uuid::new_v4().to_string();
180        let metrics = S3BlobMetrics::new(&MetricsRegistry::new());
181
182        let config = AzureBlobConfig::new(
183            EMULATOR_ACCOUNT.to_string(),
184            container_name.clone(),
185            prefix,
186            metrics,
187            Url::parse(&format!("http://localhost:40111/{}", container_name)).expect("valid url"),
188            Box::new(TestBlobKnobs),
189            Arc::new(ConfigSet::default()),
190        )?;
191
192        Ok(Some(config))
193    }
194}
195
/// Implementation of [Blob] backed by Azure Blob Storage.
#[derive(Debug)]
pub struct AzureBlob {
    // See the note on [AzureBlobConfig]: reused S3 metrics for the LgBytes fields.
    metrics: S3BlobMetrics,
    // Client already scoped to the target container.
    client: ContainerClient,
    // Prefix prepended (with a `/` separator) to every blob key; may be empty.
    prefix: String,
    // Held but not read anywhere in this file — presumably for future knobs.
    _cfg: Arc<ConfigSet>,
}
204
205impl AzureBlob {
206    /// Opens the given location for non-exclusive read-write access.
207    pub async fn open(config: AzureBlobConfig) -> Result<Self, ExternalError> {
208        if config.client.service_client().account() == EMULATOR_ACCOUNT {
209            // TODO: we could move this logic into the test harness.
210            // it's currently here because it's surprisingly annoying to
211            // create the container out-of-band
212            if let Err(error) = config.client.create().await {
213                info!(
214                    ?error,
215                    "failed to create emulator container; this is expected on repeat runs"
216                );
217            }
218        }
219
220        let ret = AzureBlob {
221            metrics: config.metrics,
222            client: config.client,
223            prefix: config.prefix,
224            _cfg: config.cfg,
225        };
226
227        Ok(ret)
228    }
229
230    fn get_path(&self, key: &str) -> String {
231        format!("{}/{}", self.prefix, key)
232    }
233}
234
#[async_trait]
impl Blob for AzureBlob {
    // Streams the blob down in chunks, reassembling them in order into a
    // `SegmentedBytes`. Returns `Ok(None)` if the blob does not exist.
    async fn get(&self, key: &str) -> Result<Option<SegmentedBytes>, ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

        /// Fetch the body of a single [`GetBlobResponse`].
        async fn fetch_chunk(
            response: GetBlobResponse,
            metrics: S3BlobMetrics,
        ) -> Result<Bytes, ExternalError> {
            let content_length = response.blob.properties.content_length;

            // Here we're being quite defensive. If `content_length` comes back
            // as 0 it's most likely incorrect. In that case we'll copy bytes
            // of the network into a growable buffer, then copy the entire
            // buffer into lgalloc.
            let mut buffer = match content_length {
                1.. => {
                    // Trusted (non-zero) length: pre-size an lgalloc region.
                    let region = metrics
                        .lgbytes
                        .persist_azure
                        .new_region(usize::cast_from(content_length));
                    PreSizedBuffer::Sized(region)
                }
                0 => PreSizedBuffer::Unknown(SegmentedBytes::new()),
            };

            // Drain the chunk's body stream into whichever buffer we picked.
            let mut body = response.data;
            while let Some(value) = body.next().await {
                let value = value
                    .map_err(|e| ExternalError::from(e.context("azure blob get body error")))?;

                match &mut buffer {
                    PreSizedBuffer::Sized(region) => region.extend_from_slice(&value),
                    PreSizedBuffer::Unknown(segments) => segments.push(value),
                }
            }

            // Spill our bytes to lgalloc, if they aren't already.
            let lgbytes: Bytes = match buffer {
                PreSizedBuffer::Sized(region) => region.into(),
                // Now that we've collected all of the segments, we know the size of our region.
                PreSizedBuffer::Unknown(segments) => {
                    let mut region = metrics.lgbytes.persist_azure.new_region(segments.len());
                    for segment in segments.into_segments() {
                        region.extend_from_slice(segment.as_ref());
                    }
                    region.into()
                }
            };

            // Report if the content-length header didn't match the number of
            // bytes we read from the network.
            if content_length != u64::cast_from(lgbytes.len()) {
                metrics.get_invalid_resp.inc();
            }

            Ok(lgbytes)
        }

        let mut requests = FuturesOrdered::new();
        // TODO: the default chunk size is 1MB. We have not tried tuning it,
        // but making this configurable / running some benchmarks could be
        // valuable.
        let mut stream = blob.get().into_stream();

        while let Some(value) = stream.next().await {
            // Return early if any of the individual fetch requests return an error.
            let response = match value {
                Ok(v) => v,
                Err(e) => {
                    // A 404 on any chunk means the blob is absent, not an error.
                    if let Some(e) = e.as_http_error() {
                        if e.status() == StatusCode::NotFound {
                            return Ok(None);
                        }
                    }

                    return Err(ExternalError::from(e.context("azure blob get error")));
                }
            };

            // Drive all of the fetch requests concurrently.
            let metrics = self.metrics.clone();
            requests.push_back(fetch_chunk(response, metrics));
        }

        // Await on all of our chunks. `FuturesOrdered` yields results in push
        // order, so segments are appended in their original blob order.
        let mut segments = SegmentedBytes::with_capacity(requests.len());
        while let Some(body) = requests.next().await {
            let segment = body.context("azure blob get body err")?;
            segments.push(segment);
        }

        Ok(Some(segments))
    }

    // Lists blobs under `key_prefix`, invoking `f` once per blob with the key
    // (stripped of this instance's root prefix) and its reported size.
    async fn list_keys_and_metadata(
        &self,
        key_prefix: &str,
        f: &mut (dyn FnMut(BlobMetadata) + Send + Sync),
    ) -> Result<(), ExternalError> {
        let blob_key_prefix = self.get_path(key_prefix);
        let strippable_root_prefix = format!("{}/", self.prefix);

        let mut stream = self
            .client
            .list_blobs()
            .prefix(blob_key_prefix.clone())
            .into_stream();

        while let Some(response) = stream.next().await {
            let response =
                response.map_err(|e| ExternalError::from(e.context("azure blob list error")))?;

            for blob in response.blobs.items {
                // Skip non-blob items (e.g. prefixes) in the listing.
                let azure_storage_blobs::container::operations::list_blobs::BlobItem::Blob(blob) =
                    blob
                else {
                    continue;
                };

                // Only surface keys under our root prefix; anything else in the
                // container is ignored.
                if let Some(key) = blob.name.strip_prefix(&strippable_root_prefix) {
                    let size_in_bytes = blob.properties.content_length;
                    f(BlobMetadata { key, size_in_bytes });
                }
            }
        }

        Ok(())
    }

    // Writes `value` as a single block blob, overwriting any existing blob at
    // this key.
    async fn set(&self, key: &str, value: Bytes) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

        blob.put_block_blob(value)
            .await
            .map_err(|e| ExternalError::from(e.context("azure blob put error")))?;

        Ok(())
    }

    // Deletes the blob at `key`, returning its size if it existed and `None`
    // if it was already absent.
    //
    // NOTE(review): the size lookup and the delete are two separate requests,
    // so a concurrent writer could change the blob in between — the returned
    // size may be stale. Confirm callers only use it for metrics.
    async fn delete(&self, key: &str) -> Result<Option<usize>, ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(path);

        match blob.get_properties().await {
            Ok(props) => {
                let size = usize::cast_from(props.blob.properties.content_length);
                blob.delete()
                    .await
                    .map_err(|e| ExternalError::from(e.context("azure blob delete error")))?;
                Ok(Some(size))
            }
            Err(e) => {
                // Deleting an absent blob is not an error.
                if let Some(e) = e.as_http_error() {
                    if e.status() == StatusCode::NotFound {
                        return Ok(None);
                    }
                }

                Err(ExternalError::from(e.context("azure blob error")))
            }
        }
    }

    // Verifies that `key` exists; a missing key is a determinate error since
    // there is nothing here to restore from (no soft-delete path implemented).
    async fn restore(&self, key: &str) -> Result<(), ExternalError> {
        let path = self.get_path(key);
        let blob = self.client.blob_client(&path);

        match blob.get_properties().await {
            Ok(_) => Ok(()),
            Err(e) => {
                if let Some(e) = e.as_http_error() {
                    if e.status() == StatusCode::NotFound {
                        return Err(Determinate::new(anyhow!(
                            "azure blob error: unable to restore non-existent key {key}"
                        ))
                        .into());
                    }
                }

                Err(ExternalError::from(e.context("azure blob error")))
            }
        }
    }
}
423
/// If possible we'll pre-allocate a chunk of memory in lgalloc and write into
/// that as we read bytes off the network.
enum PreSizedBuffer {
    /// The content length was known (non-zero) up front, so bytes are written
    /// directly into a pre-sized lgalloc region.
    Sized(MetricsRegion<u8>),
    /// The reported content length was 0 (likely bogus), so bytes accumulate
    /// in growable segments and are spilled to lgalloc once the size is known.
    Unknown(SegmentedBytes),
}
430
#[cfg(test)]
mod tests {
    use tracing::info;

    use crate::location::tests::blob_impl_test;

    use super::*;

    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `TLS_method` on OS `linux`
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    async fn azure_blob() -> Result<(), ExternalError> {
        // Skip (successfully) when no external-storage container is configured.
        let Some(config) = AzureBlobConfig::new_for_test()? else {
            info!(
                "{} env not set: skipping test that uses external service",
                AzureBlobConfig::EXTERNAL_TESTS_AZURE_CONTAINER
            );
            return Ok(());
        };

        blob_impl_test(move |_path| {
            let config = config.clone();
            async move {
                // Reuse the shared client/metrics/prefix but hand each opened
                // blob a fresh ConfigSet.
                let config = AzureBlobConfig {
                    metrics: config.metrics.clone(),
                    client: config.client.clone(),
                    cfg: Arc::new(ConfigSet::default()),
                    prefix: config.prefix.clone(),
                };
                AzureBlob::open(config).await
            }
        })
        .await
    }
}