175 changes: 173 additions & 2 deletions tests/v1/engine/test_output_processor.py
@@ -15,12 +15,19 @@
)
from vllm import PoolingParams
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine import (
EngineCoreEvent,
EngineCoreEventType,
EngineCoreOutputs,
EngineCoreRequest,
FinishReason,
)
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.metrics.stats import IterationStats, SchedulerStats


def _ref_convert_id_to_token(
@@ -895,6 +902,170 @@ def test_iteration_stats(dummy_test_vectors):
assert iteration_stats.num_generation_tokens == num_active


@pytest.mark.parametrize("log_stats", [True, False])
def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
"""Test LoRA request lifecycle tracking through waiting -> running -> finished."""
output_processor = OutputProcessor(
dummy_test_vectors.tokenizer, log_stats=log_stats
)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
engine_core_timestamp = time.monotonic()

# Create LoRA requests
lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1")
lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2")

# Create requests with different LoRA adapters:
# - request-0: lora-1
# - request-1: lora-2
# - request-2: None (no LoRA)
lora_assignments = [lora1, lora2, None]
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
prompt_token_ids=prompt_tokens,
mm_features=None,
eos_token_id=None,
arrival_time=0,
lora_request=lora_assignments[idx],
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

# Add all requests to the OutputProcessor
for request in requests:
output_processor.add_request(request, None)

# First iteration: process outputs with QUEUED events
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
for output in outputs.outputs:
output.events = [
EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, engine_core_timestamp)
]

iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)

if log_stats:
# Verify waiting counts
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 0
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 0
# Verify internal state
assert len(output_processor.lora_states.requests) == 2
assert "lora-1" in output_processor.lora_states.requests
assert "lora-2" in output_processor.lora_states.requests
else:
# When log_stats=False, no tracking should occur
assert iteration_stats is None
assert len(output_processor.lora_states.requests) == 0

# Second iteration: process outputs with SCHEDULED events
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
for output in outputs.outputs:
output.events = [
EngineCoreEvent.new_event(
EngineCoreEventType.SCHEDULED, engine_core_timestamp
)
]

iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)

if log_stats:
# Verify running counts
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
assert outputs.scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
assert outputs.scheduler_stats.running_lora_adapters.get("lora-1") == 1
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
else:
assert iteration_stats is None
assert len(output_processor.lora_states.requests) == 0

# Third iteration: finish request-0 (lora-1)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-0 as finished (it uses lora-1)
for output in outputs.outputs:
if output.request_id == "request-0":
output.finish_reason = FinishReason.LENGTH
break

iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)

if log_stats:
# lora-1 should be removed since no requests remain
assert "lora-1" not in output_processor.lora_states.requests
# lora-2 should still be running
assert outputs.scheduler_stats.running_lora_adapters.get("lora-2") == 1
assert len(output_processor.lora_states.requests) == 1
else:
assert len(output_processor.lora_states.requests) == 0

# Fourth iteration: finish request-1 (lora-2)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-1 as finished (it uses lora-2)
for output in outputs.outputs:
if output.request_id == "request-1":
output.finish_reason = FinishReason.LENGTH
break

iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)

if log_stats:
# lora-2 should be removed since no requests remain
assert "lora-2" not in output_processor.lora_states.requests
assert len(outputs.scheduler_stats.running_lora_adapters) == 0
assert len(output_processor.lora_states.requests) == 0
else:
assert len(output_processor.lora_states.requests) == 0

# Finish the last request (no LoRA)
outputs = EngineCoreOutputs(
outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
)
# Find and mark request-2 as finished (it has no LoRA)
for output in outputs.outputs:
if output.request_id == "request-2":
output.finish_reason = FinishReason.LENGTH
break

iteration_stats = IterationStats() if log_stats else None
output_processor.process_outputs(
outputs.outputs, engine_core_timestamp, iteration_stats
)
output_processor.update_scheduler_stats(outputs.scheduler_stats)

# Verify all requests are finished
assert output_processor.get_num_unfinished_requests() == 0
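
    # --- Illustrative sketch (not part of the diff) ---------------------------
    # The block below is a minimal, self-contained model of the
    # waiting -> running -> finished bookkeeping this test asserts on. Every name
    # in it (AdapterTracker, mark_queued, mark_scheduled, mark_finished) is an
    # illustrative placeholder, not the actual LoRARequestStates API from
    # vllm/v1/metrics/stats.py.
    #
    # from collections import defaultdict
    # from dataclasses import dataclass, field
    #
    # @dataclass
    # class AdapterTracker:
    #     """Toy stand-in for per-adapter waiting/running bookkeeping."""
    #
    #     # adapter name -> ids of requests currently in that lifecycle phase
    #     waiting: defaultdict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
    #     running: defaultdict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
    #
    #     def mark_queued(self, lora_name: str, request_id: str) -> None:
    #         self.waiting[lora_name].add(request_id)
    #         self.running.setdefault(lora_name, set())  # expose a zero running count
    #
    #     def mark_scheduled(self, lora_name: str, request_id: str) -> None:
    #         self.waiting[lora_name].discard(request_id)
    #         self.running[lora_name].add(request_id)
    #
    #     def mark_finished(self, lora_name: str, request_id: str) -> None:
    #         self.waiting[lora_name].discard(request_id)
    #         self.running[lora_name].discard(request_id)
    #         if not self.waiting[lora_name] and not self.running[lora_name]:
    #             # Drop the adapter once no request references it, mirroring the
    #             # '"lora-1" not in ... requests' assertions above.
    #             del self.waiting[lora_name]
    #             del self.running[lora_name]
    #
    #     def counts(self) -> tuple[dict[str, int], dict[str, int]]:
    #         return (
    #             {name: len(ids) for name, ids in self.waiting.items()},
    #             {name: len(ids) for name, ids in self.running.items()},
    #         )
    #
    # tracker = AdapterTracker()
    # tracker.mark_queued("lora-1", "request-0")     # first iteration: QUEUED events
    # tracker.mark_queued("lora-2", "request-1")
    # assert tracker.counts() == ({"lora-1": 1, "lora-2": 1}, {"lora-1": 0, "lora-2": 0})
    #
    # tracker.mark_scheduled("lora-1", "request-0")  # second iteration: SCHEDULED events
    # tracker.mark_scheduled("lora-2", "request-1")
    # assert tracker.counts() == ({"lora-1": 0, "lora-2": 0}, {"lora-1": 1, "lora-2": 1})
    #
    # tracker.mark_finished("lora-1", "request-0")   # third iteration: request-0 finishes
    # assert "lora-1" not in tracker.counts()[1] and tracker.counts()[1] == {"lora-2": 1}
    # ---------------------------------------------------------------------------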


@pytest.mark.asyncio
async def test_request_output_collector():
NUM_REQS = 3
18 changes: 1 addition & 17 deletions tests/v1/metrics/test_stats.py
@@ -5,20 +5,4 @@

def test_iteration_stats_repr():
iteration_stats = IterationStats()
iteration_stats.iteration_timestamp = 0
expected_repr = (
"IterationStats("
"iteration_timestamp=0, "
"num_generation_tokens=0, "
"num_prompt_tokens=0, "
"num_preempted_reqs=0, "
"finished_requests=[], "
"max_num_generation_tokens_iter=[], "
"n_params_iter=[], "
"time_to_first_tokens_iter=[], "
"inter_token_latencies_iter=[], "
"waiting_lora_adapters={}, "
"running_lora_adapters={}, "
"num_corrupted_reqs=0)"
)
assert repr(iteration_stats) == expected_repr
assert repr(iteration_stats).startswith("IterationStats(")
2 changes: 2 additions & 0 deletions vllm/v1/engine/async_llm.py
@@ -508,6 +508,8 @@ async def output_handler():
processed_outputs.reqs_to_abort
)

output_processor.update_scheduler_stats(outputs.scheduler_stats)

# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
1 change: 1 addition & 0 deletions vllm/v1/engine/llm_engine.py
@@ -289,6 +289,7 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
23 changes: 14 additions & 9 deletions vllm/v1/engine/output_processor.py
@@ -22,7 +22,12 @@
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats
from vllm.v1.metrics.stats import (
IterationStats,
LoRARequestStates,
RequestStateStats,
SchedulerStats,
)


class RequestOutputCollector:
@@ -310,7 +315,7 @@ def __init__(self, tokenizer: AnyTokenizer, log_stats: bool):
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
self.lora_states = LoRARequestStates(log_stats)
self.tracer: Tracer | None = None

def get_num_unfinished_requests(self):
@@ -334,7 +339,7 @@ def abort_requests(
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.abort_request(req_state)
self.lora_states.request_finished(request_id, req_state.lora_name)
request_ids_to_abort.append(request_id)
# Produce final abort output.
if req_state.queue is not None and (
@@ -382,7 +387,6 @@ def add_request(
log_stats=self.log_stats,
)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req

@@ -484,13 +488,15 @@ def process_outputs(
)
if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)

return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)

def update_scheduler_stats(self, scheduler_stats: SchedulerStats | None):
self.lora_states.update_scheduler_stats(scheduler_stats)
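
    # --- Illustrative sketch (not part of the diff) ---------------------------
    # The forwarding call above assumes LoRARequestStates now exposes an
    # update_scheduler_stats hook. A rough, hypothetical sketch of that contract
    # follows; SchedulerStatsStub, LoRAStatsStub, and their attribute names are
    # placeholders, not the real definitions in vllm/v1/metrics/stats.py.
    #
    # from dataclasses import dataclass, field
    #
    # @dataclass
    # class SchedulerStatsStub:
    #     waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
    #     running_lora_adapters: dict[str, int] = field(default_factory=dict)
    #
    # class LoRAStatsStub:
    #     def __init__(self, log_stats: bool):
    #         self.log_stats = log_stats
    #         self.waiting_counts: dict[str, int] = {}
    #         self.running_counts: dict[str, int] = {}
    #
    #     def update_scheduler_stats(self, scheduler_stats: SchedulerStatsStub | None) -> None:
    #         # No-op when stats are disabled or no scheduler stats arrived this
    #         # step, matching the log_stats=False branches asserted in the new test.
    #         if scheduler_stats is None or not self.log_stats:
    #             return
    #         scheduler_stats.waiting_lora_adapters = dict(self.waiting_counts)
    #         scheduler_stats.running_lora_adapters = dict(self.running_counts)
    # ---------------------------------------------------------------------------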

def do_tracing(
self,
engine_core_output: EngineCoreOutput,
@@ -564,8 +570,6 @@ def _update_stats_from_output(
if iteration_stats is None:
return

lora_stats = self.lora_states.get_stats(req_state)

assert engine_core_timestamp is not None
assert req_state.stats is not None
iteration_stats.update_from_output(
@@ -574,7 +578,8 @@
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats,
lora_stats,
self.lora_states,
req_state.lora_name,
)

def _update_stats_from_finished(
@@ -596,7 +601,7 @@ def _update_stats_from_finished(
max_tokens_param=req_state.max_tokens_param,
req_stats=req_state.stats,
)
self.lora_states.finish_request(req_state)
self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens
28 changes: 14 additions & 14 deletions vllm/v1/metrics/loggers.py
@@ -989,6 +989,20 @@ def record(
scheduler_stats.kv_connector_stats, engine_idx
)

if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
scheduler_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
scheduler_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
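
            # --- Illustrative sketch (not part of the diff) -------------------
            # The block above encodes adapter names as label values on an
            # info-style gauge whose value is just the time of the last update.
            # Purely as a generic prometheus_client illustration of that pattern
            # (the metric name and help string below are demo placeholders, not
            # vLLM's registered metric):
            #
            # from prometheus_client import Gauge
            #
            # lora_info_demo = Gauge(
            #     "lora_requests_info_demo",
            #     "Demo gauge: adapter names carried as labels, value = last update time",
            #     labelnames=["running_lora_adapters", "waiting_lora_adapters", "max_lora"],
            # )
            # lora_info_demo.labels(
            #     running_lora_adapters="lora-1,lora-2",
            #     waiting_lora_adapters="",
            #     max_lora="2",
            # ).set_to_current_time()
            # -------------------------------------------------------------------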

if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@@ -1055,20 +1069,6 @@ def record(
finished_request.max_tokens_param
)

if self.gauge_lora_info is not None:
running_lora_adapters = ",".join(
iteration_stats.running_lora_adapters.keys()
)
waiting_lora_adapters = ",".join(
iteration_stats.waiting_lora_adapters.keys()
)
lora_info_labels = {
self.labelname_running_lora_adapters: running_lora_adapters,
self.labelname_waiting_lora_adapters: waiting_lora_adapters,
self.labelname_max_lora: self.max_lora,
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()

def record_sleep_state(self, sleep: int = 0, level: int = 0):
awake = 1
discard_all = 0