From dc45855f9598a8b6f19377f1dcca66c081070622 Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Tue, 14 Apr 2026 19:07:54 -0700 Subject: [PATCH] Revert "[MoE Refactor] Remove MoE DP chunking (#39107)" This reverts commit e1e318af010b4f92d39324d03231cfd409766bf9. --- .buildkite/test_areas/kernels.yaml | 18 +- .../moe/modular_kernel_tools/common.py | 8 +- .../modular_kernel_tools/parallel_utils.py | 2 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 9 +- tests/kernels/moe/test_flashinfer.py | 3 - tests/kernels/moe/test_flashinfer_moe.py | 2 - tests/kernels/moe/test_moe.py | 3 +- tests/kernels/moe/test_moe_layer.py | 23 +- tests/kernels/moe/utils.py | 1 - vllm/config/parallel.py | 12 - vllm/config/scheduler.py | 1 - vllm/engine/arg_utils.py | 23 +- vllm/envs.py | 11 + vllm/forward_context.py | 61 ++++- .../model_executor/layers/fused_moe/config.py | 14 +- vllm/model_executor/layers/fused_moe/layer.py | 3 +- .../fused_moe/runner/chunking_moe_runner.py | 243 ++++++++++++++++++ .../fused_moe/runner/moe_runner_factory.py | 8 +- .../layers/fused_moe/runner/shared_experts.py | 16 ++ .../layers/quantization/quark/quark_moe.py | 4 +- 20 files changed, 389 insertions(+), 76 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/runner/chunking_moe_runner.py diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index 8b9765130aee..e7f73b496e16 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -200,14 +200,7 @@ steps: timeout_in_minutes: 90 device: h100 num_devices: 2 - source_file_dependencies: - - csrc/quantization/cutlass_w8a8/moe/ - - csrc/moe/ - - tests/kernels/moe - - vllm/model_executor/layers/fused_moe/ - - vllm/model_executor/layers/quantization/ - - vllm/distributed/device_communicators/ - - vllm/config + optional: true commands: - pytest -v -s kernels/moe/test_moe_layer.py @@ -216,13 +209,6 @@ steps: timeout_in_minutes: 90 device: b200 num_devices: 2 - source_file_dependencies: - - csrc/quantization/cutlass_w8a8/moe/ - - csrc/moe/ - - tests/kernels/moe - - vllm/model_executor/layers/fused_moe/ - - vllm/model_executor/layers/quantization/ - - vllm/distributed/device_communicators/ - - vllm/config + optional: true commands: - pytest -v -s kernels/moe/test_moe_layer.py diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index f07d4c75e752..a6f3bc35a0b6 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -46,7 +46,6 @@ has_deep_gemm, has_mori, ) -from vllm.utils.math_utils import next_power_of_2 from .mk_objects import ( TestMoEQuantConfig, @@ -605,6 +604,13 @@ def make_modular_kernel( vllm_config: VllmConfig, quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEKernel: + def next_power_of_2(x): + import math + + if x == 0: + return 1 + return 2 ** math.ceil(math.log2(x)) + # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( tp_size_=get_tensor_model_parallel_world_size(), diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 95004fa0ab4d..10a226bcd977 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -126,7 +126,7 @@ def parallel_launch_with_config( world_size: int, worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None], vllm_config: VllmConfig, - env_dict: 
dict[Any, Any] | None, + env_dict: dict[Any, Any], *args: P.args, **kwargs: P.kwargs, ) -> None: diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 6caa9d8c0687..6bde13e0ecf1 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -29,7 +29,6 @@ is_deep_gemm_supported, ) from vllm.utils.import_utils import has_deep_ep, has_deep_gemm -from vllm.utils.math_utils import next_power_of_2 from vllm.utils.torch_utils import set_random_seed from vllm.v1.worker.workspace import init_workspace_manager @@ -85,6 +84,14 @@ def with_dp_metadata(M: int, world_size: int): yield +def next_power_of_2(x): + import math + + if x == 0: + return 1 + return 2 ** math.ceil(math.log2(x)) + + def make_block_quant_fp8_weights( e: int, n: int, diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index dad25bc31959..db499b68843f 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.models.llama4 import Llama4MoE from vllm.platforms import current_platform -from vllm.utils.math_utils import next_power_of_2 from vllm.utils.torch_utils import set_random_seed try: @@ -175,7 +174,6 @@ def make_moe_tensors_8bit( routing_method=layer.routing_method_type, activation=activation, device=w13_quantized.device, - max_num_tokens=next_power_of_2(m), ) return TestData( @@ -350,7 +348,6 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: in_dtype=torch.bfloat16, is_act_and_mul=activation.is_gated, routing_method=RoutingMethodType.TopK, - max_num_tokens=next_power_of_2(m), ) kernel = mk.FusedMoEKernel( diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index d116a96f58bc..a3fb474f1517 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from vllm.utils.math_utils import next_power_of_2 from vllm.utils.torch_utils import set_random_seed if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( @@ -106,7 +105,6 @@ def test_flashinfer_fp4_moe_no_graph( in_dtype=dtype, is_act_and_mul=is_gated_act, routing_method=RoutingMethodType.TopK, - max_num_tokens=next_power_of_2(m), ) flashinfer_experts = FusedMoEKernel( diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 35b21320f826..def545977e9c 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -59,7 +59,6 @@ from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -from vllm.utils.math_utils import next_power_of_2 from vllm.utils.torch_utils import set_random_seed from vllm.v1.worker.workspace import init_workspace_manager @@ -1677,7 +1676,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend( in_dtype=dtype, is_act_and_mul=True, routing_method=RoutingMethodType.Renormalize, - max_num_tokens=next_power_of_2(m), + max_num_tokens=m, ) with set_current_vllm_config(vllm_config): diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 
aa2948b8e989..7b31edd3360d 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -26,7 +26,6 @@ from vllm.config import ( CompilationConfig, ParallelConfig, - SchedulerConfig, VllmConfig, set_current_vllm_config, ) @@ -54,7 +53,7 @@ has_flashinfer_nvlink_two_sided, ) from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep -from vllm.utils.math_utils import cdiv, next_power_of_2 +from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import set_random_seed from vllm.v1.worker.workspace import ( init_workspace_manager, @@ -66,9 +65,8 @@ SHAPE_COMBOS = [ (1, 128, 256), (32, 1024, 512), - (222, 2048, 2048), + (222, 2048, 2048), # should be big enough to exercise DP chunking ] -MAX_M = max([x[0] for x in SHAPE_COMBOS]) NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -114,7 +112,7 @@ "mori": {None, "fp8", "modelopt_fp8"}, "flashinfer_nvlink_two_sided": {None, "modelopt_fp8", "modelopt_fp4"}, "flashinfer_nvlink_one_sided": {None, "modelopt_fp8", "modelopt_fp4"}, - "deepep_low_latency": {None, "modelopt_fp8", "modelopt_fp4"}, + "deepep_low_latency": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, "deepep_high_throughput": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, "nixl_ep": {None, "fp8", "modelopt_fp8"}, } @@ -365,9 +363,9 @@ def is_valid_config(config: MoETestConfig) -> tuple[bool, str | None]: ) # routed_input_transform + quantization + high hidden dimensions - # TODO: Disable >= 2048 for now due to insane errors. + # TODO: Disable >= 2048 w/fp8 + deepep LL for now due to insane errors. if ( - config.use_routed_input_transform + (config.use_routed_input_transform or config.backend == "deepep_low_latency") and config.quantization is not None and config.k >= 2048 ): @@ -1665,6 +1663,9 @@ def test_moe_layer( verbosity = pytestconfig.getoption("verbose") + test_env = dict() + test_env["VLLM_MOE_DP_CHUNK_SIZE"] = "128" + monkeypatch.setenv("VLLM_MOE_DP_CHUNK_SIZE", "128") if os.environ.get("VLLM_LOGGING_LEVEL") is None: monkeypatch.setenv("VLLM_LOGGING_LEVEL", "ERROR") @@ -1689,11 +1690,7 @@ def test_moe_layer( compilation_config.pass_config.fuse_allreduce_rms = False # for now vllm_config = VllmConfig( - parallel_config=parallel_config, - compilation_config=compilation_config, - scheduler_config=SchedulerConfig.default_factory( - max_num_batched_tokens=next_power_of_2(MAX_M) - ), + parallel_config=parallel_config, compilation_config=compilation_config ) test_configs = generate_valid_test_configs( @@ -1721,7 +1718,7 @@ def test_moe_layer( world_size, _parallel_worker, vllm_config, - None, + test_env, test_configs, verbosity, ) diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index c9c5c97b26d5..8763ad683517 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -69,7 +69,6 @@ def make_dummy_moe_config( in_dtype=in_dtype, device="cuda", routing_method=RoutingMethodType.TopK, - max_num_tokens=512, ) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index a42b8422ef3c..0b5c97ba0639 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -622,18 +622,6 @@ def use_sequence_parallel_moe(self) -> bool: and self.data_parallel_size > 1 ) - @property - def use_batched_dp_moe(self) -> bool: - return ( - self.all2all_backend - in ( - "deepep_low_latency", - "nixl_ep", - ) - and self.enable_expert_parallel - and self.data_parallel_size > 1 - ) - @property def node_rank_within_dp(self) -> int: return self.node_rank % self.nnodes_within_dp diff --git a/vllm/config/scheduler.py 
b/vllm/config/scheduler.py index b9a48144ded4..3cd99bb082eb 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -40,7 +40,6 @@ class SchedulerConfig: """ DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 - DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP: ClassVar[int] = 256 DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 runner_type: RunnerType = "generate" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 03a460fbe95a..9d4a12c343e9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1604,6 +1604,9 @@ def create_engine_config( self._check_feature_supported() self._set_default_chunked_prefill_and_prefix_caching_args(model_config) + self._set_default_max_num_seqs_and_batched_tokens_args( + usage_context, model_config + ) self._set_default_reasoning_config_args() sliding_window: int | None = None if not is_interleaved(model_config.hf_text_config): @@ -1857,12 +1860,6 @@ def create_engine_config( target_parallel_config=parallel_config, ) - self._set_default_max_num_seqs_and_batched_tokens_args( - usage_context, - model_config, - parallel_config, - ) - assert self.max_num_batched_tokens is not None, ( "max_num_batched_tokens must be set by this point" ) @@ -2278,7 +2275,6 @@ def _set_default_max_num_seqs_and_batched_tokens_args( self, usage_context: UsageContext | None, model_config: ModelConfig, - parallel_config: ParallelConfig, ): world_size = self.pipeline_parallel_size * self.tensor_parallel_size ( @@ -2290,15 +2286,10 @@ def _set_default_max_num_seqs_and_batched_tokens_args( orig_max_num_seqs = self.max_num_seqs if self.max_num_batched_tokens is None: - if parallel_config.use_batched_dp_moe: - self.max_num_batched_tokens = ( - SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP - ) - else: - self.max_num_batched_tokens = default_max_num_batched_tokens.get( - usage_context, - SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, - ) + self.max_num_batched_tokens = default_max_num_batched_tokens.get( + usage_context, + SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) if self.max_num_seqs is None: self.max_num_seqs = default_max_num_seqs.get( diff --git a/vllm/envs.py b/vllm/envs.py index 8ed1d33434cb..ee9d006aa987 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -146,6 +146,8 @@ VLLM_ENABLE_PREGRAD_PASSES: bool = False VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 + VLLM_MOE_DP_CHUNK_SIZE: int = 256 + VLLM_ENABLE_MOE_DP_CHUNK: bool = True VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict" VLLM_RAY_EXTRA_ENV_VAR_PREFIXES_TO_COPY: str = "" @@ -1138,6 +1140,15 @@ def _get_or_set_default() -> str: "VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), # Port of the master node in the data parallel setting "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), + # In the context of executing MoE models with Data-Parallel, Expert-Parallel + # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE + # dictates the quantum of tokens that can be dispatched from a DP + # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE + # units. 
+ "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), + "VLLM_ENABLE_MOE_DP_CHUNK": lambda: bool( + int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1")) + ), # Randomize inputs during dummy runs when using Data Parallel "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: ( os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1" diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 537a28a42526..fa568c33f36d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -70,8 +70,27 @@ def _compute_sp_num_tokens( return sp_tokens.tolist() +def _compute_chunked_local_num_tokens( + num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int, + max_num_tokens: int, + chunk_idx: int, +) -> list[int]: + sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size) + sp_size = len(sp_tokens) + + local_size = [-1] * sp_size + for i in range(sp_size): + # Take into account sharding if MoE activation is sequence parallel. + local_size[i] = min(max_num_tokens, sp_tokens[i] - (max_num_tokens * chunk_idx)) + if local_size[i] <= 0: + local_size[i] = 1 # ensure lockstep even if done + return local_size + + @dataclass class DPMetadata: + max_tokens_across_dp_cpu: torch.Tensor num_tokens_across_dp_cpu: torch.Tensor # NOTE: local_sizes should only be set by the chunked_sizes context manager @@ -94,7 +113,47 @@ def make( assert num_tokens_across_dp_cpu[dp_rank] == batchsize, ( f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" ) - return DPMetadata(num_tokens_across_dp_cpu) + max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) + return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) + + @contextmanager + def chunked_sizes( + self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int + ): + """ + Context manager to compute and temporarily set the per-rank local token + sizes for a specific chunk during chunked forward execution. + + This is necessary to ensure each DP (data parallel) rank processes its + designated portion of tokens in lockstep with others, even when the + token counts are uneven or some ranks have completed their input early. + + For chunked execution, we break up the total tokens on each rank into + multiple chunks (of at most `max_chunk_size_per_rank`), and for a given + `chunk_idx`, this context manager sets `self.local_sizes` to the number + of tokens to process in that chunk on each rank. + + `self.local_sizes` is only valid inside the context. + + Args: + sequence_parallel_size: When Attn is TP and MoE layers are EP, + we use SP between the layers to avoid + redundant ops. We need this value to + compute the chunked sizes. + max_chunk_size_per_rank: The max number of tokens each rank is + allowed to process in this chunk. + chunk_idx: The index of the chunk to compute sizes for. 
+ """ + self.local_sizes = _compute_chunked_local_num_tokens( + self.num_tokens_across_dp_cpu, + sequence_parallel_size, + max_chunk_size_per_rank, + chunk_idx, + ) + try: + yield self.local_sizes + finally: + self.local_sizes = None @contextmanager def sp_local_sizes(self, sequence_parallel_size: int): diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index a3b941dfa451..0c93dc6a76f9 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -6,7 +6,8 @@ import torch -from vllm.config import ParallelConfig, SchedulerConfig +import vllm.envs as envs +from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation @@ -936,6 +937,15 @@ class FusedMoEParallelConfig: all2all_backend: str # all2all backend for MoE communication enable_eplb: bool # whether to enable expert load balancing + @property + def use_dp_chunking(self) -> bool: + return ( + self.use_deepep_ll_kernels + or self.use_mori_kernels + or self.use_fi_nvl_two_sided_kernels + or self.use_nixl_ep_kernels + ) and envs.VLLM_ENABLE_MOE_DP_CHUNK + @property def is_sequence_parallel(self) -> bool: return self.sp_size > 1 @@ -1174,7 +1184,7 @@ class FusedMoEConfig: intermediate_size_per_partition_unpadded: int | None = None moe_backend: str = "auto" - max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP + max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE has_bias: bool = False is_act_and_mul: bool = True is_lora_enabled: bool = False diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 190a9cc3b5d7..ca5467f042a5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -8,6 +8,7 @@ import torch from torch.nn.parameter import UninitializedParameter +import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy @@ -481,7 +482,7 @@ def __init__( in_dtype=moe_in_dtype, moe_backend=vllm_config.kernel_config.moe_backend, router_logits_dtype=router_logits_dtype, - max_num_tokens=vllm_config.scheduler_config.max_num_batched_tokens, + max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE, has_bias=has_bias, is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, diff --git a/vllm/model_executor/layers/fused_moe/runner/chunking_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/chunking_moe_runner.py new file mode 100644 index 000000000000..a8c75486d711 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/runner/chunking_moe_runner.py @@ -0,0 +1,243 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.forward_context import ( + get_forward_context, +) +from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, +) +from vllm.model_executor.layers.fused_moe.runner.moe_runner_base import MoERunnerBase +from vllm.model_executor.layers.fused_moe.runner.shared_experts import ( + SharedExperts, +) +from vllm.utils.math_utils import cdiv +from vllm.v1.worker.ubatching import dbo_current_ubatch_id +from vllm.v1.worker.workspace import 
current_workspace_manager + + +class ChunkingMoERunner(MoERunnerBase): + """ + MoE runner wrapper that adds chunked processing to any MoERunnerBase. + + This runner wraps an inner MoERunnerBase and overrides _forward_impl to + process large batches by breaking them into smaller chunks. Each chunk + is delegated to the inner runner's _forward_impl, making chunking + composable with any runner implementation. + + All MoERunnerBase state (moe_config, router, quant_method, etc.) is + transparently delegated to the inner runner via __getattr__. + ChunkingMoERunner only owns chunking-specific state: the pre-allocated + workspace buffers and the reduce_results override. + + Key behaviors: + - Pre-allocates workspace tensors for CUDA graph compatibility + - Processes chunks via inner._forward_impl per chunk + - Never reduces results (reduce_results always returns False) + """ + + def __init__(self, inner: MoERunnerBase): + # Assert that _maybe_dispatch/_maybe_combine will be nops. + assert inner.moe_config.pcp_size == 1 + + # Skip MoERunnerBase.__init__ — all state is delegated to inner + # via __getattr__. Only chunking-specific state lives here. + self._inner = inner + + # Pre-allocated staging buffers. These need to exist ahead of time + # due to CUDA graph construction needing fixed buffer addresses. + self.batched_hidden_states, self.batched_router_logits = ( + self._init_dp_chunking() + ) + + def __getattr__(self, name): + # Delegate attribute access to the inner runner. This is only + # called when normal lookup (instance __dict__, class MRO) fails, + # so ChunkingMoERunner's own attributes and methods take priority. + return getattr(self._inner, name) + + @property + def shared_experts(self) -> SharedExperts | None: + return self._inner.shared_experts + + # TODO(bnell): temporary hack, do not call this method. + def _replace_quant_method(self, quant_method: FusedMoEMethodBase): + self._inner._replace_quant_method(quant_method) + self.quant_method = quant_method + + def is_internal_router(self) -> bool: + return self._inner.gate is not None + + # Reducing results when chunking is handled by the MK finalize operations + # when DP chunking is enabled.. + # This will be removed by #35949 + @property + def reduce_results(self) -> bool: + return False + + def _init_dp_chunking(self) -> list[torch.Tensor]: + states_shape: tuple[int, ...] + logits_shape: tuple[int, ...] + + moe = self.moe_config + + if self.enable_dbo: + states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim) + logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts) + else: + states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim) + logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts) + + # Does this need some kind of profiling run check like modular_kernel.py? + return current_workspace_manager().get_simultaneous( + (states_shape, moe.in_dtype), + (logits_shape, moe.router_logits_dtype), + ) + + def _allocate_dp_chunking_outputs( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, torch.Tensor]: + # Assert the inputs are of the proper type and shape. 
+ assert self.batched_hidden_states is not None + assert self.batched_router_logits is not None + + assert self.batched_hidden_states.dtype == hidden_states.dtype, ( + f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}" + ) + assert self.batched_router_logits.dtype == router_logits.dtype, ( + f"{self.batched_router_logits.dtype} == {router_logits.dtype}" + ) + + # Check size compatibility. + assert self.batched_hidden_states.size(-1) == hidden_states.size(-1) + assert self.batched_router_logits.size(-1) == router_logits.size(-1) + + final_fused_hidden_states = torch.empty_like(hidden_states) + if self.shared_experts is not None: + if shared_experts_input is not None: + final_shared_hidden_states = torch.empty_like(shared_experts_input) + else: + final_shared_hidden_states = torch.empty_like(hidden_states) + else: + final_shared_hidden_states = None + + return final_shared_hidden_states, final_fused_hidden_states + + def _slice_and_copy_input( + self, + out_slice: torch.Tensor, + orig: torch.Tensor | None, + start: int, + end: int, + ) -> torch.Tensor: + assert orig is not None + slice_size = end - start + orig_slice = orig[start:end, :] + if self.enable_dbo: + assert out_slice.dim() == 3 + batch_buffer_idx = dbo_current_ubatch_id() + out_slice = out_slice[batch_buffer_idx, :] + + assert out_slice.size(0) >= slice_size + out_slice = out_slice[:slice_size, :] + out_slice.copy_(orig_slice, non_blocking=True) + return out_slice + + def _forward_impl( + self, + layer: torch.nn.Module, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + final_shared_hidden_states, final_fused_hidden_states = ( + self._allocate_dp_chunking_outputs( + hidden_states, router_logits, shared_experts_input + ) + ) + + ctx = get_forward_context() + # flashinfer_cutlass_kernels can handle: optional DP + TP/EP + max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu + moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens + + # If the input to the MoE is sequence parallel then divide by sp_size + # to find the maximum number of tokens for any individual dispatcher. + if self.moe_config.is_sequence_parallel: + max_tokens_across_dispatchers = cdiv( + max_tokens_across_dispatchers, self.moe_config.sp_size + ) + + num_tokens = hidden_states.size(0) + for chunk_idx, chunk_start_ in enumerate( + range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank) + ): + chunk_start = chunk_start_ + chunk_end = min( + chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers + ) + # clamp start and end + chunk_start = min(chunk_start, num_tokens - 1) + chunk_end = min(chunk_end, num_tokens) + chunk_sizes = ctx.dp_metadata.chunked_sizes( + self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx + ) + with chunk_sizes: + hidden_states_chunk = self._slice_and_copy_input( + self.batched_hidden_states, + hidden_states, + chunk_start, + chunk_end, + ) + + router_logits_chunk = self._slice_and_copy_input( + self.batched_router_logits, + router_logits, + chunk_start, + chunk_end, + ) + + shared_experts_input_chunk = ( + shared_experts_input[chunk_start:chunk_end, :] + if shared_experts_input is not None + else None + ) + + # Delegate per-chunk computation to the inner runner. 
+                chunk_result = self._inner._forward_impl(
+                    layer=layer,
+                    hidden_states=hidden_states_chunk,
+                    router_logits=router_logits_chunk,
+                    shared_experts_input=shared_experts_input_chunk,
+                )
+
+                # Store outputs
+                # TODO(bnell): document when chunk_start >= num_tokens
+                if chunk_start < num_tokens:
+                    if self.shared_experts is not None:
+                        assert isinstance(chunk_result, tuple)
+                        shared_output_chunk, hidden_states_chunk = chunk_result
+                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
+                            hidden_states_chunk, non_blocking=True
+                        )
+                        assert shared_output_chunk is not None
+                        assert final_shared_hidden_states is not None
+                        final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
+                            shared_output_chunk, non_blocking=True
+                        )
+                    else:
+                        assert isinstance(chunk_result, torch.Tensor)
+                        final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
+                            chunk_result, non_blocking=True
+                        )
+
+        if self.shared_experts is None:
+            return final_fused_hidden_states
+        else:
+            assert final_shared_hidden_states is not None
+            return (final_shared_hidden_states, final_fused_hidden_states)
diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py
index 2143fa3ce08a..da5068fa0912 100644
--- a/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py
+++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner_factory.py
@@ -12,6 +12,9 @@
 from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
     FusedMoERouter,
 )
+from vllm.model_executor.layers.fused_moe.runner.chunking_moe_runner import (
+    ChunkingMoERunner,
+)
 from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
     DefaultMoERunner,
 )
@@ -32,7 +35,7 @@ def create_moe_runner(
     reduce_results: bool,
     enable_dbo: bool,
 ) -> MoERunner:
-    return DefaultMoERunner(
+    runner = DefaultMoERunner(
         layer_name,
         moe_config,
         router,
@@ -43,3 +46,6 @@ def create_moe_runner(
         reduce_results,
         enable_dbo,
     )
+    if moe_config.moe_parallel_config.use_dp_chunking:
+        return ChunkingMoERunner(runner)
+    return runner
diff --git a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py
index 827a6e6bd3ea..f5b07a6a51a4 100644
--- a/vllm/model_executor/layers/fused_moe/runner/shared_experts.py
+++ b/vllm/model_executor/layers/fused_moe/runner/shared_experts.py
@@ -69,6 +69,7 @@ def __init__(
         self._moe_config = moe_config
         self._quant_method = quant_method
         self._reduce_results = reduce_results
+        self._use_dp_chunking = moe_config.moe_parallel_config.use_dp_chunking
 
         # Allow disabling of the separate shared experts stream for
         # debug purposes.
@@ -86,6 +87,20 @@ def __init__(
                 "Enabled separate cuda stream for MoE shared_experts", scope="local"
             )
 
+    @property
+    def _use_external_experts(self) -> bool:
+        if self._use_dp_chunking:
+            return False
+
+        # Disable shared expert overlap if:
+        # - we are using eplb with non-default backend, because of correctness issues
+        # - we are using flashinfer with DP, since there is nothing to gain
+        backend = self._moe_config.moe_parallel_config.all2all_backend
+        return (
+            self._moe_config.moe_parallel_config.enable_eplb
+            and backend != "allgather_reducescatter"
+        ) or self._moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
+
     def _determine_shared_experts_order(
         self,
         hidden_states: torch.Tensor,
@@ -95,6 +110,7 @@ def _determine_shared_experts_order(
 
         should_run_shared_in_aux_stream = (
             current_platform.is_cuda()
+            and not self._use_dp_chunking
             and self._stream is not None
             and hidden_states.shape[0]
             <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 2bab66709ddd..ff1e2700cdb4 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -1502,9 +1502,9 @@ def process_weights_after_loading(self, layer):
             layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
 
         # FIXME warp need to be adjusted based on batch size
-        # only apply to batched mode
+        # only apply to batched mode
         if self.moe.use_ep:
-            num_warps = 4 if self.moe.max_num_tokens <= 512 else 8
+            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
         else:
             num_warps = 8