opendal/layers/
concurrent_limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Context;
22use std::task::Poll;
23
24use futures::Stream;
25use futures::StreamExt;
26use tokio::sync::OwnedSemaphorePermit;
27use tokio::sync::Semaphore;
28
29use crate::raw::*;
30use crate::*;
31
32/// Add concurrent request limit.
33///
34/// # Notes
35///
36/// Users can control how many concurrent connections could be established
37/// between OpenDAL and underlying storage services.
38///
39/// All operators wrapped by this layer will share a common semaphore. This
40/// allows you to reuse the same layer across multiple operators, ensuring
41/// that the total number of concurrent requests across the entire
42/// application does not exceed the limit.
43///
44/// # Examples
45///
46/// Add a concurrent limit layer to the operator:
47///
48/// ```no_run
49/// # use opendal::layers::ConcurrentLimitLayer;
50/// # use opendal::services;
51/// # use opendal::Operator;
52/// # use opendal::Result;
53///
54/// # fn main() -> Result<()> {
55/// let _ = Operator::new(services::Memory::default())?
56///     .layer(ConcurrentLimitLayer::new(1024))
57///     .finish();
58/// Ok(())
59/// # }
60/// ```
61///
62/// Share a concurrent limit layer between the operators:
63///
64/// ```no_run
65/// # use opendal::layers::ConcurrentLimitLayer;
66/// # use opendal::services;
67/// # use opendal::Operator;
68/// # use opendal::Result;
69///
70/// # fn main() -> Result<()> {
71/// let limit = ConcurrentLimitLayer::new(1024);
72///
73/// let _operator_a = Operator::new(services::Memory::default())?
74///     .layer(limit.clone())
75///     .finish();
76/// let _operator_b = Operator::new(services::Memory::default())?
77///     .layer(limit.clone())
78///     .finish();
79///
80/// Ok(())
81/// # }
82/// ```
83#[derive(Clone)]
84pub struct ConcurrentLimitLayer {
85    operation_semaphore: Arc<Semaphore>,
86    http_semaphore: Option<Arc<Semaphore>>,
87}
88
89impl ConcurrentLimitLayer {
90    /// Create a new ConcurrentLimitLayer will specify permits.
91    ///
92    /// This permits will applied to all operations.
93    pub fn new(permits: usize) -> Self {
94        Self {
95            operation_semaphore: Arc::new(Semaphore::new(permits)),
96            http_semaphore: None,
97        }
98    }
99
100    /// Set a concurrent limit for HTTP requests.
101    ///
102    /// This will limit the number of concurrent HTTP requests made by the
103    /// operator.
104    pub fn with_http_concurrent_limit(mut self, permits: usize) -> Self {
105        self.http_semaphore = Some(Arc::new(Semaphore::new(permits)));
106        self
107    }
108}
109
110impl<A: Access> Layer<A> for ConcurrentLimitLayer {
111    type LayeredAccess = ConcurrentLimitAccessor<A>;
112
113    fn layer(&self, inner: A) -> Self::LayeredAccess {
114        let info = inner.info();
115
116        // Update http client with metrics http fetcher.
117        info.update_http_client(|client| {
118            HttpClient::with(ConcurrentLimitHttpFetcher {
119                inner: client.into_inner(),
120                http_semaphore: self.http_semaphore.clone(),
121            })
122        });
123
124        ConcurrentLimitAccessor {
125            inner,
126            semaphore: self.operation_semaphore.clone(),
127        }
128    }
129}
130
131pub struct ConcurrentLimitHttpFetcher {
132    inner: HttpFetcher,
133    http_semaphore: Option<Arc<Semaphore>>,
134}
135
136impl HttpFetch for ConcurrentLimitHttpFetcher {
137    async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
138        let Some(semaphore) = self.http_semaphore.clone() else {
139            return self.inner.fetch(req).await;
140        };
141
142        let permit = semaphore
143            .acquire_owned()
144            .await
145            .expect("semaphore must be valid");
146
147        let resp = self.inner.fetch(req).await?;
148        let (parts, body) = resp.into_parts();
149        let body = body.map_inner(|s| {
150            Box::new(ConcurrentLimitStream {
151                inner: s,
152                _permit: permit,
153            })
154        });
155        Ok(http::Response::from_parts(parts, body))
156    }
157}
158
159pub struct ConcurrentLimitStream<S> {
160    inner: S,
161    // Hold on this permit until this reader has been dropped.
162    _permit: OwnedSemaphorePermit,
163}
164
165impl<S> Stream for ConcurrentLimitStream<S>
166where
167    S: Stream<Item = Result<Buffer>> + Unpin + 'static,
168{
169    type Item = Result<Buffer>;
170
171    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172        self.inner.poll_next_unpin(cx)
173    }
174}
175
176#[derive(Debug, Clone)]
177pub struct ConcurrentLimitAccessor<A: Access> {
178    inner: A,
179    semaphore: Arc<Semaphore>,
180}
181
182impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
183    type Inner = A;
184    type Reader = ConcurrentLimitWrapper<A::Reader>;
185    type Writer = ConcurrentLimitWrapper<A::Writer>;
186    type Lister = ConcurrentLimitWrapper<A::Lister>;
187    type Deleter = ConcurrentLimitWrapper<A::Deleter>;
188
189    fn inner(&self) -> &Self::Inner {
190        &self.inner
191    }
192
193    async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
194        let _permit = self
195            .semaphore
196            .acquire()
197            .await
198            .expect("semaphore must be valid");
199
200        self.inner.create_dir(path, args).await
201    }
202
203    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
204        let permit = self
205            .semaphore
206            .clone()
207            .acquire_owned()
208            .await
209            .expect("semaphore must be valid");
210
211        self.inner
212            .read(path, args)
213            .await
214            .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
215    }
216
217    async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
218        let permit = self
219            .semaphore
220            .clone()
221            .acquire_owned()
222            .await
223            .expect("semaphore must be valid");
224
225        self.inner
226            .write(path, args)
227            .await
228            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
229    }
230
231    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
232        let _permit = self
233            .semaphore
234            .acquire()
235            .await
236            .expect("semaphore must be valid");
237
238        self.inner.stat(path, args).await
239    }
240
241    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
242        let permit = self
243            .semaphore
244            .clone()
245            .acquire_owned()
246            .await
247            .expect("semaphore must be valid");
248
249        self.inner
250            .delete()
251            .await
252            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
253    }
254
255    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
256        let permit = self
257            .semaphore
258            .clone()
259            .acquire_owned()
260            .await
261            .expect("semaphore must be valid");
262
263        self.inner
264            .list(path, args)
265            .await
266            .map(|(rp, s)| (rp, ConcurrentLimitWrapper::new(s, permit)))
267    }
268}
269
270pub struct ConcurrentLimitWrapper<R> {
271    inner: R,
272
273    // Hold on this permit until this reader has been dropped.
274    _permit: OwnedSemaphorePermit,
275}
276
277impl<R> ConcurrentLimitWrapper<R> {
278    fn new(inner: R, permit: OwnedSemaphorePermit) -> Self {
279        Self {
280            inner,
281            _permit: permit,
282        }
283    }
284}
285
286impl<R: oio::Read> oio::Read for ConcurrentLimitWrapper<R> {
287    async fn read(&mut self) -> Result<Buffer> {
288        self.inner.read().await
289    }
290}
291
292impl<R: oio::Write> oio::Write for ConcurrentLimitWrapper<R> {
293    async fn write(&mut self, bs: Buffer) -> Result<()> {
294        self.inner.write(bs).await
295    }
296
297    async fn close(&mut self) -> Result<Metadata> {
298        self.inner.close().await
299    }
300
301    async fn abort(&mut self) -> Result<()> {
302        self.inner.abort().await
303    }
304}
305
306impl<R: oio::List> oio::List for ConcurrentLimitWrapper<R> {
307    async fn next(&mut self) -> Result<Option<oio::Entry>> {
308        self.inner.next().await
309    }
310}
311
312impl<R: oio::Delete> oio::Delete for ConcurrentLimitWrapper<R> {
313    fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
314        self.inner.delete(path, args)
315    }
316
317    async fn flush(&mut self) -> Result<usize> {
318        self.inner.flush().await
319    }
320}