
Commit 6f7de33

[Metrics] Refactor LoRA state tracking (#26801)

Signed-off-by: Mark McLoughlin <[email protected]>

1 parent a98cc35 commit 6f7de33

7 files changed: +268 -106 lines

tests/v1/engine/test_output_processor.py

Lines changed: 173 additions & 2 deletions

@@ -15,12 +15,19 @@
 )
 from vllm import PoolingParams
 from vllm.logprobs import PromptLogprobs, SampleLogprobs
+from vllm.lora.request import LoRARequest
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine import (
+    EngineCoreEvent,
+    EngineCoreEventType,
+    EngineCoreOutputs,
+    EngineCoreRequest,
+    FinishReason,
+)
 from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
-from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats


 def _ref_convert_id_to_token(
@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
     assert iteration_stats.num_generation_tokens == num_active


+@pytest.mark.parametrize("log_stats", [True, False])
+def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
+    """Test LoRA request lifecycle tracking through waiting -> running -> finished."""
+    output_processor = OutputProcessor(
+        dummy_test_vectors.tokenizer, log_stats=log_stats
+    )
+    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
+    engine_core_timestamp = time.monotonic()
+
+    # Create LoRA requests
+    lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1")
+    lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2")
+
+    # Create requests with different LoRA adapters:
+    # - request-0: lora-1
+    # - request-1: lora-2
+    # - request-2: None (no LoRA)
+    lora_assignments = [lora1, lora2, None]
+    requests = [
+        EngineCoreRequest(
+            request_id=f"request-{idx}",
+            prompt_token_ids=prompt_tokens,
+            mm_features=None,
+            eos_token_id=None,
+            arrival_time=0,
+            lora_request=lora_assignments[idx],
+            cache_salt=None,
+            data_parallel_rank=None,
+            sampling_params=SamplingParams(),
+            pooling_params=None,
+        )
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
+    ]
+
+    # Add all requests to the OutputProcessor
+    for request in requests:
+        output_processor.add_request(request, None)
+
+    # First iteration: process outputs with QUEUED events
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    for output in outputs.outputs:
+        output.events = [
+            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp)
+        ]
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # Verify waiting counts
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0
+        # Verify internal state
+        assert len(output_processor.lora_states.requests) == 2
+        assert "lora-1" in output_processor.lora_states.requests
+        assert "lora-2" in output_processor.lora_states.requests
+    else:
+        # When log_stats=False, no tracking should occur
+        assert iteration_stats is None
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Second iteration: process outputs with SCHEDULED events
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    for output in outputs.outputs:
+        output.events = [
+            EngineCoreEvent.new_event(
+                EngineCoreEventType.SCHEDULED, engine_core_timestamp
+            )
+        ]
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # Verify running counts
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
+        assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
+    else:
+        assert iteration_stats is None
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Third iteration: finish request-0 (lora-1)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-0 as finished (it uses lora-1)
+    for output in outputs.outputs:
+        if output.request_id == "request-0":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # lora-1 should be removed since no requests remain
+        assert "lora-1" not in output_processor.lora_states.requests
+        # lora-2 should still be running
+        assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
+        assert len(output_processor.lora_states.requests) == 1
+    else:
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Fourth iteration: finish request-1 (lora-2)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-1 as finished (it uses lora-2)
+    for output in outputs.outputs:
+        if output.request_id == "request-1":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    if log_stats:
+        # lora-2 should be removed since no requests remain
+        assert "lora-2" not in output_processor.lora_states.requests
+        assert len(outputs.scheduler_stats.running_lora_adapters) == 0
+        assert len(output_processor.lora_states.requests) == 0
+    else:
+        assert len(output_processor.lora_states.requests) == 0
+
+    # Finish the last request (no LoRA)
+    outputs = EngineCoreOutputs(
+        outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
+    )
+    # Find and mark request-2 as finished (it has no LoRA)
+    for output in outputs.outputs:
+        if output.request_id == "request-2":
+            output.finish_reason = FinishReason.LENGTH
+            break
+
+    iteration_stats = IterationStats() if log_stats else None
+    output_processor.process_outputs(
+        outputs.outputs, engine_core_timestamp, iteration_stats
+    )
+    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
+    # Verify all requests are finished
+    assert output_processor.get_num_unfinished_requests() == 0
+
+
 @pytest.mark.asyncio
 async def test_request_output_collector():
     NUM_REQS = 3
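
The test pins down the contract for the refactored tracker: a QUEUED event places a request in its adapter's waiting set, a SCHEDULED event moves it to running, and the finish/abort path removes it, dropping the adapter entry once its last request is gone. Below is a minimal, self-contained sketch of a tracker that would satisfy these assertions. It is an illustration, not the shipped code: the real LoRARequestStates lives in vllm/v1/metrics/stats.py (not part of this diff), and the request_waiting()/request_running() hook names are assumptions — only the constructor, .requests, request_finished(), and update_scheduler_stats() appear above.

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class SchedulerStats:
    # Stand-in for the real SchedulerStats: per-adapter request
    # counts, keyed by LoRA adapter name.
    waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    running_lora_adapters: dict[str, int] = field(default_factory=dict)


class LoRARequestStates:
    """Sketch: tracks which requests are waiting/running per LoRA adapter."""

    def __init__(self, log_stats: bool):
        self.log_stats = log_stats
        # lora_name -> {request_id -> "waiting" | "running"}
        self.requests: dict[str, dict[str, str]] = defaultdict(dict)

    def request_waiting(self, request_id: str, lora_name: str | None) -> None:
        # Hypothetical hook, driven by a QUEUED engine-core event.
        if self.log_stats and lora_name is not None:
            self.requests[lora_name][request_id] = "waiting"

    def request_running(self, request_id: str, lora_name: str | None) -> None:
        # Hypothetical hook, driven by a SCHEDULED engine-core event.
        if self.log_stats and lora_name is not None:
            self.requests[lora_name][request_id] = "running"

    def request_finished(self, request_id: str, lora_name: str | None) -> None:
        # Shared finish/abort path, as in the diff.
        if not self.log_stats or lora_name is None:
            return
        per_adapter = self.requests.get(lora_name)
        if per_adapter is None:
            return
        per_adapter.pop(request_id, None)
        if not per_adapter:
            # Drop the adapter once its last request is gone, matching the
            # 'assert "lora-1" not in ...requests' checks above.
            del self.requests[lora_name]

    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None) -> None:
        # Snapshot per-adapter counts onto this step's SchedulerStats.
        if not self.log_stats or scheduler_stats is None:
            return
        for lora_name, states in self.requests.items():
            waiting = sum(1 for s in states.values() if s == "waiting")
            scheduler_stats.waiting_lora_adapters[lora_name] = waiting
            scheduler_stats.running_lora_adapters[lora_name] = len(states) - waiting

Keying the state by (lora_name, request_id) is what lets abort and natural completion share the single request_finished() path seen in output_processor.py below.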

tests/v1/metrics/test_stats.py

Lines changed: 1 addition & 17 deletions

@@ -5,20 +5,4 @@
 
 def test_iteration_stats_repr():
     iteration_stats = IterationStats()
-    iteration_stats.iteration_timestamp = 0
-    expected_repr = (
-        "IterationStats("
-        "iteration_timestamp=0, "
-        "num_generation_tokens=0, "
-        "num_prompt_tokens=0, "
-        "num_preempted_reqs=0, "
-        "finished_requests=[], "
-        "max_num_generation_tokens_iter=[], "
-        "n_params_iter=[], "
-        "time_to_first_tokens_iter=[], "
-        "inter_token_latencies_iter=[], "
-        "waiting_lora_adapters={}, "
-        "running_lora_adapters={}, "
-        "num_corrupted_reqs=0)"
-    )
-    assert repr(iteration_stats) == expected_repr
+    assert repr(iteration_stats).startswith("IterationStats(")

vllm/v1/engine/async_llm.py

Lines changed: 2 additions & 0 deletions

@@ -508,6 +508,8 @@ async def output_handler():
                        processed_outputs.reqs_to_abort
                    )
 
+                    output_processor.update_scheduler_stats(outputs.scheduler_stats)
+
                    # 4) Logging.
                    # TODO(rob): make into a coroutine and launch it in
                    # background thread once Prometheus overhead is non-trivial.

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 0 deletions

@@ -289,6 +289,7 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats,
        )
+        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
 
        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
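
Both the async and sync front-ends now call the same hook at the same point: after process_outputs() has applied every output for the step (including finish events) and before the stat loggers run, so the per-adapter counts on SchedulerStats reflect the post-step state. Reusing the hypothetical sketch classes from the test section above, the intended effect can be exercised end to end:

# Continues the sketch defined after the test diff above.
states = LoRARequestStates(log_stats=True)

states.request_waiting("request-0", "lora-1")    # QUEUED event observed
stats = SchedulerStats()
states.update_scheduler_stats(stats)             # the new per-step hook
print(stats.waiting_lora_adapters)               # {'lora-1': 1}

states.request_running("request-0", "lora-1")    # SCHEDULED event observed
stats = SchedulerStats()
states.update_scheduler_stats(stats)
print(stats.running_lora_adapters)               # {'lora-1': 1}

states.request_finished("request-0", "lora-1")   # finish/abort path
stats = SchedulerStats()
states.update_scheduler_stats(stats)
print(stats.running_lora_adapters)               # {}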

vllm/v1/engine/output_processor.py

Lines changed: 14 additions & 9 deletions

@@ -22,7 +22,12 @@
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
-from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
+from vllm.v1.metrics.stats import (
+    IterationStats,
+    LoRARequestStates,
+    RequestStateStats,
+    SchedulerStats,
+)


 class RequestOutputCollector:
@@ -310,7 +315,7 @@ def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
-        self.lora_states = LoRARequestStates()
+        self.lora_states = LoRARequestStates(log_stats)
        self.tracer: Tracer | None = None
 
    def get_num_unfinished_requests(self):
@@ -334,7 +339,7 @@ def abort_requests(
        for request_id in request_ids:
            req_state = self.request_states.pop(request_id, None)
            if req_state is not None:
-                self.lora_states.abort_request(req_state)
+                self.lora_states.request_finished(request_id, req_state.lora_name)
                request_ids_to_abort.append(request_id)
                # Produce final abort output.
                if req_state.queue is not None and (
@@ -382,7 +387,6 @@ def add_request(
            log_stats=self.log_stats,
        )
        self.request_states[request_id] = req_state
-        self.lora_states.add_request(req_state)
        if parent_req:
            self.parent_requests[parent_req.request_id] = parent_req
 
@@ -484,13 +488,15 @@ def process_outputs(
                )
                if self.tracer:
                    self.do_tracing(engine_core_output, req_state, iteration_stats)
-        self.lora_states.update_iteration_stats(iteration_stats)
 
        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
        )
 
+    def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
+        self.lora_states.update_scheduler_stats(scheduler_stats)
+
    def do_tracing(
        self,
        engine_core_output: EngineCoreOutput,
@@ -564,8 +570,6 @@ def _update_stats_from_output(
        if iteration_stats is None:
            return
 
-        lora_stats = self.lora_states.get_stats(req_state)
-
        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(
@@ -574,7 +578,8 @@
            req_state.is_prefilling,
            req_state.prompt_len,
            req_state.stats,
-            lora_stats,
+            self.lora_states,
+            req_state.lora_name,
        )
 
    def _update_stats_from_finished(
@@ -596,7 +601,7 @@ def _update_stats_from_finished(
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
        )
-        self.lora_states.finish_request(req_state)
+        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)
 
        ParentRequest.observe_finished_request(
            req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
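
Two things stand out in this file. First, add_request() no longer registers with lora_states, and abort and natural completion now share a single request_finished(request_id, lora_name) path, so the tracker is driven entirely by per-output information. Second, iteration_stats.update_from_output() now receives the whole tracker plus the request's lora_name instead of a pre-fetched per-adapter LoRAStats object. A plausible reading (an assumption — vllm/v1/metrics/stats.py is not shown in this section) is that update_from_output forwards the engine-core events to the tracker, roughly:

# Hypothetical helper illustrating the assumed event forwarding; the
# request_waiting/request_running hook names come from the sketch above.
from vllm.v1.engine import EngineCoreEventType

def forward_lora_events(engine_core_output, lora_states, lora_name):
    for event in engine_core_output.events or ():
        if event.type == EngineCoreEventType.QUEUED:
            lora_states.request_waiting(engine_core_output.request_id, lora_name)
        elif event.type == EngineCoreEventType.SCHEDULED:
            lora_states.request_running(engine_core_output.request_id, lora_name)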

vllm/v1/metrics/loggers.py

Lines changed: 14 additions & 14 deletions

@@ -989,6 +989,20 @@ def record(
                    scheduler_stats.kv_connector_stats, engine_idx
                )
 
+            if self.gauge_lora_info is not None:
+                running_lora_adapters = ",".join(
+                    scheduler_stats.running_lora_adapters.keys()
+                )
+                waiting_lora_adapters = ",".join(
+                    scheduler_stats.waiting_lora_adapters.keys()
+                )
+                lora_info_labels = {
+                    self.labelname_running_lora_adapters: running_lora_adapters,
+                    self.labelname_waiting_lora_adapters: waiting_lora_adapters,
+                    self.labelname_max_lora: self.max_lora,
+                }
+                self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
+
        if mm_cache_stats is not None:
            self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
            self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@@ -1055,20 +1069,6 @@ def record(
                    finished_request.max_tokens_param
                )
 
-            if self.gauge_lora_info is not None:
-                running_lora_adapters = ",".join(
-                    iteration_stats.running_lora_adapters.keys()
-                )
-                waiting_lora_adapters = ",".join(
-                    iteration_stats.waiting_lora_adapters.keys()
-                )
-                lora_info_labels = {
-                    self.labelname_running_lora_adapters: running_lora_adapters,
-                    self.labelname_waiting_lora_adapters: waiting_lora_adapters,
-                    self.labelname_max_lora: self.max_lora,
-                }
-                self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
-
    def record_sleep_state(self, sleep: int = 0, level: int = 0):
        awake = 1
        discard_all = 0
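
The Prometheus logger thus populates the LoRA info gauge from the SchedulerStats branch of record() rather than the IterationStats branch, so it is refreshed from the scheduler-level snapshot whenever scheduler stats are recorded. Assuming the gauge and labels behind the labelname_* attributes are the usual vllm:lora_requests_info ones (their declarations are not part of this diff), the recorded sample looks roughly like:

vllm:lora_requests_info{max_lora="2",running_lora_adapters="lora-2",waiting_lora_adapters=""} 1.762e+09

Note that set_to_current_time() stores the current UNIX timestamp as the gauge value; the adapter names are carried in the label values, not in the value itself.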
