Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f7c007a
Create marlin_utils_fp4.py
zhangxiaolei123456 Apr 25, 2026
e6d566d
Update marlin_template.h
zhangxiaolei123456 Apr 25, 2026
2e9f5ad
Update server_args.py
zhangxiaolei123456 Apr 25, 2026
c20d661
Update mxfp4_deepseek.py
zhangxiaolei123456 Apr 25, 2026
30cd6a2
Update fused_marlin_moe.py
zhangxiaolei123456 Apr 25, 2026
fffc4fa
Update marlin.py
zhangxiaolei123456 Apr 25, 2026
b40572f
Update fp8.py
zhangxiaolei123456 Apr 25, 2026
ab45dd7
Update marlin.py
zhangxiaolei123456 Apr 25, 2026
af9b12f
Update fused_marlin_moe.py
zhangxiaolei123456 Apr 25, 2026
0e58b9d
Update fp8.py
zhangxiaolei123456 Apr 25, 2026
c196b1f
Update marlin.py
zhangxiaolei123456 Apr 25, 2026
bbc0856
Update marlin_utils_fp4.py
zhangxiaolei123456 Apr 25, 2026
1d34449
Update mxfp4_deepseek.py
zhangxiaolei123456 Apr 25, 2026
c10de75
Update fused_marlin_moe.py
zhangxiaolei123456 Apr 25, 2026
f67f765
Update marlin_template.h
zhangxiaolei123456 Apr 25, 2026
ca259d6
Merge branch 'deepseek_v4' into deepseek_v4_w4a16
zhangxiaolei123456 Apr 25, 2026
afb1778
Update ops.cu
zhangxiaolei123456 Apr 26, 2026
3196c2b
Update ops.cu
zhangxiaolei123456 Apr 26, 2026
dc458b5
Update ops.cu
zhangxiaolei123456 Apr 26, 2026
d491f0c
Update fused_marlin_moe.py
zhangxiaolei123456 Apr 26, 2026
92ae5f6
Update marlin.py
zhangxiaolei123456 Apr 26, 2026
f3dd1e6
Update fp8.py
zhangxiaolei123456 Apr 26, 2026
a6b05c8
Update mxfp4_deepseek.py
zhangxiaolei123456 Apr 26, 2026
0c4d0ea
Update marlin_utils_fp4.py
zhangxiaolei123456 Apr 26, 2026
c06086b
Update marlin_template.h
zhangxiaolei123456 Apr 26, 2026
2c3ab62
Update ops.cu
zhangxiaolei123456 Apr 26, 2026
1b4f142
Merge branch 'deepseek_v4' into deepseek_v4_w4a16
Fridge003 Apr 27, 2026
25159a3
small fix
Fridge003 Apr 28, 2026
0177192
Merge branch 'deepseek_v4' into deepseek_v4_w4a16
zhangxiaolei123456 Apr 28, 2026
ab3ae45
Merge branch 'deepseek_v4' into deepseek_v4_w4a16
zhangxiaolei123456 Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 63 additions & 25 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,48 @@
from typing import Optional

import torch
import torch.nn.functional as F

from sglang.srt.utils import is_cuda
from sglang.srt.utils.custom_op import register_custom_op

_is_cuda = is_cuda()

if _is_cuda:
from sgl_kernel import moe_sum_reduce, silu_and_mul


def get_scalar_type(num_bits: int, has_zp: bool):
from sgl_kernel import silu_and_mul
from sgl_kernel.scalar_type import scalar_types


def get_scalar_type(num_bits: int, has_zp: bool, scales: Optional[torch.Tensor] = None):
    """Pick the ``scalar_types`` entry describing a quantized weight.

    Args:
        num_bits: Bit width of the quantized weight.
        has_zp: Whether the weight comes with zero points.
        scales: Optional scale tensor; only its dtype is inspected.

    Returns:
        ``scalar_types.uint4`` when zero points are present (4-bit only),
        ``scalar_types.float4_e2m1f`` for 4-bit weights without zero points
        whose scales are ``torch.float8_e8m0fnu``, ``scalar_types.uint4b8``
        for any other 4-bit weight, and ``scalar_types.uint8b128`` otherwise.
    """
    # Zero-point weights are always plain unsigned 4-bit.
    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4

    if num_bits == 4:
        # E8M0 scales mark the weight as FP4 (e2m1) rather than biased uint4.
        uses_e8m0_scales = (
            scales is not None and scales.dtype == torch.float8_e8m0fnu
        )
        if uses_e8m0_scales:
            return scalar_types.float4_e2m1f
        return scalar_types.uint4b8

    return scalar_types.uint8b128


def swiglu_limit_func(
    output: torch.Tensor,
    input: torch.Tensor,  # first half is gate, second half is up
    swiglu_limit: float = 0.0,
) -> None:
    """Write ``silu(gate) * up`` into ``output``, optionally clamping first.

    ``input`` is split along dim 1 at ``input.shape[1] // 2``: the leading
    columns are the gate and the remaining columns are the up projection.

    Args:
        output: Destination tensor, filled in place via ``copy_``.
        input: 2-D tensor holding gate and up side by side along dim 1.
        swiglu_limit: When greater than zero, the gate is clamped from above
            to this value and up is clamped to ``[-swiglu_limit,
            swiglu_limit]`` before the activation. Values ``<= 0`` disable
            clamping.
    """
    half = input.shape[1] // 2
    gate, up = input[:, :half], input[:, half:]

    if swiglu_limit > 0:
        # The gate is bounded only from above; up is bounded symmetrically.
        gate = gate.clamp(max=swiglu_limit)
        up = up.clamp(min=-swiglu_limit, max=swiglu_limit)

    activated_gate = F.silu(gate)
    output.copy_(activated_gate * up)


@register_custom_op(out_shape="hidden_states")
def fused_marlin_moe(
hidden_states: torch.Tensor,
Expand All @@ -44,6 +66,7 @@ def fused_marlin_moe(
is_k_full: bool = True,
inplace: bool = False,
routed_scaling_factor: Optional[float] = None,
clamp_limit: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand Down Expand Up @@ -83,12 +106,29 @@ def fused_marlin_moe(
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert (
hidden_states.dtype == w1_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
assert (
hidden_states.dtype == w2_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
is_mxfp4_marlin = (
num_bits == 4
and w1_zeros is None
and w2_zeros is None
and w1_scale.dtype == torch.float8_e8m0fnu
and w2_scale.dtype == torch.float8_e8m0fnu
)
if is_mxfp4_marlin:
assert w1_scale.dtype == torch.float8_e8m0fnu, (
"MXFP4 Marlin expects w1_scale to be torch.float8_e8m0fnu, "
f"got {w1_scale.dtype}"
)
assert w2_scale.dtype == torch.float8_e8m0fnu, (
"MXFP4 Marlin expects w2_scale to be torch.float8_e8m0fnu, "
f"got {w2_scale.dtype}"
)
else:
assert (
hidden_states.dtype == w1_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
assert (
hidden_states.dtype == w2_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
assert num_bits in [4, 8]

M, K = hidden_states.shape
Expand Down Expand Up @@ -119,8 +159,8 @@ def fused_marlin_moe(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)

scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None, w1_scale)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None, w2_scale)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
Expand All @@ -140,7 +180,7 @@ def fused_marlin_moe(
use_atomic_add = (
hidden_states.dtype == torch.half
or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
)
) and (not is_mxfp4_marlin)

intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
hidden_states,
Expand Down Expand Up @@ -171,7 +211,14 @@ def fused_marlin_moe(
is_zp_float=False,
)

silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
if clamp_limit is not None:
swiglu_limit_func(
intermediate_cache2,
intermediate_cache1.view(-1, 2 * N),
clamp_limit,
)
else:
silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)

if expert_map is not None:
intermediate_cache3.zero_()
Expand Down Expand Up @@ -206,13 +253,4 @@ def fused_marlin_moe(
).view(-1, topk, K)

output = hidden_states if inplace else torch.empty_like(hidden_states)

if routed_scaling_factor is None:
routed_scaling_factor = 1.0

moe_sum_reduce(
intermediate_cache3,
output,
routed_scaling_factor,
)
return output
return torch.sum(intermediate_cache3, dim=1, out=output)
32 changes: 30 additions & 2 deletions python/sglang/srt/layers/moe/moe_runner/marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

import torch

from sglang.srt.debug_utils.deepseek_v4_debug_utils import (
deepseek_v4_moe_code_path_checker,
)
from sglang.srt.environ import envs
from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
Expand Down Expand Up @@ -97,8 +101,31 @@ def fused_experts_none_to_marlin(
hidden_states.device, max_blocks_per_sm=4
)

if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" and (
runner_config.swiglu_limit is not None
):
deepseek_v4_moe_code_path_checker.observed += 1

marlin_hidden_states = hidden_states
# Avoid aliasing the MoE input buffer until Marlin output semantics are
# fully validated across shared-expert and overlap paths.
marlin_inplace = False
if (
quant_info.weight_bits == 4
and quant_info.w13_qzeros is None
and quant_info.w2_qzeros is None
and quant_info.w13_scales.dtype == torch.float8_e8m0fnu
and quant_info.w2_scales.dtype == torch.float8_e8m0fnu
and hidden_states.dtype == torch.float16
):
# MXFP4(E8M0) Marlin kernels are only numerically valid on the bf16
# activation path. The fp16 + E8M0 path is intentionally not generated
# in sgl-kernel, so upcast activations here and cast the result back.
marlin_hidden_states = hidden_states.to(torch.bfloat16)
marlin_inplace = False

output = fused_marlin_moe(
hidden_states=hidden_states,
hidden_states=marlin_hidden_states,
w1=quant_info.w13_qweight,
w2=quant_info.w2_qweight,
w1_scale=quant_info.w13_scales,
Expand All @@ -116,8 +143,9 @@ def fused_experts_none_to_marlin(
workspace=MARLIN_MOE_WORKSPACE,
num_bits=quant_info.weight_bits,
is_k_full=quant_info.is_k_full,
inplace=runner_config.inplace,
inplace=marlin_inplace,
routed_scaling_factor=runner_config.routed_scaling_factor,
clamp_limit=runner_config.swiglu_limit,
).to(hidden_states.dtype)

return StandardCombineInput(
Expand Down
78 changes: 46 additions & 32 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ def get_quant_method(
if (
envs.SGLANG_DSV4_MODE.get() == "2604"
and envs.SGLANG_DSV4_FP4_EXPERTS.get()
and get_moe_runner_backend().is_flashinfer_mxfp4()
and (
get_moe_runner_backend().is_flashinfer_mxfp4()
or get_moe_runner_backend().is_marlin()
)
):
from sglang.srt.layers.quantization.mxfp4_deepseek import (
DeepSeekMxfp4MoEMethod,
Expand Down Expand Up @@ -929,41 +932,52 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
will_use_deepgemm = self.is_deepgemm_moe_runner_backend_enabled()

if self.is_fp4_expert:
layer.w13_weight.data = layer.w13_weight.data.view(torch.int8)
layer.w2_weight.data = layer.w2_weight.data.view(torch.int8)

if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get():
from sglang.srt.models.deepseek_v4 import (
build_mega_moe_experts_weights,
if get_moe_runner_backend().is_marlin():
layer.w13_weight.data = layer.w13_weight.data.view(torch.int8)
layer.w2_weight.data = layer.w2_weight.data.view(torch.int8)
elif not get_moe_runner_backend().is_flashinfer_mxfp4():
raise NotImplementedError(
"DeepSeekV4 FP4 experts now require a native FP4 MoE backend. "
"Use `--moe-runner-backend marlin` on Hopper or "
"`--moe-runner-backend flashinfer_mxfp4` when available."
)

build_mega_moe_experts_weights(layer)
return
else:
layer.w13_weight.data = layer.w13_weight.data.view(torch.int8)
layer.w2_weight.data = layer.w2_weight.data.view(torch.int8)

if (
envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get()
and envs.SGLANG_DSV4_MODE.get() == "2604"
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
and will_use_deepgemm
):
from deep_gemm import transform_sf_into_required_layout

for scale_param, weight_param in [
(layer.w13_weight_scale_inv, layer.w13_weight),
(layer.w2_weight_scale_inv, layer.w2_weight),
]:
num_experts, n, _ = scale_param.data.shape
k = weight_param.shape[2] * 2
scale_param.data = transform_sf_into_required_layout(
scale_param.data,
mn=n,
k=k,
recipe=(1, 32),
num_groups=num_experts,
disable_ue8m0_cast=False,
if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get():
from sglang.srt.models.deepseek_v4 import (
build_mega_moe_experts_weights,
)
layer.w13_weight_scale_inv.format_ue8m0 = True
layer.w2_weight_scale_inv.format_ue8m0 = True

build_mega_moe_experts_weights(layer)
return

if (
envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get()
and envs.SGLANG_DSV4_MODE.get() == "2604"
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
and will_use_deepgemm
):
from deep_gemm import transform_sf_into_required_layout

for scale_param, weight_param in [
(layer.w13_weight_scale_inv, layer.w13_weight),
(layer.w2_weight_scale_inv, layer.w2_weight),
]:
num_experts, n, _ = scale_param.data.shape
k = weight_param.shape[2] * 2
scale_param.data = transform_sf_into_required_layout(
scale_param.data,
mn=n,
k=k,
recipe=(1, 32),
num_groups=num_experts,
disable_ue8m0_cast=False,
)
layer.w13_weight_scale_inv.format_ue8m0 = True
layer.w2_weight_scale_inv.format_ue8m0 = True

if (
not self.is_fp4_expert
Expand Down
Loading
Loading