Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
8802521
[V1] DP scale-out (2/N): Decouple engine process management and comms
njhill Apr 2, 2025
e869380
Headless mode
njhill Apr 3, 2025
1ca3d15
Wire data_parallel_address arg
njhill Apr 4, 2025
a551183
Some code cleanup
njhill Apr 4, 2025
a662169
Fix offline DP compatibility
njhill Apr 4, 2025
b29dcf4
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 7, 2025
8126f72
Address some review comments
njhill Apr 7, 2025
8fdc6f5
Address other minor review comments
njhill Apr 7, 2025
9c90ad4
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 17, 2025
80f9c98
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 17, 2025
efa8ad8
Fix merge error, address @russellb's ipv6 review comment
njhill Apr 17, 2025
30ab14b
Handle ipv6 URIs in all places
njhill Apr 18, 2025
acc5af3
Fix head node with no engines, don't require dp size on other nodes
njhill Apr 19, 2025
1649d7d
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 23, 2025
4fbf90e
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 23, 2025
86a0453
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 26, 2025
e70545c
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 27, 2025
24b2e1e
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill May 1, 2025
c76e8e5
[Perf] API-server scaleout with all-to-all server-engine comms
njhill Apr 4, 2025
742b532
Fix engine init num_gpu_blocks logging
njhill May 1, 2025
6340c87
Improve load balancing
njhill May 5, 2025
877f195
small fixes
njhill May 6, 2025
f7a909e
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 11, 2025
12da06b
Merge remote-tracking branch 'refs/remotes/njhill/decouple-engines' i…
njhill May 11, 2025
42c30bf
Fix test_startup_failure
njhill May 12, 2025
3904d10
Fix mock config related test failure
njhill May 12, 2025
cece58a
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 12, 2025
811d8f4
Merge remote-tracking branch 'njhill/decouple-engines' into all-to-all
njhill May 12, 2025
02f7263
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 12, 2025
77b7821
Merge remote-tracking branch 'njhill/decouple-engines' into all-to-all
njhill May 12, 2025
ee7cc58
wip
kouroshHakha May 13, 2025
e1400f7
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill May 13, 2025
1d89c90
Merge remote-tracking branch 'refs/remotes/njhill/decouple-engines' i…
njhill May 13, 2025
5403440
Merge remote-tracking branch 'origin/main' into all-to-all
njhill May 13, 2025
fe507f6
Wip
kouroshHakha May 13, 2025
70e2513
Merge branch 'all-to-all' into kh/fix-a2a-metrics
kouroshHakha May 13, 2025
b376674
lint
kouroshHakha May 13, 2025
0c0a1a5
wip
kouroshHakha May 15, 2025
0fa7aac
refactored prometheus non-sense into a separate self contained python…
kouroshHakha May 15, 2025
106265d
clean
kouroshHakha May 15, 2025
aace3a0
nit
kouroshHakha May 15, 2025
b876684
wip
kouroshHakha May 16, 2025
4cadec3
feedback
kouroshHakha May 16, 2025
e3b0dd2
wip
kouroshHakha May 16, 2025
aad9b52
wip
kouroshHakha May 16, 2025
f3f65bd
nits
kouroshHakha May 16, 2025
7260f4f
nit
kouroshHakha May 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tests/v1/core/test_kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)
Expand Down
1 change: 0 additions & 1 deletion tests/v1/core/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def make_request(request_id,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)
Expand Down
3 changes: 1 addition & 2 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def create_requests(num_requests: int,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
return requests
Expand Down Expand Up @@ -732,7 +731,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
model_runner_output)[0]

for i in range(len(requests)):
running_req = scheduler.running[i]
Expand Down
162 changes: 156 additions & 6 deletions vllm/entrypoints/cli/serve.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import multiprocessing
import os
import signal
import sys
from multiprocessing.context import SpawnProcess
from typing import Any

import uvloop
import zmq

import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (CoreEngine, get_engine_client_zmq_addr,
wait_for_engine_startup)

logger = init_logger(__name__)

Expand All @@ -34,9 +46,12 @@ def cmd(args: argparse.Namespace) -> None:
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag

if args.headless:
if args.headless or args.api_server_count < 1:
run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))

def validate(self, args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -67,6 +82,11 @@ def subparser_init(
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument('--api-server-count',
'-asc',
type=int,
default=1,
help='How many API server processes to run.')
serve_parser.add_argument(
"--config",
type=str,
Expand All @@ -86,6 +106,9 @@ def cmd_init() -> list[CLISubcommand]:

def run_headless(args: argparse.Namespace):

if args.api_server_count > 1:
raise RuntimeError("api_server_count can't be set in headless mode")

# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
Expand All @@ -98,7 +121,7 @@ def run_headless(args: argparse.Namespace):
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port)
handshake_address = get_tcp_uri(host, port)

if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
Expand All @@ -114,7 +137,7 @@ def signal_handler(signum, frame):

logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)
"with head node address %s.", local_engine_count, handshake_address)

# Create the engines.
engine_manager = CoreEngineProcManager(
Expand All @@ -124,7 +147,7 @@ def signal_handler(signum, frame):
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
Expand All @@ -134,3 +157,130 @@ def signal_handler(signum, frame):
finally:
logger.info("Shutting down.")
engine_manager.close()


def run_multi_api_server(args: argparse.Namespace):
    """Run the head node with multiple API server processes.

    Binds the HTTP listen socket once, starts the local data-parallel
    engine processes and (for dp > 1) the DP coordinator, spawns
    ``api_server_count`` API server worker processes that share the
    socket, waits for all engines to complete their ZMQ handshake, and
    then blocks until the API server workers exit.

    Args:
        args: Parsed CLI namespace. Must have ``headless`` False and
            ``api_server_count`` > 1 (enforced by asserts; the caller
            routes other configurations elsewhere).
    """
    assert not args.headless
    num_api_servers = args.api_server_count
    assert num_api_servers > 1

    # Must be configured before any worker process is spawned so that
    # per-process prometheus metrics can be aggregated.
    setup_multiprocess_prometheus()

    # Bind the listen socket in this process; the spawned API server
    # workers inherit it so they can all accept on the same address.
    listen_address, sock = setup_server(args)

    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
    parallel_config = vllm_config.parallel_config

    # Multi-API-server mode only runs on the head (rank 0) node.
    assert parallel_config.data_parallel_rank == 0

    dp_size = parallel_config.data_parallel_size
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    # If every engine is on this node we can use local-only (IPC)
    # addresses instead of TCP.
    local_only = local_engine_count == dp_size

    # One input/output ZMQ address pair per API server process.
    input_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]
    output_addresses = [
        get_engine_client_zmq_addr(local_only, host)
        for _ in range(num_api_servers)
    ]

    addresses: dict[str, Any] = {
        "input_addresses": input_addresses,
        "output_addresses": output_addresses,
    }

    # Set up coordinator for dp > 1.
    coordinator = None
    stats_update_address = None
    if dp_size > 1:
        # TODO "ready" event for coordinator
        coordinator = DPCoordinator(parallel_config)
        addresses.update(coordinator.get_engine_socket_addresses())
        stats_update_address = coordinator.get_stats_publish_address()

    handshake_address = get_engine_client_zmq_addr(
        local_only, host, parallel_config.data_parallel_rpc_port)

    with zmq_socket_ctx(handshake_address, zmq.ROUTER,
                        bind=True) as handshake_socket:

        # Start local engines (none if all engines live on other nodes).
        local_engine_manager = None
        if local_engine_count:
            local_engine_manager = CoreEngineProcManager(
                EngineCoreProc.run_engine_core,
                vllm_config=vllm_config,
                executor_class=Executor.get_class(vllm_config),
                log_stats=not engine_args.disable_log_stats,
                handshake_address=handshake_address,
                on_head_node=True,
                local_engine_count=local_engine_count,
                start_index=0,
                local_start_index=0)

        # Start API servers.
        spawn_context = multiprocessing.get_context("spawn")
        api_server_workers: list[SpawnProcess] = []
        for index, (in_addr, out_addr) in enumerate(
                zip(input_addresses, output_addresses)):
            client_config = {
                "input_address": in_addr,
                "output_address": out_addr,
                "client_index": index
            }
            if stats_update_address is not None:
                client_config["stats_update_address"] = stats_update_address

            # TODO check signal propagation
            proc = spawn_context.Process(target=run_api_server_worker,
                                         name=f"ApiServer_{index}",
                                         args=(listen_address, sock, args,
                                               client_config))
            api_server_workers.append(proc)
            proc.start()

        # Wait for engine handshakes to complete. Engines with index
        # below local_engine_count run on this node; the rest are remote.
        core_engines = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]

        wait_for_engine_startup(
            handshake_socket,
            addresses,
            core_engines,
            parallel_config,
            vllm_config.cache_config,
            local_engine_manager,
            coordinator.proc if coordinator else None,
        )

        # TODO handle failures / clean shutdown here
        for proc in api_server_workers:
            proc.join()


def run_api_server_worker(listen_address,
                          sock,
                          args,
                          client_config=None,
                          **uvicorn_kwargs) -> None:
    """Entry point for a spawned API server process.

    Tags this process's stdout/stderr with the worker's process name
    and pid (so interleaved logs from multiple API servers can be told
    apart), then runs the server loop under uvloop until it exits.
    """
    worker_pid = os.getpid()
    worker_name = multiprocessing.current_process().name
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, worker_name, worker_pid)

    server_coro = run_server_worker(listen_address, sock, args,
                                    client_config, **uvicorn_kwargs)
    uvloop.run(server_coro)
Loading