Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/models/test_deepseek_v4_mega_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from vllm.models.deepseek_v4.nvidia.model import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
make_deepseek_v4_expert_params_mapping,
)
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.platforms import current_platform

pytestmark = pytest.mark.skipif(
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
fused_topk_idx = torch.empty_like(ref_topk_idx)
fused_topk_weights = torch.empty_like(ref_topk_weights)

_stage_deepseek_v4_mega_moe_inputs(
prepare_megamoe_inputs(
hidden_states,
topk_weights,
topk_ids,
Expand Down
165 changes: 2 additions & 163 deletions vllm/models/deepseek_v4/nvidia/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
)
from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op


Expand Down Expand Up @@ -116,167 +116,6 @@ def forward(self, x):
return x


@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one BLOCK_K slice of one token and stage its routing data.

    Program axis 0 selects the token row and axis 1 selects the BLOCK_K-wide
    slice of the hidden dimension. Each program:

    * loads its slice as float32 and splits it into groups of GROUP_K,
    * derives one power-of-two scale per group from the group's abs-max,
    * writes the scaled values to ``x_fp8`` as ``float8e4nv``,
    * packs the per-group scale exponents, one byte each, into a single
      int32 stored at ``x_sf[token, slice]``.

    The program handling slice 0 additionally copies the token's top-k ids
    (widened to int64) and top-k weights into the ``*_out`` tensors.
    """
    row = tl.program_id(0)
    col_block = tl.program_id(1)

    # Load this program's slice of the token; out-of-range lanes read as 0.
    cols = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
    in_bounds = cols < hidden_size
    src_ptrs = hidden_states + row * hidden_stride_m + cols * hidden_stride_k
    values = tl.load(src_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # Per-group absolute maximum, floored so the scale is never zero.
    groups_per_block: tl.constexpr = BLOCK_K // GROUP_K
    group_abs = tl.reshape(tl.abs(values), [groups_per_block, GROUP_K])
    group_amax = tl.max(group_abs, axis=1)
    group_amax = tl.maximum(group_amax, 1.0e-4)

    # Round amax / 448 up to a power of two by working on the float32 bit
    # pattern: take the biased exponent and bump it by one whenever any
    # mantissa bit is set. Clamping to [1, 254] keeps the exponent in the
    # normal, finite range.
    raw_scale = group_amax / 448.0
    raw_bits = raw_scale.to(tl.uint32, bitcast=True)
    biased_exp = (raw_bits >> 23) & 0xFF
    mantissa_set = ((raw_bits & 0x7FFFFF) != 0).to(tl.uint32)
    exp_byte = biased_exp + mantissa_set
    exp_byte = tl.minimum(tl.maximum(exp_byte, 1), 254)
    pow2_scale = (exp_byte << 23).to(tl.float32, bitcast=True)

    # Divide each group by its scale and narrow to fp8.
    grouped = tl.reshape(values, [groups_per_block, GROUP_K])
    inv_scale = 1.0 / pow2_scale
    rescaled = grouped * inv_scale[:, None]
    rescaled = tl.reshape(rescaled, [BLOCK_K])
    quantized = rescaled.to(tl.float8e4nv)
    dst_ptrs = x_fp8 + row * x_stride_m + cols * x_stride_k
    tl.store(dst_ptrs, quantized, mask=in_bounds)

    # Group g's exponent occupies byte g of the packed word.
    byte_index = tl.arange(0, groups_per_block)
    packed = tl.sum(exp_byte << (byte_index * 8), axis=0).to(tl.int32)
    sf_ptr = x_sf + row * x_sf_stride_m + col_block * x_sf_stride_k
    tl.store(sf_ptr, packed)

    # Routing data is per token, so only the first slice's program copies it.
    if col_block == 0:
        slots = tl.arange(0, BLOCK_TOPK)
        slot_valid = slots < top_k

        ids_src = topk_ids + row * topk_ids_stride_m + slots * topk_ids_stride_k
        expert_ids = tl.load(ids_src, mask=slot_valid, other=0).to(tl.int64)
        ids_dst = topk_idx_out + row * topk_idx_stride_m + slots * topk_idx_stride_k
        tl.store(ids_dst, expert_ids, mask=slot_valid)

        w_src = (
            topk_weights
            + row * topk_weights_stride_m
            + slots * topk_weights_stride_k
        )
        expert_weights = tl.load(w_src, mask=slot_valid, other=0.0)
        w_dst = (
            topk_weights_out
            + row * topk_weights_out_stride_m
            + slots * topk_weights_out_stride_k
        )
        tl.store(w_dst, expert_weights, mask=slot_valid)


def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp8: torch.Tensor,
    x_sf: torch.Tensor,
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Launch the MegaMoE input-staging kernel, writing into the out tensors.

    Args:
        hidden_states: 2-D input; its shape is read as
            ``(num_tokens, hidden_size)``.
        topk_weights: Routing weights; must have the same shape as
            ``topk_ids``.
        topk_ids: Routing expert ids; ``shape[1]`` is taken as top-k.
        x_fp8: Destination for the quantized hidden states.
        x_sf: Destination for the packed per-group scale exponents.
        topk_idx_out: Destination for the copied expert ids.
        topk_weights_out: Destination for the copied routing weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of 128, or if
            ``topk_weights`` and ``topk_ids`` differ in shape.

    Returns immediately, without validating anything else, when there are
    no tokens.
    """
    token_count, hidden_dim = hidden_states.shape
    if token_count == 0:
        return

    if hidden_dim % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )

    experts_per_token = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    # One program per (token, 128-wide hidden slice).
    slice_width = 128
    launch_grid = (token_count, triton.cdiv(hidden_dim, slice_width))
    padded_topk = triton.next_power_of_2(experts_per_token)

    # Tensors in the kernel's positional order; the stride list follows the
    # same order, contributing (stride(0), stride(1)) for each tensor.
    kernel_tensors = (
        hidden_states,
        x_fp8,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
    )
    kernel_strides = [
        tensor.stride(dim) for tensor in kernel_tensors for dim in (0, 1)
    ]

    _deepseek_v4_stage_mega_moe_inputs_kernel[launch_grid](
        *kernel_tensors,
        *kernel_strides,
        hidden_dim,
        experts_per_token,
        BLOCK_K=slice_width,
        GROUP_K=32,
        BLOCK_TOPK=padded_topk,
        num_warps=4,
    )


def make_deepseek_v4_expert_params_mapping(
num_experts: int,
) -> list[tuple[str, str, int, str]]:
Expand Down Expand Up @@ -542,7 +381,7 @@ def _run_mega_moe(

symm_buffer = self.get_symm_buffer()
num_tokens = hidden_states.shape[0]
_stage_deepseek_v4_mega_moe_inputs(
prepare_megamoe_inputs(
hidden_states,
topk_weights,
topk_ids,
Expand Down
173 changes: 173 additions & 0 deletions vllm/models/deepseek_v4/nvidia/ops/prepare_megamoe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Triton input-staging kernel for DeepSeek V4 MegaMoE.

Quantizes hidden states to fp8 with E8M0 group scales and repacks the
routing top-k tensors into the int64/float32 layout that the DeepGEMM
MegaMoE kernels consume.
"""

import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _prepare_megamoe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one BLOCK_K slice of one token and stage its routing data.

    Program axis 0 selects the token row and axis 1 selects the BLOCK_K-wide
    slice of the hidden dimension. Each program:

    * loads its slice as float32 and splits it into groups of GROUP_K,
    * derives one power-of-two scale per group from the group's abs-max,
    * writes the scaled values to ``x_fp8`` as ``float8e4nv``,
    * packs the per-group scale exponents, one byte each, into a single
      int32 stored at ``x_sf[token, slice]``.

    The program handling slice 0 additionally copies the token's top-k ids
    (widened to int64) and top-k weights into the ``*_out`` tensors.
    """
    row = tl.program_id(0)
    col_block = tl.program_id(1)

    # Load this program's slice of the token; out-of-range lanes read as 0.
    cols = col_block * BLOCK_K + tl.arange(0, BLOCK_K)
    in_bounds = cols < hidden_size
    src_ptrs = hidden_states + row * hidden_stride_m + cols * hidden_stride_k
    values = tl.load(src_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # Per-group absolute maximum, floored so the scale is never zero.
    groups_per_block: tl.constexpr = BLOCK_K // GROUP_K
    group_abs = tl.reshape(tl.abs(values), [groups_per_block, GROUP_K])
    group_amax = tl.max(group_abs, axis=1)
    group_amax = tl.maximum(group_amax, 1.0e-4)

    # Round amax / 448 up to a power of two by working on the float32 bit
    # pattern: take the biased exponent and bump it by one whenever any
    # mantissa bit is set. Clamping to [1, 254] keeps the exponent in the
    # normal, finite range.
    raw_scale = group_amax / 448.0
    raw_bits = raw_scale.to(tl.uint32, bitcast=True)
    biased_exp = (raw_bits >> 23) & 0xFF
    mantissa_set = ((raw_bits & 0x7FFFFF) != 0).to(tl.uint32)
    exp_byte = biased_exp + mantissa_set
    exp_byte = tl.minimum(tl.maximum(exp_byte, 1), 254)
    pow2_scale = (exp_byte << 23).to(tl.float32, bitcast=True)

    # Divide each group by its scale and narrow to fp8.
    grouped = tl.reshape(values, [groups_per_block, GROUP_K])
    inv_scale = 1.0 / pow2_scale
    rescaled = grouped * inv_scale[:, None]
    rescaled = tl.reshape(rescaled, [BLOCK_K])
    quantized = rescaled.to(tl.float8e4nv)
    dst_ptrs = x_fp8 + row * x_stride_m + cols * x_stride_k
    tl.store(dst_ptrs, quantized, mask=in_bounds)

    # Group g's exponent occupies byte g of the packed word.
    byte_index = tl.arange(0, groups_per_block)
    packed = tl.sum(exp_byte << (byte_index * 8), axis=0).to(tl.int32)
    sf_ptr = x_sf + row * x_sf_stride_m + col_block * x_sf_stride_k
    tl.store(sf_ptr, packed)

    # Routing data is per token, so only the first slice's program copies it.
    if col_block == 0:
        slots = tl.arange(0, BLOCK_TOPK)
        slot_valid = slots < top_k

        ids_src = topk_ids + row * topk_ids_stride_m + slots * topk_ids_stride_k
        expert_ids = tl.load(ids_src, mask=slot_valid, other=0).to(tl.int64)
        ids_dst = topk_idx_out + row * topk_idx_stride_m + slots * topk_idx_stride_k
        tl.store(ids_dst, expert_ids, mask=slot_valid)

        w_src = (
            topk_weights
            + row * topk_weights_stride_m
            + slots * topk_weights_stride_k
        )
        expert_weights = tl.load(w_src, mask=slot_valid, other=0.0)
        w_dst = (
            topk_weights_out
            + row * topk_weights_out_stride_m
            + slots * topk_weights_out_stride_k
        )
        tl.store(w_dst, expert_weights, mask=slot_valid)


def prepare_megamoe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp8: torch.Tensor,
    x_sf: torch.Tensor,
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Launch the MegaMoE input-staging kernel, writing into the out tensors.

    Args:
        hidden_states: 2-D input; its shape is read as
            ``(num_tokens, hidden_size)``.
        topk_weights: Routing weights; must have the same shape as
            ``topk_ids``.
        topk_ids: Routing expert ids; ``shape[1]`` is taken as top-k.
        x_fp8: Destination for the quantized hidden states.
        x_sf: Destination for the packed per-group scale exponents.
        topk_idx_out: Destination for the copied expert ids.
        topk_weights_out: Destination for the copied routing weights.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of 128, or if
            ``topk_weights`` and ``topk_ids`` differ in shape.

    Returns immediately, without validating anything else, when there are
    no tokens.
    """
    token_count, hidden_dim = hidden_states.shape
    if token_count == 0:
        return

    if hidden_dim % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )

    experts_per_token = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    # One program per (token, 128-wide hidden slice).
    slice_width = 128
    launch_grid = (token_count, triton.cdiv(hidden_dim, slice_width))
    padded_topk = triton.next_power_of_2(experts_per_token)

    # Tensors in the kernel's positional order; the stride list follows the
    # same order, contributing (stride(0), stride(1)) for each tensor.
    kernel_tensors = (
        hidden_states,
        x_fp8,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
    )
    kernel_strides = [
        tensor.stride(dim) for tensor in kernel_tensors for dim in (0, 1)
    ]

    _prepare_megamoe_inputs_kernel[launch_grid](
        *kernel_tensors,
        *kernel_strides,
        hidden_dim,
        experts_per_token,
        BLOCK_K=slice_width,
        GROUP_K=32,
        BLOCK_TOPK=padded_topk,
        num_warps=4,
    )
Loading