diff --git a/rust/lance-core/src/utils.rs b/rust/lance-core/src/utils.rs index 663454e001b..565036311f9 100644 --- a/rust/lance-core/src/utils.rs +++ b/rust/lance-core/src/utils.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors pub mod address; +pub mod aimd; pub mod assume; pub mod backoff; pub mod bit; diff --git a/rust/lance-core/src/utils/aimd.rs b/rust/lance-core/src/utils/aimd.rs new file mode 100644 index 00000000000..0cbae68ca71 --- /dev/null +++ b/rust/lance-core/src/utils/aimd.rs @@ -0,0 +1,623 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! AIMD (Additive Increase / Multiplicative Decrease) rate controller. +//! +//! This module provides a reusable AIMD algorithm for dynamically adjusting +//! request rates. On success windows, the rate increases additively. On +//! windows with throttle signals, the rate decreases multiplicatively. +//! +//! The algorithm operates in discrete time windows. At the end of each window, +//! the throttle ratio (throttled / total) is compared against a threshold: +//! - Above threshold: `rate = max(rate * decrease_factor, min_rate)` +//! - At or below threshold: `rate = min(rate + additive_increment, max_rate)` + +use std::sync::Mutex; +use std::time::Duration; + +use crate::Result; + +/// Configuration for the AIMD rate controller. +/// +/// Use builder methods to customize. Defaults are tuned for cloud object stores +/// and will start at about 40% of the max rate and require 10 seconds to reach +/// the max rate. 
+/// +/// - initial_rate: 2000 req/s +/// - min_rate: 1 req/s +/// - max_rate: 5000 req/s (0.0 disables ceiling) +/// - decrease_factor: 0.5 (halve on throttle) +/// - additive_increment: 300 req/s per success window +/// - window_duration: 1 second +/// - throttle_threshold: 0.0 (any throttle triggers decrease) +#[derive(Debug, Clone)] +pub struct AimdConfig { + pub initial_rate: f64, + pub min_rate: f64, + pub max_rate: f64, + pub decrease_factor: f64, + pub additive_increment: f64, + pub window_duration: Duration, + pub throttle_threshold: f64, +} + +impl Default for AimdConfig { + fn default() -> Self { + Self { + initial_rate: 2000.0, + min_rate: 1.0, + max_rate: 5000.0, + decrease_factor: 0.5, + additive_increment: 300.0, + window_duration: Duration::from_secs(1), + throttle_threshold: 0.0, + } + } +} + +impl AimdConfig { + pub fn with_initial_rate(self, initial_rate: f64) -> Self { + Self { + initial_rate, + ..self + } + } + + pub fn with_min_rate(self, min_rate: f64) -> Self { + Self { min_rate, ..self } + } + + pub fn with_max_rate(self, max_rate: f64) -> Self { + Self { max_rate, ..self } + } + + pub fn with_decrease_factor(self, decrease_factor: f64) -> Self { + Self { + decrease_factor, + ..self + } + } + + pub fn with_additive_increment(self, additive_increment: f64) -> Self { + Self { + additive_increment, + ..self + } + } + + pub fn with_window_duration(self, window_duration: Duration) -> Self { + Self { + window_duration, + ..self + } + } + + pub fn with_throttle_threshold(self, throttle_threshold: f64) -> Self { + Self { + throttle_threshold, + ..self + } + } + + /// Validate that the configuration values are sensible. 
+ pub fn validate(&self) -> Result<()> { + if self.initial_rate <= 0.0 { + return Err(crate::Error::invalid_input(format!( + "initial_rate must be positive, got {}", + self.initial_rate + ))); + } + if self.min_rate <= 0.0 { + return Err(crate::Error::invalid_input(format!( + "min_rate must be positive, got {}", + self.min_rate + ))); + } + if self.max_rate < 0.0 { + return Err(crate::Error::invalid_input(format!( + "max_rate must be non-negative (0.0 = no ceiling), got {}", + self.max_rate + ))); + } + if self.max_rate > 0.0 && self.min_rate > self.max_rate { + return Err(crate::Error::invalid_input(format!( + "min_rate ({}) must not exceed max_rate ({})", + self.min_rate, self.max_rate + ))); + } + if self.decrease_factor <= 0.0 || self.decrease_factor >= 1.0 { + return Err(crate::Error::invalid_input(format!( + "decrease_factor must be in (0, 1), got {}", + self.decrease_factor + ))); + } + if self.additive_increment <= 0.0 { + return Err(crate::Error::invalid_input(format!( + "additive_increment must be positive, got {}", + self.additive_increment + ))); + } + if self.window_duration.is_zero() { + return Err(crate::Error::invalid_input( + "window_duration must be non-zero", + )); + } + if !(0.0..=1.0).contains(&self.throttle_threshold) { + return Err(crate::Error::invalid_input(format!( + "throttle_threshold must be in [0.0, 1.0], got {}", + self.throttle_threshold + ))); + } + if self.max_rate > 0.0 && self.initial_rate > self.max_rate { + return Err(crate::Error::invalid_input(format!( + "initial_rate ({}) must not exceed max_rate ({})", + self.initial_rate, self.max_rate + ))); + } + if self.initial_rate < self.min_rate { + return Err(crate::Error::invalid_input(format!( + "initial_rate ({}) must not be below min_rate ({})", + self.initial_rate, self.min_rate + ))); + } + Ok(()) + } +} + +/// Outcome of a single request, used to feed the AIMD controller. +/// +/// Non-throttle errors (e.g. 
404, network timeout) should be mapped to
+/// `Success` since they don't indicate capacity problems.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RequestOutcome {
+    Success,
+    Throttled,
+}
+
+struct AimdState {
+    rate: f64,
+    window_start: std::time::Instant,
+    success_count: u64,
+    throttle_count: u64,
+}
+
+/// AIMD rate controller.
+///
+/// Thread-safe: uses an internal `Mutex` to protect state. The lock is held
+/// only briefly during `record_outcome` and `current_rate`.
+pub struct AimdController {
+    config: AimdConfig,
+    state: Mutex<AimdState>,
+}
+
+impl std::fmt::Debug for AimdController {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AimdController")
+            .field("config", &self.config)
+            .field("rate", &self.current_rate())
+            .finish()
+    }
+}
+
+impl AimdController {
+    /// Create a new AIMD controller with the given configuration.
+    pub fn new(config: AimdConfig) -> Result<Self> {
+        config.validate()?;
+        let rate = config.initial_rate;
+        Ok(Self {
+            config,
+            state: Mutex::new(AimdState {
+                rate,
+                window_start: std::time::Instant::now(),
+                success_count: 0,
+                throttle_count: 0,
+            }),
+        })
+    }
+
+    /// Record a request outcome and return the current rate.
+    ///
+    /// If the current time window has expired, the rate is adjusted before
+    /// recording the new outcome in a fresh window.
+ pub fn record_outcome(&self, outcome: RequestOutcome) -> f64 { + let mut state = self.state.lock().unwrap(); + self.record_outcome_inner(&mut state, outcome, std::time::Instant::now()) + } + + fn record_outcome_inner( + &self, + state: &mut AimdState, + outcome: RequestOutcome, + now: std::time::Instant, + ) -> f64 { + // Check if the window has expired + let elapsed = now.duration_since(state.window_start); + if elapsed >= self.config.window_duration { + let total = state.success_count + state.throttle_count; + if total > 0 { + let throttle_ratio = state.throttle_count as f64 / total as f64; + if throttle_ratio > self.config.throttle_threshold { + // Multiplicative decrease + state.rate = + (state.rate * self.config.decrease_factor).max(self.config.min_rate); + } else { + // Additive increase + state.rate += self.config.additive_increment; + if self.config.max_rate > 0.0 { + state.rate = state.rate.min(self.config.max_rate); + } + } + } + // Reset window + state.window_start = now; + state.success_count = 0; + state.throttle_count = 0; + } + + // Record this outcome + match outcome { + RequestOutcome::Success => state.success_count += 1, + RequestOutcome::Throttled => state.throttle_count += 1, + } + + state.rate + } + + /// Get the current rate without recording an outcome. 
+ pub fn current_rate(&self) -> f64 { + self.state.lock().unwrap().rate + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[rstest] + #[case::zero_initial_rate( + AimdConfig::default().with_initial_rate(0.0), + "initial_rate must be positive" + )] + #[case::negative_min_rate( + AimdConfig::default().with_min_rate(-1.0), + "min_rate must be positive" + )] + #[case::negative_max_rate( + AimdConfig::default().with_max_rate(-1.0), + "max_rate must be non-negative" + )] + #[case::min_exceeds_max( + AimdConfig::default().with_min_rate(100.0).with_max_rate(10.0), + "min_rate (100) must not exceed max_rate (10)" + )] + #[case::decrease_factor_zero( + AimdConfig::default().with_decrease_factor(0.0), + "decrease_factor must be in (0, 1)" + )] + #[case::decrease_factor_one( + AimdConfig::default().with_decrease_factor(1.0), + "decrease_factor must be in (0, 1)" + )] + #[case::decrease_factor_over_one( + AimdConfig::default().with_decrease_factor(1.5), + "decrease_factor must be in (0, 1)" + )] + #[case::zero_additive_increment( + AimdConfig::default().with_additive_increment(0.0), + "additive_increment must be positive" + )] + #[case::zero_window_duration( + AimdConfig::default().with_window_duration(Duration::ZERO), + "window_duration must be non-zero" + )] + #[case::threshold_over_one( + AimdConfig::default().with_throttle_threshold(1.1), + "throttle_threshold must be in [0.0, 1.0]" + )] + #[case::threshold_negative( + AimdConfig::default().with_throttle_threshold(-0.1), + "throttle_threshold must be in [0.0, 1.0]" + )] + #[case::initial_exceeds_max( + AimdConfig::default().with_initial_rate(6000.0), + "initial_rate (6000) must not exceed max_rate (5000)" + )] + #[case::initial_below_min( + AimdConfig::default().with_initial_rate(0.5).with_min_rate(1.0), + "initial_rate (0.5) must not be below min_rate (1)" + )] + fn test_config_validation_rejects_invalid( + #[case] config: AimdConfig, + #[case] expected_msg: &str, + ) { + let err = 
config.validate().unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains(expected_msg), + "Expected error containing '{}', got: {}", + expected_msg, + msg + ); + } + + #[test] + fn test_default_config_is_valid() { + AimdConfig::default().validate().unwrap(); + } + + #[test] + fn test_no_ceiling_config_is_valid() { + AimdConfig::default().with_max_rate(0.0).validate().unwrap(); + } + + #[test] + fn test_additive_increase_on_success_window() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_additive_increment(10.0) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + // Record some successes in the first window + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, start); + } + + // Advance past the window boundary and record another success + let after_window = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, after_window); + } + + // Rate should have increased by additive_increment + assert_eq!(controller.current_rate(), 110.0); + } + + #[test] + fn test_multiplicative_decrease_on_throttle_window() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + } + + // Advance past window + let after_window = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, after_window); + } + + 
assert_eq!(controller.current_rate(), 50.0); + } + + #[test] + fn test_floor_enforcement() { + let config = AimdConfig::default() + .with_initial_rate(2.0) + .with_min_rate(1.0) + .with_decrease_factor(0.5) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + } + + // After decrease: 2.0 * 0.5 = 1.0 (at floor) + let t1 = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, t1); + } + assert_eq!(controller.current_rate(), 1.0); + + // Another decrease should stay at floor + let t2 = t1 + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t2); + } + assert_eq!(controller.current_rate(), 1.0); + } + + #[test] + fn test_ceiling_enforcement() { + let config = AimdConfig::default() + .with_initial_rate(4990.0) + .with_max_rate(5000.0) + .with_additive_increment(20.0) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, start); + } + + let t1 = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t1); + } + // 4990 + 20 = 5010, clamped to 5000 + assert_eq!(controller.current_rate(), 5000.0); + } + + #[test] + fn test_no_ceiling_allows_unbounded_growth() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_max_rate(0.0) + .with_additive_increment(50.0) + 
.with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + let mut t = start; + + for _ in 0..5 { + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t); + } + t += Duration::from_millis(150); + } + + // Trigger final window evaluation + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t); + } + + // 100 + 50*5 = 350 + assert_eq!(controller.current_rate(), 350.0); + } + + #[test] + fn test_empty_window_no_adjustment() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + // Don't record anything in the first window, just advance time + let start = std::time::Instant::now(); + let after = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + // First outcome in a new window after empty window + controller.record_outcome_inner(&mut state, RequestOutcome::Success, after); + } + // No adjustment because the expired window had 0 total + assert_eq!(controller.current_rate(), 100.0); + } + + #[test] + fn test_throttle_threshold_filtering() { + // With threshold 0.5, less than 50% throttles should still increase + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_throttle_threshold(0.5) + .with_additive_increment(10.0) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + // 1 throttle out of 3 = 33% < 50% threshold + controller.record_outcome_inner(&mut state, RequestOutcome::Success, start); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, start); + 
controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + } + + // Advance past window + let t1 = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t1); + } + + // Should have increased because 33% <= 50% + assert_eq!(controller.current_rate(), 110.0); + } + + #[test] + fn test_throttle_threshold_triggers_decrease() { + // With threshold 0.5, >= 50% throttles should decrease + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_throttle_threshold(0.5) + .with_decrease_factor(0.5) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + { + let mut state = controller.state.lock().unwrap(); + // 2 throttle out of 3 = 67% > 50% threshold + controller.record_outcome_inner(&mut state, RequestOutcome::Success, start); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + } + + let t1 = start + Duration::from_millis(150); + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t1); + } + + assert_eq!(controller.current_rate(), 50.0); + } + + #[test] + fn test_recovery_after_decrease() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_additive_increment(10.0) + .with_window_duration(Duration::from_millis(100)); + let controller = AimdController::new(config).unwrap(); + + let start = std::time::Instant::now(); + + // Window 1: throttle → decrease to 50 + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Throttled, start); + } + let t1 = start + Duration::from_millis(150); + + // Window 2: success → increase to 60 + { + let mut 
state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t1); + } + let t2 = t1 + Duration::from_millis(150); + + // Window 3: success → increase to 70 + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t2); + } + let t3 = t2 + Duration::from_millis(150); + + // Trigger final evaluation + { + let mut state = controller.state.lock().unwrap(); + controller.record_outcome_inner(&mut state, RequestOutcome::Success, t3); + } + + assert_eq!(controller.current_rate(), 70.0); + } + + #[test] + fn test_within_window_no_adjustment() { + let config = AimdConfig::default() + .with_initial_rate(100.0) + .with_window_duration(Duration::from_secs(10)); + let controller = AimdController::new(config).unwrap(); + + // Record many outcomes but all within the same window + for _ in 0..100 { + controller.record_outcome(RequestOutcome::Throttled); + } + + // Rate should still be initial since window hasn't expired + assert_eq!(controller.current_rate(), 100.0); + } +} diff --git a/rust/lance-io/src/object_store.rs b/rust/lance-io/src/object_store.rs index 02128143705..c0d0acf51a2 100644 --- a/rust/lance-io/src/object_store.rs +++ b/rust/lance-io/src/object_store.rs @@ -35,6 +35,7 @@ use super::local::LocalObjectReader; mod list_retry; pub mod providers; pub mod storage_options; +pub mod throttle; mod tracing; use crate::object_reader::SmallReader; use crate::object_writer::{LocalWriter, WriteResult}; diff --git a/rust/lance-io/src/object_store/throttle.rs b/rust/lance-io/src/object_store/throttle.rs new file mode 100644 index 00000000000..378bcff4c7d --- /dev/null +++ b/rust/lance-io/src/object_store/throttle.rs @@ -0,0 +1,1163 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! AIMD-controlled token bucket rate limiter for ObjectStore operations. +//! +//! 
Wraps any [`object_store::ObjectStore`] with per-category token buckets +//! whose fill rates are dynamically adjusted by AIMD controllers. When cloud +//! stores return HTTP 429/503, the fill rate decreases multiplicatively. During +//! sustained success windows, it increases additively. +//! +//! Operations are split into four independent categories — **read**, **write**, +//! **delete**, **list** — each with its own AIMD controller and token bucket. +//! This prevents a burst of reads from starving writes, and vice versa. +//! +//! # Example +//! +//! ```ignore +//! use lance_io::object_store::throttle::{AimdThrottleConfig, AimdThrottledStore}; +//! +//! let throttled = AimdThrottledStore::new(target, AimdThrottleConfig::default()).unwrap(); +//! ``` + +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Range; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::Bytes; +use futures::StreamExt; +use futures::stream::BoxStream; +use lance_core::utils::aimd::{AimdConfig, AimdController, RequestOutcome}; +use object_store::path::Path; +use object_store::{ + GetOptions, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart, +}; +use tokio::sync::Mutex; +use tracing::debug; + +/// Check whether an `object_store::Error` represents a throttle response +/// (HTTP 429 / 503) from a cloud object store. +/// +/// Regrettably, this information is not fully exposed by the `object_store` crate. +/// There is no generic mechanism for a custom object store to return a throttle error. +/// +/// However, the builtin object stores all use RetryError when retries are configured and +/// throttle errors are returned. Sadly, RetryError is not a public type, so we have to +/// infer it from the error message. This is potentially dangerous because these errors +/// often include the URI itself and that URI could have any characters in it (e.g. 
if we +/// look for 429 then we might match a 429 in a UUID).These error messages currently look like: +/// +/// ", after ... retries, max_retries: ..., retry_timeout: ..." +/// +/// So, as a crude heuristic, which should work for the builtin object stores, but won't +/// work for custom object stores, we simply look for the string "retries, max_retries" +/// in the error message. +pub fn is_throttle_error(err: &object_store::Error) -> bool { + // Only Generic errors can carry throttle responses + if let object_store::Error::Generic { source, .. } = err { + source.to_string().contains("retries, max_retries") + } else { + false + } +} + +/// Configuration for the AIMD-throttled ObjectStore wrapper. +/// +/// Each operation category (read, write, delete, list) has its own AIMD config. +/// Use [`with_aimd`](AimdThrottleConfig::with_aimd) to set all categories at +/// once, or per-category methods like [`with_read_aimd`](AimdThrottleConfig::with_read_aimd) +/// for fine-grained control. +#[derive(Debug, Clone)] +pub struct AimdThrottleConfig { + /// AIMD configuration for read operations (get, get_opts, get_range, get_ranges, head). + pub read: AimdConfig, + /// AIMD configuration for write operations (put, put_opts, put_multipart, copy, rename, etc.). + pub write: AimdConfig, + /// AIMD configuration for delete operations. + pub delete: AimdConfig, + /// AIMD configuration for list operations (list_with_delimiter). + pub list: AimdConfig, + /// Maximum tokens that can accumulate for bursts (shared across all categories). + pub burst_capacity: u32, +} + +impl Default for AimdThrottleConfig { + fn default() -> Self { + let aimd = AimdConfig::default(); + Self { + read: aimd.clone(), + write: aimd.clone(), + delete: aimd.clone(), + list: aimd, + burst_capacity: 100, + } + } +} + +impl AimdThrottleConfig { + /// Set the AIMD configuration for all four operation categories at once. 
+    pub fn with_aimd(self, aimd: AimdConfig) -> Self {
+        Self {
+            read: aimd.clone(),
+            write: aimd.clone(),
+            delete: aimd.clone(),
+            list: aimd,
+            ..self
+        }
+    }
+
+    /// Set the AIMD configuration for read operations.
+    pub fn with_read_aimd(self, aimd: AimdConfig) -> Self {
+        Self { read: aimd, ..self }
+    }
+
+    /// Set the AIMD configuration for write operations.
+    pub fn with_write_aimd(self, aimd: AimdConfig) -> Self {
+        Self {
+            write: aimd,
+            ..self
+        }
+    }
+
+    /// Set the AIMD configuration for delete operations.
+    pub fn with_delete_aimd(self, aimd: AimdConfig) -> Self {
+        Self {
+            delete: aimd,
+            ..self
+        }
+    }
+
+    /// Set the AIMD configuration for list operations.
+    pub fn with_list_aimd(self, aimd: AimdConfig) -> Self {
+        Self { list: aimd, ..self }
+    }
+
+    pub fn with_burst_capacity(self, burst_capacity: u32) -> Self {
+        Self {
+            burst_capacity,
+            ..self
+        }
+    }
+}
+
+struct TokenBucketState {
+    tokens: f64,
+    last_refill: std::time::Instant,
+    rate: f64,
+}
+
+/// Per-category throttle state: an AIMD controller paired with a token bucket.
+struct OperationThrottle {
+    controller: AimdController,
+    bucket: Mutex<TokenBucketState>,
+    burst_capacity: f64,
+}
+
+impl OperationThrottle {
+    fn new(aimd_config: AimdConfig, burst_capacity: f64) -> lance_core::Result<Self> {
+        let initial_rate = aimd_config.initial_rate;
+        let controller = AimdController::new(aimd_config)?;
+        Ok(Self {
+            controller,
+            bucket: Mutex::new(TokenBucketState {
+                tokens: burst_capacity,
+                last_refill: std::time::Instant::now(),
+                rate: initial_rate,
+            }),
+            burst_capacity,
+        })
+    }
+
+    /// Acquire a token from the bucket, sleeping if none are available.
+    ///
+    /// Each caller reserves a token immediately (allowing `tokens` to go
+    /// negative) so that concurrent waiters queue behind each other instead
+    /// of all waking at the same instant (thundering herd).
+    async fn acquire_token(&self) {
+        let sleep_duration = {
+            let mut bucket = self.bucket.lock().await;
+            let now = std::time::Instant::now();
+            let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
+            bucket.tokens = (bucket.tokens + elapsed * bucket.rate).min(self.burst_capacity);
+            bucket.last_refill = now;
+
+            // Reserve a token (may go negative to queue behind other waiters)
+            bucket.tokens -= 1.0;
+
+            if bucket.tokens >= 0.0 {
+                // Had a token available, no need to sleep
+                return;
+            }
+
+            // Sleep proportional to our position in the queue
+            std::time::Duration::from_secs_f64(-bucket.tokens / bucket.rate)
+        };
+
+        tokio::time::sleep(sleep_duration).await;
+    }
+
+    /// Update the bucket's fill rate from the controller.
+    async fn update_bucket_rate(&self, new_rate: f64) {
+        let mut bucket = self.bucket.lock().await;
+        bucket.rate = new_rate;
+    }
+
+    /// Classify a result and feed it back to the AIMD controller without
+    /// acquiring a token. Uses `try_lock` for the bucket update so that if the
+    /// bucket lock is contended the rate update is deferred to the next
+    /// `throttled()` call.
+    fn observe_outcome<T>(&self, result: &OSResult<T>) {
+        let outcome = match result {
+            Ok(_) => RequestOutcome::Success,
+            Err(err) if is_throttle_error(err) => {
+                debug!("Throttle error detected in stream, decreasing rate");
+                RequestOutcome::Throttled
+            }
+            Err(_) => RequestOutcome::Success,
+        };
+        let new_rate = self.controller.record_outcome(outcome);
+        if let Ok(mut bucket) = self.bucket.try_lock() {
+            bucket.rate = new_rate;
+        }
+    }
+
+    /// Execute an operation with throttling: acquire token, run, classify result.
+    async fn throttled<F, Fut, T>(&self, f: F) -> OSResult<T>
+    where
+        F: FnOnce() -> Fut,
+        Fut: std::future::Future<Output = OSResult<T>>,
+    {
+        self.acquire_token().await;
+        let result = f().await;
+        let outcome = match &result {
+            Ok(_) => RequestOutcome::Success,
+            Err(err) if is_throttle_error(err) => {
+                debug!("Throttle error detected, decreasing rate");
+                RequestOutcome::Throttled
+            }
+            Err(_) => RequestOutcome::Success, // Non-throttle errors don't indicate capacity problems
+        };
+        let new_rate = self.controller.record_outcome(outcome);
+        self.update_bucket_rate(new_rate).await;
+        result
+    }
+}
+
+impl Debug for OperationThrottle {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("OperationThrottle")
+            .field("controller", &self.controller)
+            .field("burst_capacity", &self.burst_capacity)
+            .finish()
+    }
+}
+
+/// A [`MultipartUpload`] wrapper that throttles `put_part` and observes
+/// outcomes from `put_part` and `complete`, feeding them back to the write
+/// AIMD controller.
+struct ThrottledMultipartUpload {
+    target: Box<dyn MultipartUpload>,
+    write: Arc<OperationThrottle>,
+}
+
+impl Debug for ThrottledMultipartUpload {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ThrottledMultipartUpload").finish()
+    }
+}
+
+#[async_trait]
+impl MultipartUpload for ThrottledMultipartUpload {
+    fn put_part(&mut self, data: PutPayload) -> UploadPart {
+        let write = Arc::clone(&self.write);
+        let fut = self.target.put_part(data);
+        Box::pin(async move {
+            write.acquire_token().await;
+            let result = fut.await;
+            write.observe_outcome(&result);
+            result
+        })
+    }
+
+    async fn complete(&mut self) -> OSResult<PutResult> {
+        let result = self.target.complete().await;
+        self.write.observe_outcome(&result);
+        result
+    }
+
+    async fn abort(&mut self) -> OSResult<()> {
+        self.target.abort().await
+    }
+}
+
+/// An ObjectStore wrapper that rate-limits operations using per-category token
+/// buckets whose fill rates are controlled by AIMD algorithms.
+///
+/// Operations are split into four independent categories:
+/// - **read**: `get`, `get_opts`, `get_range`, `get_ranges`, `head`
+/// - **write**: `put`, `put_opts`, `put_multipart`, `put_multipart_opts`, `copy`, `copy_if_not_exists`, `rename`, `rename_if_not_exists`
+/// - **delete**: `delete`
+/// - **list**: `list_with_delimiter`
+///
+/// Streaming operations (`list`, `list_with_offset`, `delete_stream`) do not acquire tokens,
+/// but observe each yielded item and feed the result back to the AIMD controller so it can
+/// adjust the rate for other operations in the same category.
+///
+/// This is not perfect but probably as close as we can get without moving the throttle into
+/// the object_store crate itself.
+pub struct AimdThrottledStore {
+    target: Arc<dyn ObjectStore>,
+    read: Arc<OperationThrottle>,
+    write: Arc<OperationThrottle>,
+    delete: Arc<OperationThrottle>,
+    list: Arc<OperationThrottle>,
+}
+
+impl Debug for AimdThrottledStore {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AimdThrottledStore")
+            .field("target", &self.target)
+            .field("read", &self.read)
+            .field("write", &self.write)
+            .field("delete", &self.delete)
+            .field("list", &self.list)
+            .finish()
+    }
+}
+
+impl Display for AimdThrottledStore {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "AimdThrottledStore({})", self.target)
+    }
+}
+
+impl AimdThrottledStore {
+    pub fn new(
+        target: Arc<dyn ObjectStore>,
+        config: AimdThrottleConfig,
+    ) -> lance_core::Result<Self> {
+        let burst = config.burst_capacity as f64;
+        Ok(Self {
+            target,
+            read: Arc::new(OperationThrottle::new(config.read, burst)?),
+            write: Arc::new(OperationThrottle::new(config.write, burst)?),
+            delete: Arc::new(OperationThrottle::new(config.delete, burst)?),
+            list: Arc::new(OperationThrottle::new(config.list, burst)?),
+        })
+    }
+}
+
+#[async_trait]
+#[deny(clippy::missing_trait_methods)]
+impl ObjectStore for AimdThrottledStore {
+    async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult<PutResult> {
+        self.write
+            .throttled(|| self.target.put(location, bytes))
+            .await
+    }
+
+    async fn put_opts(
+        &self,
+        location: &Path,
+        bytes: PutPayload,
+        opts: PutOptions,
+    ) -> OSResult<PutResult> {
+        self.write
+            .throttled(|| self.target.put_opts(location, bytes, opts))
+            .await
+    }
+
+    async fn put_multipart(&self, location: &Path) -> OSResult<Box<dyn MultipartUpload>> {
+        let target = self
+            .write
+            .throttled(|| self.target.put_multipart(location))
+            .await?;
+        Ok(Box::new(ThrottledMultipartUpload {
+            target,
+            write: Arc::clone(&self.write),
+        }))
+    }
+
+    async fn put_multipart_opts(
+        &self,
+        location: &Path,
+        opts: PutMultipartOptions,
+    ) -> OSResult<Box<dyn MultipartUpload>> {
+        let target = self
+            .write
+            .throttled(|| self.target.put_multipart_opts(location, opts))
+            .await?;
+        Ok(Box::new(ThrottledMultipartUpload {
+            target,
+            write: Arc::clone(&self.write),
+        }))
+    }
+
+    async fn get(&self, location: &Path) -> OSResult<GetResult> {
+        self.read.throttled(|| self.target.get(location)).await
+    }
+
+    async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult<GetResult> {
+        self.read
+            .throttled(|| self.target.get_opts(location, options))
+            .await
+    }
+
+    async fn get_range(&self, location: &Path, range: Range<u64>) -> OSResult<Bytes> {
+        self.read
+            .throttled(|| self.target.get_range(location, range.clone()))
+            .await
+    }
+
+    async fn get_ranges(&self, location: &Path, ranges: &[Range<u64>]) -> OSResult<Vec<Bytes>> {
+        self.read
+            .throttled(|| self.target.get_ranges(location, ranges))
+            .await
+    }
+
+    async fn head(&self, location: &Path) -> OSResult<ObjectMeta> {
+        self.read.throttled(|| self.target.head(location)).await
+    }
+
+    async fn delete(&self, location: &Path) -> OSResult<()> {
+        self.delete.throttled(|| self.target.delete(location)).await
+    }
+
+    fn delete_stream<'a>(
+        &'a self,
+        locations: BoxStream<'a, OSResult<Path>>,
+    ) -> BoxStream<'a, OSResult<Path>> {
+        self.target
+            .delete_stream(locations)
+            .map(|item| {
+                self.delete.observe_outcome(&item);
+                item
+            })
+            .boxed()
+    }
+
+    fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult<ObjectMeta>> {
+        let throttle = Arc::clone(&self.list);
+        self.target
+            .list(prefix)
+            .map(move |item| {
+                throttle.observe_outcome(&item);
+                item
+            })
+            .boxed()
+    }
+
+    fn list_with_offset(
+        &self,
+        prefix: Option<&Path>,
+        offset: &Path,
+    ) -> BoxStream<'static, OSResult<ObjectMeta>> {
+        let throttle = Arc::clone(&self.list);
+        self.target
+            .list_with_offset(prefix, offset)
+            .map(move |item| {
+                throttle.observe_outcome(&item);
+                item
+            })
+            .boxed()
+    }
+
+    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult<ListResult> {
+        self.list
+            .throttled(|| self.target.list_with_delimiter(prefix))
+            .await
+    }
+
+    async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.write.throttled(|| self.target.copy(from, to)).await
+    }
+
+    async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.write.throttled(|| self.target.rename(from, to)).await
+    }
+
+    async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.write
+            .throttled(|| self.target.rename_if_not_exists(from, to))
+            .await
+    }
+
+    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> {
+        self.write
+            .throttled(|| self.target.copy_if_not_exists(from, to))
+            .await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use object_store::memory::InMemory;
+    use rstest::rstest;
+    use std::collections::VecDeque;
+    use std::sync::atomic::{AtomicU64, Ordering};
+
+    fn make_generic_error(msg: &str) -> object_store::Error {
+        object_store::Error::Generic {
+            store: "test",
+            source: msg.into(),
+        }
+    }
+
+    #[rstest]
+    #[case::retry_error("Error after 10 retries, max_retries: 10, retry_timeout: 180s", true)]
+    #[case::retries_in_message(
+        "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s",
+        true
+    )]
+    #[case::not_found("Object not found", false)]
+    #[case::permission_denied("Access denied", false)]
+    #[case::timeout("Connection timed out", false)]
+    #[case::http_429_without_retries("HTTP 429 Too Many Requests", false)]
+    #[case::slowdown_without_retries("SlowDown: Please reduce your request 
rate", false)] + fn test_is_throttle_error(#[case] msg: &str, #[case] expected: bool) { + let err = make_generic_error(msg); + assert_eq!( + is_throttle_error(&err), + expected, + "is_throttle_error for '{}' should be {}", + msg, + expected + ); + } + + #[test] + fn test_non_generic_errors_are_not_throttle() { + let err = object_store::Error::NotFound { + path: "test".to_string(), + source: "not found".into(), + }; + assert!(!is_throttle_error(&err)); + } + + #[tokio::test] + async fn test_basic_put_get_through_wrapper() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default(); + let throttled = AimdThrottledStore::new(store, config).unwrap(); + + let path = Path::from("test/file.txt"); + let data = PutPayload::from_static(b"hello world"); + throttled.put(&path, data).await.unwrap(); + + let result = throttled.get(&path).await.unwrap(); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes.as_ref(), b"hello world"); + } + + #[tokio::test] + async fn test_rate_decreases_on_throttle() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default().with_aimd( + AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_window_duration(std::time::Duration::from_millis(10)), + ); + let throttled = AimdThrottledStore::new(store, config).unwrap(); + + let initial_rate = throttled.read.controller.current_rate(); + assert_eq!(initial_rate, 100.0); + + // Simulate a throttle outcome directly + throttled + .read + .controller + .record_outcome(RequestOutcome::Throttled); + + // Wait for window to expire and trigger evaluation + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + throttled + .read + .controller + .record_outcome(RequestOutcome::Success); + + let new_rate = throttled.read.controller.current_rate(); + assert!( + new_rate < initial_rate, + "Rate should decrease after throttle: {} < {}", + new_rate, + initial_rate + ); + } + + #[tokio::test] + async fn 
test_rate_recovers_on_success() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default().with_aimd( + AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_additive_increment(10.0) + .with_window_duration(std::time::Duration::from_millis(10)), + ); + let throttled = AimdThrottledStore::new(store, config).unwrap(); + + // First decrease via throttle + throttled + .read + .controller + .record_outcome(RequestOutcome::Throttled); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + throttled + .read + .controller + .record_outcome(RequestOutcome::Success); + let decreased_rate = throttled.read.controller.current_rate(); + assert_eq!(decreased_rate, 50.0); + + // Now recover via success + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + throttled + .read + .controller + .record_outcome(RequestOutcome::Success); + let recovered_rate = throttled.read.controller.current_rate(); + assert_eq!(recovered_rate, 60.0); + } + + #[tokio::test] + async fn test_as_dyn_object_store() { + let store: Arc = Arc::new(InMemory::new()); + let throttled: Arc = + Arc::new(AimdThrottledStore::new(store, AimdThrottleConfig::default()).unwrap()); + + let path = Path::from("test/data.bin"); + let data = PutPayload::from_static(b"test data"); + throttled.put(&path, data).await.unwrap(); + + let result = throttled.get(&path).await.unwrap(); + let bytes = result.bytes().await.unwrap(); + assert_eq!(bytes.as_ref(), b"test data"); + } + + #[tokio::test] + async fn test_token_bucket_delays_when_exhausted() { + let store = Arc::new(InMemory::new()); + // Very low rate and burst capacity to force waiting + let config = AimdThrottleConfig::default() + .with_burst_capacity(1) + .with_aimd(AimdConfig::default().with_initial_rate(10.0)); + let throttled = Arc::new(AimdThrottledStore::new(store, config).unwrap()); + + let path = Path::from("test/file.txt"); + let data = PutPayload::from_static(b"data"); + 
throttled.put(&path, data).await.unwrap(); + + // After consuming the burst token, the next request should take ~100ms + // (1 token / 10 tokens-per-sec). We verify it takes at least 50ms. + let start = std::time::Instant::now(); + let data2 = PutPayload::from_static(b"data2"); + throttled.put(&path, data2).await.unwrap(); + let elapsed = start.elapsed(); + + assert!( + elapsed >= std::time::Duration::from_millis(50), + "Expected delay for token refill, but elapsed was {:?}", + elapsed + ); + } + + #[tokio::test] + async fn test_list_observes_outcomes() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default(); + let throttled = AimdThrottledStore::new(store.clone(), config).unwrap(); + + let path = Path::from("prefix/file.txt"); + let data = PutPayload::from_static(b"data"); + store.put(&path, data).await.unwrap(); + + let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await; + assert_eq!(items.len(), 1); + assert!(items[0].is_ok()); + } + + /// A mock store whose `list` stream yields a configurable sequence of + /// Ok / throttle-error items. Used to verify that the AIMD wrapper + /// observes errors surfaced inside list streams. + struct ThrottlingListMockStore { + inner: InMemory, + /// Number of throttle errors to inject at the start of each list call. 
+ throttle_count: usize, + } + + impl Display for ThrottlingListMockStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ThrottlingListMockStore") + } + } + + impl Debug for ThrottlingListMockStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ThrottlingListMockStore").finish() + } + } + + #[async_trait] + impl ObjectStore for ThrottlingListMockStore { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { + self.inner.put(location, bytes).await + } + async fn put_opts( + &self, + location: &Path, + bytes: PutPayload, + opts: PutOptions, + ) -> OSResult { + self.inner.put_opts(location, bytes, opts).await + } + async fn put_multipart(&self, location: &Path) -> OSResult> { + self.inner.put_multipart(location).await + } + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> OSResult> { + self.inner.put_multipart_opts(location, opts).await + } + async fn get(&self, location: &Path) -> OSResult { + self.inner.get(location).await + } + async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { + self.inner.get_opts(location, options).await + } + async fn get_range(&self, location: &Path, range: Range) -> OSResult { + self.inner.get_range(location, range).await + } + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> OSResult> { + self.inner.get_ranges(location, ranges).await + } + async fn head(&self, location: &Path) -> OSResult { + self.inner.head(location).await + } + async fn delete(&self, location: &Path) -> OSResult<()> { + self.inner.delete(location).await + } + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, OSResult>, + ) -> BoxStream<'a, OSResult> { + self.inner.delete_stream(locations) + } + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult> { + let n = self.throttle_count; + let inner_stream = self.inner.list(prefix); + let errors = futures::stream::iter((0..n).map(|_| { + 
Err(object_store::Error::Generic { + store: "ThrottlingListMock", + source: "request failed, after 3 retries, max_retries: 5, retry_timeout: 60s" + .into(), + }) + })); + errors.chain(inner_stream).boxed() + } + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, OSResult> { + self.inner.list_with_offset(prefix, offset) + } + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { + self.inner.list_with_delimiter(prefix).await + } + async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.copy(from, to).await + } + async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.rename(from, to).await + } + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.rename_if_not_exists(from, to).await + } + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.copy_if_not_exists(from, to).await + } + } + + #[tokio::test] + async fn test_list_stream_throttle_errors_decrease_rate() { + let mock = Arc::new(ThrottlingListMockStore { + inner: InMemory::new(), + throttle_count: 5, + }); + + // Seed a file so the real items come through after the errors. + mock.put( + &Path::from("prefix/file.txt"), + PutPayload::from_static(b"data"), + ) + .await + .unwrap(); + + let config = AimdThrottleConfig::default().with_list_aimd( + AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_window_duration(std::time::Duration::from_millis(10)), + ); + let throttled = AimdThrottledStore::new(mock as Arc, config).unwrap(); + + let initial_rate = throttled.list.controller.current_rate(); + assert_eq!(initial_rate, 100.0); + + let items: Vec<_> = throttled.list(Some(&Path::from("prefix"))).collect().await; + + // 5 errors + 1 real item + assert_eq!(items.len(), 6); + assert!(items[0].is_err()); + assert!(items[5].is_ok()); + + // Wait for the AIMD window to expire and trigger evaluation. 
+ tokio::time::sleep(std::time::Duration::from_millis(20)).await; + throttled + .list + .controller + .record_outcome(RequestOutcome::Success); + + let new_rate = throttled.list.controller.current_rate(); + assert!( + new_rate < initial_rate, + "List rate should decrease after stream throttle errors: {} < {}", + new_rate, + initial_rate + ); + } + + #[tokio::test] + async fn test_per_category_independence() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default().with_aimd( + AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_window_duration(std::time::Duration::from_millis(10)), + ); + let throttled = AimdThrottledStore::new(store, config).unwrap(); + + // Push the read controller into a throttled state + throttled + .read + .controller + .record_outcome(RequestOutcome::Throttled); + tokio::time::sleep(std::time::Duration::from_millis(20)).await; + throttled + .read + .controller + .record_outcome(RequestOutcome::Success); + + let read_rate = throttled.read.controller.current_rate(); + let write_rate = throttled.write.controller.current_rate(); + let delete_rate = throttled.delete.controller.current_rate(); + let list_rate = throttled.list.controller.current_rate(); + + assert_eq!(read_rate, 50.0, "Read rate should have decreased"); + assert_eq!(write_rate, 100.0, "Write rate should be unaffected"); + assert_eq!(delete_rate, 100.0, "Delete rate should be unaffected"); + assert_eq!(list_rate, 100.0, "List rate should be unaffected"); + } + + #[tokio::test] + async fn test_per_category_config() { + let store = Arc::new(InMemory::new()); + let config = AimdThrottleConfig::default() + .with_read_aimd(AimdConfig::default().with_initial_rate(200.0)) + .with_write_aimd(AimdConfig::default().with_initial_rate(100.0)) + .with_delete_aimd(AimdConfig::default().with_initial_rate(50.0)) + .with_list_aimd(AimdConfig::default().with_initial_rate(25.0)); + let throttled = AimdThrottledStore::new(store, 
config).unwrap(); + + assert_eq!(throttled.read.controller.current_rate(), 200.0); + assert_eq!(throttled.write.controller.current_rate(), 100.0); + assert_eq!(throttled.delete.controller.current_rate(), 50.0); + assert_eq!(throttled.list.controller.current_rate(), 25.0); + } + + /// A mock [`ObjectStore`] that measures request rate over a sliding window + /// and returns 503 errors when the rate exceeds a configurable threshold. + /// Write and metadata-only operations are not rate-limited. + struct RateLimitingMockStore { + inner: InMemory, + /// Timestamps of recent successful (admitted) requests. + timestamps: std::sync::Mutex>, + /// Maximum requests allowed within `window`. + max_per_window: usize, + /// Sliding window duration. + window: std::time::Duration, + success_count: AtomicU64, + throttle_count: AtomicU64, + } + + impl RateLimitingMockStore { + fn new(max_per_window: usize, window: std::time::Duration) -> Self { + Self { + inner: InMemory::new(), + timestamps: std::sync::Mutex::new(VecDeque::new()), + max_per_window, + window, + success_count: AtomicU64::new(0), + throttle_count: AtomicU64::new(0), + } + } + + /// Returns `true` if the request is admitted, `false` if throttled. 
+ fn check_rate(&self) -> bool { + let mut ts = self.timestamps.lock().unwrap(); + let now = std::time::Instant::now(); + while let Some(&front) = ts.front() { + if now.duration_since(front) > self.window { + ts.pop_front(); + } else { + break; + } + } + if ts.len() >= self.max_per_window { + self.throttle_count.fetch_add(1, Ordering::Relaxed); + false + } else { + ts.push_back(now); + self.success_count.fetch_add(1, Ordering::Relaxed); + true + } + } + + fn throttle_error() -> object_store::Error { + object_store::Error::Generic { + store: "RateLimitingMock", + source: "request failed, after 10 retries, max_retries: 10, retry_timeout: 180s" + .into(), + } + } + } + + impl Display for RateLimitingMockStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "RateLimitingMockStore") + } + } + + impl Debug for RateLimitingMockStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RateLimitingMockStore").finish() + } + } + + #[async_trait] + impl ObjectStore for RateLimitingMockStore { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { + self.inner.put(location, bytes).await + } + + async fn put_opts( + &self, + location: &Path, + bytes: PutPayload, + opts: PutOptions, + ) -> OSResult { + self.inner.put_opts(location, bytes, opts).await + } + + async fn put_multipart(&self, location: &Path) -> OSResult> { + self.inner.put_multipart(location).await + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> OSResult> { + self.inner.put_multipart_opts(location, opts).await + } + + async fn get(&self, location: &Path) -> OSResult { + if self.check_rate() { + self.inner.get(location).await + } else { + Err(Self::throttle_error()) + } + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { + if self.check_rate() { + self.inner.get_opts(location, options).await + } else { + Err(Self::throttle_error()) + } + } + + async fn 
get_range(&self, location: &Path, range: Range) -> OSResult { + if self.check_rate() { + self.inner.get_range(location, range).await + } else { + Err(Self::throttle_error()) + } + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> OSResult> { + if self.check_rate() { + self.inner.get_ranges(location, ranges).await + } else { + Err(Self::throttle_error()) + } + } + + async fn head(&self, location: &Path) -> OSResult { + if self.check_rate() { + self.inner.head(location).await + } else { + Err(Self::throttle_error()) + } + } + + async fn delete(&self, location: &Path) -> OSResult<()> { + self.inner.delete(location).await + } + + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, OSResult>, + ) -> BoxStream<'a, OSResult> { + self.inner.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult> { + self.inner.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, OSResult> { + self.inner.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.rename(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.rename_if_not_exists(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + self.inner.copy_if_not_exists(from, to).await + } + } + + /// Verify that multiple concurrent readers sharing an AIMD-throttled store + /// converge to the backend's actual capacity. + /// + /// Setup: + /// - Mock backend allows 30 requests per 100ms (= 300 req/s). + /// - 5 reader tasks, each with their own [`AimdThrottledStore`] wrapping + /// the shared mock. 
+ /// - AIMD: 100ms window, initial rate 100 req/s, decrease 0.5, increase 2. + /// - Readers issue `head()` requests as fast as the throttle allows for 2s. + /// + /// Expected behaviour: + /// - Initial burst (100 burst tokens × 5 readers) overshoots the mock + /// capacity, causing many 503s. Each reader's AIMD halves its rate. + /// - After the transient, each reader converges to ~60 req/s (300/5). + /// - Over 2 seconds, total successful requests should be in the range + /// [300, 900] (theoretical max ≈ 600). + #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + async fn test_aimd_throttle_under_concurrent_load() { + let mock = Arc::new(RateLimitingMockStore::new( + 30, + std::time::Duration::from_millis(100), + )); + + // Seed a test file so head() succeeds when admitted. + let path = Path::from("test/data.bin"); + mock.put(&path, PutPayload::from_static(b"test data")) + .await + .unwrap(); + + let aimd = AimdConfig::default() + .with_initial_rate(100.0) + .with_decrease_factor(0.5) + .with_additive_increment(2.0) + .with_window_duration(std::time::Duration::from_millis(100)); + let throttle_config = AimdThrottleConfig::default() + .with_aimd(aimd) + .with_burst_capacity(100); + + let num_readers = 5; + let test_duration = std::time::Duration::from_secs(2); + let mut handles = Vec::new(); + + for _ in 0..num_readers { + let store = Arc::new( + AimdThrottledStore::new( + mock.clone() as Arc, + throttle_config.clone(), + ) + .unwrap(), + ); + let p = path.clone(); + handles.push(tokio::spawn(async move { + let deadline = std::time::Instant::now() + test_duration; + let mut count = 0u64; + while std::time::Instant::now() < deadline { + let _ = store.head(&p).await; + count += 1; + } + count + })); + } + + let mut total_reader_requests = 0u64; + for handle in handles { + total_reader_requests += handle.await.unwrap(); + } + + let successes = mock.success_count.load(Ordering::Relaxed); + let throttled = mock.throttle_count.load(Ordering::Relaxed); + let 
total_mock = successes + throttled; + + // Reader-side count must match mock-side count. + assert_eq!( + total_reader_requests, total_mock, + "Reader-side count ({total_reader_requests}) != mock-side count ({total_mock})" + ); + + // Mock capacity is 30/100ms = 300 req/s. Over 2s the theoretical max is + // ~600 successful requests. With AIMD ramp-up, expect somewhat fewer. + assert!( + successes >= 300, + "Expected >= 300 successful requests over 2s, got {successes}" + ); + assert!( + successes <= 900, + "Expected <= 900 successful requests, got {successes}" + ); + + // The initial burst exceeds mock capacity, so throttling must occur. + assert!(throttled > 0, "Expected some throttled requests but got 0"); + + // Without AIMD, raw tokio tasks against InMemory would fire 100k+ req/s. + // AIMD should keep the total well under 5000 over 2s. + assert!( + total_mock <= 5000, + "AIMD should limit total requests, got {total_mock}" + ); + } +}