Skip to content
Merged
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
14f13ed
added debug logging
Jul 19, 2025
b90d331
updated
Jul 20, 2025
aefeeed
updated
Jul 20, 2025
59a9583
updated
Jul 20, 2025
48cf09b
updated
Jul 20, 2025
2fd0587
updated
Jul 20, 2025
14cf3c4
updated
Jul 20, 2025
4f5d3ea
updated
Jul 20, 2025
14db660
updated
Jul 20, 2025
2aa4975
updated
Jul 20, 2025
b142571
cleanup
Jul 20, 2025
e1843b7
updated
Jul 20, 2025
d2d54e9
updated
Jul 20, 2025
4438796
fix lb issues
Jul 20, 2025
2a68433
updated
Jul 20, 2025
1ced153
updated
Jul 20, 2025
b9c0f65
nits
Jul 20, 2025
dbc51d6
nits
Jul 20, 2025
471fa4a
updated
Jul 20, 2025
6569fac
stash
Jul 20, 2025
1e5303a
stash
Jul 20, 2025
a69edca
convert to use only one prometheus stat logger per async llm
Jul 20, 2025
de91a3c
convert to use only one prometheus stat logger per async llm
Jul 20, 2025
e08e1e9
cleanup prometheus logging
Jul 20, 2025
d39cf93
updated
Jul 20, 2025
9a2e26d
updated
Jul 20, 2025
3956d8c
updated
Jul 20, 2025
cad9670
updated
Jul 20, 2025
fd0650f
updated
Jul 20, 2025
896b0a2
updated
Jul 20, 2025
54e405b
updated
Jul 20, 2025
02ecfa8
updated
Jul 20, 2025
1358836
updated
Jul 20, 2025
4eae5cb
updated
Jul 20, 2025
5e6114d
Merge pull request #19 from robertgshaw2-redhat/fix-prometheus-logging
robertgshaw2-redhat Jul 20, 2025
c08fb6d
updated
Jul 20, 2025
d9291f9
cleanup
Jul 20, 2025
876c864
updated
Jul 20, 2025
f477b50
updated
Jul 20, 2025
5ea4fa2
updated
Jul 20, 2025
e9e180d
cleanup
Jul 20, 2025
3f4ae35
updated
Jul 20, 2025
7b53f0e
revert arg utils change
Jul 20, 2025
753061b
reset
Jul 20, 2025
381d7a6
add comment
Jul 20, 2025
d2baf53
updated
Jul 20, 2025
ebbc432
updated
Jul 20, 2025
4b50833
cleanup
Jul 20, 2025
eb5b84e
stash
Jul 20, 2025
efdeb01
merged
Jul 20, 2025
c54c17e
fixing tests
Jul 21, 2025
4be985d
passing
Jul 21, 2025
20e7f17
get other failing test to pass
Jul 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,10 +1358,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if current_platform.is_rocm() or (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
): # handle hpu also for OOT platform
if (current_platform.is_rocm()
or (current_platform.is_cuda()
and current_platform.is_device_capability(100))
or current_platform.is_tpu()):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import (
Expand Down
71 changes: 30 additions & 41 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory,
setup_default_loggers)
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats

logger = init_logger(__name__)

Expand Down Expand Up @@ -95,14 +94,6 @@ def __init__(
self.log_requests = log_requests
self.log_stats = log_stats

# Set up stat loggers; independent set for each DP rank.
self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
vllm_config=vllm_config,
log_stats=self.log_stats,
engine_num=vllm_config.parallel_config.data_parallel_size,
custom_stat_loggers=stat_loggers,
)

# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
Expand All @@ -121,17 +112,24 @@ def __init__(
log_stats=self.log_stats)

# EngineCore (starts the engine in background process).

self.engine_core = EngineCoreClient.make_async_mp_client(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
client_addresses=client_addresses,
client_index=client_index,
)
if self.stat_loggers:
for stat_logger in self.stat_loggers[0]:
stat_logger.log_engine_initialized()

# Loggers.
self.logger_manager: Optional[StatLoggerManager] = None
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()

self.output_handler: Optional[asyncio.Task] = None
try:
# Start output handler eagerly if we are in the asyncio eventloop.
Expand Down Expand Up @@ -370,7 +368,7 @@ def _run_output_handler(self):
engine_core = self.engine_core
output_processor = self.output_processor
log_stats = self.log_stats
stat_loggers = self.stat_loggers if log_stats else None
logger_manager = self.logger_manager

async def output_handler():
try:
Expand Down Expand Up @@ -410,11 +408,12 @@ async def output_handler():
# 4) Logging.
# TODO(rob): make into a coroutine and launch it in
# background thread once Prometheus overhead is non-trivial.
if stat_loggers:
AsyncLLM._record_stats(
stat_loggers[outputs.engine_index],
# NOTE: we do not use self.log
if logger_manager:
logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
engine_idx=outputs.engine_index,
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
Expand All @@ -431,18 +430,6 @@ async def abort(self, request_id: str) -> None:
if self.log_requests:
logger.info("Aborted request %s.", request_id)

@staticmethod
def _record_stats(
stat_loggers: list[StatLoggerBase],
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
):
"""static so that it can be used from the output_handler task
without a circular ref to AsyncLLM."""
for stat_logger in stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats)

async def encode(
self,
prompt: PromptType,
Expand Down Expand Up @@ -547,7 +534,11 @@ async def do_log_stats(
scheduler_outputs=None,
model_output=None,
) -> None:
for loggers in self.stat_loggers:
if self.stat_loggers is None:
return
# loggers, prom_logger
per_engine_loggers, _ = self.stat_loggers
for loggers in per_engine_loggers.values():
for stat_logger in loggers:
stat_logger.log()

Expand Down Expand Up @@ -653,18 +644,16 @@ async def scale_elastic_ep(self,
new_data_parallel_size

# recreate stat loggers
if new_data_parallel_size > old_data_parallel_size:
stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers(
if new_data_parallel_size > old_data_parallel_size and self.log_stats:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ruisearch42 @kouroshHakha FYI - need to think through this in more detail. The current implementation is completely broken so I will save fixing the Elastic EP case for a follow up PR

Copy link
Collaborator

@ruisearch42 ruisearch42 Jul 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the heads up. cc @libertyeagle

# TODO(rob): fix this after talking with Ray team.
# This resets all the prometheus metrics since we
# unregister during initialization. Need to understand
# the intended behavior here better.
self.logger_manager = StatLoggerManager(
vllm_config=self.vllm_config,
log_stats=self.log_stats,
engine_num=new_data_parallel_size,
engine_idxs=list(range(new_data_parallel_size)),
custom_stat_loggers=None,
)
num_new_engines = len(stat_loggers) - len(self.stat_loggers)
self.stat_loggers.extend(stat_loggers[-num_new_engines:])
else:
for _ in range(old_data_parallel_size - new_data_parallel_size):
self.stat_loggers.pop()

@property
def is_running(self) -> bool:
Expand Down
9 changes: 5 additions & 4 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,15 @@ def __init__(
external_dp_lb = parallel_config.data_parallel_external_lb

offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = [dp_rank] if (offline_mode
or external_dp_lb) else range(dp_size)
self.engine_ranks = ([dp_rank] if
(offline_mode or external_dp_lb) else list(
range(dp_size)))
assert parallel_config.data_parallel_size_local <= len(
engine_ranks)
self.engine_ranks)

# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks
index.to_bytes(2, "little") for index in self.engine_ranks
]

# Wait for ready messages from each engine on the input socket.
Expand Down
Loading