Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions vllm/config/observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def show_hidden_metrics(self) -> bool:
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""

enable_logging_iteration_details: bool = False
"""Enable detailed logging of iteration details.
If set, the vLLM EngineCore will log iteration details.
This includes the number of context/generation requests and tokens,
and the elapsed CPU time for the iteration."""

@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""
Expand Down
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ class EngineArgs:
ObservabilityConfig.enable_layerwise_nvtx_tracing
)
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
enable_logging_iteration_details: bool = (
ObservabilityConfig.enable_logging_iteration_details
)
enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
Expand Down Expand Up @@ -1054,6 +1057,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--enable-mfu-metrics",
**observability_kwargs["enable_mfu_metrics"],
)
observability_group.add_argument(
"--enable-logging-iteration-details",
**observability_kwargs["enable_logging_iteration_details"],
)

# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
Expand Down Expand Up @@ -1707,6 +1714,7 @@ def create_engine_config(
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics,
enable_mm_processor_stats=self.enable_mm_processor_stats,
enable_logging_iteration_details=self.enable_logging_iteration_details,
)

# Compilation config overrides
Expand Down
15 changes: 15 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

from vllm._bc_linter import bc_linter_include
Expand Down Expand Up @@ -151,6 +152,20 @@ def __repr__(self) -> str:
def num_reqs(self) -> int:
return len(self.req_ids)

@cached_property
def _req_id_to_num_output_tokens(self) -> dict[str, int]:
"""Cache mapping of req_id to num_output_tokens for O(1) lookup.

This cached property is safe because CachedRequestData instances
are created fresh each scheduling iteration and not mutated during
computation of iteration details.
"""
return dict(zip(self.req_ids, self.num_output_tokens))

def is_context_phase(self, req_id: str) -> bool:
    """Return True if req_id has produced zero output tokens so far.

    Unknown req_ids return False: ``dict.get`` yields ``None`` and
    ``None == 0`` evaluates to False.
    """
    return self._req_id_to_num_output_tokens.get(req_id) == 0

@classmethod
def make_empty(cls) -> "CachedRequestData":
return cls(
Expand Down
42 changes: 39 additions & 3 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import compute_iteration_details
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
Expand Down Expand Up @@ -208,7 +209,6 @@ def __init__(
self.async_scheduling = vllm_config.scheduler_config.async_scheduling

self.aborts_queue = queue.Queue[list[str]]()

# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
Expand Down Expand Up @@ -337,6 +337,36 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
)
raise err

@contextmanager
def log_iteration_details(self, scheduler_output: "SchedulerOutput"):
    """Context manager logging per-iteration scheduling details.

    No-op when ``enable_logging_iteration_details`` is off. Otherwise,
    computes the context/generation request and token counts for this
    iteration, times the wrapped section, and emits one INFO summary
    line, incrementing a per-engine iteration counter.

    NOTE: if the wrapped section raises, the summary line is skipped
    (same as the original behavior).
    """
    if not self.vllm_config.observability_config.enable_logging_iteration_details:
        yield
        return
    # Lazily initialize the counter so no state exists when disabled.
    self._iteration_index = getattr(self, "_iteration_index", 0)
    iteration_details = compute_iteration_details(scheduler_output)
    start = time.monotonic()
    yield
    elapsed_ms = (time.monotonic() - start) * 1000
    # Lazy %-style args avoid building the message when INFO is disabled,
    # unlike the previous eager "".join of str() pieces.
    logger.info(
        "Iteration(%d): %d context requests, %d context tokens, "
        "%d generation requests, %d generation tokens, "
        "iteration elapsed time: %.2f ms",
        self._iteration_index,
        iteration_details.num_ctx_requests,
        iteration_details.num_ctx_tokens,
        iteration_details.num_generation_requests,
        iteration_details.num_generation_tokens,
        elapsed_ms,
    )
    self._iteration_index += 1

def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.

Expand All @@ -351,7 +381,10 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
with (
self.log_error_detail(scheduler_output),
self.log_iteration_details(scheduler_output),
):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
Expand Down Expand Up @@ -447,7 +480,10 @@ def step_with_batch_queue(

# Block until the next result is available.
future, scheduler_output, exec_model_fut = batch_queue.pop()
with self.log_error_detail(scheduler_output):
with (
self.log_error_detail(scheduler_output),
self.log_iteration_details(scheduler_output),
):
model_output = future.result()
if model_output is None:
# None from sample_tokens() implies that the original execute_model()
Expand Down
52 changes: 52 additions & 0 deletions vllm/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import weakref
from collections.abc import Callable, Sequence
from contextlib import AbstractContextManager
from dataclasses import dataclass
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (
Expand All @@ -27,6 +28,7 @@
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
from vllm.utils.system_utils import kill_process_tree
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -412,3 +414,53 @@ def tensor_data(tensor: torch.Tensor) -> memoryview:
A memoryview of the tensor data as uint8.
"""
return tensor.flatten().contiguous().view(torch.uint8).numpy().data


@dataclass
class IterationDetails:
    """Per-iteration counts of context/generation requests and tokens."""

    # Requests that have not yet produced any output tokens.
    num_ctx_requests: int
    num_ctx_tokens: int
    # Requests that already have output tokens (decode phase).
    num_generation_requests: int
    num_generation_tokens: int

    def __repr__(self) -> str:
        # The previous backslash-continued f-string leaked the source
        # indentation (runs of spaces) into the repr; build it from
        # adjacent f-string fragments instead.
        return (
            f"IterationDetails(num_ctx_requests={self.num_ctx_requests}, "
            f"num_ctx_tokens={self.num_ctx_tokens}, "
            f"num_generation_requests={self.num_generation_requests}, "
            f"num_generation_tokens={self.num_generation_tokens})"
        )


def compute_iteration_details(scheduler_output: SchedulerOutput) -> IterationDetails:
    """Compute context/generation request and token counts for one iteration.

    A request is regarded as a context request if it has produced no
    output tokens yet; an extended chunk of chunked prefill falls into
    this category. All newly scheduled requests are context requests.

    Args:
        scheduler_output: The scheduler output for the current iteration.

    Returns:
        An IterationDetails object containing the number of
        context/generation requests and tokens.
    """
    num_context_requests = 0
    num_context_tokens = 0
    num_generation_requests = 0
    num_generation_tokens = 0
    new_req_ids = {new_req.req_id for new_req in scheduler_output.scheduled_new_reqs}
    # Hoist the attribute lookup out of the loop.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for req_id, num_tokens in scheduler_output.num_scheduled_tokens.items():
        # Cheap set-membership check first; fall back to the cached-request
        # lookup for requests continuing from a previous iteration.
        if req_id in new_req_ids or cached_reqs.is_context_phase(req_id):
            num_context_requests += 1
            num_context_tokens += num_tokens
        else:
            num_generation_requests += 1
            num_generation_tokens += num_tokens
    return IterationDetails(
        num_context_requests,
        num_context_tokens,
        num_generation_requests,
        num_generation_tokens,
    )
25 changes: 18 additions & 7 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
DraftTokenIds,
ModelRunnerOutput,
)
from vllm.v1.utils import report_usage_stats
from vllm.v1.utils import compute_iteration_details, report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
Expand Down Expand Up @@ -545,18 +545,29 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:

def annotate_profile(self, scheduler_output):
    """Return a profiler annotation context for the current iteration.

    The annotation makes it easy to distinguish context/generation
    request and token counts per iteration in the trace. A context
    request is a request that has not yet generated any tokens.
    Returns a no-op context when profiling is disabled.
    """
    if not self.profiler:
        return nullcontext()

    self.profiler.step()

    details = compute_iteration_details(scheduler_output)
    # Identical string to before, assembled with an f-string.
    annotation = (
        f"execute_context_{details.num_ctx_requests}"
        f"({details.num_ctx_tokens})"
        f"_generation_{details.num_generation_requests}"
        f"({details.num_generation_tokens})"
    )
    return self.profiler.annotate_context_manager(annotation)

@torch.inference_mode()
def sample_tokens(
Expand Down