Skip to content

Commit c855106

Browse files
abrarsheikhdstrodtman
authored andcommitted
Aggregate autoscaling metrics on controller (#56306)
## Controller Metrics Aggregation + Code Refactoring ### What Changed - **New Feature**: Added `RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER` flag to enable metrics aggregation at controller level using timeseries merging - **Code Cleanup**: Refactored `get_total_num_requests()` as it was starting to get complicated - **Enhanced Testing**: Added multi-environment test variants to cover different metrics collection modes **Fully backward compatible** - existing behavior unchanged when flag is disabled. ### Changed the merge algorithm i think the problem is the current algorithm that uses latest in bucket during merge is not robust and is lossy. Which make it highly susceptible to the choice of bucket width. I am going to rewrite that algorithm as follows to see if it helps Interpret each replica’s gauge as **right-continuous, last-observation-carried-forward (LOCF)**. Then: 1. **Turn each replica into “delta events.”** For a sorted series $(t_0,v_0),(t_1,v_1),…$ emit: * at $t_0$: $+\;v_0$ * at $t_j$: $+\;(v_j - v_{j-1})$ for $j\ge1$ 2. **K-way merge all events by time.** Maintain `current_sum`. At each event time $t$, apply the sum of all deltas at $t$, update `current_sum`, and **record a point** $(t, current_sum)$. The result is an **event-driven, piecewise-constant series** $S(t)$: between event timestamps it holds constant and represents the instantaneous total across replicas. 3. **Resample to a regular grid** To get the instantaneous value at grid time $g_k$, take the last event at or before $g_k$ (LOCF on the merged step series). ## Next PR #56311 --------- Signed-off-by: abrar <[email protected]> Signed-off-by: Douglas Strodtman <[email protected]>
1 parent 324d823 commit c855106

File tree

9 files changed

+834
-403
lines changed

9 files changed

+834
-403
lines changed

bazel/python.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,22 @@ def py_test_run_all_notebooks(include, exclude, allow_empty=False, **kwargs):
114114
args = ["--find-recursively", "--path", file],
115115
**kwargs
116116
)
117+
118+
def py_test_module_list_with_env_variants(files, env_variants, size="medium", **kwargs):
119+
"""Create multiple py_test_module_list targets with different environment variable configurations.
120+
121+
Args:
122+
files: List of test files to run
123+
env_variants: Dict where keys are variant names and values are dicts containing
124+
'env' and 'name_suffix' keys
125+
size: Test size
126+
**kwargs: Additional arguments passed to py_test_module_list
127+
"""
128+
for variant_name, variant_config in env_variants.items():
129+
py_test_module_list(
130+
size = size,
131+
files = files,
132+
env = variant_config.get("env", {}),
133+
name_suffix = variant_config.get("name_suffix", "_{}".format(variant_name)),
134+
**kwargs
135+
)

python/ray/serve/_private/autoscaling_state.py

Lines changed: 296 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,15 @@
1313
TargetCapacityDirection,
1414
)
1515
from ray.serve._private.constants import (
16+
RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER,
1617
RAY_SERVE_MIN_HANDLE_METRICS_TIMEOUT_S,
1718
SERVE_LOGGER_NAME,
1819
)
1920
from ray.serve._private.deployment_info import DeploymentInfo
21+
from ray.serve._private.metrics_utils import (
22+
merge_timeseries_dicts,
23+
time_weighted_average,
24+
)
2025
from ray.serve._private.utils import get_capacity_adjusted_num_replicas
2126

2227
logger = logging.getLogger(SERVE_LOGGER_NAME)
@@ -162,7 +167,6 @@ def record_request_metrics_for_replica(
162167
self, replica_metric_report: ReplicaMetricReport
163168
) -> None:
164169
"""Records average number of ongoing requests at a replica."""
165-
166170
replica_id = replica_metric_report.replica_id
167171
send_timestamp = replica_metric_report.timestamp
168172

@@ -268,18 +272,281 @@ def get_decision_num_replicas(
268272

269273
return self.apply_bounds(decision_num_replicas)
270274

271-
def get_total_num_requests(self) -> float:
272-
"""Get average total number of requests aggregated over the past
273-
`look_back_period_s` number of seconds.
275+
def _collect_replica_running_requests(self) -> List[Dict[str, List]]:
276+
"""Collect running requests metrics from replicas for aggregation."""
277+
metrics_timeseries_dicts = []
274278

275-
If there are 0 running replicas, then returns the total number
276-
of requests queued at handles
279+
for replica_id in self._running_replicas:
280+
replica_metric_report = self._replica_metrics.get(replica_id, None)
281+
if (
282+
replica_metric_report is not None
283+
and RUNNING_REQUESTS_KEY in replica_metric_report.metrics
284+
):
285+
metrics_timeseries_dicts.append(
286+
{
287+
RUNNING_REQUESTS_KEY: replica_metric_report.metrics[
288+
RUNNING_REQUESTS_KEY
289+
]
290+
}
291+
)
277292

278-
This code assumes that the metrics are either emmited on handles
279-
or on replicas, but not both. Its the responsibility of the writer
280-
to ensure enclusivity of the metrics.
293+
return metrics_timeseries_dicts
294+
295+
def _collect_handle_queued_requests(self) -> float:
296+
"""Collect total queued requests from all handles."""
297+
total_queued_requests = 0
298+
for handle_metric_report in self._handle_requests.values():
299+
total_queued_requests += handle_metric_report.queued_requests
300+
return total_queued_requests
301+
302+
def _collect_handle_running_requests(self) -> List[Dict[str, List]]:
303+
"""Collect running requests metrics from handles when not collected on replicas.
304+
305+
Returns:
306+
A list of dictionaries, each containing a key-value pair:
307+
- The key is the name of the metric (RUNNING_REQUESTS_KEY)
308+
- The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
309+
this list is sorted by timestamp ascending
310+
- The TimeStampedValue object contains a timestamp and a value
311+
- The timestamp is the time at which the measurement was taken
312+
- The value is the measurement of the metric
313+
314+
Example:
315+
If there are 2 handles, each managing 2 replicas, and the running requests metrics are:
316+
- Handle 1: Replica 1: 5, Replica 2: 7
317+
- Handle 2: Replica 1: 3, Replica 2: 1
318+
and the timestamp is 0.1 and 0.2 respectively
319+
Then the returned list will be:
320+
[
321+
{
322+
"running_requests": [
323+
TimeStampedValue(timestamp=0.1, value=5.0),
324+
]
325+
},
326+
{
327+
"running_requests": [
328+
TimeStampedValue(timestamp=0.2, value=7.0),
329+
]
330+
},
331+
{
332+
"running_requests": [
333+
TimeStampedValue(timestamp=0.1, value=3.0),
334+
]
335+
},
336+
{
337+
"running_requests": [
338+
TimeStampedValue(timestamp=0.2, value=1.0),
339+
]
340+
}
341+
]
281342
"""
343+
metrics_timeseries_dicts = []
344+
345+
for handle_metric in self._handle_requests.values():
346+
for replica_id in self._running_replicas:
347+
if (
348+
RUNNING_REQUESTS_KEY not in handle_metric.metrics
349+
or replica_id not in handle_metric.metrics[RUNNING_REQUESTS_KEY]
350+
):
351+
continue
352+
metrics_timeseries_dicts.append(
353+
{
354+
RUNNING_REQUESTS_KEY: handle_metric.metrics[
355+
RUNNING_REQUESTS_KEY
356+
][replica_id]
357+
}
358+
)
359+
360+
return metrics_timeseries_dicts
361+
362+
def _aggregate_running_requests(
363+
self, metrics_timeseries_dicts: List[Dict[str, List]]
364+
) -> float:
365+
"""Aggregate and average running requests from timeseries data using instantaneous merge.
366+
367+
Args:
368+
metrics_timeseries_dicts: A list of dictionaries, each containing a key-value pair:
369+
- The key is the name of the metric (RUNNING_REQUESTS_KEY)
370+
- The value is a list of TimeStampedValue objects, each representing a single measurement of the metric
371+
this list is sorted by timestamp ascending
372+
373+
Returns:
374+
The time-weighted average of the running requests
375+
376+
Example:
377+
If the metrics_timeseries_dicts is:
378+
[
379+
{
380+
"running_requests": [
381+
TimeStampedValue(timestamp=0.1, value=5.0),
382+
TimeStampedValue(timestamp=0.2, value=7.0),
383+
]
384+
},
385+
{
386+
"running_requests": [
387+
TimeStampedValue(timestamp=0.2, value=3.0),
388+
TimeStampedValue(timestamp=0.3, value=1.0),
389+
]
390+
}
391+
]
392+
Then the returned value will be:
393+
(5.0*0.1 + 7.0*0.2 + 3.0*0.2 + 1.0*0.3) / (0.1 + 0.2 + 0.2 + 0.3) = 4.5 / 0.8 = 5.625
394+
"""
395+
396+
if not metrics_timeseries_dicts:
397+
return 0.0
398+
399+
# Use instantaneous merge approach - no arbitrary windowing needed
400+
aggregated_metrics = merge_timeseries_dicts(*metrics_timeseries_dicts)
401+
running_requests_timeseries = aggregated_metrics.get(RUNNING_REQUESTS_KEY, [])
402+
if running_requests_timeseries:
403+
404+
# assume that the last recorded metric is valid for last_window_s seconds
405+
last_metric_time = running_requests_timeseries[-1].timestamp
406+
# we dont want to make any assumption about how long the last metric will be valid
407+
# only conclude that the last metric is valid for last_window_s seconds that is the
408+
# difference between the current time and the last metric recorded time
409+
last_window_s = time.time() - last_metric_time
410+
# adding a check to negative values caused by clock skew
411+
# between replicas and controller. Also add a small epsilon to avoid division by zero
412+
if last_window_s <= 0:
413+
last_window_s = 1e-3
414+
# Calculate the time-weighted average of the running requests
415+
avg_running = time_weighted_average(
416+
running_requests_timeseries, last_window_s=last_window_s
417+
)
418+
return avg_running if avg_running is not None else 0.0
419+
420+
return 0.0
421+
422+
def _calculate_total_requests_aggregate_mode(self) -> float:
423+
"""Calculate total requests using aggregate metrics mode with timeseries data.
424+
425+
This method works with raw timeseries metrics data and performs aggregation
426+
at the controller level, providing more accurate and stable metrics compared
427+
to simple mode.
428+
429+
Processing Steps:
430+
1. Collect raw timeseries data (eg: running request) from replicas (if available)
431+
2. Collect queued requests from handles (always tracked at handle level)
432+
3. Collect raw timeseries data (eg: running request) from handles (if not available from replicas)
433+
4. Merge timeseries using instantaneous approach for mathematically correct totals
434+
5. Calculate time-weighted average running requests from the merged timeseries
435+
436+
Key Differences from Simple Mode:
437+
- Uses raw timeseries data instead of pre-aggregated metrics
438+
- Performs instantaneous merging for exact gauge semantics
439+
- Aggregates at the controller level rather than using pre-computed averages
440+
- Uses time-weighted averaging over the look_back_period_s interval for accurate calculations
441+
442+
Metrics Collection:
443+
Running requests are collected with either replica-level or handle-level metrics.
444+
445+
Queued requests are always collected from handles regardless of where
446+
running requests are collected.
447+
448+
Timeseries Aggregation:
449+
Raw timeseries data from multiple sources is merged using an instantaneous
450+
approach that treats gauges as right-continuous step functions. This provides
451+
mathematically correct totals without arbitrary windowing bias.
452+
453+
Example with Numbers:
454+
Assume metrics_interval_s = 0.5s, current time = 2.0s
282455
456+
Step 1: Collect raw timeseries from 2 replicas (r1, r2)
457+
replica_metrics = [
458+
{"running_requests": [(t=0.2, val=5), (t=0.8, val=7), (t=1.5, val=6)]}, # r1
459+
{"running_requests": [(t=0.1, val=3), (t=0.9, val=4), (t=1.4, val=8)]} # r2
460+
]
461+
462+
Step 2: Collect queued requests from handles
463+
handle_queued = 2 + 3 = 5 # total from all handles
464+
465+
Step 3: No handle metrics needed (replica metrics available)
466+
handle_metrics = []
467+
468+
Step 4: Merge timeseries using instantaneous approach
469+
# Create delta events: r1 starts at 5 (t=0.2), changes to 7 (t=0.8), then 6 (t=1.5)
470+
# r2 starts at 3 (t=0.1), changes to 4 (t=0.9), then 8 (t=1.4)
471+
# Merged instantaneous total: [(t=0.1, val=3), (t=0.2, val=8), (t=0.8, val=10), (t=0.9, val=11), (t=1.4, val=15), (t=1.5, val=14)]
472+
merged_timeseries = {"running_requests": [(0.1, 3), (0.2, 8), (0.8, 10), (0.9, 11), (1.4, 15), (1.5, 14)]}
473+
474+
Step 5: Calculate time-weighted average over full timeseries (t=0.1 to t=1.5+0.5=2.0)
475+
# Time-weighted calculation: (3*0.1 + 8*0.6 + 10*0.1 + 11*0.5 + 15*0.1 + 14*0.5) / 2.0 = 10.05
476+
avg_running = 10.05
477+
478+
Final result: total_requests = avg_running + queued = 10.05 + 5 = 15.05
479+
480+
Returns:
481+
Total number of requests (average running + queued) calculated from
482+
timeseries data aggregation.
483+
"""
484+
# Collect replica-based running requests
485+
replica_metrics = self._collect_replica_running_requests()
486+
metrics_collected_on_replicas = len(replica_metrics) > 0
487+
488+
# Collect queued requests from handles
489+
total_requests = self._collect_handle_queued_requests()
490+
491+
if not metrics_collected_on_replicas:
492+
# Collect handle-based running requests if not collected on replicas
493+
handle_metrics = self._collect_handle_running_requests()
494+
else:
495+
handle_metrics = []
496+
497+
# Combine all running requests metrics
498+
all_running_metrics = replica_metrics + handle_metrics
499+
500+
# Aggregate and add running requests to total
501+
total_requests += self._aggregate_running_requests(all_running_metrics)
502+
503+
return total_requests
504+
505+
def _calculate_total_requests_simple_mode(self) -> float:
506+
"""Calculate total requests using simple aggregated metrics mode.
507+
508+
This method works with pre-aggregated metrics that are computed by averaging
509+
(or other functions) over the past look_back_period_s seconds.
510+
511+
Metrics Collection:
512+
Metrics can be collected at two levels:
513+
1. Replica level: Each replica reports one aggregated metric value
514+
2. Handle level: Each handle reports metrics for multiple replicas
515+
516+
Replica-Level Metrics Example:
517+
For 3 replicas (r1, r2, r3), metrics might look like:
518+
{
519+
"r1": 10,
520+
"r2": 20,
521+
"r3": 30
522+
}
523+
Total requests = 10 + 20 + 30 = 60
524+
525+
Handle-Level Metrics Example:
526+
For 3 handles (h1, h2, h3), each managing 2 replicas:
527+
- h1 manages r1, r2
528+
- h2 manages r2, r3
529+
- h3 manages r3, r1
530+
531+
Metrics structure:
532+
{
533+
"h1": {"r1": 10, "r2": 20},
534+
"h2": {"r2": 20, "r3": 30},
535+
"h3": {"r3": 30, "r1": 10}
536+
}
537+
538+
Total requests = 10 + 20 + 20 + 30 + 30 + 10 = 120
539+
540+
Note: We can safely sum all handle metrics because each unique request
541+
is counted only once across all handles (no double-counting).
542+
543+
Queued Requests:
544+
Queued request metrics are always tracked at the handle level, regardless
545+
of whether running request metrics are collected at replicas or handles.
546+
547+
Returns:
548+
Total number of requests (running + queued) across all replicas/handles.
549+
"""
283550
total_requests = 0
284551

285552
for id in self._running_replicas:
@@ -289,20 +556,39 @@ def get_total_num_requests(self) -> float:
289556
)
290557

291558
metrics_collected_on_replicas = total_requests > 0
559+
560+
# Add handle metrics
292561
for handle_metric in self._handle_requests.values():
293562
total_requests += handle_metric.queued_requests
294563

564+
# Add running requests from handles if not collected on replicas
295565
if not metrics_collected_on_replicas:
296566
for replica_id in self._running_replicas:
297567
if replica_id in handle_metric.aggregated_metrics.get(
298-
RUNNING_REQUESTS_KEY
568+
RUNNING_REQUESTS_KEY, {}
299569
):
300570
total_requests += handle_metric.aggregated_metrics.get(
301571
RUNNING_REQUESTS_KEY
302572
).get(replica_id)
303573

304574
return total_requests
305575

576+
def get_total_num_requests(self) -> float:
577+
"""Get average total number of requests aggregated over the past
578+
`look_back_period_s` number of seconds.
579+
580+
If there are 0 running replicas, then returns the total number
581+
of requests queued at handles
582+
583+
This code assumes that the metrics are either emmited on handles
584+
or on replicas, but not both. Its the responsibility of the writer
585+
to ensure enclusivity of the metrics.
586+
"""
587+
if RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER:
588+
return self._calculate_total_requests_aggregate_mode()
589+
else:
590+
return self._calculate_total_requests_simple_mode()
591+
306592
def get_replica_metrics(self, agg_func: str) -> Dict[ReplicaID, List[Any]]:
307593
"""Get the raw replica metrics dict."""
308594
# arcyleung TODO: pass agg_func from autoscaling policy https://github.com/ray-project/ray/pull/51905

python/ray/serve/_private/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,3 +506,8 @@
506506
# This is used to detect and warn about long RPC latencies
507507
# between the controller and the replicas.
508508
RAY_SERVE_RPC_LATENCY_WARNING_THRESHOLD_MS = 2000
509+
510+
# Feature flag to aggregate metrics at the controller instead of the replicas or handles.
511+
RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER = get_env_bool(
512+
"RAY_SERVE_AGGREGATE_METRICS_AT_CONTROLLER", "0"
513+
)

0 commit comments

Comments
 (0)