Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def test_default_pooling_type(model_id, default_pooling_type, pooling_type):
)
def test_moe_model_detection(model_id, expected_is_moe_model):
model_config = ModelConfig(model_id)
# Just check that is_moe_model field exists and is a boolean
assert model_config.is_model_moe() == expected_is_moe_model
# Just check that is_moe field exists and is a boolean
assert model_config.is_moe == expected_is_moe_model


@pytest.mark.parametrize(
Expand All @@ -224,7 +224,7 @@ def test_moe_model_detection(model_id, expected_is_moe_model):
def test_is_quantized(model_id, quantized):
model_config = ModelConfig(model_id)
# Just check that quantized field exists and is a boolean
assert model_config.is_quantized() == quantized
assert model_config.is_quantized == quantized


@pytest.mark.skipif(
Expand Down Expand Up @@ -925,7 +925,7 @@ def test_vllm_config_callable_defaults():
model_config=quantized_model, optimization_level=OptimizationLevel.O2
)
enable_if_quantized = lambda cfg: (
cfg.model_config is not None and cfg.model_config.is_quantized()
cfg.model_config is not None and cfg.model_config.is_quantized
)
assert enable_if_quantized(config_quantized) is True
assert enable_if_quantized(config_no_model) is False
Expand All @@ -936,7 +936,7 @@ def test_vllm_config_callable_defaults():
model_config=moe_model, optimization_level=OptimizationLevel.O2
)
enable_if_sequential = lambda cfg: (
cfg.model_config is not None and not cfg.model_config.is_model_moe()
cfg.model_config is not None and not cfg.model_config.is_moe
)
assert enable_if_sequential(config_moe) is False
assert enable_if_sequential(config_quantized) is True
Expand Down Expand Up @@ -1050,3 +1050,46 @@ def test_scheduler_config_init():
with pytest.raises(AttributeError):
# InitVar does not become an attribute
print(SchedulerConfig.default_factory().max_model_len)


@pytest.mark.parametrize(
    (
        "model_id",
        "data_parallel_size",
        "external_lb",
        "expected_needs_coordinator",
    ),
    [
        # Dense model, DP=1: no coordinator.
        ("facebook/opt-125m", 1, False, False),
        # Dense model, DP>1, internal LB: coordinator required.
        ("facebook/opt-125m", 2, False, True),
        # Dense model, DP>1, external LB: no coordinator.
        ("facebook/opt-125m", 2, True, False),
        # MoE model, DP=1: no coordinator.
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 1, False, False),
        # MoE model, DP>1, internal LB: coordinator required (it also
        # performs wave coordination).
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, False, True),
        # MoE model, DP>1, external LB: coordinator still required, because
        # wave coordination runs inside the coordinator process.
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", 2, True, True),
    ],
)
def test_needs_dp_coordination(
    model_id,
    data_parallel_size,
    external_lb,
    expected_needs_coordinator,
):
    """Check `VllmConfig.needs_dp_coordinator` for each model/DP/LB combination.

    Each case builds a `VllmConfig` from a `ModelConfig` for `model_id` and a
    `ParallelConfig` with the given DP size and external-LB flag, then compares
    the property against the expected value.
    """
    from vllm.config import ParallelConfig

    vllm_config = VllmConfig(
        model_config=ModelConfig(model_id),
        parallel_config=ParallelConfig(
            data_parallel_size=data_parallel_size,
            data_parallel_external_lb=external_lb,
        ),
    )

    needs_coordinator = vllm_config.needs_dp_coordinator
    assert needs_coordinator == expected_needs_coordinator
1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def setsockopt(self, *_args, **_kwargs):
parallel_config = SimpleNamespace(
data_parallel_size=1,
data_parallel_rank=0,
data_parallel_index=0,
data_parallel_size_local=1,
data_parallel_rank_local=None,
data_parallel_hybrid_lb=False,
Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def __call__(
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_rank = vllm_config.parallel_config.data_parallel_index
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir
Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __call__(self, *args, **kwargs):
)

rank = self.vllm_config.parallel_config.rank
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model")
try:
Expand Down
13 changes: 6 additions & 7 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def _get_transformers_backend_cls(self) -> str:
cls = "Transformers"
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else ""
cls += "MoE" if self.is_moe else ""
# Check if the architecture we're wrapping has defaults
runner = None
task = None
Expand Down Expand Up @@ -1001,8 +1001,7 @@ def _verify_bnb_config(self) -> None:
self.enforce_eager = True

def _verify_with_expert_parallelism(self) -> None:
num_experts = self.get_num_experts()
if num_experts < 1:
if not self.is_moe:
raise ValueError(
"Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled."
Expand Down Expand Up @@ -1797,11 +1796,11 @@ def is_prefix_caching_supported(self) -> bool:
logger.debug("Generative models support prefix caching.")
return True

def is_model_moe(
self,
) -> bool:
return self.get_num_experts() > 1
@property
def is_moe(self) -> bool:
    """Whether `get_num_experts()` reports at least one expert."""
    num_experts = self.get_num_experts()
    return num_experts > 0

@property
def is_quantized(self) -> bool:
    """Whether `hf_config` has a non-None `quantization_config` attribute."""
    quantization_config = getattr(self.hf_config, "quantization_config", None)
    return quantization_config is not None

Expand Down
15 changes: 15 additions & 0 deletions vllm/config/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ class ParallelConfig:
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
is_moe_model: bool | None = None
"""Whether the deployed model is MoE (if known)."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
Expand Down Expand Up @@ -255,6 +257,10 @@ class is dynamically inherited by the worker class. This is used to inject
Block_size should be divisible by cp_kv_cache_interleave_size.
"""

data_parallel_index: int = Field(init=False)
"""Equal to the data parallel rank but not used for torch process groups
and not overridden for dense models."""

_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
Expand Down Expand Up @@ -466,6 +472,7 @@ def compute_hash(self):
"data_parallel_rank",
"data_parallel_rank_local",
"data_parallel_size_local",
"data_parallel_index",
"data_parallel_backend",
"data_parallel_external_lb",
"data_parallel_hybrid_lb",
Expand Down Expand Up @@ -546,6 +553,14 @@ def __post_init__(self) -> None:
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

if self.data_parallel_size > 1 and self.is_moe_model is False:
raise ValueError(
"Offline data parallel mode is not supported/useful"
" for dense models."
)

self.data_parallel_index = self.data_parallel_rank

if self.distributed_executor_backend == "external_launcher":
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
Expand Down
41 changes: 33 additions & 8 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,29 @@ def pad_for_cudagraph(self, batch_size: int) -> int:
# i.e., batch_size <= self.compilation_config.max_cudagraph_capture_size
return self.compilation_config.bs_to_padded_graph_size[batch_size]

@property
def needs_dp_coordinator(self) -> bool:
    """
    Report whether a DPCoordinator process has to be started.

    A coordinator is only relevant when data parallelism is in use
    (`data_parallel_size > 1`). In that case it is required when:
    1. The model is MoE (or no model config is available): the coordinator
       handles wave coordination, even in external LB mode.
    2. The model is not MoE and external LB is not used: the coordinator
       collects and publishes queue stats for load balancing across DP ranks.

    Returns:
        True if DPCoordinator process is needed, False otherwise.
    """
    parallel_config = self.parallel_config

    # Without data parallelism there is nothing to coordinate.
    if not parallel_config.data_parallel_size > 1:
        return False

    model_config = self.model_config
    # Unknown model: treated like MoE, so the coordinator is kept.
    if model_config is None:
        return True

    # MoE models always need it; dense models only when vLLM itself does the
    # load balancing (i.e. not external LB), for stats collection.
    return model_config.is_moe or not parallel_config.data_parallel_external_lb

def enable_trace_function_call_for_thread(self) -> None:
"""
Set up function tracing for the current thread,
Expand Down Expand Up @@ -522,6 +545,8 @@ def __post_init__(self):
self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(self.load_config)

self.parallel_config.is_moe_model = self.model_config.is_moe

self.cache_config.verify_with_parallel_config(self.parallel_config)

if self.lora_config is not None:
Expand Down Expand Up @@ -811,9 +836,14 @@ def has_blocked_weights():
)

# Do this after all the updates to compilation_config.mode
effective_dp_size = (
self.parallel_config.data_parallel_size
if self.model_config is None or self.model_config.is_moe
else 1
)
self.compilation_config.set_splitting_ops_for_v1(
all2all_backend=self.parallel_config.all2all_backend,
data_parallel_size=self.parallel_config.data_parallel_size,
data_parallel_size=effective_dp_size,
)

if self.compilation_config.pass_config.enable_sp:
Expand Down Expand Up @@ -1281,13 +1311,8 @@ def compile_debug_dump_path(self) -> Path | None:
if self.compilation_config.debug_dump_path is None:
return None
tp_rank = self.parallel_config.rank
dp_rank = self.parallel_config.data_parallel_rank
data_parallel_size = self.parallel_config.data_parallel_size
append_path = (
f"rank_{tp_rank}"
if data_parallel_size == 1
else f"rank_{tp_rank}_dp_{dp_rank}"
)
dp_rank = self.parallel_config.data_parallel_index
append_path = f"rank_{tp_rank}_dp_{dp_rank}"
path = self.compilation_config.debug_dump_path / append_path
return path

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,6 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
# This logic is now centralized
return (
envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ vllm_config.parallel_config.data_parallel_rank
+ vllm_config.parallel_config.data_parallel_index
* vllm_config.parallel_config.tensor_parallel_size
)
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT
+ vllm_config.parallel_config.data_parallel_rank
+ vllm_config.parallel_config.data_parallel_index
)
assert vllm_config.kv_transfer_config is not None
if current_platform.device_type == "cpu":
Expand Down
37 changes: 22 additions & 15 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,11 @@ def get_dp_group() -> GroupCoordinator:


def get_ep_group() -> GroupCoordinator:
assert _EP is not None, "expert parallel group is not initialized"
assert _EP is not None, (
"expert parallel group is not initialized. "
"EP group is only created for MoE models with num_experts > 0. "
"This function should only be called for MoE models."
)
return _EP


Expand Down Expand Up @@ -1400,20 +1404,23 @@ def initialize_model_parallel(

global _EP
assert _EP is None, "expert parallel group is already initialized"
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
-1,
data_parallel_size
* prefill_context_model_parallel_size
* tensor_model_parallel_size,
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
-1,
data_parallel_size
* prefill_context_model_parallel_size
* tensor_model_parallel_size,
)
.unbind(0)
)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
group_ranks = [x.tolist() for x in group_ranks]
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
# If no EP group needed, _EP remains None

logger.info_once(
"rank %s in world size %s is assigned as "
Expand All @@ -1425,7 +1432,7 @@ def initialize_model_parallel(
_PP.rank_in_group,
_PCP.rank_in_group,
_TP.rank_in_group,
_EP.rank_in_group,
_EP.rank_in_group if _EP is not None else "N/A",
)


Expand Down
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,7 @@ def create_engine_config(
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo,
Expand Down
1 change: 1 addition & 0 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def make(
) -> "DPMetadata":
assert num_tokens_across_dp_cpu is not None
assert parallel_config.data_parallel_size > 1
assert parallel_config.is_moe_model is not False
dp_rank = parallel_config.data_parallel_rank
batchsize = num_tokens

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(

self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config,
self.parallel_config.data_parallel_rank,
self.parallel_config.data_parallel_index,
)
self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None:
Expand Down
Loading