diff --git a/.gitignore b/.gitignore
index 96b97a552c54..5dc0f04b6fbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*
 
+# triton jit
+.triton
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
new file mode 100644
index 000000000000..3f9b32ce5a36
--- /dev/null
+++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py
@@ -0,0 +1,375 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass, fields
+
+import pytest
+import torch
+import torch.nn.functional as F
+import triton_kernels.swiglu
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
+                                                  upcast_from_mxfp)
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
+from triton_kernels.testing import assert_close
+
+from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedPrepareAndFinalize)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    BatchedOAITritonExperts, triton_kernel_moe_forward)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.utils import shuffle_weight
+from vllm.utils import round_up
+
+
+def deshuffle(w: torch.Tensor):
+    first = w[..., ::2]
+    second = w[..., 1::2]
+
+    deshuffled = torch.concat((first, second), dim=-1)
+    return deshuffled
+
+
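A quick sanity check of this pair: `shuffle_weight` (imported above, defined later in this diff) interleaves the two halves of the last dimension, and `deshuffle` splits the even/odd columns back apart. A minimal illustrative sketch with made-up values:

```python
# deshuffle is the inverse of shuffle_weight
import torch

from vllm.model_executor.layers.utils import shuffle_weight

w = torch.arange(12, dtype=torch.float32).reshape(2, 6)
print(shuffle_weight(w)[0])  # tensor([0., 3., 1., 4., 2., 5.])
assert torch.equal(deshuffle(shuffle_weight(w)), w)
```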
+def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
+    randbits = [torch.randperm(E) for _ in range(M)]
+    x_list = [
+        (-1)**i *
+        ((16384 +
+          ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
+        for i, bits in enumerate(randbits)
+    ]
+    exp_data = torch.stack(x_list).to(
+        device="cuda")  # simulating gate_output (M, E)
+
+    # create input tensors
+    x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
+    w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
+    w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")
+
+    w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
+    w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")
+
+    exp_data_tri = exp_data.clone()
+    x_tri = x.clone()
+    w1_tri = w1.clone()
+    w2_tri = w2.clone()
+
+    w1_bias_tri = w1_bias.clone()
+    w2_bias_tri = w2_bias.clone()
+    w1_bias_tri = w1_bias_tri.to(torch.float32)
+    w2_bias_tri = w2_bias_tri.to(torch.float32)
+
+    dtype_dict = {
+        "bf16": torch.bfloat16,
+        "fp8_e4m3": torch.float8_e4m3fn,
+        "fp8_e5m2": torch.float8_e5m2
+    }
+
+    x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
+    if w_dtype != "mx4":
+        # simulate quantization support on the reference impl
+        w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
+        w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)
+
+    # the triton moe kernel uses transposed shapes for matmul
+    w1_tri = w1_tri.transpose(-2, -1)
+    w2_tri = w2_tri.transpose(-2, -1)
+
+    # shuffle weights
+    w1_tri = shuffle_weight(w1_tri)
+    w1_bias_tri = shuffle_weight(w1_bias_tri)
+
+    # quantize the triton weights
+    x_tri = x.to(dtype_dict[a_dtype])
+    if w_dtype != "mx4":
+        pytest.skip("NYI")
+    else:  # quantize to mx4
+        # Be careful with the padding here: the activation padding needs to
+        # be a multiple of 64; the actual engine is not implemented.
+        w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
+        w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
+
+        w2_bottom_pad = w1_right_pad // 2
+        w2_right_pad = w1_bottom_pad
+
+        x_pad = w1_bottom_pad
+
+        w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
+                       mode="constant",
+                       value=0)
+        w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
+                       mode="constant",
+                       value=0)
+
+        w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
+                            mode="constant",
+                            value=0)
+        w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
+                            mode="constant",
+                            value=0)
+
+        x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
+
+        w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
+            mx_axis=1)
+        w_scale_layout, w_scale_layout_opts = (
+            layout.make_default_matmul_mxfp4_w_scale_layout(
+                mx_axis=1, num_warps=num_warps))
+
+        w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
+        w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
+
+        w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
+        w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
+
+        w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
+                                **w_layout_opts)
+        w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
+                                      w_scale_layout, **w_scale_layout_opts)
+
+        w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
+                                **w_layout_opts)
+        w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
+                                      w_scale_layout, **w_scale_layout_opts)
+
+        pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
+                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
+        pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
+                              flex_ctx=FlexCtx(rhs_data=InFlexData()))
+
+        # truncate so the rest can run properly
+        w1 = w1[..., :K, :2 * N]
+        w2 = w2[..., :N, :K]
+
+        w1 = deshuffle(w1)
+
+        w1 = w1.transpose(-1, -2).contiguous()
+        w2 = w2.transpose(-1, -2).contiguous()
+
+    return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
+            exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
+
+
+@dataclass
+class ModelConfig:
+    num_hidden_layers: int = 36
+    num_experts: int = 128
+    experts_per_token: int = 4
+    vocab_size: int = 201088
+    hidden_size: int = 2880
+    intermediate_size: int = 2880
+    head_dim: int = 64
+    num_attention_heads: int = 64
+    num_key_value_heads: int = 8
+    sliding_window: int = 128
+    initial_context_length: int = 4096
+    rope_theta: float = 150000.0
+    rope_scaling_factor: float = 32.0
+    rope_ntk_alpha: float = 1.0
+    rope_ntk_beta: float = 32.0
+
+
+def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
+    # Note: we add an extra bias of 1 to the linear layer
+    x_glu, x_linear = torch.chunk(x, 2, dim=-1)
+    if limit is not None:
+        x_glu = x_glu.clamp(max=limit)
+    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+    if limit is not None:
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+    return out_glu * (x_linear + 1)
+
+
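To make the padding rules above concrete, here is the arithmetic for one sharded configuration. This is a sketch: tp=8 (so N = 2880 // 8 = 360) is an assumed example, and `round_up` is re-defined locally as a stand-in for `vllm.utils.round_up`:

```python
# worked example of the mx4 padding above, assuming tp=8 so N = 360
def round_up(v: int, m: int) -> int:
    return (v + m - 1) // m * m

K, N = 2880, 360          # w1_tri is (E, K, 2N) after the transpose
w1_bottom_pad = round_up(K, 64) - K          # 2880 is aligned -> 0
w1_right_pad = round_up(2 * N, 128) - 2 * N  # 720 -> 768, pad 48

# w2_tri is (E, N, K): its N dim must match the padded 2N halved by swiglu,
# and its K dim must match the padded activation width
w2_bottom_pad = w1_right_pad // 2  # 24, so N grows 360 -> 384
w2_right_pad = w1_bottom_pad       # 0
x_pad = w1_bottom_pad              # activations pad in K, here 0
```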
+def oai_moe_forward(
+        hidden_states: torch.Tensor,  # (M, K)
+        w1: torch.Tensor,  # (E, 2N, K)
+        w1_bias: torch.Tensor,  # (E, 2N)
+        w2: torch.Tensor,  # (E, K, N)
+        w2_bias: torch.Tensor,  # (E, K)
+        gating_output: torch.Tensor,  # (M, E)
+        topk: int):
+    # follows model.py lines 309:330, assuming gating and norm
+    t = hidden_states
+    experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
+    expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
+    expert_indices = experts.indices
+
+    # MLP #1
+    mlp1_weight = w1[expert_indices, ...]
+    mlp1_bias = w1_bias[expert_indices, ...]
+    t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
+    t = swiglu(t, limit=7)
+
+    # MLP #2
+    mlp2_weight = w2[expert_indices, ...]
+    mlp2_bias = w2_bias[expert_indices, ...]
+    t = torch.einsum("beck,bek->bec", mlp2_weight, t)
+    t += mlp2_bias
+
+    # Weighted sum of experts
+    t = torch.einsum("bec,be->bc", t, expert_weights)
+
+    return t
+
+
+@dataclass
+class Case:
+    a_dtype: str
+    w_dtype: str
+
+
+@pytest.mark.parametrize(
+    ", ".join(f.name for f in fields(Case)),
+    [
+        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+            # Case(a_dtype="bf16", w_dtype="bf16"),
+            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
+            Case(a_dtype="bf16", w_dtype="mx4")
+        ]
+    ],
+)
+@pytest.mark.parametrize("num_token", [2])
+@pytest.mark.parametrize("tp", [1, 2, 4, 8])
+def test_equiv(num_token, a_dtype, w_dtype, tp):
+    M = num_token
+    E = ModelConfig.num_experts
+    K = ModelConfig.hidden_size
+    N = ModelConfig.intermediate_size // tp
+    topk = ModelConfig.experts_per_token
+
+    x, w1, w1_bias, w2, w2_bias, exp_data, \
+        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
+        w2_bias_tri, pc1, pc2 = init_compute_data(
+            M, K, N, E, a_dtype, w_dtype, num_warps=8)
+
+    out_triton_monolithic = triton_kernel_moe_forward(
+        hidden_states=x_tri,
+        w1=w1_tri,
+        w2=w2_tri,
+        gating_output=exp_data_tri,
+        topk=topk,
+        renormalize=True,
+        w1_bias=w1_bias_tri,
+        w2_bias=w2_bias_tri,
+        w1_precision=pc1,
+        w2_precision=pc2)
+    out_triton_monolithic = out_triton_monolithic[..., :K]
+
+    out_ref = oai_moe_forward(hidden_states=x,
+                              w1=w1,
+                              w1_bias=w1_bias,
+                              w2=w2,
+                              w2_bias=w2_bias,
+                              gating_output=exp_data,
+                              topk=topk)
+    assert_close(ref=out_ref,
+                 tri=out_triton_monolithic,
+                 maxtol=0.025,
+                 rmstol=0.005)
+
+
+def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
+                topk: int, renormalize: bool, w1_bias: torch.Tensor,
+                w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
+                w2_precision: PrecisionConfig) -> torch.Tensor:
+    max_num_tokens = round_up(a.shape[0], 64)
+
+    fused_experts = FusedMoEModularKernel(
+        BatchedPrepareAndFinalize(max_num_tokens,
+                                  num_dispatchers=1,
+                                  num_local_experts=w1.shape[0],
+                                  rank=0),
+        BatchedOAITritonExperts(
+            None,
+            max_num_tokens=max_num_tokens,
+            num_dispatchers=1,
+            w1_precision=w1_precision,
+            w2_precision=w2_precision,
+        ),
+    )
+
+    extra_expert_args = {
+        "w1_bias": w1_bias,
+        "w2_bias": w2_bias,
+    }
+
+    topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
+
+    return fused_experts(
+        a,
+        w1,
+        w2,
+        topk_weight,
+        topk_ids,
+        extra_expert_args=extra_expert_args,
+    )
+
+
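For reference, `fused_topk` above implements softmax-first gating. A pure-torch sketch of the intended semantics (the real implementation is a fused kernel; this is only the reference math):

```python
# softmax over the router logits, select top-k, optionally renormalize the
# selected weights so they sum to 1 per token
import torch

def fused_topk_ref(gating_output: torch.Tensor, topk: int,
                   renormalize: bool):
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```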
+@pytest.mark.parametrize(
+    ", ".join(f.name for f in fields(Case)),
+    [
+        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
+            # Case(a_dtype="bf16", w_dtype="bf16"),
+            # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
+            Case(a_dtype="bf16", w_dtype="mx4")
+        ]
+    ],
+)
+@pytest.mark.parametrize("num_token", [64])
+@pytest.mark.parametrize("ep", [1, 2, 4, 8])
+def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
+    M = num_token
+    E = ModelConfig.num_experts // ep
+    K = ModelConfig.hidden_size
+    N = ModelConfig.intermediate_size
+    topk = ModelConfig.experts_per_token
+
+    x, w1, w1_bias, w2, w2_bias, exp_data, \
+        x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
+        w2_bias_tri, pc1, pc2 = init_compute_data(
+            M, K, N, E, a_dtype, w_dtype, num_warps=4)
+
+    out_tri = batched_moe(a=x_tri,
+                          w1=w1_tri,
+                          w2=w2_tri,
+                          gating_output=exp_data_tri,
+                          topk=topk,
+                          renormalize=True,
+                          w1_bias=w1_bias_tri,
+                          w2_bias=w2_bias_tri,
+                          w1_precision=pc1,
+                          w2_precision=pc2)
+    out_tri = out_tri[..., :K]
+
+    out_ref = oai_moe_forward(hidden_states=x,
+                              w1=w1,
+                              w1_bias=w1_bias,
+                              w2=w2,
+                              w2_bias=w2_bias,
+                              gating_output=exp_data,
+                              topk=topk)
+    assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
+
+
+def test_unit_shuffle():
+    N = ModelConfig.intermediate_size
+    K = ModelConfig.hidden_size
+    m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")
+
+    x = torch.randn(K, dtype=torch.bfloat16, device="cuda")
+
+    m_shuffled = shuffle_weight(m)
+
+    out_ref = x @ m
+    out_ref = swiglu(out_ref, limit=1.0)
+
+    out = x @ m_shuffled
+    out = triton_kernels.swiglu.swiglu_torch(
+        out,
+        alpha=1.702,
+        precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
+
+    assert_close(ref=out_ref, tri=out)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
new file mode 100644
index 000000000000..4482029c16a8
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -0,0 +1,230 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
+from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+
+if True:
+    import triton_kernels.swiglu
+    from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
+                                           PrecisionConfig, matmul_ogs)
+    from triton_kernels.routing import routing
+
+
+def triton_kernel_moe_forward(
+        hidden_states: torch.Tensor,
+        w1,  # Tensor or triton_kernels.Tensor
+        w2,  # Tensor or triton_kernels.Tensor
+        gating_output: torch.Tensor,
+        topk: int,
+        renormalize: bool,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None,
+        w1_precision=None,  # PrecisionConfig or None
+        w2_precision=None,  # PrecisionConfig or None
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    routing_data, gather_idx, scatter_idx = routing(gating_output,
+                                                    topk,
+                                                    sm_first=not renormalize)
+
+    return triton_kernel_fused_experts(
+        None,
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        w1_precision=w1_precision,
+        w2_precision=w2_precision,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape)
+
+
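The `sm_first=not renormalize` flag leans on a standard identity: restricting a softmax to the top-k set and renormalizing gives the same weights as applying softmax after the top-k selection. A quick numerical check of that identity (illustrative only; the actual output format of `routing` is a triton_kernels detail):

```python
import torch

logits = torch.randn(4, 8)
k = 2

# softmax first, then select and renormalize (renormalize=True path)
p = torch.softmax(logits, dim=-1)
w_sel, idx = torch.topk(p, k=k, dim=-1)
w_renorm = w_sel / w_sel.sum(dim=-1, keepdim=True)

# select first, then softmax over the selected logits (sm_first=False path)
l_sel = torch.gather(logits, -1, idx)
w_sm_after = torch.softmax(l_sel, dim=-1)

torch.testing.assert_close(w_renorm, w_sm_after)
```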
+# This is a Triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+        output_tensor: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1,  # Tensor or triton_kernels.Tensor
+        w2,  # Tensor or triton_kernels.Tensor
+        routing_data,  # RoutingData
+        gather_indx,  # GatherIndx
+        scatter_indx,  # ScatterIndx
+        activation: str = "silu",
+        swiglu_alpha: float = 1.702,
+        swiglu_limit: float = 7.0,
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        per_channel_quant: bool = False,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        w1_bias: Optional[torch.Tensor] = None,
+        w2_bias: Optional[torch.Tensor] = None,
+        w1_precision=None,  # PrecisionConfig or None
+        w2_precision=None,  # PrecisionConfig or None
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    # dtype checks; a uint8 weight dtype means mxfp4
+    assert hidden_states.dtype == torch.bfloat16
+    assert w1_bias is None or w1_bias.dtype == torch.float32
+    assert w2_bias is None or w2_bias.dtype == torch.float32
+
+    # shape check; only checks the non-mxfp4 case
+    assert hidden_states.shape[-1] == w1.shape[-2]
+    assert w2.shape[-1] == w1.shape[1]
+
+    E, _, N = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    act = FusedActivation(
+        FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
+        (swiglu_alpha, swiglu_limit), 2)
+    gammas = routing_data.gate_scal if routing_data else None
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        w1_bias,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_precision,
+        gammas=gammas if apply_router_weight_on_input else None,
+        fused_activation=act)
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache1,
+        w2,
+        w2_bias,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_precision,
+        gammas=None if apply_router_weight_on_input else gammas,
+        y=output_tensor,
+    )
+    return intermediate_cache3
+
+
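Conceptually, the two `matmul_ogs` calls implement gather, then GEMM1 with a fused SwiGLU epilogue, then GEMM2, then a weighted scatter-reduce back to token order. A pure-torch sketch of that dataflow (the loop over experts and the index layout are illustrative assumptions, not the actual RoutingData/GatherIndx/ScatterIndx format triton_kernels uses; weights follow the transposed (E, K, 2N) and (E, N, K) layouts):

```python
import torch

def moe_two_gemms_ref(x, w1, w1_bias, w2, w2_bias, topk_ids, topk_weights,
                      swiglu_fn):
    out = torch.zeros_like(x)
    for e in range(w1.shape[0]):
        token_idx, slot = torch.where(topk_ids == e)   # "gather index"
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w1[e] + w1_bias[e]          # GEMM1
        h = swiglu_fn(h)                               # fused epilogue
        y = h @ w2[e] + w2_bias[e]                     # GEMM2
        g = topk_weights[token_idx, slot].unsqueeze(-1)
        out.index_add_(0, token_idx, g * y)            # "scatter" + reduce
    return out
```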
+class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int,
+                 w1_precision: PrecisionConfig, w2_precision: PrecisionConfig):
+        super().__init__(quant_config)
+        self.max_num_tokens = max_num_tokens
+        self.num_dispatchers = num_dispatchers
+        self.w1_precision = w1_precision
+        self.w2_precision = w2_precision
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.BatchedExperts,
+                mk.FusedMoEActivationFormat.BatchedExperts)
+
+    def supports_chunking(self) -> bool:
+        return False
+
+    def supports_expert_map(self) -> bool:
+        return False
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
+    def workspace_shapes(
+            self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
+            topk: int, global_num_experts: int, local_num_experts: int,
+            expert_tokens_meta: Optional[mk.ExpertTokensMetadata]
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # workspaces are allocated inside the kernel
+        assert a.dim() == 2
+        num_dp = self.num_dispatchers
+        num_experts = local_num_experts
+        max_num_tokens = self.max_num_tokens
+        workspace2 = (0, 0, 0)
+        output = (num_experts, max_num_tokens * num_dp, N)
+        return (output, workspace2, output, a.dtype)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+        extra_expert_args: Optional[dict[str, Any]],
+    ):
+        w1_bias, w2_bias = (extract_required_args(extra_expert_args,
+                                                  ["w1_bias", "w2_bias"]))
+
+        return triton_kernel_fused_experts(
+            output,
+            hidden_states,
+            w1,
+            w2,
+            None,
+            None,
+            None,
+            activation=activation,
+            apply_router_weight_on_input=False,
+            use_fp8_w8a8=False,
+            per_channel_quant=False,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_bias=w1_bias,
+            w2_bias=w2_bias,
+            w1_precision=self.w1_precision,
+            w2_precision=self.w2_precision,
+            a1_scale=a1q_scale,
+            a2_scale=a2_scale)
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index a4a6157fa4bf..fb9758f09031 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -34,7 +34,7 @@
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
-                        round_up)
+                        has_triton_kernels, is_torch_equal_or_newer, round_up)
 from vllm.utils.flashinfer import has_flashinfer
 
 if current_platform.is_cuda_alike():
@@ -721,10 +721,17 @@ def __init__(
         self.global_num_experts = num_experts + num_redundant_experts
 
         # we pad globally so EP buffer allocation works
-        if quant_config and quant_config.get_name() == "mxfp4" and (
-                envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            if not is_torch_equal_or_newer("2.8.0"):
+                raise RuntimeError("Mxfp4 on hopper requires torch >= 2.8.0")
+            if current_platform.is_device_capability(
+                    90) and not has_triton_kernels():
+                raise NotImplementedError(
+                    "Triton kernels must be installed for mxfp4 on hopper")
+            if (current_platform.is_rocm()
+                    or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+                    or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+                hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
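For scale, here is how the padding above works out for the gpt-oss sizes used in the test's ModelConfig (hidden_size = intermediate_size = 2880). A sketch; `round_up` is re-implemented locally as a stand-in for `vllm.utils.round_up`:

```python
def round_up(v: int, m: int) -> int:
    return (v + m - 1) // m * m

assert round_up(2880, 256) == 3072  # flashinfer/rocm hidden_size padding
assert round_up(2880, 128) == 2944  # rocm intermediate padding (tp=1)
assert round_up(2880, 64) == 2880   # triton_kernels padding: already aligned
```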
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index b6d7bc5d5ccc..b17b157ec155 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -8,16 +8,19 @@
 from vllm import envs
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
+from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
+    triton_kernel_moe_forward)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    _can_support_mxfp4)
+    _can_support_mxfp4, _swizzle_mxfp4)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import next_power_of_2, round_up
 
 if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
@@ -39,7 +42,7 @@ def from_config(cls, config):
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 100
+        return 90
 
     @classmethod
     def get_name(cls) -> QuantizationMethods:
@@ -100,11 +103,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             intermediate_size_per_partition
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # to hold non-uniformly sharded tensors as well as for swizzling
+        # other padding to increase performance
         if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
                 or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
+        elif current_platform.is_rocm():
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 128)
+        else:
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size_per_partition, 64)
 
         self.intermediate_size = intermediate_size_per_partition_after_pad
 
         self.hidden_size = hidden_size
@@ -284,7 +294,41 @@ def swap_every_two_rows(x, axis=-1):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-            return
+        else:
+            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+            w13_bias = layer.w13_bias.to(torch.float32)
+            w2_bias = layer.w2_bias.to(torch.float32)
+
+            layer.w13_bias = Parameter(w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(w2_bias, requires_grad=False)
+
+            # FIXME: the warp count needs to be adjusted based on batch size;
+            # this only applies to batched mode
+            if self.moe.use_ep:
+                num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
+            else:
+                num_warps = 8
+
+            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+                layer.w13_weight, layer.w13_weight_scale, num_warps)
+            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+                layer.w2_weight, layer.w2_weight_scale, num_warps)
+
+            self.w13_precision_config = PrecisionConfig(
+                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex))
+            self.w2_precision_config = PrecisionConfig(
+                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+            self.w13_weight_triton_tensor = w13_weight
+            self.w2_weight_triton_tensor = w2_weight
+
+            # we need to delete the original weights to save memory on a
+            # single GPU
+            del layer.w13_weight
+            del layer.w2_weight
+            layer.w13_weight = None
+            layer.w2_weight = None
+            torch.cuda.empty_cache()
 
     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
         # Number of tokens in the input tensor.
@@ -385,3 +429,19 @@ def apply(
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
+        else:
+            return triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=self.w13_weight_triton_tensor,
+                w2=self.w2_weight_triton_tensor,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                w1_bias=layer.w13_bias,
+                w2_bias=layer.w2_bias,
+                w1_precision=self.w13_precision_config,
+                w2_precision=self.w2_precision_config,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 4a4e199e1318..4084dd837c08 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -4,11 +4,55 @@
 
 import torch
 
-from vllm.utils import direct_register_custom_op
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+
+logger = init_logger(__name__)
 
 OCP_MX_BLOCK_SIZE = 32
 
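The storage arithmetic implied by OCP_MX_BLOCK_SIZE: in mxfp4, each uint8 packs two FP4 values, and each 32-element block along the quantized axis shares one 8-bit scale (per the OCP MX spec). A size-only sketch, with no actual quantization performed:

```python
OCP_MX_BLOCK_SIZE = 32

def mxfp4_storage_bytes(rows: int, cols: int) -> tuple[int, int]:
    value_bytes = rows * cols // 2                    # 2 x FP4 per uint8
    scale_bytes = rows * (cols // OCP_MX_BLOCK_SIZE)  # 1 scale per block
    return value_bytes, scale_bytes

# gpt-oss w2 per expert: (hidden, intermediate) = (2880, 2880)
vals, scales = mxfp4_storage_bytes(2880, 2880)
# ~4.15 MB of packed values + ~0.26 MB of scales, vs ~16.6 MB in bf16
```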
+
+def _swizzle_mxfp4(quant_tensor, scale, num_warps):
+    """Weight swizzle for mxfp4 MoE, used by the OAI mxfp4 kernel."""
+    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
+    from triton_kernels.numerics import InFlexData
+    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+    from triton_kernels.tensor_details import layout
+    from triton_kernels.tensor_details.layout import StridedLayout
+    if (current_platform.is_cuda()
+            and current_platform.is_device_capability(90)
+            and not is_torch_equal_or_newer("2.8.1")):
+        logger.warning_once(
+            "Mxfp4 on hopper is running on torch < 2.8.1, "
+            "which causes swizzling to be disabled and may "
+            "cause performance degradation. Please upgrade to torch nightly")
+        value_layout, value_layout_opts = StridedLayout, dict()
+        scale_layout, scale_layout_opts = StridedLayout, dict()
+    else:
+        value_layout, value_layout_opts = \
+            layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
+        scale_layout, scale_layout_opts = (
+            layout.make_default_matmul_mxfp4_w_scale_layout(
+                mx_axis=1, num_warps=num_warps))
+    if current_platform.is_cuda() and \
+            current_platform.is_device_capability(100):
+        constraints = {
+            "is_persistent": True,
+            "epilogue_subtile": 1,
+        }
+        opt_flags.update_opt_flags_constraints(constraints)
+    # transpose the tensor so that the quantization axis is on dim1
+    quant_tensor = quant_tensor.transpose(-2, -1)
+    scale = scale.transpose(-2, -1)
+    quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4),
+                                  value_layout, **value_layout_opts)
+    scale = convert_layout(wrap_torch_tensor(scale), scale_layout,
+                           **scale_layout_opts)
+    return quant_tensor, InFlexData(), scale
+
+
 def _can_support_mxfp4(use_grouped_topk: bool = False,
                        topk_group: Optional[int] = None,
                        num_expert_group: Optional[int] = None,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index cd32f12f3c26..48a347a8f561 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -11,6 +11,27 @@
 from vllm.utils import direct_register_custom_op
 
 
+def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
+    # Interleave the two halves of the last dimension so that paired
+    # weights end up in adjacent locations.
+    # Example:
+    # input:
+    #   [[1, 2, 3, 4, 5, 6],
+    #    [7, 8, 9, 10, 11, 12]]
+    # output:
+    #   [[1, 4, 2, 5, 3, 6],
+    #    [7, 10, 8, 11, 9, 12]]
+    # This is used together with the triton swiglu kernel.
+    shape = w.shape
+    N = shape[-1]
+    first = w[..., :N // 2]
+    second = w[..., N // 2:]
+
+    stacked = torch.stack((first, second), dim=-1)
+    w_shuffled = stacked.reshape(shape)
+    return w_shuffled
+
+
 def get_token_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index ce62282c2199..fc55a6857f20 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3243,6 +3243,12 @@ def has_deep_gemm() -> bool:
     return _has_module("deep_gemm")
 
 
+def has_triton_kernels() -> bool:
+    """Whether the optional `triton_kernels` package is available."""
+
+    return _has_module("triton_kernels")
+
+
 def set_process_title(name: str,
                       suffix: str = "",
                       append: bool = False) -> None:
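`has_triton_kernels` follows the same probe pattern as `has_deep_gemm` above. A minimal sketch, assuming `_has_module` (not shown in this diff) is a thin wrapper over importlib:

```python
# detect an optional dependency without importing it
import importlib.util

def _has_module(name: str) -> bool:
    return importlib.util.find_spec(name) is not None

assert isinstance(_has_module("triton_kernels"), bool)
```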