Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions examples/online_serving/multi_instance_data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedLoggingStatLogger

"""
To run this example, run the following commands simultaneously with
Expand All @@ -21,37 +23,64 @@
"""


def _do_background_logging(engine, interval, stop_event):
try:
while not stop_event.is_set():
asyncio.run(engine.do_log_stats())
stop_event.wait(interval)
except Exception as e:
print(f"vLLM background logging shutdown: {e}")
pass


async def main():
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto",
max_model_len=2048,
data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300,
data_parallel_size_local=1,
enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
)

engine_client = AsyncLLMEngine.from_engine_args(engine_args)

engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
# Example: Using aggregated logger
stat_loggers=[AggregatedLoggingStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=100,
)
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: RequestOutput | None = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)

prompt = "Who won the 2004 World Series?"
final_output: RequestOutput | None = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
stop_logging_event.set()
logging_thread.join()


if __name__ == "__main__":
Expand Down
56 changes: 51 additions & 5 deletions tests/v1/engine/test_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
from vllm.v1.metrics.loggers import (
AggregatedLoggingStatLogger,
LoggingStatLogger,
PerEngineStatLoggerAdapter,
PrometheusStatLogger,
)

if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
Expand Down Expand Up @@ -384,6 +389,12 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
self.log = MagicMock()


class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
super().__init__(vllm_config, engine_indexes)
self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
Expand All @@ -401,10 +412,45 @@ async def test_customize_loggers(monkeypatch):

await engine.do_log_stats()

stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
stat_loggers[0][0].log.assert_called_once()
stat_loggers = engine.logger_manager.stat_loggers
assert (
len(stat_loggers) == 3
) # MockLoggingStatLogger + LoggingStatLogger + Promethus Logger
print(f"{stat_loggers=}")
stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
"""Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""

with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
)
after.callback(engine.shutdown)

await engine.do_log_stats()

stat_loggers = engine.logger_manager.stat_loggers
assert len(stat_loggers) == 4
# MockLoggingStatLogger + MockAggregatedStatLogger
# + LoggingStatLogger + PrometheusStatLogger
stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
stat_loggers[1].log.assert_called_once()
assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
assert isinstance(stat_loggers[3], PrometheusStatLogger)


@pytest.mark.asyncio(scope="module")
Expand Down
8 changes: 5 additions & 3 deletions tests/v1/metrics/test_engine_logger_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args):
engine = AsyncLLM.from_engine_args(
log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger]
)
assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger)
assert isinstance(engine.logger_manager.stat_loggers[0], RayPrometheusStatLogger)
engine.shutdown()


Expand All @@ -73,9 +73,11 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args):
disabled_log_engine_args, stat_loggers=[DummyStatLogger]
)

assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1
assert len(engine.logger_manager.stat_loggers) == 2
assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1
assert isinstance(
engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger
engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0],
DummyStatLogger,
)

# log_stats is still True, since custom stat loggers are used
Expand Down
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ class EngineArgs:
max_logprobs: int = ModelConfig.max_logprobs
logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode
disable_log_stats: bool = False
aggregate_engine_logging: bool = False
revision: str | None = ModelConfig.revision
code_revision: str | None = ModelConfig.code_revision
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
Expand Down Expand Up @@ -1043,6 +1044,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Disable logging statistics.",
)

parser.add_argument(
"--aggregate-engine-logging",
action="store_true",
help="Log aggregate rather than per-engine statistics "
"when using data parallelism.",
)
return parser

@classmethod
Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ async def build_async_engine_client_from_engine_args(
vllm_config=vllm_config,
usage_context=usage_context,
enable_log_requests=engine_args.enable_log_requests,
aggregate_engine_logging=engine_args.aggregate_engine_logging,
disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config,
client_count=client_count,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
log_requests: bool = True,
start_engine_loop: bool = True,
stat_loggers: list[StatLoggerFactory] | None = None,
aggregate_engine_logging: bool = False,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
custom_stat_loggers=stat_loggers,
enable_default_loggers=log_stats,
client_count=client_count,
aggregate_engine_logging=aggregate_engine_logging,
)
self.logger_manager.log_engine_initialized()

Expand Down Expand Up @@ -187,6 +189,7 @@ def from_vllm_config(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: list[StatLoggerFactory] | None = None,
enable_log_requests: bool = False,
aggregate_engine_logging: bool = False,
disable_log_stats: bool = False,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
Expand All @@ -209,6 +212,7 @@ def from_vllm_config(
stat_loggers=stat_loggers,
log_requests=enable_log_requests,
log_stats=not disable_log_stats,
aggregate_engine_logging=aggregate_engine_logging,
usage_context=usage_context,
client_addresses=client_addresses,
client_count=client_count,
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
aggregate_engine_logging: bool = False,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: list[StatLoggerFactory] | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
vllm_config=vllm_config,
custom_stat_loggers=stat_loggers,
enable_default_loggers=log_stats,
aggregate_engine_logging=aggregate_engine_logging,
)
self.logger_manager.log_engine_initialized()

Expand Down
Loading