diff --git a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py index 517aa0f81901..c7e10da61c73 100644 --- a/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py +++ b/python/ray/data/_internal/execution/backpressure_policy/concurrency_cap_backpressure_policy.py @@ -1,9 +1,10 @@ import logging import math -from collections import defaultdict, deque -from typing import TYPE_CHECKING, Deque, Dict +from collections import defaultdict +from typing import TYPE_CHECKING, Dict from .backpressure_policy import BackpressurePolicy +from ray._private.ray_constants import env_float from ray.data._internal.execution.operators.map_operator import MapOperator from ray.data._internal.execution.operators.task_pool_map_operator import ( TaskPoolMapOperator, @@ -20,90 +21,55 @@ class ConcurrencyCapBackpressurePolicy(BackpressurePolicy): """A backpressure policy that caps the concurrency of each operator. - This policy dynamically limits the number of concurrent tasks per operator - based on queue pressure. It combines: - - - Adaptive threshold built from EWMA of the queue level and its - absolute deviation: ``threshold = max(level + K * dev, current_queue_size_bytes)``. - The ``current_queue_size_bytes`` term enables immediate upward revision to avoid - throttling ramp-up throughput. - - Why this threshold works: - - level + K*dev: Sets threshold above typical queue size by K standard deviations - - max(..., current_queue_size_bytes): Prevents false throttling during legitimate ramp-up periods - - When queue grows faster than EWMA can track, current_queue_size_bytes > level + K*dev - - This allows the system to "catch up" to the new higher baseline before throttling - - - Quantized step controller that nudges running concurrency by - ``{-1, 0, +1, +2}`` using normalized pressure and trend signals. - - Key Concepts: - - Level (EWMA): Typical queue size; slowly tracks the central tendency. - - Deviation (EWMA): Typical absolute deviation around the level; acts as - a normalization scale for pressure and trend signals. - - Threshold: Dynamic limit derived from observed signal: if the current - queue exceeds the threshold, we consider backoff. The ``max(..., current_queue_size_bytes)`` - makes this instantaneously responsive upward. - - Instantaneous pressure: How far the current queue is from threshold, - normalized by deviation. - - Trend: Whether the queue is rising/falling over a short horizon (recent - vs older HISTORY_LEN/2 samples), normalized by deviation. - - Example: - Consider an operator with configured cap=10 and queue pressure over time: - - Queue samples: [100, 120, 140, 160, 180, 200, 220, 240, 260, 280] - Threshold: 150 (level=180, dev=20, K=4.0) - - Ramp-up scenario (queue growing, pressure < 0): - - pressure_signal = (100-150)/20 = -2.5, trend_signal = -1.0 - - Decision: step = +2 (strong growth, low pressure) - - Result: concurrency increases from 8 -> 10 (capped at configured max) - - Dial-down scenario (queue growing, pressure > 0): - - pressure_signal = (200-150)/20 = +2.5, trend_signal = +1.0 - - Decision: step = -1 (high pressure, growing trend) - - Result: concurrency decreases from 10 -> 9 - - Stable scenario (queue stable, pressure ~ 0): - - pressure_signal = (150-150)/20 = 0.0, trend_signal = 0.0 - - Decision: step = 0 (no change needed) - - Result: concurrency stays at 10 + based on the output queue growth rate. + + - Maintain asymmetric EWMA of total enqueued output bytes as the + typical level: `level`. + - Maintain asymmetric EWMA of absolute residual vs the *previous* level as a + scale proxy: `dev = EWMA(|q - level_prev|)`. + - Define deadband: Deadband is the acceptable range of the output queue size + around the typical level where the queue size is expected to stay stable. + deadband [lower, upper] = [level - K_DEV*dev, level + K_DEV*dev]. + - If q > upper -> target cap = running - BACKOFF_FACTOR (back off) + If q < lower -> target cap = running + RAMPUP_FACTOR (ramp up) + Else -> target cap = running (hold) + - Apply user-configured max concurrency cap, admit iff running < target cap. NOTE: Only support setting concurrency cap for `TaskPoolMapOperator` for now. TODO(chengsu): Consolidate with actor scaling logic of `ActorPoolMapOperator`. """ - # Queue history window for recent-trend estimation. Small window to capture recent trend. - HISTORY_LEN = 10 - # Smoothing factor for both level and dev. - EWMA_ALPHA = 0.2 - # Deviation multiplier to define "over-threshold". - K_DEV = 4.0 + # Smoothing factor for the asymmetric EWMA (slow fall, faster rise). + EWMA_ALPHA = env_float("RAY_DATA_CONCURRENCY_CAP_EWMA_ALPHA", 0.2) + # Deadband width in units of the EWMA absolute deviation estimate. + K_DEV = env_float("RAY_DATA_CONCURRENCY_CAP_K_DEV", 2.0) + # Factor to back off when the queue is too large. + BACKOFF_FACTOR = env_float("RAY_DATA_CONCURRENCY_CAP_BACKOFF_FACTOR", 1) + # Factor to ramp up when the queue is too small. + RAMPUP_FACTOR = env_float("RAY_DATA_CONCURRENCY_CAP_RAMPUP_FACTOR", 1) + # Threshold for per-Op object store budget (available) vs total usage (used) + # (available / used) ratio to enable dynamic output queue size backpressure. + OBJECT_STORE_USAGE_RATIO = env_float( + "RAY_DATA_CONCURRENCY_CAP_OBJECT_STORE_USAGE_RATIO", 0.1 + ) def __init__(self, *args, **kwargs): - """Initialize the ConcurrencyCapBackpressurePolicy.""" super().__init__(*args, **kwargs) - # Explicit concurrency caps for each operator. Infinite if not specified. + # Configured per-operator caps (+inf if unset). self._concurrency_caps: Dict["PhysicalOperator", float] = {} - # Queue history for recent-trend estimation. Small window to capture recent trend. - self._queue_history: Dict["PhysicalOperator", Deque[int]] = defaultdict( - lambda: deque(maxlen=self.HISTORY_LEN) - ) - - # Per-operator cached threshold (bootstrapped from first sample). - self._queue_thresholds: Dict["PhysicalOperator", int] = defaultdict(int) - - # EWMA state for level and absolute deviation. + # EWMA state for level self._q_level_nbytes: Dict["PhysicalOperator", float] = defaultdict(float) - # EWMA state for absolute deviation. + # EWMA state for dev self._q_level_dev: Dict["PhysicalOperator", float] = defaultdict(float) - # Track last effective cap per operator for change detection. + # Per-operator cached threshold (bootstrapped from first sample). + self._queue_level_thresholds: Dict["PhysicalOperator", int] = defaultdict(int) + + # Last effective cap for change logs. self._last_effective_caps: Dict["PhysicalOperator", int] = {} # Initialize caps from operators (infinite if unset) @@ -121,16 +87,21 @@ def __init__(self, *args, **kwargs): self._data_context.enable_dynamic_output_queue_size_backpressure ) + dynamic_output_queue_size_backpressure_configs = "" + if self.enable_dynamic_output_queue_size_backpressure: + dynamic_output_queue_size_backpressure_configs = ( + f", EWMA_ALPHA={self.EWMA_ALPHA}, K_DEV={self.K_DEV}, " + f"BACKOFF_FACTOR={self.BACKOFF_FACTOR}, RAMPUP_FACTOR={self.RAMPUP_FACTOR}, " + f"OBJECT_STORE_USAGE_RATIO={self.OBJECT_STORE_USAGE_RATIO}" + ) logger.debug( - "ConcurrencyCapBackpressurePolicy caps: %s, cap based on queue size: %s", - self._concurrency_caps, - self.enable_dynamic_output_queue_size_backpressure, + f"ConcurrencyCapBackpressurePolicy caps: {self._concurrency_caps}, " + f"enabled: {self.enable_dynamic_output_queue_size_backpressure}{dynamic_output_queue_size_backpressure_configs}" ) def _update_ewma_asymmetric(self, prev_value: float, sample: float) -> float: """ Update EWMA with asymmetric behavior: fast rise, slow fall. - Args: prev_value: Previous EWMA value sample: New sample value @@ -145,28 +116,57 @@ def _update_ewma_asymmetric(self, prev_value: float, sample: float) -> float: alpha = alpha_up if sample > prev_value else self.EWMA_ALPHA # slow fall return (1 - alpha) * prev_value + alpha * sample - def can_add_input(self, op: "PhysicalOperator") -> bool: - """Return whether `op` may accept another input now. + def _update_level_and_dev(self, op: "PhysicalOperator", q_bytes: int) -> None: + """Update EWMA level and dev (residual w.r.t. previous level).""" + q = float(q_bytes) - Admission control logic: - * Under threshold: Allow full concurrency up to configured cap - * Over threshold: Adjust concurrency using step controller (-1,0,+1,+2) - based on pressure and trend signals + level_prev = self._q_level_nbytes[op] + dev_prev = self._q_level_dev[op] - Args: - op: The operator under consideration. + # Deviation vs the previous level + dev_sample = abs(q - level_prev) if level_prev > 0 else 0.0 + dev = self._update_ewma_asymmetric(dev_prev, dev_sample) - Returns: - True if admitting one more input is allowed. - """ - running = op.metrics.num_tasks_running + # Now update the level itself + level = self._update_ewma_asymmetric(level_prev, q) + + self._q_level_nbytes[op] = level + self._q_level_dev[op] = dev + + # For visibility, store the integer center of the band + self._queue_level_thresholds[op] = max(1, int(level)) + + def can_add_input(self, op: "PhysicalOperator") -> bool: + """Return whether `op` may accept another input now.""" + num_tasks_running = op.metrics.num_tasks_running + + # If not a MapOperator or feature disabled, just enforce configured cap. if ( not isinstance(op, MapOperator) or not self.enable_dynamic_output_queue_size_backpressure ): - return running < self._concurrency_caps[op] + return num_tasks_running < self._concurrency_caps[op] - # Observe fresh queue size for this operator and its downstream. + # For this Op, if the objectstore budget (available) to total usage (used) + # ratio is below threshold (10%), skip dynamic output queue size backpressure. + op_usage = self._resource_manager.get_op_usage(op) + op_budget = self._resource_manager.get_budget(op) + if ( + op_usage is not None + and op_budget is not None + and op_budget.object_store_memory > 0 + and op_usage.object_store_memory > 0 + ): + if ( + op_budget.object_store_memory / op_usage.object_store_memory + > self.OBJECT_STORE_USAGE_RATIO + ): + # If the objectstore budget (available) to total usage (used) + # ratio is above threshold (10%), skip dynamic output queue size + # backpressure, but still enforce the configured cap. + return num_tasks_running < self._concurrency_caps[op] + + # Current total queued bytes (this op + downstream) current_queue_size_bytes = ( self._resource_manager.get_op_internal_object_store_usage(op) + self._resource_manager.get_op_outputs_object_store_usage_with_downstream( @@ -174,248 +174,58 @@ def can_add_input(self, op: "PhysicalOperator") -> bool: ) ) - # Update short history and refresh the adaptive threshold. - self._queue_history[op].append(current_queue_size_bytes) - threshold = self._update_queue_threshold(op, current_queue_size_bytes) - - # If configured to cap based on queue size, use the effective cap. - if current_queue_size_bytes > threshold: - # Over-threshold: potentially back off via effective cap. - effective_cap = self._effective_cap(op) - is_capped = running < effective_cap - - last_effective_cap = self._last_effective_caps.get(op, None) - if last_effective_cap != effective_cap: - logger.debug( - "Effective concurrency cap changed for operator %s: %d -> %d " - "running=%d tasks, queue=%d bytes, threshold=%d bytes", - op.name, - last_effective_cap, - effective_cap, - running, - current_queue_size_bytes, - threshold, - ) - self._last_effective_caps[op] = effective_cap - - return is_capped - else: - # Under-threshold: only enforce the configured cap. - return running < self._concurrency_caps[op] + # Update EWMA state (level & dev) and compute effective cap. Note that + # we don't update the EWMA state if the objectstore budget (available) vs total usage (used) + # ratio is above threshold (10%), because the level and dev adjusts quickly. + self._update_level_and_dev(op, current_queue_size_bytes) + effective_cap = self._effective_cap( + op, num_tasks_running, current_queue_size_bytes + ) - def _update_queue_threshold( - self, op: "PhysicalOperator", current_queue_size_bytes: int - ) -> int: - """Update and return the current adaptive threshold for `op`. + last = self._last_effective_caps.get(op, None) + if last != effective_cap: + logger.debug( + f"Cap change {op.name}: {last if last is not None else 'None'} -> " + f"{effective_cap} (running={num_tasks_running}, queue={current_queue_size_bytes}, " + f"thr={self._queue_level_thresholds[op]})" + ) + self._last_effective_caps[op] = effective_cap - Motivation: Adaptive thresholds prevent both over-throttling (too aggressive) and - under-throttling (too permissive). The logic balances responsiveness with stability: - - Fast upward response to pressure spikes (immediate threshold increase) - - Thresholds only increase, never decrease, since we don't know if low pressure is steady state + return num_tasks_running < effective_cap + def _effective_cap( + self, + op: "PhysicalOperator", + num_tasks_running: int, + current_queue_size_bytes: int, + ) -> int: + """A simple controller around EWMA level. Args: - op: Operator whose threshold is being updated. + op: The operator to compute the effective cap for. + num_tasks_running: The number of tasks currently running. current_queue_size_bytes: Current total queued bytes for this operator + downstream. - - Returns: - The updated threshold in bytes. - - Examples: - # Bootstrap: first sample sets threshold - # Input: current_queue_size_bytes = 1000, level_prev = 0, dev_prev = 0 - # EWMA: level = 1000, dev = 0 - # Base: 1000 + 4*0 = 1000, Threshold: max(1000, 1000) = 1000 - # Result: 1000 (bootstrap) - - # Pressure increase: threshold updated immediately - # Input: current_queue_size_bytes = 1500, level_prev = 1000, dev_prev = 100 - # EWMA: level = 1100, dev = 180, Base: 1100 + 4*180 = 1820 - # Threshold: max(1820, 1500) = 1820, prev_threshold = 1000 - # Result: 1820 (threshold increased) - - # Pressure decrease: threshold maintained (no decrease) - # Input: current_queue_size_bytes = 100, level_prev = 200, dev_prev = 50 - # EWMA: level = 180, dev = 60, Base: 180 + 4*60 = 420 - # Threshold: max(420, 100) = 420, prev_threshold = 500 - # Result: 500 (threshold maintained, no decrease) - - """ - hist = self._queue_history[op] - if not hist: - return 0 - - q = float(current_queue_size_bytes) - - # Step 1: update EWMAs - level_prev = self._q_level_nbytes[op] - dev_prev = self._q_level_dev[op] - - # Update EWMA level (typical queue size) with asymmetric behavior - # Why asymmetric? Quick to detect problems, slow to recover (prevents oscillation) - # Example: queue grows 100->200->150, EWMA follows 100->180->170 - # (jumps up fast, drops down slow) - level = self._update_ewma_asymmetric(level_prev, q) - - # Update EWMA deviation (typical absolute deviation) with asymmetric behavior - # Same logic: quick to detect high variability, slow to recover (prevents noise) - # Example: deviation jumps 10->30->20, EWMA follows 10->25->23 (fast up, slow down) - dev_sample = abs(q - level) - dev = self._update_ewma_asymmetric(dev_prev, dev_sample) - - self._q_level_nbytes[op] = level - self._q_level_dev[op] = dev - - # Step 2: base threshold from level & dev - # Example: level=1000, dev=200, K_DEV=4.0 -> base = 1000 + 4*200 = 1800 - base = level + self.K_DEV * dev - - # Step 3: fast ramp-up - threshold = max(1, int(max(base, q))) - - # Step 4: cache & return - prev_threshold = self._queue_thresholds[op] - - # Bootstrap - if prev_threshold == 0: - self._queue_thresholds[op] = max(1, threshold) - return self._queue_thresholds[op] - - # Only increase threshold when there's clear pressure - if threshold > prev_threshold: - self._queue_thresholds[op] = max(1, threshold) - return self._queue_thresholds[op] - - # Keep existing threshold when pressure decreases - return prev_threshold - - def _effective_cap(self, op: "PhysicalOperator") -> int: - """Compute a reduced concurrency cap via a tiny {-1,0,+1,+2} controller. - - Pressure and trend signals: - - pressure_signal: How far current queue is above threshold (normalized by absolute deviation estimate) - Formula: (current_queue_size_bytes - threshold) / max(dev, 1) - - Examples: - * queue=200, threshold=150, dev=20 -> pressure = (200-150)/20 = +2.5 - Meaning: Queue is 2.5x absolute deviation estimate level above threshold (high pressure, throttle!) - * queue=100, threshold=150, dev=20 -> pressure = (100-150)/20 = -2.5 - Meaning: Queue is 2.5x absolute deviation estimate level below threshold (low pressure, safe) - * queue=150, threshold=150, dev=20 -> pressure = (150-150)/20 = 0.0 - Meaning: Queue exactly at threshold (neutral pressure) - - - trend_signal: Whether queue is growing or shrinking (normalized by absolute deviation estimate) - Formula: (avg(recent_window) - avg(older_window)) / max(dev, 1) - - Examples: - * recent_avg=180, older_avg=160, dev=20 -> trend = (180-160)/20 = +1.0 - Meaning: Queue growing at 1x absolute deviation estimate level (upward trend, getting worse) - * recent_avg=140, older_avg=160, dev=20 -> trend = (140-160)/20 = -1.0 - Meaning: Queue shrinking at 1x absolute deviation estimate level (downward trend, getting better) - * recent_avg=160, older_avg=160, dev=20 -> trend = (160-160)/20 = 0.0 - Meaning: Queue stable (no trend) - - Controller decision logic: - - Decides concurrency adjustment {-1,0,+1,+2} based on pressure and trend signals - - Decision rules table: - +----------+----------+----------+--------------------------------+------------------+ - | Pressure | Trend | Step | Action | Example | - +----------+----------+----------+--------------------------------+------------------+ - | >= +2.0 | >= +1.0 | -1 | Emergency backoff | +2.5, +1.0 -> -1 | - | | | | (immediate reduction to | | - | | | | prevent overload) | | - +----------+----------+----------+--------------------------------+------------------+ - | >= +1.0 | >= 0.0 | 0 | Wait and see | +1.5, +0.5 -> 0 | - | | | | (let current level stabilize) | | - +----------+----------+----------+--------------------------------+------------------+ - | <= -1.0 | <= -1.0 | +1 | Conservative growth | -1.5, -1.0 -> +1 | - | | | | (safe to increase when | | - | | | | improving) | | - +----------+----------+----------+--------------------------------+------------------+ - | <= -2.0 | <= -2.0 | +2 | Aggressive growth | -2.5, -2.0 -> +2 | - | | | | (underutilized and improving | | - | | | | rapidly) | | - +----------+----------+----------+--------------------------------+------------------+ - | Other | Other | 0 | Hold | +0.5, -0.5 -> 0 | - | | | | (moderate signals, no clear | | - | | | | direction) | | - +----------+----------+----------+--------------------------------+------------------+ - - Logic summary: - - High pressure + growing trend = emergency backoff - - High pressure + stable trend = wait and see - - Low pressure + shrinking trend = safe to grow - - Very low pressure + strong improvement = aggressive growth - - Moderate signals = maintain current concurrency - - Args: - op: Operator whose effective cap we compute. - Returns: - An integer cap in [1, configured_cap]. + The effective cap. """ - hist = self._queue_history[op] - running = op.metrics.num_tasks_running - - # Need enough samples to evaluate short trend (recent + older windows). - recent_window = self.HISTORY_LEN // 2 - older_window = self.HISTORY_LEN // 2 - min_samples = recent_window + older_window - - if len(hist) < min_samples: - return max(1, running) - - # Trend windows and normalized signals - h = list(hist) - recent_avg = sum(h[-recent_window:]) / float(recent_window) - older_avg = sum(h[-(recent_window + older_window) : -recent_window]) / float( - older_window - ) - dev = max(1.0, self._q_level_dev[op]) - threshold = float(max(1, self._queue_thresholds[op])) - current_queue_size_bytes = float(hist[-1]) - - # Calculate normalized pressure and trend signals - scale = max(1.0, float(dev)) - pressure_signal = (current_queue_size_bytes - threshold) / scale - trend_signal = (recent_avg - older_avg) / scale + cap_cfg = self._concurrency_caps[op] - # Quantized controller decision - step = self._quantized_controller_step(pressure_signal, trend_signal) + level = float(self._q_level_nbytes[op]) + dev = max(1.0, float(self._q_level_dev[op])) + upper = level + self.K_DEV * dev + lower = level - self.K_DEV * dev + + if current_queue_size_bytes > upper: + # back off + target = num_tasks_running - self.BACKOFF_FACTOR + elif current_queue_size_bytes < lower: + # ramp up + target = num_tasks_running + self.RAMPUP_FACTOR + else: + # hold + target = num_tasks_running - # Apply step to current running concurrency, clamp by configured cap. - target = max(1, running + step) - cap_cfg = self._concurrency_caps[op] + # Clamp to [1, configured_cap] + target = max(1, target) if not math.isinf(cap_cfg): target = min(target, int(cap_cfg)) - return target - - def _quantized_controller_step( - self, pressure_signal: float, trend_signal: float - ) -> int: - """Compute the quantized controller step based on pressure and trend signals. - - This method implements the decision logic for the quantized controller: - - High pressure + growing trend = emergency backoff (-1) - - High pressure + stable/declining trend = wait and see (0) - - Low pressure + declining trend = safe to grow (+1) - - Very low pressure + strong improvement = aggressive growth (+2) - - Moderate signals = maintain current concurrency (0) - - Args: - pressure_signal: Normalized pressure signal (queue vs threshold) - trend_signal: Normalized trend signal (recent vs older average) - - Returns: - Step adjustment: -1, 0, +1, or +2 - """ - if pressure_signal >= 2.0 and trend_signal >= 1.0: - return -1 - elif pressure_signal >= 1.0 and trend_signal >= 0.0: - return 0 - elif pressure_signal <= -2.0 and trend_signal <= -2.0: - return +2 - elif pressure_signal <= -1.0 and trend_signal <= -1.0: - return +1 - else: - return 0 + return int(target) diff --git a/python/ray/data/tests/test_backpressure_policies.py b/python/ray/data/tests/test_backpressure_policies.py index cb2b6a5173d1..e1372158ad51 100644 --- a/python/ray/data/tests/test_backpressure_policies.py +++ b/python/ray/data/tests/test_backpressure_policies.py @@ -2,8 +2,8 @@ import math import time import unittest -from collections import defaultdict, deque -from unittest.mock import MagicMock, patch +from collections import defaultdict +from unittest.mock import MagicMock import pytest @@ -134,517 +134,224 @@ def test_e2e_normal(self): start2, end2 = ray.get(actor.get_start_and_end_time_for_op.remote(2)) assert start1 < start2 < end1 < end2, (start1, start2, end1, end2) - def test_can_add_input_with_normal_concurrency_cap(self): - """Test can_add_input when using normal concurrency cap (queue size disabled).""" - mock_op = MagicMock() - mock_op.name = "TestOperator" - mock_op.metrics.num_tasks_running = 3 - mock_op.throttling_disabled.return_value = False - mock_op.execution_finished.return_value = False - mock_op.output_dependencies = [] + def test_can_add_input_with_dynamic_output_queue_size_backpressure_disabled(self): + """Test can_add_input when dynamic output queue size backpressure is disabled.""" + input_op = InputDataBuffer(DataContext.get_current(), input_data=[MagicMock()]) + map_op = TaskPoolMapOperator( + map_transformer=MagicMock(), + data_context=DataContext.get_current(), + input_op=input_op, + max_concurrency=5, + ) + map_op.metrics.num_tasks_running = 3 + + topology = {map_op: MagicMock(), input_op: MagicMock()} + # Create policy with dynamic output queue size backpressure disabled policy = ConcurrencyCapBackpressurePolicy( DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + topology, + MagicMock(), # resource_manager ) - - # Disable queue size based backpressure policy.enable_dynamic_output_queue_size_backpressure = False - policy._concurrency_caps[mock_op] = 5 - # Should allow input when running < cap - result = policy.can_add_input(mock_op) - self.assertTrue(result) + # Should only check against configured concurrency cap + self.assertTrue(policy.can_add_input(map_op)) # 3 < 5 - # Should deny input when running >= cap - mock_op.metrics.num_tasks_running = 5 - result = policy.can_add_input(mock_op) - self.assertFalse(result) - - def test_update_queue_threshold_bootstrap(self): - """Test threshold update for first sample (bootstrap).""" - mock_op = MagicMock() - policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), - ) + map_op.metrics.num_tasks_running = 5 + self.assertFalse(policy.can_add_input(map_op)) # 5 >= 5 - # Add sample to history first (required for threshold calculation) - policy._queue_history[mock_op].append(1000) + def test_can_add_input_with_non_map_operator(self): + """Test can_add_input with non-MapOperator (should use basic cap check).""" + input_op = InputDataBuffer(DataContext.get_current(), input_data=[MagicMock()]) + input_op.metrics.num_tasks_running = 1 - # First sample should bootstrap threshold - # The threshold will be calculated as max(level + K_DEV * dev, q_now) - # where level=q_now=1000, dev=0 (first sample), so threshold = max(1000 + 4*0, 1000) = 1000 - threshold = policy._update_queue_threshold(mock_op, 1000) - self.assertEqual(threshold, 1000) - self.assertEqual(policy._queue_thresholds[mock_op], 1000) + topology = {input_op: MagicMock()} - # Test bootstrap with zero queue (should set threshold to 1 due to rounding) - fresh_mock_op = MagicMock() - fresh_policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {fresh_mock_op: MagicMock()}, - MagicMock(), - ) - fresh_policy._queue_thresholds[fresh_mock_op] = 0 # Reset to idle state - fresh_policy._queue_history[fresh_mock_op] = deque([0]) - # Fresh policy starts with clean EWMA state - - threshold_zero = fresh_policy._update_queue_threshold(fresh_mock_op, 0) - # When q_now=0, level=0, dev=0, threshold = max(1, max(0 + 4*0, 0)) = 1 - self.assertEqual(threshold_zero, 1) - self.assertEqual(fresh_policy._queue_thresholds[fresh_mock_op], 1) - - def test_update_queue_threshold_asymmetric_ewma(self): - """Test threshold update with asymmetric EWMA behavior.""" - mock_op = MagicMock() policy = ConcurrencyCapBackpressurePolicy( DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + topology, + MagicMock(), # resource_manager ) - # Set up initial state - policy._q_level_nbytes[mock_op] = 100.0 - policy._q_level_dev[mock_op] = 20.0 - policy._queue_history[mock_op] = deque([100, 120, 140, 160, 180, 200]) + # InputDataBuffer has infinite concurrency cap, so should always allow + self.assertTrue(policy.can_add_input(input_op)) - # Test with growing queue (should use faster alpha_up) - threshold = policy._update_queue_threshold(mock_op, 300) + def test_can_add_input_with_object_store_memory_usage_ratio_above_threshold(self): + """Test can_add_input when object store memory usage ratio is above threshold.""" + input_op = InputDataBuffer(DataContext.get_current(), input_data=[MagicMock()]) + map_op = TaskPoolMapOperator( + map_transformer=MagicMock(), + data_context=DataContext.get_current(), + input_op=input_op, + max_concurrency=5, + ) + map_op.metrics.num_tasks_running = 3 - # Threshold should be at least as high as current queue - self.assertGreaterEqual(threshold, 300) + topology = {map_op: MagicMock(), input_op: MagicMock()} - # Level should have moved toward the new sample using alpha_up - self.assertGreater(policy._q_level_nbytes[mock_op], 100.0) + mock_resource_manager = MagicMock() - # Test with declining queue (should use slower EWMA_ALPHA) - policy._q_level_nbytes[mock_op] = 200.0 - policy._q_level_dev[mock_op] = 30.0 - policy._update_queue_threshold(mock_op, 150) + # Mock object store memory usage ratio above threshold + threshold = ConcurrencyCapBackpressurePolicy.OBJECT_STORE_USAGE_RATIO + mock_usage = MagicMock() + mock_usage.object_store_memory = 1000 # usage + mock_budget = MagicMock() + mock_budget.object_store_memory = int( + 1000 * (threshold + 0.1) + ) # budget above threshold - # Level should have moved less aggressively downward - self.assertGreater(policy._q_level_nbytes[mock_op], 150.0) - self.assertLess(policy._q_level_nbytes[mock_op], 200.0) + mock_resource_manager.get_op_usage.return_value = mock_usage + mock_resource_manager.get_budget.return_value = mock_budget - def test_update_queue_threshold_no_decrease(self): - """Test that thresholds are never decreased, only maintained or increased.""" - mock_op = MagicMock() policy = ConcurrencyCapBackpressurePolicy( DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + topology, + mock_resource_manager, ) + policy.enable_dynamic_output_queue_size_backpressure = True - # Set up initial state with high threshold - policy._queue_thresholds[mock_op] = 200 - policy._q_level_nbytes[mock_op] = 10.0 # Very low level - policy._q_level_dev[mock_op] = 1.0 # Very low deviation - policy._queue_history[mock_op] = deque([10, 11, 12, 13, 14, 15]) + # Should skip dynamic backpressure and use basic cap check + self.assertTrue(policy.can_add_input(map_op)) # 3 < 5 - # Test that threshold is maintained when calculated threshold is lower - threshold = policy._update_queue_threshold(mock_op, 150) + map_op.metrics.num_tasks_running = 5 + self.assertFalse(policy.can_add_input(map_op)) # 5 >= 5 - # Should maintain the existing threshold (no decrease) - self.assertEqual(threshold, 200) - self.assertEqual(policy._queue_thresholds[mock_op], 200) - - # Test with even lower queue size - threshold_small = policy._update_queue_threshold(mock_op, 50) - self.assertEqual(threshold_small, 200) # Still maintained - self.assertEqual(policy._queue_thresholds[mock_op], 200) - - def test_update_queue_threshold_increase(self): - """Test that thresholds are increased when calculated threshold is higher.""" - mock_op = MagicMock() - policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + def test_can_add_input_with_object_store_memory_usage_ratio_below_threshold(self): + """Test can_add_input when object store memory usage ratio is below threshold.""" + input_op = InputDataBuffer(DataContext.get_current(), input_data=[MagicMock()]) + map_op = TaskPoolMapOperator( + map_transformer=MagicMock(), + data_context=DataContext.get_current(), + input_op=input_op, + max_concurrency=5, ) + map_op.metrics.num_tasks_running = 3 - # Set up initial state with moderate threshold - policy._queue_thresholds[mock_op] = 100 - policy._q_level_nbytes[mock_op] = 50.0 - policy._q_level_dev[mock_op] = 20.0 - policy._queue_history[mock_op] = deque([50, 60, 70, 80, 90, 100]) - - # Test that threshold is increased when calculated threshold is higher - threshold = policy._update_queue_threshold(mock_op, 200) + topology = {map_op: MagicMock(), input_op: MagicMock()} - # Should increase the threshold - self.assertGreaterEqual(threshold, 200) - self.assertGreaterEqual(policy._queue_thresholds[mock_op], 200) + mock_resource_manager = MagicMock() - def test_effective_cap_calculation_with_trend(self): - """Test effective cap calculation with different trend scenarios.""" - mock_op = MagicMock() - mock_op.metrics.num_tasks_running = 5 + # Mock object store memory usage ratio below threshold + threshold = ConcurrencyCapBackpressurePolicy.OBJECT_STORE_USAGE_RATIO + mock_usage = MagicMock() + mock_usage.object_store_memory = 1000 # usage + mock_budget = MagicMock() + mock_budget.object_store_memory = int( + 1000 * (threshold - 0.05) + ) # below threshold - policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), - ) + mock_resource_manager.get_op_usage.return_value = mock_usage + mock_resource_manager.get_budget.return_value = mock_budget - # Set up queue history for trend calculation - policy._queue_history[mock_op] = deque( - [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] + # Mock queue size methods + mock_resource_manager.get_op_internal_object_store_usage.return_value = 100 + mock_resource_manager.get_op_outputs_object_store_usage_with_downstream.return_value = ( + 200 ) - policy._q_level_dev[mock_op] = 100.0 - policy._queue_thresholds[mock_op] = 500 - policy._concurrency_caps[mock_op] = 10 - - # Test with high pressure (queue > threshold) - with patch.object( - policy._resource_manager, - "get_op_internal_object_store_usage", - return_value=1000, - ), patch.object( - policy._resource_manager, - "get_op_outputs_object_store_usage_with_downstream", - return_value=1000, - ): - effective_cap = policy._effective_cap(mock_op) - # Should be reduced due to high pressure - self.assertLess(effective_cap, 10) - self.assertGreaterEqual(effective_cap, 1) # Should be at least 1 - - def test_effective_cap_insufficient_history(self): - """Test effective cap when there's insufficient history for trend calculation.""" - mock_op = MagicMock() - mock_op.metrics.num_tasks_running = 5 policy = ConcurrencyCapBackpressurePolicy( DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + topology, + mock_resource_manager, ) + policy.enable_dynamic_output_queue_size_backpressure = True - # Set up insufficient history (less than 6 samples) - policy._queue_history[mock_op] = deque([100, 200, 300]) - policy._concurrency_caps[mock_op] = 10 - - effective_cap = policy._effective_cap(mock_op) - # Should return max(1, running) when insufficient history - self.assertEqual(effective_cap, 5) - - def test_signal_calculation_formulas(self): - """Test pressure_signal and trend_signal calculation formulas.""" - # Test pressure_signal formula: (q_now - threshold) / max(1.0, dev) - pressure_cases = [ - (1000, 500, 100, 5.0, "High pressure"), - (500, 500, 100, 0.0, "Neutral pressure"), - (200, 500, 100, -3.0, "Low pressure"), - (500, 500, 0, 0.0, "Zero deviation uses scale=1.0"), - ] - - for q_now, threshold, dev, expected, description in pressure_cases: - with self.subTest(signal="pressure", description=description): - scale = max(1.0, float(dev)) - pressure_signal = (q_now - threshold) / scale - self.assertAlmostEqual(pressure_signal, expected, places=5) - - # Test trend_signal formula: (recent_avg - older_avg) / max(1.0, dev) - trend_cases = [ - (1000, 500, 100, 5.0, "Strong growth"), - (500, 500, 100, 0.0, "No trend"), - (200, 500, 100, -3.0, "Strong decline"), - (500, 500, 0, 0.0, "Zero deviation uses scale=1.0"), - ] - - for recent_avg, older_avg, dev, expected, description in trend_cases: - with self.subTest(signal="trend", description=description): - scale = max(1.0, float(dev)) - trend_signal = (recent_avg - older_avg) / scale - self.assertAlmostEqual(trend_signal, expected, places=5) + # Should proceed with dynamic backpressure logic + # Initialize EWMA state for the operator + policy._q_level_nbytes[map_op] = 300.0 + policy._q_level_dev[map_op] = 50.0 - def test_decision_rules_table_comprehensive(self): - """Test all decision rules from the table comprehensively.""" - test_cases = [ - # (pressure_signal, trend_signal, expected_step, description) - # High pressure scenarios - (2.5, 1.5, -1, "High pressure + growing trend -> backoff"), - (2.0, 1.0, -1, "High pressure + growing trend (boundary) -> backoff"), - (2.5, 0.5, 0, "High pressure + mild growth -> hold"), - (2.5, 0.0, 0, "High pressure + no trend -> hold"), - (2.5, -0.5, 0, "High pressure + mild decline -> hold"), - (2.5, -1.0, 0, "High pressure + declining trend -> hold"), - # Moderate pressure scenarios - (1.5, 1.5, 0, "Moderate pressure + growing trend -> hold"), - (1.0, 1.0, 0, "Moderate pressure + growing trend (boundary) -> hold"), - (1.5, 0.5, 0, "Moderate pressure + mild growth -> hold"), - (1.5, 0.0, 0, "Moderate pressure + no trend -> hold"), - (1.5, -0.5, 0, "Moderate pressure + mild decline -> hold"), - (1.5, -1.0, 0, "Moderate pressure + declining trend -> hold"), - # Low pressure scenarios - (-1.5, -1.5, 1, "Low pressure + declining trend -> increase"), - (-1.0, -1.0, 1, "Low pressure + declining trend (boundary) -> increase"), - (-1.5, -0.5, 0, "Low pressure + mild decline -> hold"), - (-1.5, 0.0, 0, "Low pressure + no trend -> hold"), - (-1.5, 0.5, 0, "Low pressure + mild growth -> hold"), - (-1.5, 1.0, 0, "Low pressure + growing trend -> hold"), - # Very low pressure scenarios - (-2.5, -2.5, 2, "Very low pressure + declining trend -> increase by 2"), - ( - -2.0, - -2.0, - 2, - "Very low pressure + declining trend (boundary) -> increase by 2", - ), - (-2.5, -1.5, 1, "Very low pressure + mild decline -> increase by 1"), - (-2.5, -1.0, 1, "Very low pressure + mild decline -> increase by 1"), - (-2.5, 0.0, 0, "Very low pressure + no trend -> hold"), - (-2.5, 0.5, 0, "Very low pressure + mild growth -> hold"), - (-2.5, 1.0, 0, "Very low pressure + growing trend -> hold"), - # Neutral scenarios - (0.5, 0.5, 0, "Low pressure + mild growth -> hold"), - (0.0, 0.0, 0, "Neutral pressure + no trend -> hold"), - (-0.5, -0.5, 0, "Low pressure + mild decline -> hold"), - # Edge cases - (1.0, 0.0, 0, "Moderate pressure + no trend (boundary) -> hold"), - (0.0, 1.0, 0, "Neutral pressure + growing trend -> hold"), - (0.0, -1.0, 0, "Neutral pressure + declining trend -> hold"), - ] + result = policy.can_add_input(map_op) + # With queue size 300 in hold region (level=300, dev=50, bounds=[200, 400]), + # should hold current level, so running=3 < effective_cap=3 should be False + self.assertFalse(result) - # Create a policy instance to access the helper method - mock_op = MagicMock() - policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + def test_can_add_input_effective_cap_calculation(self): + """Test that effective cap calculation works correctly with different queue sizes.""" + input_op = InputDataBuffer(DataContext.get_current(), input_data=[MagicMock()]) + map_op = TaskPoolMapOperator( + map_transformer=MagicMock(), + data_context=DataContext.get_current(), + input_op=input_op, + max_concurrency=8, ) + map_op.metrics.num_tasks_running = 4 - for pressure_signal, trend_signal, expected_step, description in test_cases: - with self.subTest(description=description): - # Use the actual helper method from the policy - step = policy._quantized_controller_step(pressure_signal, trend_signal) + topology = {map_op: MagicMock(), input_op: MagicMock()} - self.assertEqual( - step, - expected_step, - f"Failed for pressure={pressure_signal}, trend={trend_signal}", - ) - - def test_ewma_calculation_formulas(self): - """Test EWMA level, deviation, and alpha calculation formulas.""" - # Test EWMA level formula: (1 - alpha) * prev + alpha * sample - level_cases = [ - (100.0, 120.0, 0.2, 104.0, "Normal alpha"), - (100.0, 80.0, 0.2, 96.0, "Normal alpha down"), - (100.0, 100.0, 0.2, 100.0, "Stable"), - (0.0, 100.0, 0.2, 20.0, "Bootstrap"), - ] - - for prev_level, sample, alpha, expected, description in level_cases: - with self.subTest(formula="level", description=description): - new_level = (1 - alpha) * prev_level + alpha * sample - self.assertAlmostEqual(new_level, expected, places=5) - - # Test EWMA deviation formula: (1 - alpha) * prev_dev + alpha * abs(sample - prev_level) - dev_cases = [ - (20.0, 120.0, 100.0, 0.2, 20.0, "Growing"), - (20.0, 100.0, 100.0, 0.2, 16.0, "Stable"), - (0.0, 100.0, 0.0, 0.2, 20.0, "Bootstrap"), - ] - - for prev_dev, sample, prev_level, alpha, expected, description in dev_cases: - with self.subTest(formula="deviation", description=description): - new_dev = (1 - alpha) * prev_dev + alpha * abs(sample - prev_level) - self.assertAlmostEqual(new_dev, expected, places=5) - - # Test alpha_up calculation: 1.0 - (1.0 - EWMA_ALPHA) ** 2 - alpha_cases = [ - (0.2, 0.36, "Normal EWMA_ALPHA"), - (0.1, 0.19, "Low EWMA_ALPHA"), - (0.5, 0.75, "High EWMA_ALPHA"), - ] - - for EWMA_ALPHA, expected, description in alpha_cases: - with self.subTest(formula="alpha_up", description=description): - alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 - self.assertAlmostEqual(alpha_up, expected, places=5) - - def test_threshold_calculation_formula(self): - """Test threshold calculation: max(level + K_DEV * dev, q_now).""" - test_cases = [ - (100.0, 20.0, 150.0, 4.0, 180.0, "Normal case"), - (100.0, 20.0, 200.0, 4.0, 200.0, "High queue"), - (100.0, 0.0, 150.0, 4.0, 150.0, "Zero deviation"), - (100.0, 20.0, 150.0, 2.0, 150.0, "Lower K_DEV"), - ] + mock_resource_manager = MagicMock() + threshold = ConcurrencyCapBackpressurePolicy.OBJECT_STORE_USAGE_RATIO + mock_usage = MagicMock() + mock_usage.object_store_memory = 1000 + mock_budget = MagicMock() + mock_budget.object_store_memory = int( + 1000 * (threshold - 0.05) + ) # below threshold - for level, dev, q_now, K_DEV, expected, description in test_cases: - with self.subTest(description=description): - threshold = max(level + K_DEV * dev, q_now) - self.assertAlmostEqual(threshold, expected, places=5) + mock_resource_manager.get_op_usage.return_value = mock_usage + mock_resource_manager.get_budget.return_value = mock_budget - def test_threshold_update_logic_comprehensive(self): - """Test comprehensive threshold update logic including bootstrap, upward, and no-decrease cases.""" - mock_op = MagicMock() policy = ConcurrencyCapBackpressurePolicy( DataContext.get_current(), - {mock_op: MagicMock()}, - MagicMock(), + topology, + mock_resource_manager, ) + policy.enable_dynamic_output_queue_size_backpressure = True - # Test 1: Bootstrap case (prev_threshold = 0) - policy._queue_thresholds[mock_op] = 0 - policy._queue_history[mock_op] = deque([100]) - threshold1 = policy._update_queue_threshold(mock_op, 100) - # Bootstrap: threshold = max(level + K_DEV * dev, q_now) = max(100 + 4*0, 100) = 100 - self.assertEqual(threshold1, 100) - - # Test 2: Upward adjustment (threshold > prev_threshold) - policy._queue_thresholds[mock_op] = 100 - policy._q_level_nbytes[mock_op] = 50.0 - policy._q_level_dev[mock_op] = 10.0 - policy._queue_history[mock_op] = deque([50, 60, 70, 80, 90, 100]) - threshold2 = policy._update_queue_threshold(mock_op, 200) - # The EWMA will update level and dev, so we can't predict exact value - # Just verify it's >= 200 (upward adjustment) - self.assertGreaterEqual(threshold2, 200) - - # Test 3: No decrease (threshold < prev_threshold, should maintain existing) - policy._queue_thresholds[mock_op] = 200 - policy._q_level_nbytes[mock_op] = 10.0 # Very low level - policy._q_level_dev[mock_op] = 1.0 # Very low deviation - policy._queue_history[mock_op] = deque([10, 11, 12, 13, 14, 15]) - threshold3 = policy._update_queue_threshold(mock_op, 150) - self.assertEqual(threshold3, 200) - - # Test 4: Zero threshold case - fresh_mock_op = MagicMock() - fresh_policy = ConcurrencyCapBackpressurePolicy( - DataContext.get_current(), - {fresh_mock_op: MagicMock()}, - MagicMock(), - ) - fresh_policy._queue_thresholds[fresh_mock_op] = 0 - fresh_policy._queue_history[fresh_mock_op] = deque([0]) - # Fresh policy starts with clean EWMA state - threshold4 = fresh_policy._update_queue_threshold(fresh_mock_op, 0) - self.assertEqual(threshold4, 1) # Should round up to 1 - - def test_trend_and_effective_cap_formulas(self): - """Test trend calculation and effective cap formulas.""" - # Test trend calculation: recent_avg - older_avg - trend_cases = [ - ([100, 200, 300, 400, 500, 600], 500.0, 200.0, 300.0, "6 samples"), - ([100, 200, 300, 400, 500, 600, 700], 600.0, 300.0, 300.0, "7 samples"), + # Test different queue sizes using policy constants + test_cases = [ + # (internal_usage, downstream_usage, level, dev, expected_result, description) + ( + 50, + 50, + 5000.0, + 200.0, + True, + "low_queue_below_lower_bound", + ), # 100 < 5000 - 2*200 = 4600, ramp up + ( + 200, + 200, + 400.0, + 50.0, + False, + "medium_queue_in_hold_region", + ), # 400 in [300, 500], hold + ( + 300, + 300, + 200.0, + 50.0, + False, + "high_queue_above_upper_bound", + ), # 600 > 200 + 2*50 = 300, backoff ] for ( - history, - expected_recent, - expected_older, - expected_trend, + internal_usage, + downstream_usage, + level, + dev, + expected_result, description, - ) in trend_cases: - with self.subTest(formula="trend", description=description): - h = list(history) - recent_window = len(h) // 2 - older_window = len(h) // 2 - - recent_avg = sum(h[-recent_window:]) / float(recent_window) - older_avg = sum( - h[-(recent_window + older_window) : -recent_window] - ) / float(older_window) - trend = recent_avg - older_avg - - self.assertAlmostEqual(recent_avg, expected_recent, places=5) - self.assertAlmostEqual(older_avg, expected_older, places=5) - self.assertAlmostEqual(trend, expected_trend, places=5) - - # Test effective cap formula: max(1, running + step) - cap_cases = [ - (5, -1, 4, "Reduce by 1"), - (5, 0, 5, "No change"), - (5, 1, 6, "Increase by 1"), - (1, -1, 1, "Min cap"), - ] - - for running, step, expected, description in cap_cases: - with self.subTest(formula="effective_cap", description=description): - effective_cap = max(1, running + step) - self.assertEqual(effective_cap, expected) - - def test_ewma_asymmetric_behavior(self): - """Test EWMA asymmetric behavior and level calculation.""" - # Test alpha selection: alpha_up if sample > prev else EWMA_ALPHA - alpha_cases = [ - (100.0, 150.0, 0.2, 0.36, "Rising uses alpha_up"), - (100.0, 50.0, 0.2, 0.2, "Falling uses EWMA_ALPHA"), - (100.0, 100.0, 0.2, 0.2, "Stable uses EWMA_ALPHA"), - ] - - for prev_level, sample, EWMA_ALPHA, expected, description in alpha_cases: - with self.subTest(behavior="alpha_selection", description=description): - alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 - alpha = alpha_up if sample > prev_level else EWMA_ALPHA - self.assertAlmostEqual(alpha, expected, places=5) - - # Test level calculation with asymmetric alpha - level_cases = [ - (100.0, 150.0, 0.2, 118.0, "Rising with alpha_up"), - (100.0, 50.0, 0.2, 90.0, "Falling with EWMA_ALPHA"), - (0.0, 100.0, 0.2, 100.0, "Bootstrap uses sample"), - ] - - for prev_level, sample, EWMA_ALPHA, expected, description in level_cases: - with self.subTest(behavior="level_calculation", description=description): - if prev_level <= 0: - level = sample - else: - alpha_up = 1.0 - (1.0 - EWMA_ALPHA) ** 2 - alpha = alpha_up if sample > prev_level else EWMA_ALPHA - level = (1 - alpha) * prev_level + alpha * sample - self.assertAlmostEqual(level, expected, places=5) - - def test_simple_calculation_formulas(self): - """Test simple calculation formulas: scale, min_samples, and windows.""" - # Test scale calculation: max(1.0, float(dev)) - scale_cases = [ - (100.0, 100.0, "Normal deviation"), - (0.0, 1.0, "Zero deviation"), - (0.5, 1.0, "Small deviation"), - (1.1, 1.1, "Just above unit"), - ] - - for dev, expected, description in scale_cases: - with self.subTest(formula="scale", description=description): - scale = max(1.0, float(dev)) - self.assertAlmostEqual(scale, expected, places=5) - - # Test min_samples calculation: recent_window + older_window - min_samples_cases = [ - (10, 10, "HISTORY_LEN=10"), - (6, 6, "HISTORY_LEN=6"), - (12, 12, "HISTORY_LEN=12"), - ] + ) in test_cases: + with self.subTest(description=description): + mock_resource_manager.get_op_internal_object_store_usage.return_value = ( + internal_usage + ) + mock_resource_manager.get_op_outputs_object_store_usage_with_downstream.return_value = ( + downstream_usage + ) - for HISTORY_LEN, expected, description in min_samples_cases: - with self.subTest(formula="min_samples", description=description): - recent_window = HISTORY_LEN // 2 - older_window = HISTORY_LEN // 2 - min_samples = recent_window + older_window - self.assertEqual(min_samples, expected) - - # Test window calculation: recent_window = older_window = HISTORY_LEN // 2 - window_cases = [ - (10, 5, 5, "HISTORY_LEN=10"), - (6, 3, 3, "HISTORY_LEN=6"), - (9, 4, 4, "HISTORY_LEN=9 (integer division)"), - ] + # Initialize EWMA state + policy._q_level_nbytes[map_op] = level + policy._q_level_dev[map_op] = dev - for HISTORY_LEN, expected_recent, expected_older, description in window_cases: - with self.subTest(formula="windows", description=description): - recent_window = HISTORY_LEN // 2 - older_window = HISTORY_LEN // 2 - self.assertEqual(recent_window, expected_recent) - self.assertEqual(older_window, expected_older) + result = policy.can_add_input(map_op) + assert ( + result == expected_result + ), f"Expected {expected_result} for {description}" if __name__ == "__main__":