diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f25db40a4efa..c81e64e7b05c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -957,11 +957,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, c_strides, per_act_token, per_out_ch) -def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, - a_scales: torch.Tensor, b_scales: torch.Tensor, - alphas: torch.Tensor, problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, - out_dtype: torch.dtype, device: torch.device): +def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, alphas: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. @@ -978,14 +978,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. """ - m_topk = a_tensors.shape[0] - n = b_tensors.shape[1] - c_shape = (m_topk, n) - c = torch.empty(c_shape, device=device, dtype=out_dtype) - torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, - b_scales, alphas, problem_sizes, - expert_offsets, sf_offsets) - return c.to(out_dtype) + return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, alphas, + problem_sizes, expert_offsets, + sf_offsets) # aqlm diff --git a/vllm/envs.py b/vllm/envs.py index 502978c76851..df6e9b11a094 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -121,6 +121,7 @@ VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_FLASHINFER_MOE: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -868,6 +869,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. 
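For illustration, a minimal usage sketch (not part of the diff) of the VLLM_USE_FLASHINFER_MOE flag added above. It is parsed with bool(int(...)), so "1" enables it and "0" (the default) disables it; the sketch assumes the usual lazy attribute lookup in vllm.envs via its environment_variables table, as shown in the hunk above.

    import os

    # Opt in to the FlashInfer CUTLASS fused MoE path before vLLM reads the flag.
    os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

    import vllm.envs as envs

    # envs resolves variables lazily through the lambda registered above,
    # so the attribute reflects the value set in the environment.
    assert envs.VLLM_USE_FLASHINFER_MOE is True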
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index e61d350388ea..628aa5c7bb06 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -255,28 +255,18 @@ def workspace_shapes( output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 1a63b3237343..fc30e84e6656 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -142,7 +142,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None @@ -150,4 +151,4 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, activation, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input) + apply_router_weight_on_input, extra_expert_args) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index def1c2b4556b..9bebb6a65fce 100644 --- 
a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import cdiv +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) @@ -188,6 +189,11 @@ def use_deepep_ll_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + @property + def use_flashinfer_cutlass_kernels(self): + return (envs.VLLM_USE_FLASHINFER_MOE + and has_flashinfer_cutlass_fused_moe()) + @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": @@ -392,6 +398,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + @staticmethod def make( num_experts: int, @@ -435,6 +445,12 @@ def make( if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config) + if quant_dtype is None and isinstance(quant_config, + ModelOptNvFp4Config): + quant_dtype = torch.uint8 + if weight_quant is not None: per_out_ch_quant = ( weight_quant.strategy == QuantizationStrategy.CHANNEL) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 978c53223625..484e36a2cf0b 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) + _resize_cache, + extract_required_args) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -298,7 +299,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" @@ -431,23 +433,28 @@ def cutlass_moe_fp8( FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -def cutlass_moe_fp4(a: torch.Tensor, - a1_gscale: torch.Tensor, - w1_fp4: torch.Tensor, - w1_blockscale: torch.Tensor, - w1_alphas: torch.Tensor, - a2_gscale: torch.Tensor, - w2_fp4: torch.Tensor, - w2_blockscale: torch.Tensor, - w2_alphas: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - m: int, - n: int, - k: int, - e: int, - device: torch.device, - apply_router_weight_on_input: bool = False): +def run_cutlass_moe_fp4( + output: torch.Tensor, + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: 
torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + apply_router_weight_on_input: bool = False, +) -> None: """ MoE implementation for FP4 Inputs @@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor, assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", " between weights.") - assert (k_a // 2 == half_k_w1 + assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " - "expected `n`") + assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " + "expected `n`") assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") - + topk = topk_ids.size(1) out_dtype = a.dtype num_topk = topk_ids.size(1) @@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor, blockscale_offsets) a = ops.shuffle_rows(a, a_map) - rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( a, a1_gscale, @@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor, blockscale_offsets, num_topk, ) - - c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, - w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1], - out_dtype, device) + c1 = _resize_cache(workspace13, (m * topk, n * 2)) + c2 = _resize_cache(workspace2, (m * topk, n)) + c3 = _resize_cache(workspace13, (m * topk, k)) + ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1]) del rep_a_fp4, rep_a_blockscale - # hidden size dimension is split to one halfpytho sized tensor. 
- intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), - device=device, - dtype=out_dtype) - - torch.ops._C.silu_and_mul(intermediate, c1) - + torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) - c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, - w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device) + ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1]) del int_fp4, int_blockscale - c2 = ops.shuffle_rows(c2, c_map) + c3 = ops.shuffle_rows(c3, c_map) + + assert output.dtype == out_dtype if not apply_router_weight_on_input: - out = (c2.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1) + output.copy_( + (c3.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), + non_blocking=True) else: - out = c2.view(m, num_topk, k).sum(dim=1) - return out.to(dtype=out_dtype) + output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True) + return + + +class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_experts_per_worker: int, + out_dtype: torch.dtype, + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + use_batched_format: bool = False, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + )) + self.max_experts_per_worker = max_experts_per_worker + self.out_dtype = out_dtype + self.use_batched_format = use_batched_format + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.use_batched_format: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + else: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1: tuple[int, ...] = () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] 
= () + if self.use_batched_format: + padded_M = aq.size(1) + workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) + + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, + w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): + required_keys = [ + "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", + "e", "device" + ] + (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, + device) = extract_required_args(extra_expert_args, required_keys) + run_cutlass_moe_fp4( + output=output, + a=hidden_states, + a1_gscale=a1_gscale, + w1_fp4=w1, + w1_blockscale=w1_scale, + w1_alphas=g1_alphas, + a2_gscale=a2_gscale, + w2_fp4=w2, + w2_blockscale=w2_scale, + w2_alphas=g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace13=workspace13, + workspace2=workspace2, + m=m, + n=n, + k=k, + e=e, + device=device, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False) -> torch.Tensor: + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp4( + max_experts_per_worker=e, + out_dtype=a.dtype, + per_act_token_quant=False, + per_out_ch_quant=False, + use_batched_format=False, + ), + ) + extra_expert_args = { + 'g1_alphas': g1_alphas, + 'g2_alphas': g2_alphas, + 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, + 'm': m, + 'n': n, + 'k': k, + 'e': e, + 'device': device, + } + + # NVFP4 requires two levels of quantization, which involves computing some + # scaling factors dynamically. This makes it incompatible with the typical + # prepare -> MoE -> finalize pipeline. Move the quantization logic into the + # MoE body. + extra_prepare_args = { + 'skip_quant': True, + } + # Similar reason as above. 
+ extra_finalize_args = { + 'skip_weight_reduce': True, + } + return fn( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + global_num_experts=e, + expert_map=None, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + a1_scale=None, + a2_scale=None, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, + ) def _valid_cutlass_block_scaled_grouped_gemm( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index bb462938a392..dee6ad138a81 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Optional +from typing import Any, Optional import torch @@ -136,6 +136,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): assert self.block_shape is not None assert a1q_scale is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index e10927c4dce5..7016ff34c3a8 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import deep_ep import torch @@ -127,16 +127,12 @@ def _do_dispatch(self, tokens: torch.Tensor, expert_topk_weights) def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -191,7 +187,8 @@ def prepare( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert self.handle is not None diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b04f01975849..57871ca250ae 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from 
typing import Optional, Union +from typing import Any, Optional, Union import deep_ep import torch @@ -111,16 +111,12 @@ def _do_quant( return x, x_scales def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -169,7 +165,8 @@ def prepare( def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py new file mode 100644 index 000000000000..1753c4f6e238 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import extract_required_args +from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, + has_flashinfer_cutlass_fused_moe) + +logger = init_logger(__name__) + + +def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: + """ + Check if the given problem size is supported by the FlashInfer CUTLASS MoE + kernel. 
+ """ + if not has_flashinfer_cutlass_fused_moe(): + logger.debug_once("FlashInferExperts disabled: " + "flashinfer_cutlass_fused_moe not available.") + return False + # Data type checks + if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 + or hidden_states.dtype + not in [torch.float32, torch.float16, torch.bfloat16]): + logger.debug_once( + "FlashInferExperts disabled: w1/w2 must be torch.uint8 " + f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " + f"float32, float16, or bfloat16 (got {hidden_states.dtype}).") + return False + return True + + +class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_nvfp4_w4a4: bool = False, + use_fp8_w8a8: bool = False, + use_dp: bool = False, + ep_rank: int = 0, + ep_size: int = 1, + tp_rank: int = 0, + tp_size: int = 1, + num_dispatchers: Optional[int] = None, + use_batched_format: bool = False, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=False, + block_shape=None, + )) + self.use_nvfp4_w4a4 = use_nvfp4_w4a4 + self.use_fp8_w8a8 = use_fp8_w8a8 + self.ep_rank = ep_rank + self.ep_size = ep_size + self.tp_rank = tp_rank + self.tp_size = tp_size + self.use_dp = use_dp + assert not use_batched_format or num_dispatchers is not None + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + # This refers to TP chunking; DP chunking is handled separately. + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # We use global_num_experts due to how moe_align_block_size handles + # expert_maps. + """ + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. + + Returns a tuple of: + - workspace13 shape tuple: must be large enough to hold the + result of either expert gemm. + - workspace2 shape tuple: must be large enough to hold the + result of the activation function. + - output shape tuple: must be exact size of the final gemm output. + - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. + """ + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + aq_m, aq_n = aq.shape + workspace2 = () + output_shape = (aq_m, aq_n * 2) + workspace_dtype = a.dtype + workspace1 = output_shape + # The workspace is determined by `aq`, since it comes after any + # potential communication op and is involved in the expert computation. 
+ return (workspace1, workspace2, output_shape, workspace_dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], # Not used + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: Optional[bool], + extra_expert_args: Optional[dict[str, Any]], + ): + assert extra_expert_args is not None, \ + "extra_expert_args must be provided" + required_keys = [ + 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' + ] + + g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( + extract_required_args(extra_expert_args, required_keys)) + + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + + assert not apply_router_weight_on_input + + quant_scales = [ + a1_gscale, + w1_scale.view(torch.int32), + g1_alphas, + a2_gscale, + w2_scale.view(torch.int32), + g2_alphas, + ] + _ = flashinfer_cutlass_fused_moe( + hidden_states, + topk_ids.to(torch.int), + topk_weights, + # FlashInfer API requires weight to be long for nvfp4 + w1.view(torch.long), + w2.view(torch.long), + output_dtype=out_dtype, + quant_scales=quant_scales, + input_sf=a1q_scale, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + output=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py new file mode 100644 index 000000000000..49819504c8ec --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( + extract_required_args, moe_kernel_quantize_input) +from vllm.utils.flashinfer import fp4_swizzle_blockscale + + +def get_local_sizes(local_tokens): + cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu + sizes = [cu_sizes[0].item()] + for i in range(1, len(cu_sizes)): + sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) + max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE + sizes_chunked = [max_num_tokens] * len(sizes) + if local_tokens < max_num_tokens: + # When the number of local tokens is less than max_num_tokens, all other + # ranks will also have fewer than max_num_tokens. The remaining tokens + # are accounted for as residual. 
+ sizes_chunked = [x % max_num_tokens for x in sizes] + + return sizes_chunked + + +class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + num_dispatchers: int = 1, + ): + super().__init__() + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.quant_dtype = quant_dtype + self.num_dispatchers_ = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], # Not used + a2_scale: Optional[torch.Tensor], # Not used + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + assert not apply_router_weight_on_input + + (a1_gscale, use_dp, local_tokens) = extract_required_args( + extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_gscale, + quant_config.quant_dtype, + self.per_channel_quant, + self.block_shape, + is_fp4_scale_swizzled=not use_dp, # Swizzling after communication + ) + if use_dp: + topk_weights, topk_ids, a1q, a1q_scale = \ + get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 + dim=0, + sizes=get_local_sizes(local_tokens)) + a1_m, a1_n = a1q.shape + a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + + (use_dp, + local_tokens) = extract_required_args(extra_finalize_args, + ['use_dp', 'local_tokens']) + if use_dp: + fused_expert_output = get_dp_group().reduce_scatterv( + fused_expert_output, + dim=0, + sizes=get_local_sizes(local_tokens), + ) + output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index ab8a281b3901..9a5c85e120cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Optional +from typing import Any, Optional import torch @@ -496,16 +496,12 @@ def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: 
Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -594,15 +590,11 @@ def prepare( return b_a1, b_a1_scale, expert_tokens_meta, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) weight_and_reduce_impl.apply( @@ -706,7 +698,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert hidden_states.dim() == 3 assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens @@ -911,7 +904,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool): + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ddda87c441b7..459360260073 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1646,6 +1646,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. 
if self.use_int4_w4a16: diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index da772c111559..93adfb81cb1b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -34,6 +34,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts @@ -45,6 +46,9 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) + if has_flashinfer(): + from .flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -99,6 +103,9 @@ def maybe_make_prepare_finalize( prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + if moe.use_flashinfer_cutlass_kernels: + prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=moe.quant_dtype, ) if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -204,6 +211,12 @@ def select_gemm_impl( f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + pass + @abstractmethod def apply( self, @@ -780,12 +793,15 @@ def __init__( moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + if isinstance(self.quant_method, FusedMoEMethodBase): + self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, @@ -837,6 +853,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -1438,9 +1458,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): final_hidden_states, non_blocking=True) ctx = get_forward_context() + # flashinfer_cutlass_kernels can handle: optional DP + TP/EP max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens - num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1460,13 +1480,20 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) 
is enabled. + use_flashinfer_cutlass_kernels = ( + self.dp_size > 1 + and self.moe_parallel_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels) + and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1496,7 +1523,6 @@ def forward_impl(self, hidden_states: torch.Tensor, if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs. final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index bc4eb3b1932a..6262904e4dca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Any, Optional, final import torch @@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -190,15 +186,11 @@ def prepare( raise NotImplementedError @abstractmethod - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: TopKWeightAndReduce, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -376,6 +368,7 @@ def apply( workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -460,21 +453,19 @@ def __init__( f"{fused_experts.__class__.__name__}." 
f"{fused_experts.activation_formats[0]}") - def _do_fused_experts(self, fused_out: Optional[torch.Tensor], - a1: torch.Tensor, a1q: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - activation: str, global_num_experts: int, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool) -> torch.Tensor: + def _do_fused_experts( + self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, + a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -517,7 +508,8 @@ def _do_fused_experts(self, fused_out: Optional[torch.Tensor], workspace13=workspace13, workspace2=workspace2, expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) return fused_out @@ -541,6 +533,7 @@ def _maybe_chunk_fused_experts( a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ) -> torch.Tensor: _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) @@ -568,7 +561,8 @@ def _maybe_chunk_fused_experts( a1q_scale=a1q_scale, a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) # Chunking required case assert num_chunks > 1 @@ -624,6 +618,15 @@ def slice_expert_tokens_metadata( expert_num_tokens=c_expert_num_tokens, expert_num_tokens_cpu=c_expert_num_tokens_cpu) + m = None + if extra_expert_args is not None and 'm' in extra_expert_args: + m = extra_expert_args.get('m') + + if extra_expert_args is not None: + chunked_extra_expert_args = extra_expert_args + else: + chunked_extra_expert_args = {} + for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( slice_input_tensors(chunk_idx)) @@ -634,6 +637,11 @@ def slice_expert_tokens_metadata( expert_tokens_meta, c_topk_ids, local_num_experts, expert_map) + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + + if m is not None: + chunked_extra_expert_args['m'] = e - s self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), a1=a1, @@ -653,7 +661,8 @@ def slice_expert_tokens_metadata( a1q_scale=c_a1q_scale, a2_scale=c_a2_scale, expert_tokens_meta=c_expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=chunked_extra_expert_args) return fused_out @@ -675,6 +684,9 @@ def 
forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + extra_expert_args: Optional[dict] = None, + extra_prepare_args: Optional[dict] = None, + extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -707,6 +719,12 @@ def forward( - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. + - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to + fused_experts.apply. + - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass + to prepare. + - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass + to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -730,6 +748,7 @@ def forward( expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, + extra_prepare_args, ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. @@ -766,11 +785,13 @@ def forward( a1q_scale=a1q_scale, a2_scale=a2_scale, expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) self.prepare_finalize.finalize( output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl()) + self.fused_experts.finalize_weight_and_reduce_impl(), + extra_finalize_args) return output diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 5a23a9f1ab09..46931f2dd7c7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import pplx_kernels as pplx import torch @@ -89,16 +89,12 @@ def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -217,15 +213,11 @@ def prepare( return expert_x, expert_x_scale, expert_tokens_meta, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: 
Optional[dict[str, Any]]) -> None: assert isinstance( weight_and_reduce_impl, TopKWeightAndReduceDelegate ), ("Weight application and reduction happens in the combine kernel.") diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index b15c00c44b5d..696c7cdba9a7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -38,6 +38,7 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -48,26 +49,33 @@ def prepare( assert topk == 1, \ "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) + + if (extra_prepare_args is not None + and extra_prepare_args.get("skip_quant", True)): + # Skip quantization if explicitly requested + return a1, None, None, None, None + a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + if (extra_finalize_args is not None + and extra_finalize_args.get("skip_weight_reduce", True)): + assert output.shape == fused_expert_output.shape + output.copy_(fused_expert_output) + else: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 51b95c9aa922..1b31368c79cd 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -119,28 +119,18 @@ def workspace_shapes( local_num_experts, expert_tokens_meta) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: 
torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) or is_blackwell_deep_gemm_used())) @@ -168,4 +158,5 @@ def apply( workspace2, expert_tokens_meta, apply_router_weight_on_input, + extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index c120d964b3cd..966471b5c59b 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Optional, Union +from typing import Any, Optional, Union import torch @@ -15,6 +15,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv +from vllm.utils.flashinfer import fp4_quantize @triton.jit @@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: return x.flatten()[:prod(v)].view(*v) +def _fp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + is_sf_swizzled_layout: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + return fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_sf_swizzled_layout) + + def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -172,11 +183,16 @@ def moe_kernel_quantize_input( quant_dtype: Union[None, torch.dtype, str], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, + is_fp4_scale_swizzled: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.uint8: # nvfp4 + return _fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: @@ -236,3 +252,17 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def extract_required_args( + extra_args: Optional[dict[str, Any]], + required_keys: list[str], +) -> tuple[Any, ...]: + if 
extra_args is None: + raise ValueError("`extra_args` must be provided.") + + missing_keys = [k for k in required_keys if k not in extra_args] + if missing_keys: + raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") + + return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index fcf8ea023f63..1a31410c3385 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -339,19 +339,19 @@ def apply( return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, topk_weights=topk_weights, topk_ids=topk_ids, m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, device=x.device, apply_router_weight_on_input=apply_router_weight_on_input).to( x.dtype) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 788f0a9116f8..3807899fc3e5 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -7,9 +7,15 @@ from torch.nn import Module from torch.nn.parameter import Parameter +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.distributed import get_ep_group from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -713,6 +719,18 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() self.use_marlin = False + self.allow_flashinfer_cutlass = False + + if envs.VLLM_USE_FLASHINFER_MOE: + if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ + and current_platform.is_device_capability(100): + logger.info_once( + "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") + self.allow_flashinfer_cutlass = True + else: + logger.warning_once( + "Flashinfer CUTLASS Fused MoE not supported " + "or found on the current platform.") if not self.cutlass_nvfp4_supported: if is_fp4_marlin_supported(): @@ -722,6 +740,73 @@ def __init__(self, quant_config: ModelOptNvFp4Config): " quantization. 
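# Usage sketch for the new extract_required_args helper; the dict contents are
# illustrative. Missing keys fail fast with a ValueError instead of surfacing later
# inside the expert kernel.
import torch
from vllm.model_executor.layers.fused_moe.utils import extract_required_args

extra_expert_args = {
    "g1_alphas": torch.ones(8),
    "g2_alphas": torch.ones(8),
    "out_dtype": torch.bfloat16,
}
g1_alphas, g2_alphas = extract_required_args(
    extra_expert_args, ["g1_alphas", "g2_alphas"])
# extract_required_args(extra_expert_args, ["a1_gscale"])  # -> ValueError: Missing keys ...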
Please use Blackwell and" " above.") + self.fused_experts = None # type: ignore + + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + if not self.allow_flashinfer_cutlass: + return + + logger.debug_once("FlashInferExperts") + # default to TP/EP case only + + experts_kwargs: dict[str, Any] = { + "use_nvfp4_w4a4": True, + "use_dp": moe_parallel_config.dp_size > 1, + "ep_rank": moe_parallel_config.ep_rank, + "ep_size": moe_parallel_config.ep_size, + "tp_rank": moe_parallel_config.tp_rank, + "tp_size": moe_parallel_config.tp_size, + } + + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + experts = FlashInferExperts(**experts_kwargs) + self.fused_experts = mk.FusedMoEModularKernel( + FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=torch.uint8, + #meaning 2x e2m1 packed in one, kernel requirement + ), + experts, + ) + + # This method update self.fused_experts + # only prepare_finalize is not None call select_gemm_impl + # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert + # when it's not called(TP case), we still have 2 kernels to use. + def select_gemm_impl(self, prepare_finalize, + moe) -> mk.FusedMoEPermuteExpertsUnpermute: + + assert moe is not None + assert prepare_finalize is not None + experts = None + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + if self.allow_flashinfer_cutlass: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + logger.debug_once("Using FlashInferExperts") + experts = FlashInferExperts( + use_nvfp4_w4a4=True, + use_dp=moe.moe_parallel_config.dp_size > 1, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + else: + assert moe.dp_size > 1 + logger.debug_once("Using CutlassExpertsFp4") + # Currently CutlassExpertsFp4 doesn't support DP + raise ValueError( + "CutlassExpertsFp4 doesn't support DP. " + "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)" + " backend instead.") + + return experts + def uses_weight_scale_2_pattern(self) -> bool: """ FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. @@ -842,8 +927,30 @@ def swizzle_blockscale(self, scale: torch.tensor): if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 + # The FlashInfer Cutlass fused MoE kernel expects the combined weights + # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. 
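# Toy check (illustrative shapes) of the [w1, w3] -> [w3, w1] reordering performed
# below: the fused w13 tensor is split into two halves along dim=-2 and the halves
# are swapped; the block scales get the same treatment.
import torch

w13 = torch.arange(16).reshape(1, 4, 4)   # pretend rows 0-1 are w1, rows 2-3 are w3
half = w13.size(-2) // 2
w1, w3 = w13.split(half, dim=-2)
reordered = torch.cat([w3, w1], dim=-2).contiguous()
assert torch.equal(reordered[:, :half], w3)
assert torch.equal(reordered[:, half:], w1)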
+ gemm1_weight = layer.w13_weight.data + gemm1_weight_scale = layer.w13_weight_scale.data + + if self.allow_flashinfer_cutlass: + dim = -2 + size = gemm1_weight.size(dim) + assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" + half = size // 2 + + # Reorder weight + w1, w3 = gemm1_weight.split(half, dim=dim) + gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous() + + # Reorder scale + s1, s3 = gemm1_weight_scale.split(half, dim=dim) + gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous() + + layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) + layer.w13_weight_scale = Parameter(gemm1_weight_scale, + requires_grad=False) + if not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): logger.warning_once( @@ -874,9 +981,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False) - layer.w13_weight = Parameter(layer.w13_weight.data, - requires_grad=False) - # GEMM 2 layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), @@ -961,31 +1065,74 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE.") - - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) - - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4( - a=x, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - device=x.device, - apply_router_weight_on_input=apply_router_weight_on_input).to( - x.dtype) + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + out = cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w2_blockscale=layer.w2_blockscale_swizzled, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + device=x.device, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) + else: + # TP or DP case + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + is_valid_flashinfer_cutlass_fused_moe) + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + a1_gscale = torch.min(layer.w13_input_scale_quant) + a2_gscale = torch.min(layer.w2_input_scale_quant) + extra_expert_args = { + 'g1_alphas': layer.g1_alphas, + 'g2_alphas': layer.g2_alphas, + 'out_dtype': x.dtype, + # Avoid confusion with a1_scale and a2_scale + # where are batch size related. 
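# Sketch of the global activation-scale reduction used on the FlashInfer path:
# torch.min collapses the stored input scales (one value per expert, by assumption)
# into a single scalar shared by every expert in that GEMM. Values are made up.
import torch

w13_input_scale_quant = torch.tensor([0.9, 1.1, 0.8, 1.0])
a1_gscale = torch.min(w13_input_scale_quant)
assert a1_gscale.ndim == 0
assert torch.isclose(a1_gscale, torch.tensor(0.8))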
+ 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, + } + extra_prepare_args = { + 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], + 'a1_gscale': a1_gscale, + } + extra_finalize_args = { + 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], + } + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_blockscale_swizzled, + w2_scale=layer.w2_blockscale_swizzled, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, + ) + return out diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py new file mode 100644 index 000000000000..dbd2dc393046 --- /dev/null +++ b/vllm/utils/flashinfer.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for FlashInfer API changes. + +Users of vLLM should always import **only** these wrappers. +""" +from __future__ import annotations + +import contextlib +import functools +import importlib +import importlib.util +from typing import Any, Callable, NoReturn + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@functools.cache +def has_flashinfer() -> bool: + """Return ``True`` if FlashInfer is available.""" + # Use find_spec to check if the module exists without importing it + # This avoids potential CUDA initialization side effects + return importlib.util.find_spec("flashinfer") is not None + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable FlashInfer backend.""" + raise RuntimeError( + "FlashInfer backend is not available. 
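# Usage sketch: opting in to the FlashInfer CUTLASS MoE path. The environment
# variable must be set before vLLM evaluates it; when FlashInfer is missing or the
# platform check fails, ModelOptNvFp4FusedMoE keeps using the existing
# cutlass_moe_fp4 kernel.
import os

os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

from vllm.utils.flashinfer import has_flashinfer, has_flashinfer_cutlass_fused_moe

if not (has_flashinfer() and has_flashinfer_cutlass_fused_moe()):
    print("FlashInfer CUTLASS fused MoE unavailable; staying on cutlass_moe_fp4.")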
Please install the package " + "to enable FlashInfer kernels: " + "https://github.com/flashinfer-ai/flashinfer") + + +def _get_submodule(module_name: str) -> Any | None: + """Safely import a submodule and return it, or None if not available.""" + try: + return importlib.import_module(module_name) + except (ImportError, ModuleNotFoundError): + return None + + +# General lazy import wrapper +def _lazy_import_wrapper(module_name: str, + attr_name: str, + fallback_fn: Callable[..., Any] = _missing): + """Create a lazy import wrapper for a specific function.""" + + @functools.cache + def _get_impl(): + if not has_flashinfer(): + return None + mod = _get_submodule(module_name) + return getattr(mod, attr_name, None) if mod else None + + def wrapper(*args, **kwargs): + impl = _get_impl() + if impl is None: + return fallback_fn(*args, **kwargs) + return impl(*args, **kwargs) + + return wrapper + + +# Create lazy wrappers for each function +flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", + "cutlass_fused_moe") +fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") +fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer", + "fp4_swizzle_blockscale") + +# Special case for autotune since it returns a context manager +autotune = _lazy_import_wrapper( + "flashinfer.autotuner", + "autotune", + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) + + +@functools.cache +def has_flashinfer_cutlass_fused_moe() -> bool: + """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" + if not has_flashinfer(): + return False + + # Check if all required functions are available + required_functions = [ + ("flashinfer.fused_moe", "cutlass_fused_moe"), + ("flashinfer", "fp4_quantize"), + ("flashinfer", "fp4_swizzle_blockscale"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + +__all__ = [ + "has_flashinfer", + "has_flashinfer_cutlass_fused_moe", + "flashinfer_cutlass_fused_moe", + "fp4_quantize", + "fp4_swizzle_blockscale", + "autotune", +]
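# Usage sketch of the compatibility layer when FlashInfer is not installed:
# autotune() degrades to a harmless nullcontext, while calling any wrapped kernel
# routes to the _missing placeholder and raises with install guidance.
from vllm.utils.flashinfer import (autotune, flashinfer_cutlass_fused_moe,
                                   has_flashinfer)

if not has_flashinfer():
    with autotune():
        pass
    try:
        flashinfer_cutlass_fused_moe()
    except RuntimeError as err:
        print(err)  # "FlashInfer backend is not available. Please install ..."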