diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 3d4657bba130..3daae242d459 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -8,9 +8,9 @@ from vllm.models.deepseek_v4.nvidia.model import ( DeepseekV4MegaMoEExperts, - _stage_deepseek_v4_mega_moe_inputs, make_deepseek_v4_expert_params_mapping, ) +from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform pytestmark = pytest.mark.skipif( @@ -164,7 +164,7 @@ def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): fused_topk_idx = torch.empty_like(ref_topk_idx) fused_topk_weights = torch.empty_like(ref_topk_weights) - _stage_deepseek_v4_mega_moe_inputs( + prepare_megamoe_inputs( hidden_states, topk_weights, topk_ids, diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 44c97715d848..0c0cee2a01e1 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -59,9 +59,9 @@ DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) +from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op @@ -116,167 +116,6 @@ def forward(self, x): return x -@triton.jit -def _deepseek_v4_stage_mega_moe_inputs_kernel( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_stride_m: tl.constexpr, - hidden_stride_k: tl.constexpr, - x_stride_m: tl.constexpr, - x_stride_k: tl.constexpr, - x_sf_stride_m: tl.constexpr, - x_sf_stride_k: tl.constexpr, - topk_ids_stride_m: tl.constexpr, - topk_ids_stride_k: tl.constexpr, - topk_weights_stride_m: tl.constexpr, - topk_weights_stride_k: tl.constexpr, - topk_idx_stride_m: tl.constexpr, - topk_idx_stride_k: tl.constexpr, - topk_weights_out_stride_m: tl.constexpr, - topk_weights_out_stride_k: tl.constexpr, - hidden_size: tl.constexpr, - top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: tl.constexpr, - BLOCK_TOPK: tl.constexpr, -) -> None: - token_id = tl.program_id(0) - k_block_id = tl.program_id(1) - - k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) - k_mask = k_offsets < hidden_size - hidden = tl.load( - hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, - mask=k_mask, - other=0.0, - ).to(tl.float32) - - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) - amax = tl.maximum(amax, 1.0e-4) - - scale = amax / 448.0 - scale_bits = scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) - - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) - tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, - ) - - scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) - tl.store( - x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, - packed_scale, - ) - - if k_block_id == 0: - topk_offsets = tl.arange(0, BLOCK_TOPK) - topk_mask = topk_offsets < top_k - - ids = tl.load( - topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, - mask=topk_mask, - other=0, - ).to(tl.int64) - tl.store( - topk_idx_out - + token_id * topk_idx_stride_m - + topk_offsets * topk_idx_stride_k, - ids, - mask=topk_mask, - ) - - weights = tl.load( - topk_weights - + token_id * topk_weights_stride_m - + topk_offsets * topk_weights_stride_k, - mask=topk_mask, - other=0.0, - ) - tl.store( - topk_weights_out - + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, - weights, - mask=topk_mask, - ) - - -def _stage_deepseek_v4_mega_moe_inputs( - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, - topk_idx_out: torch.Tensor, - topk_weights_out: torch.Tensor, -) -> None: - num_tokens, hidden_size = hidden_states.shape - if num_tokens == 0: - return - if hidden_size % 128 != 0: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires hidden_size to be " - "a multiple of 128." - ) - top_k = topk_ids.shape[1] - if topk_weights.shape != topk_ids.shape: - raise ValueError( - "DeepSeek V4 MegaMoE input staging requires topk_weights and " - "topk_ids to have the same shape." - ) - - block_k = 128 - grid = (num_tokens, triton.cdiv(hidden_size, block_k)) - block_topk = triton.next_power_of_2(top_k) - _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( - hidden_states, - x_fp8, - x_sf, - topk_ids, - topk_weights, - topk_idx_out, - topk_weights_out, - hidden_states.stride(0), - hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), - x_sf.stride(0), - x_sf.stride(1), - topk_ids.stride(0), - topk_ids.stride(1), - topk_weights.stride(0), - topk_weights.stride(1), - topk_idx_out.stride(0), - topk_idx_out.stride(1), - topk_weights_out.stride(0), - topk_weights_out.stride(1), - hidden_size, - top_k, - BLOCK_K=block_k, - GROUP_K=32, - BLOCK_TOPK=block_topk, - num_warps=4, - ) - - def make_deepseek_v4_expert_params_mapping( num_experts: int, ) -> list[tuple[str, str, int, str]]: @@ -542,7 +381,7 @@ def _run_mega_moe( symm_buffer = self.get_symm_buffer() num_tokens = hidden_states.shape[0] - _stage_deepseek_v4_mega_moe_inputs( + prepare_megamoe_inputs( hidden_states, topk_weights, topk_ids, diff --git a/vllm/models/deepseek_v4/nvidia/ops/prepare_megamoe.py b/vllm/models/deepseek_v4/nvidia/ops/prepare_megamoe.py new file mode 100644 index 000000000000..7cdb39e9b686 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/prepare_megamoe.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton input-staging kernel for DeepSeek V4 MegaMoE. + +Quantizes hidden states to fp8 with E8M0 group scales and repacks the +routing top-k tensors into the int64/float32 layout that the DeepGEMM +MegaMoE kernels consume. +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _prepare_megamoe_inputs_kernel( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_stride_m: tl.constexpr, + hidden_stride_k: tl.constexpr, + x_stride_m: tl.constexpr, + x_stride_k: tl.constexpr, + x_sf_stride_m: tl.constexpr, + x_sf_stride_k: tl.constexpr, + topk_ids_stride_m: tl.constexpr, + topk_ids_stride_k: tl.constexpr, + topk_weights_stride_m: tl.constexpr, + topk_weights_stride_k: tl.constexpr, + topk_idx_stride_m: tl.constexpr, + topk_idx_stride_k: tl.constexpr, + topk_weights_out_stride_m: tl.constexpr, + topk_weights_out_stride_k: tl.constexpr, + hidden_size: tl.constexpr, + top_k: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_K: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +) -> None: + token_id = tl.program_id(0) + k_block_id = tl.program_id(1) + + k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = k_offsets < hidden_size + hidden = tl.load( + hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, + mask=k_mask, + other=0.0, + ).to(tl.float32) + + num_groups: tl.constexpr = BLOCK_K // GROUP_K + hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(hidden_groups, axis=1) + amax = tl.maximum(amax, 1.0e-4) + + scale = amax / 448.0 + scale_bits = scale.to(tl.uint32, bitcast=True) + scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( + tl.uint32 + ) + scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) + rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + scaled = hidden_groups * (1.0 / rounded_scale)[:, None] + scaled = tl.reshape(scaled, [BLOCK_K]) + fp8 = scaled.to(tl.float8e4nv) + tl.store( + x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, + fp8, + mask=k_mask, + ) + + scale_offsets = tl.arange(0, num_groups) + packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) + tl.store( + x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, + packed_scale, + ) + + if k_block_id == 0: + topk_offsets = tl.arange(0, BLOCK_TOPK) + topk_mask = topk_offsets < top_k + + ids = tl.load( + topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, + mask=topk_mask, + other=0, + ).to(tl.int64) + tl.store( + topk_idx_out + + token_id * topk_idx_stride_m + + topk_offsets * topk_idx_stride_k, + ids, + mask=topk_mask, + ) + + weights = tl.load( + topk_weights + + token_id * topk_weights_stride_m + + topk_offsets * topk_weights_stride_k, + mask=topk_mask, + other=0.0, + ) + tl.store( + topk_weights_out + + token_id * topk_weights_out_stride_m + + topk_offsets * topk_weights_out_stride_k, + weights, + mask=topk_mask, + ) + + +def prepare_megamoe_inputs( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + x_fp8: torch.Tensor, + x_sf: torch.Tensor, + topk_idx_out: torch.Tensor, + topk_weights_out: torch.Tensor, +) -> None: + num_tokens, hidden_size = hidden_states.shape + if num_tokens == 0: + return + if hidden_size % 128 != 0: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires hidden_size to be " + "a multiple of 128." + ) + top_k = topk_ids.shape[1] + if topk_weights.shape != topk_ids.shape: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires topk_weights and " + "topk_ids to have the same shape." + ) + + block_k = 128 + grid = (num_tokens, triton.cdiv(hidden_size, block_k)) + block_topk = triton.next_power_of_2(top_k) + _prepare_megamoe_inputs_kernel[grid]( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_states.stride(0), + hidden_states.stride(1), + x_fp8.stride(0), + x_fp8.stride(1), + x_sf.stride(0), + x_sf.stride(1), + topk_ids.stride(0), + topk_ids.stride(1), + topk_weights.stride(0), + topk_weights.stride(1), + topk_idx_out.stride(0), + topk_idx_out.stride(1), + topk_weights_out.stride(0), + topk_weights_out.stride(1), + hidden_size, + top_k, + BLOCK_K=block_k, + GROUP_K=32, + BLOCK_TOPK=block_topk, + num_warps=4, + )