Merged
Changes from 11 commits (24 commits in total)

Commits:
8802521 [V1] DP scale-out (2/N): Decouple engine process management and comms (njhill, Apr 2, 2025)
e869380 Headless mode (njhill, Apr 3, 2025)
1ca3d15 Wire data_parallel_address arg (njhill, Apr 4, 2025)
a551183 Some code cleanup (njhill, Apr 4, 2025)
a662169 Fix offline DP compatibility (njhill, Apr 4, 2025)
b29dcf4 Merge remote-tracking branch 'refs/remotes/origin/main' into decouple… (njhill, Apr 7, 2025)
8126f72 Address some review comments (njhill, Apr 7, 2025)
8fdc6f5 Address other minor review comments (njhill, Apr 7, 2025)
9c90ad4 Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, Apr 17, 2025)
80f9c98 Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, Apr 17, 2025)
efa8ad8 Fix merge error, address @russellb's ipv6 review comment (njhill, Apr 17, 2025)
30ab14b Hande ipv6 URIs in all places (njhill, Apr 18, 2025)
acc5af3 Fix head node with no engines, don't require dp size on other nodes (njhill, Apr 19, 2025)
1649d7d Merge remote-tracking branch 'refs/remotes/origin/main' into decouple… (njhill, Apr 23, 2025)
4fbf90e Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, Apr 23, 2025)
86a0453 Merge remote-tracking branch 'refs/remotes/origin/main' into decouple… (njhill, Apr 26, 2025)
e70545c Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, Apr 27, 2025)
24b2e1e Merge remote-tracking branch 'refs/remotes/origin/main' into decouple… (njhill, May 1, 2025)
f7a909e Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, May 11, 2025)
42c30bf Fix test_startup_failure (njhill, May 12, 2025)
3904d10 Fix mock config related test failure (njhill, May 12, 2025)
cece58a Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, May 12, 2025)
02f7263 Merge remote-tracking branch 'origin/main' into decouple-engines (njhill, May 12, 2025)
e1400f7 Merge remote-tracking branch 'refs/remotes/origin/main' into decouple… (njhill, May 13, 2025)
11 changes: 9 additions & 2 deletions vllm/config.py
@@ -121,7 +121,7 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise

Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
@@ -1564,12 +1564,16 @@ class ParallelConfig:
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1
"""Number of local data parallel groups."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_rank_local: Optional[int] = None
"""Local rank of the data parallel group, defaults to global rank."""
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port: int = 29550
"""Port for data parallel messaging."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
@@ -1682,10 +1686,13 @@ def __post_init__(self) -> None:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
"data_parallel_size_local must be <= data_parallel_size")

if self.data_parallel_size > 1:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
# TODO multi-node
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
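For orientation, here is a minimal standalone sketch of the data-parallel fields `ParallelConfig` carries after this hunk (the new ones being `data_parallel_size_local` and `data_parallel_rpc_port`), together with the new `__post_init__` check. `ParallelConfigSketch` is a stand-in name used only for illustration, not the actual vLLM class; the field defaults are copied from the diff above.

```python
# Minimal sketch mirroring the new data-parallel fields and the validation
# added in __post_init__. Not the real vllm.config.ParallelConfig.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelConfigSketch:
    data_parallel_size: int = 1              # global number of DP groups
    data_parallel_size_local: int = 1        # DP groups launched on this node
    data_parallel_rank: int = 0
    data_parallel_rank_local: Optional[int] = None
    data_parallel_master_ip: str = "127.0.0.1"
    data_parallel_rpc_port: int = 29550      # DP messaging port
    data_parallel_master_port: int = 29500   # torch.distributed master port

    def __post_init__(self) -> None:
        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                "data_parallel_size_local must be <= data_parallel_size")


# A node hosting 2 of 4 DP ranks is valid; a local size above the global
# size raises ValueError immediately.
cfg = ParallelConfigSketch(data_parallel_size=4, data_parallel_size_local=2)
```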
38 changes: 38 additions & 0 deletions vllm/engine/arg_utils.py
@@ -150,6 +150,9 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
@@ -526,6 +529,21 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument('--data-parallel-size', '-dp',
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
help='Number of data parallel replicas '
'to run on this node.')
parallel_group.add_argument('--data-parallel-address',
'-dpa',
type=str,
help='Address of data parallel cluster '
'head-node.')
parallel_group.add_argument('--data-parallel-rpc-port',
'-dpp',
type=int,
help='Port for data parallel RPC '
'communication.')
parallel_group.add_argument(
'--enable-expert-parallel',
**parallel_kwargs["enable_expert_parallel"])
@@ -1223,10 +1241,30 @@ def create_engine_config(
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()

# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size if (
self.data_parallel_size_local
is None) else self.data_parallel_size_local

# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address = self.data_parallel_address if (
self.data_parallel_address
is not None) else ParallelConfig.data_parallel_master_ip

# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
data_parallel_rpc_port = self.data_parallel_rpc_port if (
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
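The three new `EngineArgs` fields default to `None` so that `create_engine_config` can apply the fallbacks shown above: the local size falls back to the global DP size, while the address and RPC port fall back to the `ParallelConfig` defaults. A small sketch of that defaulting logic, assuming the defaults `127.0.0.1` and `29550` from the config hunk; `resolve_dp_args` is a hypothetical helper introduced here only to illustrate the rules, not a vLLM function.

```python
from typing import Optional


def resolve_dp_args(
        data_parallel_size: int,
        data_parallel_size_local: Optional[int],
        data_parallel_address: Optional[str],
        data_parallel_rpc_port: Optional[int]) -> tuple[int, str, int]:
    # --data-parallel-size-local defaults to the global DP size.
    size_local = (data_parallel_size if data_parallel_size_local is None
                  else data_parallel_size_local)
    # --data-parallel-address defaults to the single-node master IP.
    address = (data_parallel_address
               if data_parallel_address is not None else "127.0.0.1")
    # --data-parallel-rpc-port defaults to 29550; it only matters when some
    # engines are remote, since local engines use IPC transport instead.
    rpc_port = (data_parallel_rpc_port
                if data_parallel_rpc_port is not None else 29550)
    return size_local, address, rpc_port


# Single node, dp=2: everything stays local with the defaults.
assert resolve_dp_args(2, None, None, None) == (2, "127.0.0.1", 29550)
# Multi-node, dp=4 with 2 local ranks pointing at a remote head node.
assert resolve_dp_args(4, 2, "10.0.0.1", None) == (2, "10.0.0.1", 29550)
```

So a single-node launch behaves exactly as before, while a multi-node launch only needs `--data-parallel-address` (plus optionally `-dpl` and `-dpp`) on each node.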
70 changes: 69 additions & 1 deletion vllm/entrypoints/cli/serve.py
@@ -4,11 +4,20 @@

import uvloop

import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor

logger = init_logger(__name__)


class ServeSubcommand(CLISubcommand):
@@ -24,7 +33,10 @@ def cmd(args: argparse.Namespace) -> None:
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag

uvloop.run(run_server(args))
if args.headless:
run_headless(args)
else:
uvloop.run(run_server(args))

def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
@@ -42,6 +54,19 @@ def subparser_init(
nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument(
"--headless",
action='store_true',
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
serve_parser.add_argument(
'--data-parallel-start-rank',
'-dpr',
type=int,
default=0,
help='Starting data parallel rank for secondary '
'nodes.')
serve_parser.add_argument(
"--config",
type=str,
@@ -57,3 +82,46 @@

def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]


def run_headless(args: argparse.Namespace):

# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)

if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1")

parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = f"tcp://{host}:{port}"

if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
"headless mode")

logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)

# Create the engines.
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)

try:
engine_manager.join_first()
finally:
engine_manager.close()
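In headless mode a secondary node runs only engine processes (launched with something like `vllm serve <model> --headless --data-parallel-start-rank <N> ...`, using the flags added above) and connects them back to the head node at `tcp://<data_parallel_address>:<data_parallel_rpc_port>`. The sketch below shows only the process-lifecycle shape that `run_headless` relies on: start all local engines, wait for the first one to exit, then tear the rest down. `ProcManagerSketch` and `fake_engine` are toy stand-ins, not vLLM's `CoreEngineProcManager` or `EngineCoreProc`.

```python
# Toy sketch of the start / join_first / close lifecycle used by run_headless.
import multiprocessing
import time


def fake_engine(rank: int) -> None:
    # Stand-in for an engine core process.
    time.sleep(1)
    print(f"engine rank {rank} exiting")


class ProcManagerSketch:

    def __init__(self, local_engine_count: int, start_index: int) -> None:
        self.procs = [
            multiprocessing.Process(target=fake_engine,
                                    args=(start_index + i, ))
            for i in range(local_engine_count)
        ]
        for proc in self.procs:
            proc.start()

    def join_first(self) -> None:
        # Block until any one engine process has exited.
        while all(proc.is_alive() for proc in self.procs):
            time.sleep(0.1)

    def close(self) -> None:
        # Terminate and reap whatever is still running.
        for proc in self.procs:
            if proc.is_alive():
                proc.terminate()
            proc.join()


if __name__ == "__main__":
    manager = ProcManagerSketch(local_engine_count=2, start_index=2)
    try:
        manager.join_first()
    finally:
        manager.close()
```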
4 changes: 4 additions & 0 deletions vllm/utils.py
@@ -604,6 +604,10 @@ def is_valid_ipv6_address(address: str) -> bool:


def get_distributed_init_method(ip: str, port: int) -> str:
return get_tcp_uri(ip, port)


def get_tcp_uri(ip: str, port: int) -> str:
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"