diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 2b27202b6b6f..7942f738a1e5 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -31,12 +31,13 @@ VllmConfig, set_current_vllm_config, ) -from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator -from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace -from vllm.distributed.parallel_state import ( +from vllm.distributed import ( get_ep_group, get_eplb_group, + tensor_model_parallel_all_gather, ) +from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe import FusedMoE, fused_experts from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -49,6 +50,7 @@ ModelOptFp8Config, ModelOptNvFp4Config, ) +from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.utils.flashinfer import ( has_flashinfer_nvlink_one_sided, @@ -81,6 +83,9 @@ [1, 4, False], [2, 1, True], [4, 1, True], + # This combination indicates sequence parallel. + # See ParallelConfig.use_sequence_parallel. + [2, 2, True], ] # TODO: should this even be set manually? let oracles handle this @@ -112,24 +117,24 @@ # Which quantization methods each backend supports. # fmt: off BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = { - "allgather_reducescatter": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, - "mori": {None, "fp8", "modelopt_fp8"}, - "flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"}, - "flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"}, - "deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"}, + "allgather_reducescatter": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501 + "mori": {None, "fp8", "modelopt_fp8"}, + "flashinfer_nvlink_two_sided": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501 + "flashinfer_nvlink_one_sided": {None, "modelopt_fp4"}, # noqa: E501 + "deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501 "deepep_high_throughput": {None, "fp8_blocked", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501 - "nixl_ep": {None, "fp8", "modelopt_fp8"}, + "nixl_ep": {None, "fp8_blocked", "modelopt_fp8"}, } -# Map from backend -> (DP/EP support, DP support, TP support) -BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool]] = { - "allgather_reducescatter": (True, True, True), - "mori": (True, False, False), - "flashinfer_nvlink_two_sided": (False, True, False), - "flashinfer_nvlink_one_sided": (False, True, False), - "deepep_low_latency": (True, False, False), - "deepep_high_throughput": (True, False, False), - "nixl_ep": (True, False, False), +# Map from backend -> (DP/EP support, DP support, TP support, SP support) +BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool, bool]] = { + "allgather_reducescatter": (True, True, True, True), + "mori": (True, False, False, True), + "flashinfer_nvlink_two_sided": (False, True, False, False), + "flashinfer_nvlink_one_sided": (False, True, False, False), + "deepep_low_latency": (True, False, False, True), + "deepep_high_throughput": (True, False, False, True), + "nixl_ep": (True, False, False, True), } # fmt: on @@ -163,6 +168,45 @@ def override_normalize_e4m3fn_to_e4m3fnuz(): vllm.model_executor.layers.quantization.utils.w8a8_utils.normalize_e4m3fn_to_e4m3fnuz = mock_normalize_e4m3fn_to_e4m3fnuz # noqa: E501 +def sp_wrapper( + fn: Callable | FusedMoE, is_sequence_parallel: bool | None = None +) -> Callable: + """Wrapper to handle sequence parallelism chunking and gathering. + + For SP with EP: + - The TP group is created with the original tensor_parallel_size (e.g., 2) + - get_tp_group() has the correct world_size for SP operations + - sequence_parallel_chunk() uses get_tensor_model_parallel_world_size() + - tensor_model_parallel_all_gather() uses get_tp_group() + - Both should work correctly even when EP is enabled + """ + if isinstance(fn, FusedMoE): + assert is_sequence_parallel is None + is_sequence_parallel = fn.is_sequence_parallel + else: + assert is_sequence_parallel is not None + + if is_sequence_parallel: + + def wrapper( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> torch.Tensor: + # Split sequence across TP ranks + # Both hidden_states and router_logits have [num_tokens, ...] shape + hidden_states = sequence_parallel_chunk(hidden_states) + router_logits = sequence_parallel_chunk(router_logits) + # Run MoE on local chunk + result = fn(hidden_states, router_logits) + # Gather results from all TP ranks + result = tensor_model_parallel_all_gather(result, 0) + # Remove any padding added by SP. + return result[: hidden_states.shape[0]] + + return wrapper + return fn + + def maybe_roundup_layer_hidden_size( hidden_size: int, act_dtype: torch.dtype, @@ -272,6 +316,15 @@ class MoETestConfig: dp_size: int = 1 tp_size: int = 1 + @property + def is_sequence_parallel(self) -> bool: + # Sequence parallelism: EP enabled + TP dimension used for sequence splitting + # In test config: ep_size represents total expert parallel size + # tp_size represents the original TP dimension (becomes sp_size in FusedMoE) + # dp_size represents data parallel size + # For SP: we need EP enabled (ep_size > 1) and sequence splitting (tp_size > 1) + return self.ep_size > 1 and self.tp_size > 1 + # TODO: add more error messages def id(self) -> str: def proc(s: str) -> str: @@ -404,11 +457,6 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: "leads to large differences.", ) - # gate requires shared_experts (use_overlapped mode) - # TODO: also not sure this is true - if config.use_gate and not config.use_shared_experts: - return False, "gate requires shared_experts (use_overlapped mode)" - # Skip modelopt_fp4 if not on B100+ (compute capability 10.0+) if ( config.quantization == "modelopt_fp4" @@ -445,7 +493,7 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ) if config.backend == "nixl_ep": - from vllm.model_executor.layers.fused_moe.nixl_ep_prepare_finalize import ( # noqa: E501 + from vllm.model_executor.layers.fused_moe.prepare_finalize.nixl_ep import ( # noqa: E501 NixlEPPrepareAndFinalize, ) @@ -456,11 +504,11 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ) if config.backend is not None: - supports_ep_dp, supports_dp, supports_tp = BACKEND_EP_DP_TP_SUPPORT[ - config.backend - ] + supports_ep_dp, supports_dp, supports_tp, supports_sp = ( + BACKEND_EP_DP_TP_SUPPORT[config.backend] + ) - if config.tp_size > 1 and not supports_tp: + if config.tp_size > 1 and not supports_tp and not config.is_sequence_parallel: return False, f"{config.backend} does not support TP." if config.dp_size > 1 and config.ep_size == 1 and not supports_dp: @@ -468,10 +516,34 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: if config.dp_size > 1 and config.ep_size > 1 and not supports_ep_dp: return False, f"{config.backend} does not support EP/DP." + + if config.is_sequence_parallel and not supports_sp: + return False, f"{config.backend} does not support SP." else: if config.tp_size > 1 or config.ep_size > 1 or config.dp_size > 1: return False, "An all2all backend is required for parallelism." + # Sequence parallelism specific validations + if config.is_sequence_parallel: + if config.ep_size == 1: + return False, "Sequence parallelism requires EP to be enabled (ep_size > 1)" + + if config.tp_size == 1: + return ( + False, + "Sequence parallelism requires tp_size > 1 for sequence splitting", + ) + + # SP is essentially EP + sequence splitting + # Verify the relationship: ep_size should equal dp_size * tp_size + # (when pcp_size=1). + expected_ep_size = config.dp_size * config.tp_size + if config.ep_size != expected_ep_size: + return False, ( + f"For sequence parallelism: ep_size ({config.ep_size}) should equal " + f"dp_size * tp_size ({expected_ep_size})" + ) + if config.enable_eplb: if config.ep_size == 1: return False, "EPLB requires EP." @@ -485,14 +557,6 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: if config.num_experts % config.dp_size != 0: return False, "EPLB requires num_experts divisible by ep_size" - # Disable fp4 tests until flashinfer is updated or the Dockerfile is - # modified to install cublasLt.h. See #39525. - if ( - config.quantization == "modelopt_fp4" - and current_platform.is_device_capability_family(100) - ): - return False, "Temporarily skip until #39525 is resolved" - return True, None @@ -751,7 +815,8 @@ def create_shared_experts_from_config( in_dtype: torch.dtype, tp_size: int = 1, tp_rank: int = 0, - device: torch.device | str | None = None, + is_sequence_parallel: bool = False, + device: torch.device | str | None = "cuda", ) -> TestMLP | None: """Create TestMLP for shared experts from config. @@ -772,7 +837,7 @@ def create_shared_experts_from_config( s_w2 = shared_experts_config.w2 # Apply TP chunking if needed - if tp_size > 1: + if tp_size > 1 and not is_sequence_parallel: s_w1 = tp_chunk_gate_up(s_w1, tp_rank, tp_size, dim=1, device=device) s_w2 = chunk_by_rank(s_w2, tp_rank, tp_size, dim=0, device=device) else: @@ -920,6 +985,7 @@ def make_fused_moe_layer( routed_input_transform: torch.nn.Module | None = None, routed_output_transform: torch.nn.Module | None = None, pcp_size: int | None = 1, + is_sequence_parallel: bool = False, ) -> FusedMoE: quant_config, qw = make_quant_config(quantization, w1, w2, global_num_experts) @@ -959,6 +1025,7 @@ def make_fused_moe_layer( enable_eplb=enable_eplb, num_redundant_experts=num_redundant_experts, has_bias=has_bias, + is_sequence_parallel=is_sequence_parallel, **kwargs, ) @@ -1014,6 +1081,7 @@ def make_fake_moe_layer( tp_size: int = 1, dp_size: int = 1, ep_size: int = 1, + is_sequence_parallel: bool = False, ) -> Callable: quant_dtype = None activation = MoEActivation.from_str(activation) @@ -1043,7 +1111,8 @@ def make_fake_moe_layer( w2_s = None shared_experts = create_shared_experts_from_config( - shared_experts_config, in_dtype, 1, 0, "cuda" + shared_experts_config, + in_dtype, ) quant_config = FusedMoEQuantConfig.make( @@ -1114,7 +1183,7 @@ def _moe( def _test_body_regular( - moe_layer: Callable, + moe_layer: FusedMoE, hidden_states: torch.Tensor, router_logits: torch.Tensor, vllm_config: VllmConfig, @@ -1131,7 +1200,7 @@ def _test_body_regular( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, ): - output = moe_layer(hidden_states, router_logits) + output = sp_wrapper(moe_layer)(hidden_states, router_logits) return baseline_output, output @@ -1164,6 +1233,8 @@ def _test_body_eplb( ) -> tuple[torch.Tensor, torch.Tensor]: device = torch.accelerator.current_accelerator() + is_sequence_parallel = moe_layer.is_sequence_parallel + """EPLB test body: compare output before and after expert weight rearrangement.""" # Get "before" output with original weight arrangement with set_forward_context( @@ -1172,7 +1243,7 @@ def _test_body_eplb( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, ): - output_before = moe_layer(hidden_states, router_logits) + output_before = sp_wrapper(moe_layer)(hidden_states, router_logits) # Create a fresh FusedMoE layer with enable_eplb=True # Delete the original layer's registration so the constructor can @@ -1203,6 +1274,7 @@ def _test_body_eplb( gate=gate, routed_input_transform=routed_input_transform, routed_output_transform=routed_output_transform, + is_sequence_parallel=is_sequence_parallel, ) if eplb_moe_layer._expert_map is not None: @@ -1264,7 +1336,7 @@ def _test_body_eplb( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, ): - output_after = eplb_moe_layer(hidden_states, router_logits) + output_after = sp_wrapper(eplb_moe_layer)(hidden_states, router_logits) return output_before, output_after @@ -1272,11 +1344,12 @@ def _test_body_eplb( # TODO: make this take a MoETestConfig def _run_one_config( vllm_config: VllmConfig, - ep_size: int, - dp_size: int, - tp_size: int, - dp_rank: int, - tp_rank: int, + ep_size: int, # Expert parallel size (total across all ranks) + dp_size: int, # Data parallel size (number of DP groups) + tp_size: int, # Tensor parallel size OR sequence parallel size (when use_ep=True) + dp_rank: int, # Current rank in data parallel dimension + tp_rank: int, # Current rank in tensor/sequence parallel dimension + is_sequence_parallel: bool, # Whether to use sequence parallelism m: int, n: int, k: int, @@ -1290,15 +1363,22 @@ def _run_one_config( use_routed_input_transform: bool, **kwargs, ) -> None: - set_random_seed(7) - """Generic test loop that sets up environment and delegates to test_body_fn. - This function is called directly by test_moe_layer and test_moe_layer_eplb - via parallel_launch_with_config, passing either _test_body_regular or - _test_body_eplb as the test_body_fn parameter. + Parameter Interpretation: + - When is_sequence_parallel=False (standard TP or EP): + * ep_size: Number of expert parallel ranks (or 1 if no EP) + * tp_size: Number of tensor parallel ranks (or 1 if no TP) + * Weights are chunked by ep_size (experts) and tp_size (tensors) + + - When is_sequence_parallel=True (EP + sequence splitting): + * ep_size: Number of expert parallel ranks (equals dp_size * tp_size) + * tp_size: Number of ranks to split sequence across (becomes sp_size in FusedMoE) + * Weights are chunked by ep_size (experts) but NOT by tp_size + * Input sequences are chunked by tp_size (via sp_wrapper) """ - world_size = tp_size * dp_size + set_random_seed(7) + use_ep = ep_size > 1 assert vllm_config.parallel_config.enable_expert_parallel == use_ep @@ -1334,6 +1414,8 @@ def _run_one_config( routed_output_transform = test_data.routed_output_transform activation = "silu" + # Create baseline layer with FULL weights (no EP chunking) + # Baseline represents the expected output using full model baseline_layer = make_fake_moe_layer( w1=w1, w2=w2, @@ -1351,26 +1433,45 @@ def _run_one_config( ep_size=ep_size, dp_size=dp_size, activation=activation, + is_sequence_parallel=is_sequence_parallel, ) - baseline_output = baseline_layer(hidden_states, router_logits) + with set_current_vllm_config(vllm_config): + # Compute baseline output with SP wrapper if needed + # sp_wrapper handles sequence chunking/gathering for SP + baseline_output = sp_wrapper(baseline_layer, is_sequence_parallel)( + hidden_states, router_logits + ) del baseline_layer torch.accelerator.empty_cache() with set_current_vllm_config(vllm_config): - # Chunk weights for EP/TP (after baseline is created) + # Chunk weights for EP BEFORE creating FusedMoE + # FusedMoE uses EP-chunked weights and handles reductions internally if ep_size > 1: - w1 = chunk_by_rank(w1, dp_rank, dp_size, dim=0, device=device) - w2 = chunk_by_rank(w2, dp_rank, dp_size, dim=0, device=device) - - if tp_size > 1: + # Split experts across ranks (dimension 0 is the expert dimension) + # When EP is enabled, use EP group rank and ep_size for chunking + ep_rank = get_ep_group().rank_in_group + w1 = chunk_by_rank(w1, ep_rank, ep_size, dim=0, device=device) + w2 = chunk_by_rank(w2, ep_rank, ep_size, dim=0, device=device) + + # Chunk weights for TP (only if NOT doing sequence parallelism) + # Sequence parallelism splits tokens/sequences, not weight tensors + if tp_size > 1 and not is_sequence_parallel: w1 = tp_chunk_gate_up(w1, tp_rank, tp_size, dim=1, device=device) w2 = chunk_by_rank(w2, tp_rank, tp_size, dim=2, device=device) # Setup shared experts if needed + # In SP mode, shared experts should NOT be TP-chunked (same as routed experts) + # tp_size is used for sequence splitting, not weight splitting shared_experts = create_shared_experts_from_config( - shared_experts_config, in_dtype, tp_size, tp_rank, device + shared_experts_config, + in_dtype, + tp_size, + tp_rank, + is_sequence_parallel, + device, ) # Determine hidden size for MoE layer @@ -1396,14 +1497,17 @@ def _run_one_config( routed_input_transform=routed_input_transform, routed_output_transform=routed_output_transform, activation=activation, + is_sequence_parallel=is_sequence_parallel, ) if moe_layer._expert_map is not None: moe_layer._expert_map = moe_layer._expert_map.to(device) num_tokens = m + # num_tokens_across_dp should have one entry per DP group, not per total rank + # When EP is enabled, dp_size represents the number of DP groups num_tokens_across_dp = torch.tensor( - [num_tokens] * world_size, + [num_tokens] * dp_size, device=device, dtype=torch.int, ) @@ -1445,7 +1549,7 @@ def _run_one_config( else: atol, rtol = 3.5e-2, 3.5e-2 elif quantization in ("fp8", "fp8_blocked", "modelopt_fp8"): - atol, rtol = 6e-2, 6e-2 + atol, rtol = 6.5e-2, 6.5e-2 elif quantization == "modelopt_fp4": if k >= 2048: atol = rtol = 1e-1 + (k * 1e-4) @@ -1529,6 +1633,7 @@ def test_moe_layer_no_parallel( test_config.tp_size, 0, 0, + False, test_config.m, test_config.n, test_config.k, @@ -1589,6 +1694,7 @@ def _parallel_worker( test_config.tp_size, dp_rank, tp_rank, + test_config.is_sequence_parallel, test_config.m, test_config.n, test_config.k, @@ -1673,7 +1779,12 @@ def test_moe_layer( """ num_gpus = current_platform.device_count() world_size = tp_size * dp_size - ep_size = 1 if not use_ep else world_size # or dp_size? + # When use_ep=True: FusedMoEParallelConfig flattens tp_size across dp ranks + # Result: ep_size = dp_size * pcp_size * tp_size + # Since pcp_size=1 in these tests: ep_size = dp_size * tp_size = world_size + # When use_ep=False: no expert parallelism, ep_size = 1 + ep_size = 1 if not use_ep else world_size + assert world_size > 1 # Check if enough GPUs available diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 57ef6e9cf148..7fd02bce6153 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -335,6 +335,7 @@ class NixlEPAll2AllManager(All2AllManagerBase): _lock = threading.Lock() def __init__(self, cpu_group, tcp_store_group=None): + assert tcp_store_group is not None super().__init__(cpu_group, tcp_store_group) self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 74056134c095..e46d57cf9230 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -270,7 +270,7 @@ def reduce_scatterv( input_tensor = input_.movedim(0, dim).contiguous() if sizes is not None: - assert len(sizes) == world_size + assert len(sizes) == world_size, f"{len(sizes)} == {world_size}" assert input_tensor.shape[0] == sum(sizes) chunk_size = sizes[self.rank_in_group] else: diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py index f89f375709e3..73083db139fa 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_two_sided.py @@ -185,16 +185,21 @@ def flashinfer_alltoall_dispatch( ep_size, ) - x_sf = MnnvlMoe.mnnvl_moe_alltoallv( - x_sf, - alltoall_info, - all2all_manager.workspace_tensor, # type: ignore[attr-defined] - ep_rank, - ep_size, - ) + if x_sf is not None: + x_sf = MnnvlMoe.mnnvl_moe_alltoallv( + x_sf, + alltoall_info, + all2all_manager.workspace_tensor, # type: ignore[attr-defined] + ep_rank, + ep_size, + ) # Swizzle after the A2A if MoE kernel expects swizzled scales. - if quant_config.quant_dtype == "nvfp4" and quant_config.is_scale_swizzled: + if ( + x_sf is not None + and quant_config.quant_dtype == "nvfp4" + and quant_config.is_scale_swizzled + ): if x_sf.element_size() == 1: x_sf = x_sf.view(torch.uint8) x_sf = nvfp4_block_scale_interleave(x_sf)