Merged
Changes from 12 commits
53 commits
92d647c
feat(ep_moe): integrate deepgemm into origin ep moe
TianQiLin666666 Apr 22, 2025
e057acb
fix(ep_moe): group_gemm_mask bug
TianQiLin666666 Apr 22, 2025
19ec50e
fix bugs
TianQiLin666666 Apr 22, 2025
3ce1a91
fix bugs
TianQiLin666666 Apr 23, 2025
3d51a71
fix(em_moe): offset bugs
TianQiLin666666 Apr 24, 2025
c80fc3c
fix(deepgemm): bugfix
TianQiLin666666 Apr 24, 2025
af94a8b
fix: remove redundant code
TianQiLin666666 Apr 28, 2025
2022070
fix: clang-format
TianQiLin666666 Apr 28, 2025
55ea483
fix: remove print
TianQiLin666666 Apr 28, 2025
988a522
fix(ep_moe): replace EPMOE_USE_DEEPGEMM with _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Apr 28, 2025
1f81f01
merge main
xutizhou Jun 2, 2025
b4ae984
Refactor moe_ep_deepgemm_preprocess to remove CUDA-specific handling …
xutizhou Jun 2, 2025
c6d51d2
Fix condition for expert fusion by updating the check for 'enable_dee…
xutizhou Jun 2, 2025
0397c25
Fix typo in function name from 'moe_ep_deepgemm_preproess' to 'moe_ep…
xutizhou Jun 2, 2025
aeb437d
Update moe_ep_deepgemm_preprocess to adjust m_max calculation for mas…
xutizhou Jun 3, 2025
d2c19bf
Refactor compute_masked_m_triton_kernel to remove num_experts paramet…
xutizhou Jun 3, 2025
0d3e793
Refactor moe_ep_deepgemm_preprocess to improve assertions and streaml…
xutizhou Jun 3, 2025
ee1a2b1
Refactor kernel functions in kernels.py to improve variable naming an…
xutizhou Jun 3, 2025
d530ee7
Enhance EPMoE layer by capturing hidden states' shape, dtype, and dev…
xutizhou Jun 3, 2025
cb6ea36
Refactor EPMoE layer to improve memory management by explicitly delet…
xutizhou Jun 3, 2025
2769865
Update EPMoE layer to use hidden_states_device for tensor creation, e…
xutizhou Jun 3, 2025
a7a6235
Refactor deepgemm_post_reorder_triton_kernel to improve variable hand…
xutizhou Jun 3, 2025
a1493e5
Optimize EPMoE layer by explicitly deleting intermediate tensor varia…
xutizhou Jun 3, 2025
b1edec5
fix(ep_moe_deepgemm): use dispose_tensor to really free tensor mem
TianQiLin666666 Jun 3, 2025
773f151
Merge pull request #1 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 3, 2025
4975723
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 4, 2025
6584c21
Enhance EPMoE layer to conditionally use deep GEMM based on FP8 setti…
xutizhou Jun 5, 2025
c54a141
Optimize memory management in EPMoE layer by removing unnecessary ten…
xutizhou Jun 5, 2025
86f06b6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
15c9514
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
afd55c2
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 5, 2025
b596fa2
fix: remove 'del gateup_input_scale' to avoid H20*8 OOM, and move ann…
TianQiLin666666 Jun 6, 2025
fe97d1f
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 7, 2025
431f47a
Merge pull request #2 from TianQiLin666666/feat/ep_moe_deepgemm_zxt
xutizhou Jun 9, 2025
1422da6
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 10, 2025
bbbd98a
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 10, 2025
35cc918
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 11, 2025
1a26425
dispose hidden_states
ch-wan Jun 11, 2025
a902338
Merge branch 'main' into feat/ep_moe_deepgemm
TianQiLin666666 Jun 16, 2025
d6b3afd
fix: call deep_gemm_wrapper APIs in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
b51f324
assert act=="silu" in epmoe forward_deepgemm
TianQiLin666666 Jun 16, 2025
ad9abb2
fix(epmoe): remove _ENABLE_JIT_DEEPGEMM
TianQiLin666666 Jun 16, 2025
1869b18
fix(fill_gateup_input_triton_kernel): pre-define a tl.arange() outsid…
TianQiLin666666 Jun 16, 2025
77123c6
replace deepgemm_post_reorder_triton_kernel with post_reorder_triton_…
TianQiLin666666 Jun 16, 2025
ddc0c33
fix args of post_reorder_triton_kernel in all tests and benchmarks
TianQiLin666666 Jun 16, 2025
8e71771
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 17, 2025
a6d61f6
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
e30b3ab
Merge branch 'main' into feat/ep_moe_deepgemm
zhyncs Jun 18, 2025
aafbe4e
add num_fused_shared_experts
TianQiLin666666 Jun 19, 2025
0ea5bd9
fix(moe_deepgemm): convert per-tensor weight quant to per-block quant…
TianQiLin666666 Jun 19, 2025
87d68e9
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 19, 2025
26040f4
Merge branch 'main' into feat/ep_moe_deepgemm
xutizhou Jun 20, 2025
bbb127d
Merge branch 'main' into feat/ep_moe_deepgemm
ch-wan Jun 22, 2025
207 changes: 207 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1085,3 +1085,210 @@ def tma_align_input_scale(input_scale: torch.Tensor):
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
return output.t()[:m]


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m, num_experts, N):
Contributor (medium):
The parameter N in the compute_masked_m_triton_kernel function signature appears to be unused within the kernel's body. If it's not required for any logic (e.g., boundary checks that might be missing or planned), could it be removed to simplify the signature and avoid confusion? The kernel is launched with a grid size of (num_experts,) and uses tl.program_id(0) to get expert_id, suggesting N might not be directly used for indexing in its current form.

expert_id = tl.program_id(0)
start = tl.load(seg_indptr + expert_id)
end = tl.load(seg_indptr + expert_id + 1)
tl.store(masked_m + expert_id, (end - start))
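
A minimal sketch of the simplification the comment above suggests (illustrative only, not necessarily the merged change; the _no_n suffix is made up): the same kernel with the unused N parameter dropped. The launch grid stays (num_experts,) and expert_id still comes from tl.program_id(0).

@triton.jit
def compute_masked_m_triton_kernel_no_n(seg_indptr, masked_m):
    # one program per expert: masked_m[e] = number of tokens routed to expert e
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, end - start)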


@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
max_m,
num_toks,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
expert_dst_start = tl.load(seg_indptr + expert_id)
expert_dst_offset = dst_id - expert_dst_start
dst_id = expert_id * max_m + expert_dst_offset
tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
input_ptr,
scale_ptr,
gateup_input_ptr,
gateup_input_scale_ptr,
src2dst_ptr,
topk_ids_ptr,
start_expert_id,
end_expert_id,
topk,
m_max,
hidden_size,
scale_size,
BLOCK_SIZE: tl.constexpr,
):

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
src_ptr = input_ptr + src_idx * hidden_size
scale_src_ptr = scale_ptr + src_idx * scale_size

for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
dst_idx = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx - start_expert_id * m_max
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
Collaborator:
If we pre-define a tl.arange() outside this for-loop and reuse it, can this kernel be faster?

Collaborator:
Done!

mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask)
tl.store(dst_ptr + offset, in_data, mask=mask)
scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < scale_size
in_scale = tl.load(scale_src_ptr + offset, mask=mask)
tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
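
A standalone toy kernel illustrating the pattern agreed on above, assuming this module's existing `import triton` / `import triton.language as tl` (a sketch only, not the merged code): the tl.arange vector is built once before the copy loop and reused on every iteration rather than re-created inside it.

@triton.jit
def copy_row_triton_kernel(src_ptr, dst_ptr, hidden_size, BLOCK_SIZE: tl.constexpr):
    # one program per row; vec is materialized once and reused below
    row = tl.program_id(0)
    src_ptr = src_ptr + row * hidden_size
    dst_ptr = dst_ptr + row * hidden_size
    vec = tl.arange(0, BLOCK_SIZE)
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size
        data = tl.load(src_ptr + offset, mask=mask)
        tl.store(dst_ptr + offset, data, mask=mask)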


def exp2_upper(num: int) -> int:
for i in range(2, 31):
value = pow(2, i)
if num <= value:
return value
return num
Contributor (medium):
The exp2_upper function calculates powers of 2 from 2^2 up to 2^30.

  1. Because the loop starts at 2^2, small inputs (e.g., num=1 or num=2) return 4 rather than the tight power of 2 (1 or 2).
  2. If num is greater than 2^30, the function returns num unchanged, which is generally not a power of 2.

Is this behavior intended for these edge cases? If the goal is to always find the smallest power of 2 greater than or equal to num (for num > 0), an alternative like 1 << (num - 1).bit_length() could be considered. If the current behavior is correct, adding a comment to clarify these boundary conditions would be helpful.

def exp2_upper(num: int) -> int:
    if num <= 0:
        # Or raise an error, or return a defined value like 1 or 2
        # depending on how non-positive inputs should be handled.
        # Assuming num is expected to be positive for this context.
        return 1 # Smallest power of 2, or handle error
    if num == 1:
        return 1 # 2^0, or 2 if strictly greater power of 2 is needed
    # Smallest power of 2 greater than or equal to num
    return 1 << (num - 1).bit_length()
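
A hand-checked comparison of the two variants at the edge cases raised above, assuming the exp2_upper defined in this diff and the bit_length alternative from the suggestion:

assert exp2_upper(1) == 4 and (1 << (1 - 1).bit_length()) == 1
assert exp2_upper(2) == 4 and (1 << (2 - 1).bit_length()) == 2
assert exp2_upper(100) == 128 == (1 << (100 - 1).bit_length())
assert exp2_upper(2**30 + 1) == 2**30 + 1   # falls through the loop; not a power of 2
assert (1 << (2**30).bit_length()) == 2**31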



def moe_ep_deepgemm_preproess(
topk_ids: torch.Tensor,
num_experts: int,
hidden_states: torch.Tensor,
top_k: int,
start_expert_id,
end_expert_id,
block_shape,
output_dtype: torch.dtype = torch.float8_e4m3fn,
):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)

compute_seg_indptr_triton_kernel[(num_experts,)](
reorder_topk_ids, seg_indptr, topk_ids.numel()
)

grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
compute_masked_m_triton_kernel[(num_experts,)](
seg_indptr, masked_m, num_experts, reorder_topk_ids.numel()
)

m_max = exp2_upper(hidden_states.size(0))
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
gateup_input = torch.empty(
(int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
device=hidden_states.device,
dtype=output_dtype,
)

deepgemm_compute_src2dst_triton_kernel[grid](
topk_ids,
reorder_ids,
seg_indptr,
src2dst,
m_max,
topk_ids.numel(),
BLOCK_SIZE=256,
)

if block_shape is not None:
assert len(block_shape) == 2
block_n, block_k = block_shape[0], block_shape[1]
hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

gateup_input_scale = torch.empty(
(gateup_input.size(0), gateup_input.size(1), scale.size(1)),
device=hidden_states.device,
dtype=scale.dtype,
)

fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
scale,
gateup_input,
gateup_input_scale,
src2dst,
topk_ids,
start_expert_id,
end_expert_id,
top_k,
m_max,
hidden_states.size(1),
scale.size(1),
BLOCK_SIZE=1024,
)
Contributor (critical):
There's a potential UnboundLocalError for scale and gateup_input_scale and incorrect arguments passed to fill_gateup_input_triton_kernel if block_shape is None.

  1. If block_shape is None, scale (from per_token_group_quant_fp8) and gateup_input_scale are not defined within the if block_shape is not None: block (lines 1211-1220).
  2. However, scale is unconditionally passed to fill_gateup_input_triton_kernel (line 1224), and gateup_input_scale is returned by the function (line 1244).
  3. Additionally, hidden_states.size(1) and scale.size(1) are passed as hidden_size and scale_size to the kernel. If block_shape is None, hidden_states refers to the original unquantized tensor, but scale would be undefined, causing an error when scale.size(1) is accessed.

These variables need to be correctly defined and the appropriate (possibly quantized) hidden states and their sizes should be passed to the kernel regardless of the block_shape condition. Consider initializing scale and gateup_input_scale (e.g., to None or dummy tensors if the kernel requires them) and using a separate variable for the hidden states that are actually passed to the kernel.

    # Initialize variables that might be conditionally defined or modified
    kernel_passed_hidden_states = hidden_states
    kernel_passed_scale = None
    # gateup_input_scale is returned, so it must be defined.
    # Initialize to None or a default based on expected behavior when block_shape is None.
    # If the kernel always expects a tensor, a dummy tensor should be created here.
    # For now, let's assume it can be None if not used by the kernel in that path.
    # This needs careful verification against the kernel's expectations.
    returnable_gateup_input_scale = None

    # Determine sizes for the kernel call, these will be updated if quantization occurs
    final_kernel_hidden_size = hidden_states.size(1)
    final_kernel_scale_size = 0  # Default if no scale quantization

    if block_shape is not None:
        assert len(block_shape) == 2
        # block_n is not used in this part of the preprocessing for input quantization
        block_k = block_shape[1]
        quantized_hidden_states, scale_values = per_token_group_quant_fp8(hidden_states, block_k)

        kernel_passed_hidden_states = quantized_hidden_states
        kernel_passed_scale = scale_values

        returnable_gateup_input_scale = torch.empty(
            (gateup_input.size(0), gateup_input.size(1), kernel_passed_scale.size(1)),
            device=kernel_passed_hidden_states.device,  # Use device of the tensor being processed
            dtype=kernel_passed_scale.dtype,
        )
        final_kernel_hidden_size = kernel_passed_hidden_states.size(1)
        final_kernel_scale_size = kernel_passed_scale.size(1)
    else:
        # If block_shape is None, quantization via per_token_group_quant_fp8 is skipped.
        # The kernel fill_gateup_input_triton_kernel still expects scale_ptr and scale_size.
        # If the kernel requires valid tensors for scale and gateup_input_scale even when no
        # block_shape is provided (e.g., for a non-quantized path or different quant type),
        # dummy tensors should be created here. Otherwise, if it can handle None or scale_size=0,
        # current_scale = None and current_gateup_input_scale = None would be appropriate.
        # This example creates minimal dummy scales if the kernel cannot handle None.
        # This part needs to align with the Triton kernel's expectation for a no-quantization path.
        # For safety, providing dummy tensors if the kernel always reads them:
        kernel_passed_scale = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.float32)
        returnable_gateup_input_scale = torch.ones((gateup_input.size(0), gateup_input.size(1), 1), device=hidden_states.device, dtype=torch.float32)
        final_kernel_scale_size = 1 # For the dummy scale

    fill_gateup_input_triton_kernel[(kernel_passed_hidden_states.shape[0],)](
        kernel_passed_hidden_states,
        kernel_passed_scale, 
        gateup_input,
        returnable_gateup_input_scale,
        src2dst,
        topk_ids,
        start_expert_id,
        end_expert_id,
        top_k,
        m_max,
        final_kernel_hidden_size, # Size of kernel_passed_hidden_states
        final_kernel_scale_size,  # Size of kernel_passed_scale
        BLOCK_SIZE=1024,
    )

    return (
        m_max,
        masked_m[start_expert_id : (end_expert_id + 1)],
        expected_m,
        src2dst,
        gateup_input,
        returnable_gateup_input_scale, # Return the (potentially dummy or None) gateup_input_scale
    )


return (
m_max,
masked_m[start_expert_id : (end_expert_id + 1)],
expected_m,
src2dst,
gateup_input,
gateup_input_scale,
)
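
For readability, a plain-PyTorch CPU reference (a sketch, not part of this PR; the helper name is made up) of what the preprocessing above computes: masked_m counts tokens per expert, seg_indptr is its prefix sum, and src2dst maps every (token, top-k slot) into the expert-major, m_max-padded layout consumed by the masked grouped GEMM.

import torch

def reference_ep_preprocess(topk_ids: torch.Tensor, num_experts: int, m_max: int):
    # flatten the (num_tokens, top_k) routing ids and sort them stably per expert
    flat = topk_ids.reshape(-1).long()
    reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)
    # tokens routed to each expert and the per-expert start offsets
    masked_m = torch.bincount(flat, minlength=num_experts).to(torch.int32)
    seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
    seg_indptr[1:] = torch.cumsum(masked_m, dim=0)
    # sorted slot j lands at row (expert * m_max + rank-within-expert)
    rank_in_expert = torch.arange(flat.numel()) - seg_indptr[reorder_topk_ids]
    src2dst = torch.empty(flat.numel(), dtype=torch.int32)
    src2dst[reorder_ids] = (reorder_topk_ids * m_max + rank_in_expert).to(torch.int32)
    return masked_m, seg_indptr, src2dst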


@triton.jit
def deepgemm_post_reorder_triton_kernel(
Collaborator:
this kernel is very similar to post_reorder_triton_kernel. Can we simply add an additional argument to post_reorder_triton_kernel to control the workflow? This would avoid a repetitive implementation.

Collaborator:
Done!

down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
start_expert_id,
end_expert_id,
topk,
hidden_size,
max_m,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk

computed = False
store_ptr = output_ptr + src_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size

sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
expert_id = tl.load(topk_ids_ptr + idx)
if expert_id >= start_expert_id and expert_id <= end_expert_id:
computed = True
dst_idx = tl.load(src2dst_ptr + idx)
dst_idx = dst_idx - start_expert_id * max_m
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)

if computed == False:
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
tl.store(
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
)
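
Likewise, a plain-PyTorch CPU reference (illustrative only; the helper name is made up and torch is assumed imported as in this module) of the scatter-back the kernel above performs: each token gathers its per-expert rows from the (local_experts, m_max, hidden) buffer via src2dst, weights them by topk_weights, and tokens that hit no local expert produce zeros.

def reference_post_reorder(down_output, src2dst, topk_ids, topk_weights,
                           start_expert_id, end_expert_id, max_m):
    num_tokens, topk = topk_ids.shape
    hidden_size = down_output.shape[-1]
    rows = down_output.reshape(-1, hidden_size)   # (local_experts * max_m, hidden)
    output = torch.zeros(num_tokens, hidden_size, dtype=down_output.dtype)
    for t in range(num_tokens):
        for k in range(topk):
            expert_id = int(topk_ids[t, k])
            if start_expert_id <= expert_id <= end_expert_id:
                dst = int(src2dst[t * topk + k]) - start_expert_id * max_m
                output[t] += topk_weights[t, k].to(down_output.dtype) * rows[dst]
    return output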
139 changes: 135 additions & 4 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -31,10 +31,12 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
deepgemm_post_reorder_triton_kernel,
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preproess,
Copilot AI (Jun 3, 2025):
[nitpick] The function name moe_ep_deepgemm_preproess seems to have a typo (preproess); consider renaming it to moe_ep_deepgemm_preprocess for clarity.

Suggested change
moe_ep_deepgemm_preproess,
moe_ep_deepgemm_preprocess,

post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
@@ -49,13 +51,21 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
scaled_fp8_quant,
sglang_per_token_quant_fp8,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
from sglang.srt.utils import (
DeepEPMode,
dispose_tensor,
get_bool_env_var,
is_cuda,
is_hip,
set_weight_attrs,
)

_is_hip = is_hip()

@@ -228,13 +238,134 @@ def __init__(

self.grouped_gemm_runner = None

self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if use_deep_gemm and _ENABLE_JIT_DEEPGEMM:
return self.forward_deepgemm(hidden_states, router_logits)
else:
return self.forward_normal(hidden_states, router_logits)

def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
assert self.quant_method is not None

topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)

# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preproess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
gateup_input_fp8 = (
gateup_input,
get_col_major_tma_aligned_tensor(gateup_input_scale),
)

# GroupGemm-0
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
)

m_grouped_gemm_fp8_fp8_bf16_nt_masked(
gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
)

# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=gateup_output.device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=gateup_output.device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
)

# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
get_col_major_tma_aligned_tensor(down_input_scale),
)
down_output = torch.empty(
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
)

# PostReorder
output = torch.empty_like(hidden_states)
deepgemm_post_reorder_triton_kernel[(hidden_states.size(0),)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.size(1),
m_max,
BLOCK_SIZE=512,
)
return output
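
A hedged shape summary of forward_deepgemm above, assuming the usual grouped-GEMM weight layout (w13_weight is (E_local, 2*I, H) and w2_weight is (E_local, H, I), with E_local local experts, H hidden size, and I intermediate size per partition); the helper below only reports the intermediate buffer shapes and is not part of the PR.

def deepgemm_moe_shapes(num_local_experts: int, m_max: int, hidden: int, intermediate: int):
    # intermediate tensors created in forward_deepgemm (dtypes noted in comments)
    return {
        "gateup_input": (num_local_experts, m_max, hidden),             # fp8
        "gateup_output": (num_local_experts, m_max, 2 * intermediate),  # bf16
        "down_input": (num_local_experts, m_max, intermediate),         # fp8
        "down_input_scale": (num_local_experts, m_max, intermediate // 128),  # fp32
        "down_output": (num_local_experts, m_max, hidden),              # bf16
    }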

def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,