diff --git a/core/Cargo.lock b/core/Cargo.lock
index 852e81b01d52..e21e517d84f4 100644
--- a/core/Cargo.lock
+++ b/core/Cargo.lock
@@ -5773,6 +5773,7 @@ dependencies = [
  "http 1.4.0",
  "mea",
  "opendal-core",
+ "tokio",
 ]
 
 [[package]]
diff --git a/core/layers/concurrent-limit/Cargo.toml b/core/layers/concurrent-limit/Cargo.toml
index da80d3e45d2a..d57f8a2a0b66 100644
--- a/core/layers/concurrent-limit/Cargo.toml
+++ b/core/layers/concurrent-limit/Cargo.toml
@@ -38,3 +38,4 @@ opendal-core = { path = "../../core", version = "0.55.0", default-features = fal
 
 [dev-dependencies]
 opendal-core = { path = "../../core", version = "0.55.0" }
+tokio = { workspace = true, features = ["macros", "rt", "time"] }
diff --git a/core/layers/concurrent-limit/src/lib.rs b/core/layers/concurrent-limit/src/lib.rs
index bde32113e6c5..d5a3f4f06759 100644
--- a/core/layers/concurrent-limit/src/lib.rs
+++ b/core/layers/concurrent-limit/src/lib.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::fmt::Debug;
+use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::Context;
@@ -29,6 +29,25 @@ use mea::semaphore::Semaphore;
 use opendal_core::raw::*;
 use opendal_core::*;
 
+/// ConcurrentLimitSemaphore abstracts a semaphore-like concurrency primitive
+/// that yields an owned permit released on drop.
+pub trait ConcurrentLimitSemaphore: Send + Sync + Clone + Unpin + 'static {
+    /// The owned permit type associated with the semaphore. Dropping it
+    /// must release the permit back to the semaphore.
+    type Permit: Send + Sync + 'static;
+
+    /// Acquire an owned permit asynchronously.
+    fn acquire(&self) -> impl Future<Output = Self::Permit> + MaybeSend;
+}
+
+impl ConcurrentLimitSemaphore for Arc<Semaphore> {
+    type Permit = OwnedSemaphorePermit;
+
+    async fn acquire(&self) -> Self::Permit {
+        self.clone().acquire_owned(1).await
+    }
+}
+
 /// Add concurrent request limit.
 ///
 /// # Notes
@@ -50,7 +69,6 @@ use opendal_core::*;
 /// # use opendal_core::services;
 /// # use opendal_core::Operator;
 /// # use opendal_core::Result;
-///
 /// # fn main() -> Result<()> {
 /// let _ = Operator::new(services::Memory::default())?
 ///     .layer(ConcurrentLimitLayer::new(1024))
@@ -66,7 +84,6 @@
 /// # use opendal_core::services;
 /// # use opendal_core::Operator;
 /// # use opendal_core::Result;
-///
 /// # fn main() -> Result<()> {
 /// let limit = ConcurrentLimitLayer::new(1024);
 ///
@@ -81,41 +98,67 @@
 /// # }
 /// ```
 #[derive(Clone)]
-pub struct ConcurrentLimitLayer {
-    operation_semaphore: Arc<Semaphore>,
-    http_semaphore: Option<Arc<Semaphore>>,
+pub struct ConcurrentLimitLayer<S: ConcurrentLimitSemaphore = Arc<Semaphore>> {
+    operation_semaphore: S,
+    http_semaphore: Option<S>,
 }
 
-impl ConcurrentLimitLayer {
-    /// Create a new ConcurrentLimitLayer will specify permits.
+impl ConcurrentLimitLayer<Arc<Semaphore>> {
+    /// Create a new `ConcurrentLimitLayer` with the specified number of
+    /// permits.
     ///
-    /// This permits will applied to all operations.
+    /// These permits will be applied to all operations.
     pub fn new(permits: usize) -> Self {
+        Self::with_semaphore(Arc::new(Semaphore::new(permits)))
+    }
+
+    /// Set a concurrent limit for HTTP requests.
+    ///
+    /// This convenience helper constructs a new semaphore with the specified
+    /// number of permits and calls [`ConcurrentLimitLayer::with_http_semaphore`].
+    /// Use [`ConcurrentLimitLayer::with_http_semaphore`] directly when reusing
+    /// a shared semaphore.
+    pub fn with_http_concurrent_limit(self, permits: usize) -> Self {
+        self.with_http_semaphore(Arc::new(Semaphore::new(permits)))
+    }
+}
+
+impl<S: ConcurrentLimitSemaphore> ConcurrentLimitLayer<S> {
+    /// Create a layer with any ConcurrentLimitSemaphore implementation.
+    ///
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use mea::semaphore::Semaphore;
+    /// # use opendal_layer_concurrent_limit::ConcurrentLimitLayer;
+    /// let semaphore = Arc::new(Semaphore::new(1024));
+    /// let _layer = ConcurrentLimitLayer::with_semaphore(semaphore);
+    /// ```
+    pub fn with_semaphore(operation_semaphore: S) -> Self {
         Self {
-            operation_semaphore: Arc::new(Semaphore::new(permits)),
+            operation_semaphore,
             http_semaphore: None,
         }
     }
 
-    /// Set a concurrent limit for HTTP requests.
-    ///
-    /// This will limit the number of concurrent HTTP requests made by the
-    /// operator.
-    pub fn with_http_concurrent_limit(mut self, permits: usize) -> Self {
-        self.http_semaphore = Some(Arc::new(Semaphore::new(permits)));
+    /// Provide a custom HTTP concurrency semaphore instance.
+    pub fn with_http_semaphore(mut self, semaphore: S) -> Self {
+        self.http_semaphore = Some(semaphore);
         self
     }
 }
 
-impl<A: Access> Layer<A> for ConcurrentLimitLayer {
-    type LayeredAccess = ConcurrentLimitAccessor<A>;
+impl<A: Access, S: ConcurrentLimitSemaphore> Layer<A> for ConcurrentLimitLayer<S>
+where
+    S::Permit: Unpin,
+{
+    type LayeredAccess = ConcurrentLimitAccessor<A, S>;
 
     fn layer(&self, inner: A) -> Self::LayeredAccess {
         let info = inner.info();
 
-        // Update http client with metrics http fetcher.
+        // Update http client with concurrent limit http fetcher.
         info.update_http_client(|client| {
-            HttpClient::with(ConcurrentLimitHttpFetcher {
+            HttpClient::with(ConcurrentLimitHttpFetcher::<S> {
                 inner: client.into_inner(),
                 http_semaphore: self.http_semaphore.clone(),
             })
@@ -128,23 +171,26 @@ impl<A: Access> Layer<A> for ConcurrentLimitLayer {
     }
 }
 
-pub struct ConcurrentLimitHttpFetcher {
+pub struct ConcurrentLimitHttpFetcher<S: ConcurrentLimitSemaphore> {
     inner: HttpFetcher,
-    http_semaphore: Option<Arc<Semaphore>>,
+    http_semaphore: Option<S>,
 }
 
-impl HttpFetch for ConcurrentLimitHttpFetcher {
+impl<S: ConcurrentLimitSemaphore> HttpFetch for ConcurrentLimitHttpFetcher<S>
+where
+    S::Permit: Unpin,
+{
     async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
         let Some(semaphore) = self.http_semaphore.clone() else {
             return self.inner.fetch(req).await;
         };
-        let permit = semaphore.acquire_owned(1).await;
+        let permit = semaphore.acquire().await;
 
         let resp = self.inner.fetch(req).await?;
         let (parts, body) = resp.into_parts();
         let body = body.map_inner(|s| {
-            Box::new(ConcurrentLimitStream {
+            Box::new(ConcurrentLimitStream::<_, S::Permit> {
                 inner: s,
                 _permit: permit,
             })
@@ -153,48 +199,62 @@ impl HttpFetch for ConcurrentLimitHttpFetcher {
     }
 }
 
-pub struct ConcurrentLimitStream<S> {
+pub struct ConcurrentLimitStream<S, P> {
     inner: S,
     // Hold on this permit until this reader has been dropped.
-    _permit: OwnedSemaphorePermit,
+    _permit: P,
 }
 
-impl<S> Stream for ConcurrentLimitStream<S>
+impl<S, P> Stream for ConcurrentLimitStream<S, P>
 where
     S: Stream<Item = Result<Buffer>> + Unpin + 'static,
+    P: Unpin,
 {
     type Item = Result<Buffer>;
 
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.inner.poll_next_unpin(cx)
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        // Safe due to Unpin bounds on S and P (thus on Self).
+        let this = self.get_mut();
+        this.inner.poll_next_unpin(cx)
     }
 }
 
-#[derive(Debug, Clone)]
-pub struct ConcurrentLimitAccessor<A: Access> {
+#[derive(Clone)]
+pub struct ConcurrentLimitAccessor<A: Access, S: ConcurrentLimitSemaphore> {
     inner: A,
-    semaphore: Arc<Semaphore>,
+    semaphore: S,
+}
+
+impl<A: Access, S: ConcurrentLimitSemaphore> std::fmt::Debug for ConcurrentLimitAccessor<A, S> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ConcurrentLimitAccessor")
+            .field("inner", &self.inner)
+            .finish_non_exhaustive()
+    }
 }
 
-impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
+impl<A: Access, S: ConcurrentLimitSemaphore> LayeredAccess for ConcurrentLimitAccessor<A, S>
+where
+    S::Permit: Unpin,
+{
     type Inner = A;
-    type Reader = ConcurrentLimitWrapper<A::Reader>;
-    type Writer = ConcurrentLimitWrapper<A::Writer>;
-    type Lister = ConcurrentLimitWrapper<A::Lister>;
-    type Deleter = ConcurrentLimitWrapper<A::Deleter>;
+    type Reader = ConcurrentLimitWrapper<A::Reader, S::Permit>;
+    type Writer = ConcurrentLimitWrapper<A::Writer, S::Permit>;
+    type Lister = ConcurrentLimitWrapper<A::Lister, S::Permit>;
+    type Deleter = ConcurrentLimitWrapper<A::Deleter, S::Permit>;
 
     fn inner(&self) -> &Self::Inner {
         &self.inner
     }
 
     async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
-        let _permit = self.semaphore.acquire(1).await;
+        let _permit = self.semaphore.acquire().await;
 
         self.inner.create_dir(path, args).await
     }
 
     async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
-        let permit = self.semaphore.clone().acquire_owned(1).await;
+        let permit = self.semaphore.acquire().await;
 
         self.inner
             .read(path, args)
@@ -203,7 +263,7 @@ impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
     }
 
     async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
-        let permit = self.semaphore.clone().acquire_owned(1).await;
+        let permit = self.semaphore.acquire().await;
 
         self.inner
             .write(path, args)
@@ -212,13 +272,13 @@ impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
     }
 
     async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
-        let _permit = self.semaphore.acquire(1).await;
+        let _permit = self.semaphore.acquire().await;
 
         self.inner.stat(path, args).await
     }
 
     async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
-        let permit = self.semaphore.clone().acquire_owned(1).await;
+        let permit = self.semaphore.acquire().await;
 
         self.inner
             .delete()
@@ -227,7 +287,7 @@ impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
     }
 
     async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
-        let permit = self.semaphore.clone().acquire_owned(1).await;
+        let permit = self.semaphore.acquire().await;
 
         self.inner
             .list(path, args)
@@ -236,15 +296,15 @@ impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
     }
 }
 
-pub struct ConcurrentLimitWrapper<R> {
+pub struct ConcurrentLimitWrapper<R, P> {
     inner: R,
     // Hold on this permit until this reader has been dropped.
-    _permit: OwnedSemaphorePermit,
+    _permit: P,
 }
 
-impl<R> ConcurrentLimitWrapper<R> {
-    fn new(inner: R, permit: OwnedSemaphorePermit) -> Self {
+impl<R, P> ConcurrentLimitWrapper<R, P> {
+    fn new(inner: R, permit: P) -> Self {
         Self {
             inner,
             _permit: permit,
@@ -252,13 +312,13 @@ impl<R> ConcurrentLimitWrapper<R> {
     }
 }
 
-impl<R: oio::Read> oio::Read for ConcurrentLimitWrapper<R> {
+impl<R: oio::Read, P: Send + Sync + Unpin + 'static> oio::Read for ConcurrentLimitWrapper<R, P> {
     async fn read(&mut self) -> Result<Buffer> {
         self.inner.read().await
     }
 }
 
-impl<R: oio::Write> oio::Write for ConcurrentLimitWrapper<R> {
+impl<R: oio::Write, P: Send + Sync + Unpin + 'static> oio::Write for ConcurrentLimitWrapper<R, P> {
     async fn write(&mut self, bs: Buffer) -> Result<()> {
         self.inner.write(bs).await
     }
@@ -272,13 +332,15 @@ impl<R: oio::Write> oio::Write for ConcurrentLimitWrapper<R> {
     }
 }
 
-impl<R: oio::List> oio::List for ConcurrentLimitWrapper<R> {
+impl<R: oio::List, P: Send + Sync + Unpin + 'static> oio::List for ConcurrentLimitWrapper<R, P> {
     async fn next(&mut self) -> Result<Option<oio::Entry>> {
         self.inner.next().await
     }
 }
 
-impl<R: oio::Delete> oio::Delete for ConcurrentLimitWrapper<R> {
+impl<R: oio::Delete, P: Send + Sync + Unpin + 'static> oio::Delete
+    for ConcurrentLimitWrapper<R, P>
+{
     async fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
         self.inner.delete(path, args).await
     }
@@ -287,3 +349,84 @@ impl<R: oio::Delete> oio::Delete for ConcurrentLimitWrapper<R> {
         self.inner.close().await
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use opendal_core::Operator;
+    use opendal_core::services;
+    use std::sync::Arc;
+    use std::time::Duration;
+    use tokio::time::timeout;
+
+    use futures::stream;
+    use http::Response;
+
+    #[tokio::test]
+    async fn operation_semaphore_can_be_shared() {
+        let semaphore = Arc::new(Semaphore::new(1));
+        let layer = ConcurrentLimitLayer::with_semaphore(semaphore.clone());
+
+        let permit = semaphore.clone().acquire_owned(1).await;
+
+        let op = Operator::new(services::Memory::default())
+            .expect("operator must build")
+            .layer(layer)
+            .finish();
+
+        let blocked = timeout(Duration::from_millis(50), op.stat("any")).await;
+        assert!(
+            blocked.is_err(),
+            "operation should be limited by shared semaphore"
+        );
+
+        drop(permit);
+
+        let completed = timeout(Duration::from_millis(50), op.stat("any")).await;
+        assert!(
+            completed.is_ok(),
+            "operation should proceed once permit is released"
+        );
+    }
+
+    #[tokio::test]
+    async fn http_semaphore_holds_until_body_dropped() {
+        struct DummyFetcher;
+
+        impl HttpFetch for DummyFetcher {
+            async fn fetch(&self, _req: http::Request<Buffer>) -> Result<Response<HttpBody>> {
+                let body = HttpBody::new(stream::empty(), None);
+                Ok(Response::builder()
+                    .status(http::StatusCode::OK)
+                    .body(body)
+                    .expect("response must build"))
+            }
+        }
+
+        let semaphore = Arc::new(Semaphore::new(1));
+        let layer = ConcurrentLimitLayer::new(1).with_http_semaphore(semaphore.clone());
+        let fetcher = ConcurrentLimitHttpFetcher::<Arc<Semaphore>> {
+            inner: HttpClient::with(DummyFetcher).into_inner(),
+            http_semaphore: layer.http_semaphore.clone(),
+        };
+
+        let request = http::Request::builder()
+            .uri("http://example.invalid/")
+            .body(Buffer::new())
+            .expect("request must build");
+        let _resp = fetcher
+            .fetch(request)
+            .await
+            .expect("first fetch should succeed");
+
+        let request = http::Request::builder()
+            .uri("http://example.invalid/")
+            .body(Buffer::new())
+            .expect("request must build");
+        let blocked = timeout(Duration::from_millis(50), fetcher.fetch(request)).await;
+        assert!(
+            blocked.is_err(),
+            "http fetch should block while the body holds the permit"
+        );
+    }
+}
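
Note (not part of the patch): the new `ConcurrentLimitSemaphore` trait is what lets callers plug in a concurrency primitive other than `mea`'s semaphore. As a rough sketch of how a downstream crate might use it, the example below implements the trait for a wrapper around `tokio::sync::Semaphore`. The `TokioSemaphore` type and the `tokio_limited_layer` helper are illustrative names only, and the sketch assumes the trait and layer are exported from the crate root as shown in the diff.

```rust
use std::sync::Arc;

use opendal_layer_concurrent_limit::ConcurrentLimitLayer;
use opendal_layer_concurrent_limit::ConcurrentLimitSemaphore;

/// Illustrative wrapper: a clonable handle around a tokio semaphore.
#[derive(Clone)]
struct TokioSemaphore(Arc<tokio::sync::Semaphore>);

impl ConcurrentLimitSemaphore for TokioSemaphore {
    // Dropping the owned permit releases it back to the semaphore,
    // which is exactly what the trait contract requires.
    type Permit = tokio::sync::OwnedSemaphorePermit;

    async fn acquire(&self) -> Self::Permit {
        // `acquire_owned` only errors when the semaphore is closed; this
        // wrapper never closes it, so unwrapping is acceptable for a sketch.
        self.0
            .clone()
            .acquire_owned()
            .await
            .expect("semaphore is never closed")
    }
}

/// Build a layer whose operation limit is enforced by a tokio semaphore.
fn tokio_limited_layer(permits: usize) -> ConcurrentLimitLayer<TokioSemaphore> {
    ConcurrentLimitLayer::with_semaphore(TokioSemaphore(Arc::new(
        tokio::sync::Semaphore::new(permits),
    )))
}
```

Such a layer can then be applied with `.layer(...)` exactly like the default one, since `tokio::sync::OwnedSemaphorePermit` is `Unpin` and therefore satisfies the `S::Permit: Unpin` bound on the generic `Layer` implementation.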