Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/v1/metrics/test_engine_logger_apis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger

# Engine arguments shared by every test in this module. Stat logging is kept
# enabled (disable_log_stats=False) because the tests below inspect
# `engine.logger_manager` and call `engine.add_logger`.
DEFAULT_ENGINE_ARGS = AsyncEngineArgs(
model="distilbert/distilgpt2",
dtype="half",
disable_log_stats=False,
enforce_eager=True,
)


@pytest.mark.asyncio
async def test_async_llm_replace_default_loggers():
    """An empty ``stat_loggers`` list leaves only explicitly added loggers.

    The engine is built with ``stat_loggers=[]`` and a single
    ``RayPrometheusStatLogger`` is then registered through ``add_logger``;
    that logger must be the only entry in ``shared_loggers``.
    """
    # Empty stat_loggers removes default loggers
    engine = AsyncLLM.from_engine_args(DEFAULT_ENGINE_ARGS, stat_loggers=[])
    try:
        await engine.add_logger(RayPrometheusStatLogger)

        # Verify that only this logger is present in shared loggers
        shared_loggers = engine.logger_manager.shared_loggers
        assert len(shared_loggers) == 1
        assert isinstance(shared_loggers[0], RayPrometheusStatLogger)
    finally:
        # Tear the engine down even when an assertion fails so it does not
        # outlive this test.
        engine.shutdown()


@pytest.mark.asyncio
async def test_async_llm_add_to_default_loggers():
    """``add_logger`` appends to the default loggers instead of replacing them.

    With no ``stat_loggers`` argument the engine starts with its default
    loggers; registering a ``RayPrometheusStatLogger`` afterwards must leave
    two entries in ``shared_loggers``.
    """
    # Start with default loggers, including PrometheusStatLogger
    engine = AsyncLLM.from_engine_args(DEFAULT_ENGINE_ARGS)
    try:
        # Add another PrometheusStatLogger subclass
        await engine.add_logger(RayPrometheusStatLogger)

        assert len(engine.logger_manager.shared_loggers) == 2
    finally:
        # Tear the engine down even when an assertion fails so it does not
        # outlive this test.
        engine.shutdown()
23 changes: 19 additions & 4 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.loggers import (DpSharedStatLoggerFactory,
StatLoggerFactory, StatLoggerManager)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats

Expand All @@ -55,7 +56,8 @@ def __init__(
use_cached_outputs: bool = False,
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_loggers: Optional[list[Union[StatLoggerFactory,
DpSharedStatLoggerFactory]]] = None,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> None:
Expand Down Expand Up @@ -144,7 +146,8 @@ def from_vllm_config(
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_loggers: Optional[list[Union[StatLoggerFactory,
DpSharedStatLoggerFactory]]] = None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None,
Expand Down Expand Up @@ -176,7 +179,8 @@ def from_engine_args(
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[list[StatLoggerFactory]] = None,
stat_loggers: Optional[list[Union[StatLoggerFactory,
DpSharedStatLoggerFactory]]] = None,
) -> "AsyncLLM":
"""Create an AsyncLLM from the EngineArgs."""

Expand Down Expand Up @@ -596,6 +600,17 @@ async def collective_rpc(self,
return await self.engine_core.collective_rpc_async(
method, timeout, args, kwargs)

async def add_logger(
    self,
    logger_factory: Union[StatLoggerFactory, DpSharedStatLoggerFactory],
) -> None:
    """Register an additional stat logger with this engine.

    Args:
        logger_factory: Factory handed unchanged to
            ``self.logger_manager.add_logger``.

    Raises:
        RuntimeError: If ``self.logger_manager`` is ``None``.
    """
    manager = self.logger_manager
    if manager is not None:
        manager.add_logger(logger_factory)
        return

    raise RuntimeError(
        "Stat logging is disabled. Set `disable_log_stats=False` "
        "engine argument to enable.")

async def wait_for_requests_to_drain(self, drain_timeout: int = 300):
"""Wait for all requests to be drained."""
start_time = time.time()
Expand Down
71 changes: 52 additions & 19 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
logger = init_logger(__name__)

# Factory for a per-engine stat logger; this module invokes it as
# `factory(vllm_config, engine_idx)`, once for each engine index.
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
# Factory for a logger shared across engines (a `PrometheusStatLogger` or a
# subclass); this module invokes it as `factory(vllm_config, engine_idxs)`.
DpSharedStatLoggerFactory = Callable[[VllmConfig, Optional[list[int]]],
"PrometheusStatLogger"]


class StatLoggerBase(ABC):
Expand Down Expand Up @@ -633,37 +635,67 @@ def __init__(
self,
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None,
custom_stat_loggers: Optional[list[StatLoggerFactory]] = None,
custom_stat_loggers: Optional[list[Union[
StatLoggerFactory, DpSharedStatLoggerFactory]]] = None,
):
"""
Initializes the StatLoggerManager.

Args:
vllm_config (VllmConfig): The configuration object for vLLM.
engine_idxs (Optional[list[int]]): List of engine indices. If None,
defaults to [0].
custom_stat_loggers (Optional[list[Union[
StatLoggerFactory, DpSharedStatLoggerFactory
]]]):
Optional list of custom stat logger factories to use. If None,
default loggers are used.
"""
self.engine_idxs = engine_idxs if engine_idxs else [0]
self.vllm_config = vllm_config

factories: list[StatLoggerFactory]
factories: list[StatLoggerFactory] = []
shared_logger_factories: list[DpSharedStatLoggerFactory] = []
if custom_stat_loggers is not None:
factories = custom_stat_loggers
for factory in custom_stat_loggers:
if isinstance(factory, type) and issubclass(
factory, PrometheusStatLogger):
shared_logger_factories.append(factory) # type: ignore
else:
factories.append(factory) # type: ignore
else:
factories = []
if logger.isEnabledFor(logging.INFO):
factories.append(LoggingStatLogger)

shared_logger_factories.append(PrometheusStatLogger)

self.shared_loggers = []
if len(shared_logger_factories) > 0:
for factory in shared_logger_factories:
self.shared_loggers.append(factory(vllm_config, engine_idxs))

# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
prometheus_factory = PrometheusStatLogger
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
# If we get a custom prometheus logger, use that
# instead. This is typically used for the ray case.
if (isinstance(logger_factory, type)
and issubclass(logger_factory, PrometheusStatLogger)):
prometheus_factory = logger_factory
continue
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers

# For Prometheus, need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
def add_logger(
    self,
    logger_factory: Union[StatLoggerFactory, DpSharedStatLoggerFactory],
) -> None:
    """Instantiate ``logger_factory`` and register the resulting logger(s).

    A factory that is a ``PrometheusStatLogger`` subclass is built once with
    ``self.engine_idxs`` and appended to ``self.shared_loggers``. Any other
    factory is built once per entry of ``self.per_engine_logger_dict`` and
    appended to that engine's logger list.

    Args:
        logger_factory: The factory to instantiate.
    """
    is_shared_factory = (isinstance(logger_factory, type)
                         and issubclass(logger_factory, PrometheusStatLogger))

    if not is_shared_factory:
        # Per-engine logger: one instance for every known engine index.
        for engine_idx, engine_loggers in self.per_engine_logger_dict.items():
            new_logger = logger_factory(self.vllm_config,
                                        engine_idx)  # type: ignore
            engine_loggers.append(new_logger)
        return

    # Shared logger: a single instance covering all engine indices.
    shared_logger = logger_factory(self.vllm_config,
                                   self.engine_idxs)  # type: ignore
    self.shared_loggers.append(shared_logger)

def record(
self,
Expand All @@ -678,17 +710,18 @@ def record(
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)

self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
for logger in self.shared_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)

def log(self):
    """Call ``log()`` on every logger held in ``per_engine_logger_dict``."""
    every_engine_logger = (
        stat_logger
        for engine_loggers in self.per_engine_logger_dict.values()
        for stat_logger in engine_loggers)
    for stat_logger in every_engine_logger:
        stat_logger.log()

def log_engine_initialized(self):
    """Notify every registered stat logger that the engine is initialized.

    Loggers in ``self.shared_loggers`` are notified first, followed by each
    logger in ``self.per_engine_logger_dict``. Every logger is notified
    exactly once.
    """
    for shared_logger in self.shared_loggers:
        shared_logger.log_engine_initialized()

    for per_engine_loggers in self.per_engine_logger_dict.values():
        for per_engine_logger in per_engine_loggers:
            per_engine_logger.log_engine_initialized()