Skip to content

Commit 3f1ba30

Browse files
committed
[Metrics] Move LoRA request counts to SchedulerStats
SchedulerStats is the right place for this really, just like the regular running/waiting counts. Make sure to call LoRARequestStates.update_scheduler_stats() even where there were no engine core outputs. Signed-off-by: Mark McLoughlin <[email protected]>
1 parent b3dd9a5 commit 3f1ba30

File tree

6 files changed

+85
-49
lines changed

6 files changed

+85
-49
lines changed

tests/v1/engine/test_output_processor.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
from vllm.v1.engine import (
2323
EngineCoreEvent,
2424
EngineCoreEventType,
25+
EngineCoreOutputs,
2526
EngineCoreRequest,
2627
FinishReason,
2728
)
2829
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
29-
from vllm.v1.metrics.stats import IterationStats
30+
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
3031

3132

3233
def _ref_convert_id_to_token(
@@ -940,21 +941,26 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
940941
output_processor.add_request(request, None)
941942

942943
# First iteration: process outputs with QUEUED events
943-
outputs = engine_core.get_outputs()
944-
for output in outputs:
944+
outputs = EngineCoreOutputs(
945+
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
946+
)
947+
for output in outputs.outputs:
945948
output.events = [
946949
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp)
947950
]
948951

949952
iteration_stats = IterationStats() if log_stats else None
950-
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
953+
output_processor.process_outputs(
954+
outputs.outputs, engine_core_timestamp, iteration_stats
955+
)
956+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
951957

952958
if log_stats:
953959
# Verify waiting counts
954-
assert iteration_stats.waiting_lora_adapters.get("lora-1") == 1
955-
assert iteration_stats.waiting_lora_adapters.get("lora-2") == 1
956-
assert iteration_stats.running_lora_adapters.get("lora-1") == 0
957-
assert iteration_stats.running_lora_adapters.get("lora-2") == 0
960+
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
961+
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
962+
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0
963+
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0
958964
# Verify internal state
959965
assert len(output_processor.lora_states.requests) == 2
960966
assert "lora-1" in output_processor.lora_states.requests
@@ -965,76 +971,96 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
965971
assert len(output_processor.lora_states.requests) == 0
966972

967973
# Second iteration: process outputs with SCHEDULED events
968-
outputs = engine_core.get_outputs()
969-
for output in outputs:
974+
outputs = EngineCoreOutputs(
975+
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
976+
)
977+
for output in outputs.outputs:
970978
output.events = [
971979
EngineCoreEvent.new_event(
972980
EngineCoreEventType.SCHEDULED, engine_core_timestamp
973981
)
974982
]
975983

976984
iteration_stats = IterationStats() if log_stats else None
977-
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
985+
output_processor.process_outputs(
986+
outputs.outputs, engine_core_timestamp, iteration_stats
987+
)
988+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
978989

979990
if log_stats:
980991
# Verify running counts
981-
assert iteration_stats.waiting_lora_adapters.get("lora-1") == 0
982-
assert iteration_stats.waiting_lora_adapters.get("lora-2") == 0
983-
assert iteration_stats.running_lora_adapters.get("lora-1") == 1
984-
assert iteration_stats.running_lora_adapters.get("lora-2") == 1
992+
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
993+
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
994+
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1
995+
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
985996
else:
986997
assert iteration_stats is None
987998
assert len(output_processor.lora_states.requests) == 0
988999

9891000
# Third iteration: finish request-0 (lora-1)
990-
outputs = engine_core.get_outputs()
1001+
outputs = EngineCoreOutputs(
1002+
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
1003+
)
9911004
# Find and mark request-0 as finished (it uses lora-1)
992-
for output in outputs:
1005+
for output in outputs.outputs:
9931006
if output.request_id == "request-0":
9941007
output.finish_reason = FinishReason.LENGTH
9951008
break
9961009

9971010
iteration_stats = IterationStats() if log_stats else None
998-
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
1011+
output_processor.process_outputs(
1012+
outputs.outputs, engine_core_timestamp, iteration_stats
1013+
)
1014+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
9991015

10001016
if log_stats:
10011017
# lora-1 should be removed since no requests remain
10021018
assert "lora-1" not in output_processor.lora_states.requests
10031019
# lora-2 should still be running
1004-
assert iteration_stats.running_lora_adapters.get("lora-2") == 1
1020+
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
10051021
assert len(output_processor.lora_states.requests) == 1
10061022
else:
10071023
assert len(output_processor.lora_states.requests) == 0
10081024

10091025
# Fourth iteration: finish request-1 (lora-2)
1010-
outputs = engine_core.get_outputs()
1026+
outputs = EngineCoreOutputs(
1027+
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
1028+
)
10111029
# Find and mark request-1 as finished (it uses lora-2)
1012-
for output in outputs:
1030+
for output in outputs.outputs:
10131031
if output.request_id == "request-1":
10141032
output.finish_reason = FinishReason.LENGTH
10151033
break
10161034

10171035
iteration_stats = IterationStats() if log_stats else None
1018-
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
1036+
output_processor.process_outputs(
1037+
outputs.outputs, engine_core_timestamp, iteration_stats
1038+
)
1039+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
10191040

10201041
if log_stats:
10211042
# lora-2 should be removed since no requests remain
10221043
assert "lora-2" not in output_processor.lora_states.requests
1023-
assert len(iteration_stats.running_lora_adapters) == 0
1044+
assert len(outputs.scheduler_stats.running_lora_adapters) == 0
10241045
assert len(output_processor.lora_states.requests) == 0
10251046
else:
10261047
assert len(output_processor.lora_states.requests) == 0
10271048

10281049
# Finish the last request (no LoRA)
1029-
outputs = engine_core.get_outputs()
1050+
outputs = EngineCoreOutputs(
1051+
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
1052+
)
10301053
# Find and mark request-2 as finished (it has no LoRA)
1031-
for output in outputs:
1054+
for output in outputs.outputs:
10321055
if output.request_id == "request-2":
10331056
output.finish_reason = FinishReason.LENGTH
10341057
break
10351058

10361059
iteration_stats = IterationStats() if log_stats else None
1037-
output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
1060+
output_processor.process_outputs(
1061+
outputs.outputs, engine_core_timestamp, iteration_stats
1062+
)
1063+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
10381064

10391065
# Verify all requests are finished
10401066
assert output_processor.get_num_unfinished_requests() == 0

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,8 @@ async def output_handler():
519519
processed_outputs.reqs_to_abort
520520
)
521521

522+
output_processor.update_scheduler_stats(outputs.scheduler_stats)
523+
522524
# 4) Logging.
523525
# TODO(rob): make into a coroutine and launch it in
524526
# background thread once Prometheus overhead is non-trivial.

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
301301
engine_core_timestamp=outputs.timestamp,
302302
iteration_stats=iteration_stats,
303303
)
304+
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
304305

305306
# 3) Abort any reqs that finished due to stop strings.
306307
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

vllm/v1/engine/output_processor.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@
2222
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
2323
from vllm.v1.engine.logprobs import LogprobsProcessor
2424
from vllm.v1.engine.parallel_sampling import ParentRequest
25-
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
25+
from vllm.v1.metrics.stats import (
26+
IterationStats,
27+
LoRARequestStates,
28+
RequestStateStats,
29+
SchedulerStats,
30+
)
2631

2732

2833
class RequestOutputCollector:
@@ -477,13 +482,15 @@ def process_outputs(
477482
)
478483
if self.tracer:
479484
self.do_tracing(engine_core_output, req_state, iteration_stats)
480-
self.lora_states.update_iteration_stats(iteration_stats)
481485

482486
return OutputProcessorOutput(
483487
request_outputs=request_outputs,
484488
reqs_to_abort=reqs_to_abort,
485489
)
486490

491+
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
492+
self.lora_states.update_scheduler_stats(scheduler_stats)
493+
487494
def do_tracing(
488495
self,
489496
engine_core_output: EngineCoreOutput,

vllm/v1/metrics/loggers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,20 @@ def record(
974974
scheduler_stats.kv_connector_stats, engine_idx
975975
)
976976

977+
if self.gauge_lora_info is not None:
978+
running_lora_adapters = ",".join(
979+
scheduler_stats.running_lora_adapters.keys()
980+
)
981+
waiting_lora_adapters = ",".join(
982+
scheduler_stats.waiting_lora_adapters.keys()
983+
)
984+
lora_info_labels = {
985+
self.labelname_running_lora_adapters: running_lora_adapters,
986+
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
987+
self.labelname_max_lora: self.max_lora,
988+
}
989+
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
990+
977991
if mm_cache_stats is not None:
978992
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
979993
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@@ -1037,20 +1051,6 @@ def record(
10371051
finished_request.max_tokens_param
10381052
)
10391053

1040-
if self.gauge_lora_info is not None:
1041-
running_lora_adapters = ",".join(
1042-
iteration_stats.running_lora_adapters.keys()
1043-
)
1044-
waiting_lora_adapters = ",".join(
1045-
iteration_stats.waiting_lora_adapters.keys()
1046-
)
1047-
lora_info_labels = {
1048-
self.labelname_running_lora_adapters: running_lora_adapters,
1049-
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
1050-
self.labelname_max_lora: self.max_lora,
1051-
}
1052-
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
1053-
10541054
def record_sleep_state(self, sleep: int = 0, level: int = 0):
10551055
awake = 1
10561056
discard_all = 0

vllm/v1/metrics/stats.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ class SchedulerStats:
170170

171171
num_corrupted_reqs: int = 0
172172

173+
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
174+
running_lora_adapters: dict[str, int] = field(default_factory=dict)
175+
173176

174177
@dataclass
175178
class RequestStateStats:
@@ -219,8 +222,6 @@ def __init__(self):
219222
self.n_params_iter: list[int] = []
220223
self.time_to_first_tokens_iter: list[float] = []
221224
self.inter_token_latencies_iter: list[float] = []
222-
self.waiting_lora_adapters: dict[str, int] = {}
223-
self.running_lora_adapters: dict[str, int] = {}
224225

225226
def __repr__(self) -> str:
226227
field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
@@ -393,10 +394,9 @@ def request_running(self, req_id: str, lora_name: str | None):
393394
def request_finished(self, req_id: str, lora_name: str | None):
394395
self._request_update(req_id, lora_name, waiting=False, running=False)
395396

396-
def update_iteration_stats(self, iteration_stats: IterationStats | None):
397-
if not self.log_stats:
397+
def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
398+
if not self.log_stats or scheduler_stats is None:
398399
return
399-
assert iteration_stats is not None
400400
for lora_name, stats in self.requests.items():
401-
iteration_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
402-
iteration_stats.running_lora_adapters[lora_name] = len(stats.running)
401+
scheduler_stats.waiting_lora_adapters[lora_name] = len(stats.waiting)
402+
scheduler_stats.running_lora_adapters[lora_name] = len(stats.running)

0 commit comments

Comments
 (0)