Merged
Changes from all commits (44 commits)
fc5834f
feat: NVFP4 Marlin fallback for non-Blackwell (SM75+); Linear + MoE, …
Godmook Mar 2, 2026
a30158b
lint
Godmook Mar 2, 2026
3a0b745
Modify
Godmook Mar 2, 2026
8bf0d2a
Resolve 'NaN' Bugs
Godmook Mar 2, 2026
f2d2980
feat(quantization): Enable NVFP4 inference on non-Blackwell GPUs via …
Godmook Mar 2, 2026
2f4ea2a
Lint
Godmook Mar 2, 2026
ea50117
Fixed — replaced the direct sgl_kernel import with get_scalar_types()…
Godmook Mar 2, 2026
80cfc25
ci: add test_nvfp4_marlin_fallback.py to __not_in_ci__ in run_suite
Godmook Mar 6, 2026
741c505
Move testfile and add test CI
Godmook Mar 6, 2026
443788f
Rerun CI
Godmook Mar 10, 2026
51e2675
ReRun CI
Godmook Mar 10, 2026
0c1083b
CUDI_CI Time Increase
Godmook Mar 12, 2026
1aabcea
fix(nvfp4): apply global_scale correctly in MoE Marlin fallback
Godmook Mar 13, 2026
5f6e09e
Restore Lazy Import
Godmook Mar 13, 2026
97c92c3
Fix Some Errors
Godmook Mar 13, 2026
c48421f
Add CUDA Detection and Make CI more gracefully due to JIT Cache Issue
Godmook Mar 13, 2026
15fffba
Add SGLANG_FORCE_NVFP4_MARLIN env var, use conditional JIT for NVFP4 …
Godmook Mar 13, 2026
75dd81f
Add #ifdef SGL_MOE_MARLIN_FP4 guard in marlin_template.h to reduce JI…
Godmook Mar 13, 2026
c79a549
Keep Scale as FP8
Godmook Mar 14, 2026
99b5714
Move Tensor Location
Godmook Mar 14, 2026
20aa3d1
FP8 Scale Location
Godmook Mar 14, 2026
6a3cc63
Change Formular of BF16/FP16
Godmook Mar 14, 2026
02298cb
Add Debug Log
Godmook Mar 14, 2026
4343901
Adaptor for FP8/FP4
Godmook Mar 14, 2026
83273b1
Modify Integer Issue
Godmook Mar 14, 2026
81f5f50
Modify Scale_Max
Godmook Mar 14, 2026
e7cdf84
Modify Global_Scale
Godmook Mar 14, 2026
e45a379
Remove normalization based on vLLM
Godmook Mar 14, 2026
4101751
Add More Test
Godmook Mar 14, 2026
1b0dad2
Fix Kernel Issue
Godmook Mar 14, 2026
a57d646
Refactoring and Remove Debugging and Fix MoE Kernel Issues
Godmook Mar 14, 2026
671c73c
Reduce DocString
Godmook Mar 14, 2026
eedab50
Remove E2E test on testnvfp4file
Godmook Mar 14, 2026
ca8b695
Return Global_scale
Godmook Mar 15, 2026
ef6f6c6
Fixing CI Errors
Godmook Mar 18, 2026
d1bf6a5
Merge main, fix test_mixed_precision_uses_nvfp4_min_capability for SM75+
Godmook Mar 18, 2026
d0b7ee5
Merge branch 'main' into nvfp4-marlin-fallback
Godmook Mar 24, 2026
2ea9789
Fix getAttr
Godmook Mar 27, 2026
2d48807
Merge branch 'main' into nvfp4-marlin-fallback
Godmook Mar 28, 2026
22a66a8
PCG support for NVFP4 Marlin linear (custom op + torch.ops)
Godmook Mar 30, 2026
3ea1b89
Change CI Name
Godmook Mar 30, 2026
bd73fc5
Add Test Coverage
Godmook Mar 30, 2026
133f017
Fix Test Issues
Godmook Mar 31, 2026
1699078
Merge branch 'main' into nvfp4-marlin-fallback
Godmook Mar 31, 2026
2 changes: 2 additions & 0 deletions docs/references/environment_variables.md
@@ -121,6 +121,8 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
| `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Apply per token group quantization kernel with fused silu and mul and masked m | `false` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
| `SGLANG_FORCE_NVFP4_MARLIN` | Force using NVFP4 Marlin fallback kernels even on Blackwell GPUs with native FP4 support | `false` |
Collaborator:
Does this have some performance advantage on Blackwell, or is it just a normal feature?

Contributor Author:
No performance advantage. Blackwell's native FP4 is the default and faster. This env is purely for debugging/testing — e.g., comparing native vs Marlin accuracy, regression testing the Marlin path on Blackwell, or as a workaround if native FP4 has issues.
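To illustrate, the selection this flag overrides can be sketched roughly like this (the function name and parsing below are illustrative stand-ins, not the actual SGLang code; assumes Blackwell means SM100+):

```python
import os

def choose_fp4_backend(sm_version: int) -> str:
    # Illustrative sketch: native FP4 is preferred on Blackwell (SM100+)
    # unless the Marlin fallback is forced for debugging/regression testing.
    force_marlin = os.environ.get("SGLANG_FORCE_NVFP4_MARLIN", "false").lower() in ("1", "true")
    if sm_version >= 100 and not force_marlin:
        return "native_fp4"
    return "marlin_fallback"
```

On pre-Blackwell parts (e.g. SM89) the Marlin fallback is the only option, so the flag only matters on SM100+.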

Comment:
For RL you might want NVFP4 weights + BF16 activations (+LoRA)

| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | `` |
Collaborator:
Merge conflict

Contributor Author (@Godmook, Apr 3, 2026):
Hi @b8zhong, thanks for the heads-up, and sorry for these problems.
I traced all 16 files against current main and found 5 files with stale merge resolutions from my earlier main merge. Here's the full list of what I saw:

1. docs/references/environment_variables.md
   The SGLANG_FLASHINFER_FP4_GEMM_BACKEND line was removed by #21536; I accidentally kept it. I'll fix it.

2. python/sglang/srt/environ.py
   SGLANG_HICACHE_MAX_PINNED_RATIO was removed by #21884.
   SGLANG_ENABLE_MM_SPLITTING was removed by #21899.

3. python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
   silu_and_mul moved from sgl_kernel to sglang.jit_kernel.activation (#21766).

4. compressed_tensors_w4a4_nvfp4.py
   Missing the `and not get_fp4_gemm_runner_backend().is_cutlass()` guard on the flashinfer path.

5. modelopt_quant.py
   - cutlass_fp4_gemm import changed to a top-level try/except
   - New CUTLASS FP4 GEMM code path added
   - The same .is_cutlass() guard is missing

The remaining 11 files are identical to main. If you approve this plan, I'll rebase on latest main to resolve all of these cleanly. Really sorry about that.

| `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize q_b_proj from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` |
| `SGLANG_MOE_NVFP4_DISPATCH` | Use nvfp4 for moe dispatch (on flashinfer_cutlass or flashinfer_cutedsl moe runner backend) | `"false"` |
| `SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE` | Quantize moe of nextn layer from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` |
25 changes: 8 additions & 17 deletions python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h
@@ -484,11 +484,11 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1)
: 1;
// FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes)
int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

@@ -540,8 +540,7 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
}
}
auto s_sh_wr = threadIdx.x;
@@ -563,15 +562,7 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;

s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + warp_row % 2;

} else if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;
@@ -876,7 +867,7 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1));
int cur_group_id = k_blocks / group_blocks;

int4* sh_s_stage = sh_s + s_sh_stage * pipe;

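For context on the stride changes in this file: scales travel through shared memory as 16-byte `int4` vectors, so one vector carries 16 one-byte FP8 scales for FP4 weights but only 8 two-byte FP16/BF16 scales for other weight types. A small sketch of that arithmetic (names are illustrative):

```python
INT4_BYTES = 16  # scales are moved as 16-byte int4 vectors

def scale_gl_stride(prob_n: int, is_fp4: bool) -> int:
    # FP4 (kFE2M1f) weights carry FP8 scales (1 byte each); other weight
    # types carry FP16/BF16 scales (2 bytes each), hence the /16 vs /8 split.
    scale_bytes = 1 if is_fp4 else 2
    scales_per_vector = INT4_BYTES // scale_bytes  # 16 or 8
    return prob_n // scales_per_vector
```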
53 changes: 21 additions & 32 deletions python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h
@@ -626,11 +626,10 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1)
: 1;
int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;

@@ -682,8 +681,7 @@ __global__ void Marlin(
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) +
s_sh_stride * slice_col + threadIdx.x;
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
}
}
auto s_sh_wr = threadIdx.x;
@@ -705,15 +703,7 @@ __global__ void Marlin(
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) {
auto warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;

s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;

} else if constexpr (group_blocks != -1)
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop)))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;
@@ -1038,18 +1028,15 @@ __global__ void Marlin(
cur_k += k_iter_size * (k % b_sh_wr_iters);

int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1));
int cur_group_id = k_blocks / group_blocks;

int4* sh_s_stage = sh_s + s_sh_stage * pipe;

if constexpr (w_type_id != host::kFE2M1f.id()) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
} else {
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2];
reinterpret_cast<int2*>(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
}
}
@@ -1243,17 +1230,19 @@ __global__ void Marlin(
}
}

// Commented out FP4/FP8 scale dequantization since we don't generate
// kFE2M1f kernels to reduce compilation time
// if constexpr (w_type == host::kFE2M1f) {
// int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
// int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
//
// dequant_fp8_scales<scalar_t2, s_type_id>(
// s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
// dequant_fp8_scales<scalar_t2, s_type_id>(
// s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
// }
#ifdef SGL_MOE_MARLIN_FP4
// Convert FP8 per-group scales to BF16/FP16 before applying them.
// Required for kFE2M1f (NVFP4): frag_s holds raw float8_e4m3fn bytes;
// without this conversion scale<scalar_t> would misinterpret them as
// BF16/FP16, producing NaN/Inf multipliers.
if constexpr (w_type == host::kFE2M1f) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

dequant_fp8_scales<scalar_t2, s_type_id>(s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
dequant_fp8_scales<scalar_t2, s_type_id>(s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
}
#endif

// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
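The re-enabled `dequant_fp8_scales` path above matters because `frag_s` holds raw float8_e4m3fn bytes that must be widened before use. As a rough illustration of the E4M3 format itself (a pure-Python sketch, not the kernel's code; e4m3fn uses bias 7 and has no infinities, with NaN at exponent 15 / mantissa 7):

```python
def decode_fp8_e4m3fn(byte: int) -> float:
    # Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits, 3 mantissa bits.
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")  # the only NaN encoding; no infinity in e4m3fn
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal: no implicit leading 1
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)
```

Reinterpreting two such bytes as one FP16/BF16 value instead of decoding them is exactly how the NaN/Inf multipliers mentioned in the comment arise.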
@@ -453,7 +453,9 @@ MarlinFuncPtr get_marlin_kernel(
COMMON_GET_IF(host::kU4B8)
COMMON_GET_IF(host::kU8B128)

#ifdef SGL_MOE_MARLIN_FP4
NVFP4_GET_IF(host::kFE2M1f)
#endif

BIGGROUP_GET_IF(host::kFE4M3fn)

24 changes: 23 additions & 1 deletion python/sglang/jit_kernel/moe_wna16_marlin.py
@@ -31,6 +31,24 @@ def _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module:
)


@cache_once
def _jit_moe_wna16_marlin_fp4_module(dtype: torch.dtype) -> Module:
"""Separate JIT module with NVFP4 (kFE2M1f) kernel instantiations enabled."""
args = make_cpp_args(dtype)
return load_jit(
"moe_wna16_marlin_fp4",
*args,
cuda_files=["gemm/marlin_moe/moe_wna16_marlin.cuh"],
extra_cuda_cflags=["-DSGL_MOE_MARLIN_FP4"],
cuda_wrappers=[
(
"moe_wna16_marlin_gemm",
f"moe_wna16_marlin_gemm<{args}>",
)
],
)


def _or_empty(
t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
@@ -134,7 +152,11 @@ def moe_wna16_marlin_gemm(
b_bias_t = _or_empty(b_bias_or_none, device, a.dtype)
global_scale_t = _or_empty(global_scale_or_none, device, a.dtype)

module = _jit_moe_wna16_marlin_module(a.dtype)
is_fp4 = global_scale_or_none is not None and global_scale_or_none.numel() > 0
if is_fp4:
module = _jit_moe_wna16_marlin_fp4_module(a.dtype)
else:
module = _jit_moe_wna16_marlin_module(a.dtype)
module.moe_wna16_marlin_gemm(
a,
c,
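The dispatch above compiles the FP4 kernel variants only when a global scale is actually supplied, so the default JIT module stays small. The pattern can be sketched independently of the real `load_jit` (all names below are stand-ins):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _load_module(fp4: bool) -> dict:
    # Stand-in for load_jit: the FP4 build adds -DSGL_MOE_MARLIN_FP4 so the
    # kFE2M1f kernel instantiations are compiled only when actually needed.
    flags = ["-DSGL_MOE_MARLIN_FP4"] if fp4 else []
    return {"name": "moe_wna16_marlin_fp4" if fp4 else "moe_wna16_marlin",
            "extra_cuda_cflags": flags}

def select_module(global_scale_or_none):
    # FP4 mode is detected by the presence of a non-empty global scale tensor.
    is_fp4 = global_scale_or_none is not None and len(global_scale_or_none) > 0
    return _load_module(is_fp4)
```

Caching on the boolean means each variant is built at most once per process, mirroring the `@cache_once` decorators in the diff.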
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
@@ -332,6 +332,7 @@ class Envs:
SGLANG_CPU_QUANTIZATION = EnvBool(False)
SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False)
SGLANG_FORCE_FP8_MARLIN = EnvBool(False)
SGLANG_FORCE_NVFP4_MARLIN = EnvBool(False)
SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False)
SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False)
SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False)
43 changes: 33 additions & 10 deletions python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py
@@ -23,6 +23,13 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


def _get_fp4_scalar_type():
from sglang.srt.layers.quantization.utils import get_scalar_types

_, scalar_types = get_scalar_types()
return scalar_types.float4_e2m1f


@register_custom_op(out_shape="hidden_states")
def fused_marlin_moe(
hidden_states: torch.Tensor,
@@ -46,6 +53,8 @@ def fused_marlin_moe(
is_k_full: bool = True,
inplace: bool = False,
routed_scaling_factor: Optional[float] = None,
w1_global_scale: Optional[torch.Tensor] = None,
w2_global_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -76,6 +85,13 @@ def fused_marlin_moe(
"""
from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size

# Detect FP4 Marlin mode (when global scales are provided)
_is_fp4_marlin = w1_global_scale is not None
if _is_fp4_marlin:
assert (
w2_global_scale is not None
), "Both w1_global_scale and w2_global_scale must be provided for FP4 Marlin mode"

assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[1] == w2.shape[2] // (
@@ -85,12 +101,14 @@ def fused_marlin_moe(
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16]
assert (
hidden_states.dtype == w1_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
assert (
hidden_states.dtype == w2_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
# For FP4 Marlin, scales are in special float8_e4m3fn format (not input dtype)
if not _is_fp4_marlin:
assert (
hidden_states.dtype == w1_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
assert (
hidden_states.dtype == w2_scale.dtype
), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
assert num_bits in [4, 8]

M, K = hidden_states.shape
@@ -121,8 +139,13 @@ def fused_marlin_moe(
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
)

scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
# FP4 Marlin uses float4_e2m1f scalar type (not uint4b8/uint8b128)
if _is_fp4_marlin:
scalar_type1 = _get_fp4_scalar_type()
scalar_type2 = _get_fp4_scalar_type()
else:
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N),
@@ -150,7 +173,7 @@ def fused_marlin_moe(
w1,
None, # b_bias_or_none
w1_scale,
None, # global_scale_or_none
w1_global_scale, # None for INT4/INT8, tensor for FP4 Marlin
w1_zeros,
g_idx1,
sort_indices1,
@@ -184,7 +207,7 @@ def fused_marlin_moe(
w2,
None, # b_bias_or_none
w2_scale,
None, # global_scale_or_none
w2_global_scale, # None for INT4/INT8, tensor for FP4 Marlin
w2_zeros,
g_idx2,
sort_indices2,
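For background on why the global scales are threaded through here: NVFP4 dequantization is two-level, with a per-group FP8 scale for each block of 16 FP4 values and one FP32 global factor on top. A toy sketch, under the assumption that the stored global factor is applied multiplicatively (illustrative only; the real kernel fuses this, and the exact convention for the global factor follows the checkpoint format):

```python
def dequant_nvfp4(fp4_vals, fp8_scales, global_scale, group_size=16):
    # Two-level scaling: each group of 16 FP4 values shares one FP8 scale,
    # and a single global factor rescales everything on top.
    out = []
    for i, v in enumerate(fp4_vals):
        out.append(v * fp8_scales[i // group_size] * global_scale)
    return out
```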
10 changes: 9 additions & 1 deletion python/sglang/srt/layers/moe/moe_runner/marlin.py
@@ -69,8 +69,13 @@ class MarlinMoeQuantInfo(MoeQuantInfo):
w13_qzeros: Optional[torch.Tensor] = None
w2_qzeros: Optional[torch.Tensor] = None

# Optional
# FP4 Marlin specific (Optional)
w13_global_scale: Optional[torch.Tensor] = None
w2_global_scale: Optional[torch.Tensor] = None

# EP support (Optional)
expert_map: Optional[torch.Tensor] = None
global_num_experts: int = -1
Collaborator:
Why do we need this extra arg, global_num_experts? When will it be used?

Contributor Author:
This is needed for Expert Parallelism (EP). Under EP, each rank holds only a subset of experts, so the local weight tensor's expert count E < total model experts. But topk_ids contains global expert IDs, and moe_align_block_size creates buckets indexed by expert ID — it needs the global count to size the output correctly. Without it, global IDs would exceed the local range and cause incorrect routing or out-of-bounds access. When EP is not used, -1 falls back to E (lines 125-126 in fused_marlin_moe.py), so it's backward-compatible.
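The sizing argument above can be sketched with a toy histogram (an illustrative helper, not the real moe_align_block_size):

```python
def count_tokens_per_expert(topk_ids, num_local_experts, global_num_experts=-1):
    # global_num_experts == -1 falls back to the local expert count, keeping
    # the non-EP case unchanged. Under EP, topk_ids holds GLOBAL expert IDs,
    # so the histogram must span all experts or indexing goes out of bounds.
    num_experts = num_local_experts if global_num_experts == -1 else global_num_experts
    counts = [0] * num_experts
    for row in topk_ids:
        for expert_id in row:
            counts[expert_id] += 1
    return counts
```

With 2 local experts out of 8 global ones, a routing decision like `[[7, 0]]` only indexes safely when the global count is passed.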



@register_fused_func("none", "marlin")
@@ -106,6 +111,7 @@ def fused_experts_none_to_marlin(
gating_output=topk_output.router_logits,
topk_weights=topk_output.topk_weights,
topk_ids=topk_output.topk_ids,
global_num_experts=quant_info.global_num_experts,
expert_map=quant_info.expert_map,
g_idx1=quant_info.w13_g_idx,
g_idx2=quant_info.w2_g_idx,
@@ -118,6 +124,8 @@ def fused_experts_none_to_marlin(
is_k_full=quant_info.is_k_full,
inplace=runner_config.inplace,
routed_scaling_factor=runner_config.routed_scaling_factor,
w1_global_scale=quant_info.w13_global_scale,
w2_global_scale=quant_info.w2_global_scale,
).to(hidden_states.dtype)

return StandardCombineInput(