diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 1f1c3e709d49..586041f1b9cd 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -1,19 +1,25 @@ from typing import Optional import torch +import torch.nn.functional as F from sglang.srt.utils import is_cuda from sglang.srt.utils.custom_op import register_custom_op _is_cuda = is_cuda() - if _is_cuda: - from sgl_kernel import moe_sum_reduce, silu_and_mul - - -def get_scalar_type(num_bits: int, has_zp: bool): + from sgl_kernel import silu_and_mul from sgl_kernel.scalar_type import scalar_types + +def get_scalar_type(num_bits: int, has_zp: bool, scales: Optional[torch.Tensor] = None): + if ( + not has_zp + and num_bits == 4 + and scales is not None + and scales.dtype == torch.float8_e8m0fnu + ): + return scalar_types.float4_e2m1f if has_zp: assert num_bits == 4 return scalar_types.uint4 @@ -21,6 +27,22 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 +def swiglu_limit_func( + output: torch.Tensor, + input: torch.Tensor, # first half is gate, second half is up + swiglu_limit: float = 0.0, +) -> None: + d = input.shape[1] // 2 + gate = input[:, :d] + up = input[:, d:] + + if swiglu_limit > 0: + gate = torch.clamp(gate, max=swiglu_limit) + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + + output.copy_(F.silu(gate) * up) + + @register_custom_op(out_shape="hidden_states") def fused_marlin_moe( hidden_states: torch.Tensor, @@ -44,6 +66,7 @@ def fused_marlin_moe( is_k_full: bool = True, inplace: bool = False, routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -83,12 +106,29 @@ def fused_marlin_moe( assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert ( - hidden_states.dtype == w1_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" - assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + is_mxfp4_marlin = ( + num_bits == 4 + and w1_zeros is None + and w2_zeros is None + and w1_scale.dtype == torch.float8_e8m0fnu + and w2_scale.dtype == torch.float8_e8m0fnu + ) + if is_mxfp4_marlin: + assert w1_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w1_scale to be torch.float8_e8m0fnu, " + f"got {w1_scale.dtype}" + ) + assert w2_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w2_scale to be torch.float8_e8m0fnu, " + f"got {w2_scale.dtype}" + ) + else: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -119,8 +159,8 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None, w1_scale) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None, w2_scale) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -140,7 +180,7 @@ def fused_marlin_moe( use_atomic_add = ( hidden_states.dtype == torch.half or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - ) + ) and (not is_mxfp4_marlin) intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default( hidden_states, @@ -171,7 +211,14 @@ def fused_marlin_moe( is_zp_float=False, ) - silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) + if clamp_limit is not None: + swiglu_limit_func( + intermediate_cache2, + intermediate_cache1.view(-1, 2 * N), + clamp_limit, + ) + else: + silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) if expert_map is not None: intermediate_cache3.zero_() @@ -206,13 +253,4 @@ def fused_marlin_moe( ).view(-1, topk, K) output = hidden_states if inplace else torch.empty_like(hidden_states) - - if routed_scaling_factor is None: - routed_scaling_factor = 1.0 - - moe_sum_reduce( - intermediate_cache3, - output, - routed_scaling_factor, - ) - return output + return torch.sum(intermediate_cache3, dim=1, out=output) diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py index 45104dd27805..a25791c423a9 100644 --- a/python/sglang/srt/layers/moe/moe_runner/marlin.py +++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py @@ -5,6 +5,10 @@ import torch +from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, +) +from sglang.srt.environ import envs from sglang.srt.layers.moe.moe_runner.base import ( MoeQuantInfo, MoeRunnerConfig, @@ -97,8 +101,31 @@ def fused_experts_none_to_marlin( hidden_states.device, max_blocks_per_sm=4 ) + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" and ( + runner_config.swiglu_limit is not None + ): + deepseek_v4_moe_code_path_checker.observed += 1 + + marlin_hidden_states = hidden_states + # Avoid aliasing the MoE input buffer until Marlin output semantics are + # fully validated across shared-expert and overlap paths. + marlin_inplace = False + if ( + quant_info.weight_bits == 4 + and quant_info.w13_qzeros is None + and quant_info.w2_qzeros is None + and quant_info.w13_scales.dtype == torch.float8_e8m0fnu + and quant_info.w2_scales.dtype == torch.float8_e8m0fnu + and hidden_states.dtype == torch.float16 + ): + # MXFP4(E8M0) Marlin kernels are only numerically valid on the bf16 + # activation path. The fp16 + E8M0 path is intentionally not generated + # in sgl-kernel, so upcast activations here and cast the result back. + marlin_hidden_states = hidden_states.to(torch.bfloat16) + marlin_inplace = False + output = fused_marlin_moe( - hidden_states=hidden_states, + hidden_states=marlin_hidden_states, w1=quant_info.w13_qweight, w2=quant_info.w2_qweight, w1_scale=quant_info.w13_scales, @@ -116,8 +143,9 @@ def fused_experts_none_to_marlin( workspace=MARLIN_MOE_WORKSPACE, num_bits=quant_info.weight_bits, is_k_full=quant_info.is_k_full, - inplace=runner_config.inplace, + inplace=marlin_inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + clamp_limit=runner_config.swiglu_limit, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ed74b1d47e52..029cc9eedde9 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -189,7 +189,10 @@ def get_quant_method( if ( envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() - and get_moe_runner_backend().is_flashinfer_mxfp4() + and ( + get_moe_runner_backend().is_flashinfer_mxfp4() + or get_moe_runner_backend().is_marlin() + ) ): from sglang.srt.layers.quantization.mxfp4_deepseek import ( DeepSeekMxfp4MoEMethod, @@ -929,41 +932,52 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: will_use_deepgemm = self.is_deepgemm_moe_runner_backend_enabled() if self.is_fp4_expert: - layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) - layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) - - if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): - from sglang.srt.models.deepseek_v4 import ( - build_mega_moe_experts_weights, + if get_moe_runner_backend().is_marlin(): + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + elif not get_moe_runner_backend().is_flashinfer_mxfp4(): + raise NotImplementedError( + "DeepSeekV4 FP4 experts now require a native FP4 MoE backend. " + "Use `--moe-runner-backend marlin` on Hopper or " + "`--moe-runner-backend flashinfer_mxfp4` when available." ) - build_mega_moe_experts_weights(layer) - return + else: + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) - if ( - envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() - and envs.SGLANG_DSV4_MODE.get() == "2604" - and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 - and will_use_deepgemm - ): - from deep_gemm import transform_sf_into_required_layout - - for scale_param, weight_param in [ - (layer.w13_weight_scale_inv, layer.w13_weight), - (layer.w2_weight_scale_inv, layer.w2_weight), - ]: - num_experts, n, _ = scale_param.data.shape - k = weight_param.shape[2] * 2 - scale_param.data = transform_sf_into_required_layout( - scale_param.data, - mn=n, - k=k, - recipe=(1, 32), - num_groups=num_experts, - disable_ue8m0_cast=False, + if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + from sglang.srt.models.deepseek_v4 import ( + build_mega_moe_experts_weights, ) - layer.w13_weight_scale_inv.format_ue8m0 = True - layer.w2_weight_scale_inv.format_ue8m0 = True + + build_mega_moe_experts_weights(layer) + return + + if ( + envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + and will_use_deepgemm + ): + from deep_gemm import transform_sf_into_required_layout + + for scale_param, weight_param in [ + (layer.w13_weight_scale_inv, layer.w13_weight), + (layer.w2_weight_scale_inv, layer.w2_weight), + ]: + num_experts, n, _ = scale_param.data.shape + k = weight_param.shape[2] * 2 + scale_param.data = transform_sf_into_required_layout( + scale_param.data, + mn=n, + k=k, + recipe=(1, 32), + num_groups=num_experts, + disable_ue8m0_cast=False, + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True if ( not self.is_fp4_expert diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..d0f75bdcbd66 --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, +) +from sglang.srt.utils import is_cuda +from sglang.srt.utils.common import get_bool_env_var + +_is_cuda = is_cuda() +_INVERT_MXFP4_MARLIN_SCALES = get_bool_env_var("SGLANG_MXFP4_MARLIN_INVERT_SCALE") +_SKIP_MXFP4_MARLIN_SCALE_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_SCALE_TRANSPOSE" +) +_SKIP_MXFP4_MARLIN_SCALE_SWIZZLE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_SCALE_SWIZZLE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_TRANSPOSE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_PERMUTE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_PERMUTE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_SWIZZLE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_SWIZZLE" +) +_SKIP_MXFP4_MARLIN_WEIGHT_REPACK_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_WEIGHT_REPACK_TRANSPOSE" +) + +if _is_cuda: + from sgl_kernel import gptq_marlin_repack + + +def mxfp4_marlin_process_scales( + marlin_scales: torch.Tensor, + input_dtype: torch.dtype | None = None, + apply_swizzle: bool = True, +) -> torch.Tensor: + if ( + apply_swizzle + and not _SKIP_MXFP4_MARLIN_SCALE_SWIZZLE + and (input_dtype is None or input_dtype.itemsize == 2) + ): + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + if input_dtype == torch.float8_e4m3fn: + marlin_scales = marlin_scales.view(torch.uint8) + assert marlin_scales.max() <= 249 + # exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6 + marlin_scales = marlin_scales + 6 + marlin_scales = marlin_scales.view(torch.float8_e8m0fnu) + return marlin_scales + + +def _normalize_scale_tensor( + scales: torch.Tensor, target_dtype: torch.dtype +) -> torch.Tensor: + # The kernel consumes E8M0 exponents. Regardless of the placeholder dtype + # the loader used, we want the *numerical* value 2**e in ``target_dtype``. + # float32/bfloat16/float16 containers hold the numerical 2**e directly + # (they were filled via a dtype-promoting copy from uint8/e8m0). + # uint8/int8 containers hold the raw E8M0 byte and must be reinterpreted. + if scales.dtype == torch.float8_e8m0fnu: + return scales.to(target_dtype) + if scales.dtype == torch.uint8: + return scales.view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype == torch.int8: + return scales.view(torch.uint8).view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype in (torch.float32, torch.bfloat16, torch.float16): + return scales.to(target_dtype) + raise TypeError(f"Unsupported MXFP4 scale dtype for Marlin: {scales.dtype}") + + +def prepare_moe_mxfp4_layer_for_marlin(layer: torch.nn.Module) -> None: + group_size = 32 + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + num_experts = w13.shape[0] + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + param_dtype = getattr( + layer, + "orig_dtype", + w13_bias.dtype if w13_bias is not None else torch.bfloat16, + ) + + device = w13.device + layer.workspace = marlin_make_workspace(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + def _repack_weight(weight: torch.Tensor, is_w13: bool) -> torch.Tensor: + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, intermediate_size + assert weight.shape == (num_experts, size_n, size_k // 2) + + tensor_list = [] + for i in range(num_experts): + qweight = weight[i].view(torch.int32) + expected_packed_k = size_k // (32 // 4) + if ( + not _SKIP_MXFP4_MARLIN_WEIGHT_REPACK_TRANSPOSE + or qweight.size(0) != expected_packed_k + ): + qweight = qweight.T + qweight = qweight.contiguous() + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + tensor_list.append(marlin_qweight) + return torch.stack(tensor_list) + + def _permute_scales(scales: torch.Tensor, is_w13: bool) -> torch.Tensor: + scales = _normalize_scale_tensor(scales, param_dtype) + if _INVERT_MXFP4_MARLIN_SCALES: + scales = torch.reciprocal(scales) + + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, intermediate_size + + tensor_list = [] + for i in range(num_experts): + scale = scales[i] + skip_transpose = _SKIP_MXFP4_MARLIN_SCALE_TRANSPOSE or ( + is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_TRANSPOSE + ) + if not skip_transpose: + scale = scale.T + scale = scale.contiguous() + skip_permute = is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_PERMUTE + if skip_permute: + marlin_scales = scale + else: + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + apply_swizzle = not (is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_SWIZZLE) + tensor_list.append( + mxfp4_marlin_process_scales( + marlin_scales, + input_dtype=param_dtype, + apply_swizzle=apply_swizzle, + ) + ) + return torch.stack(tensor_list) + + def _permute_bias(bias: torch.Tensor | None) -> torch.Tensor | None: + if bias is None: + return None + tensor_list = [] + for i in range(num_experts): + tensor_list.append(marlin_permute_bias(bias[i].to(param_dtype))) + return torch.stack(tensor_list) + + w13_marlin = _repack_weight(w13, True) + w2_marlin = _repack_weight(w2, False) + w13_scale_marlin = _permute_scales(w13_scale, True) + w2_scale_marlin = _permute_scales(w2_scale, False) + + layer.w13_weight = torch.nn.Parameter(w13_marlin, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_marlin, requires_grad=False) + layer.w13_weight_scale_inv = torch.nn.Parameter(w13_scale_marlin, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter(w2_scale_marlin, requires_grad=False) + + if w13_bias is not None: + layer.w13_bias = torch.nn.Parameter(_permute_bias(w13_bias), requires_grad=False) + if w2_bias is not None: + layer.w2_bias = torch.nn.Parameter(_permute_bias(w2_bias), requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/mxfp4_deepseek.py b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py index f4b76dc5a35c..8705f395eb87 100644 --- a/python/sglang/srt/layers/quantization/mxfp4_deepseek.py +++ b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py @@ -15,10 +15,13 @@ use_symmetric_memory, ) from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo +from sglang.srt.layers.moe.utils import MoeRunnerBackend, get_moe_runner_backend from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( is_flashinfer_available, + is_sm90_supported, log_info_on_rank0, set_weight_attrs, ) @@ -129,12 +132,17 @@ class DeepSeekMxfp4MoEMethod: def __init__(self, fp8_method, prefix: str): self._fp8 = fp8_method self.prefix = prefix + self.moe_runner_backend = get_moe_runner_backend() self.flashinfer_mxfp4_moe_precision = ( get_global_server_args().flashinfer_mxfp4_moe_precision ) def create_moe_runner(self, layer, moe_runner_config): self.moe_runner_config = moe_runner_config + if self.moe_runner_backend.is_marlin(): + from sglang.srt.layers.moe.moe_runner import MoeRunner + + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) swiglu_limit = moe_runner_config.swiglu_limit is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" @@ -225,6 +233,38 @@ def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_mega_moe_weights_built", False): return + if self.moe_runner_backend.is_marlin(): + from sglang.srt.layers.quantization.marlin_utils import ( + check_moe_marlin_supports_layer, + ) + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + if not is_sm90_supported(): + raise RuntimeError( + "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." + ) + if not check_moe_marlin_supports_layer(layer, 32): + raise RuntimeError( + "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." + ) + + # NOTE: the Marlin MoE runner consumes w13 in the checkpoint's + # native ``[w1; w3]`` order -- see ``silu_and_mul`` in + # fused_marlin_moe.py which expects ``gate = intermediate[:, :N]`` + # (first half) and ``up = intermediate[:, N:]`` (second half). + # Unlike the flashinfer trtllm_fp4 kernel (which wants [w3, w1]), + # we must *not* call ``reorder_w1w3_to_w3w1`` here. + + log_info_on_rank0( + logger, + f"Preparing DeepSeekV4 MXFP4 experts for Marlin backend (layer: {self.prefix})...", + ) + prepare_moe_mxfp4_layer_for_marlin(layer) + layer._dsv4_mxfp4_backend = "marlin" + return + w13_w, w13_s = reorder_w1w3_to_w3w1( layer.w13_weight.data, layer.w13_weight_scale_inv.data ) @@ -340,6 +380,27 @@ def apply( layer: Module, dispatch_output: DispatchOutput, ) -> CombineInput: + if self.moe_runner_backend.is_marlin(): + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + topk_output = dispatch_output.topk_output + if not TopKOutputChecker.format_is_standard(topk_output): + raise ValueError(f"Unsupported topk output format: {topk_output.format}") + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale_inv, + w2_scales=layer.w2_weight_scale_inv, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + is_k_full=True, + ) + runner_output = self.runner.run(dispatch_output, quant_info=quant_info) + return StandardCombineInput(hidden_states=runner_output.hidden_states) + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput from sglang.srt.layers.moe.topk import TopKOutputChecker @@ -475,3 +536,4 @@ def apply( output.mul_(rsf) return StandardCombineInput(hidden_states=output) + diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 28a3c92e1b55..1aebc4143d66 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -181,6 +181,7 @@ "flashinfer_mxfp4", "flashinfer_cutedsl", "cutlass", + "marlin", ] MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep", "flashinfer"] diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index 1ba99787ba60..01da81c95395 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -28,6 +28,8 @@ #include "gemm/marlin/marlin_dtypes.cuh" #include "scalar_type.hpp" +#include + #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert( \ std::is_same::value || std::is_same::value, \ @@ -357,6 +359,7 @@ __global__ void Marlin( constexpr bool has_zp = w_type == sglang::kU4 || w_type == sglang::kU8; constexpr bool is_int_type = w_type == sglang::kU4 || w_type == sglang::kU8 || w_type == sglang::kU4B8 || w_type == sglang::kU8B128; + constexpr bool is_8bit_scale = s_type.size_bits() == 8; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = w_type == sglang::kFE4M3fn || w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn || @@ -371,7 +374,7 @@ __global__ void Marlin( static_assert(thread_m_blocks == 1 || !m_block_size_8); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == sglang::kFE2M1f ? 16 : 8); + const int scales_expert_stride = prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; @@ -442,52 +445,75 @@ __global__ void Marlin( locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred( + sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); -#pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; - } - __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + if constexpr (moe_block_size >= 16) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16); + if constexpr (moe_block_size >= 8) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8); + if constexpr (moe_block_size >= 4) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4); + if constexpr (moe_block_size >= 2) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2); + + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1); + block_num_valid_tokens = local_count; +#else + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); +#endif -#pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k; + if (lane_id == 0) reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; if (mul_topk_weights) { -#pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - // idx = idx < block_num_valid_tokens ? idx : 0; - if (idx < block_num_valid_tokens) { - if constexpr (w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn) { - sh_block_topk_weights[idx] = - __hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = - Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } - } + idx = idx < prob_m_top_k ? idx : 0; + scalar_t topk_weight_tmp = Dtype::float2num(topk_weights_ptr[idx]); + if constexpr (w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn) { + sh_block_topk_weights[threadIdx.x] = + __hmul2(global_scale, Dtype::num2num2(topk_weight_tmp)); + } else { + sh_block_topk_weights[threadIdx.x] = Dtype::num2num2(topk_weight_tmp); } } } + + __syncthreads(); + + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -629,11 +655,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -684,13 +709,15 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; @@ -708,15 +735,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -910,43 +929,21 @@ __global__ void Marlin( } else { if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1024,35 +1021,33 @@ __global__ void Marlin( } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } } } else { auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != sglang::kFE2M1f.id()) { - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + if constexpr (!is_8bit_scale) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1246,17 +1241,27 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == sglang::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + // FP4/FP8 scale dequantization (E4M3 for NVFP4 and E8M0 for MXFP4). + // Aligns with vLLM's single-shot marlin_template.h path: convert the raw + // FP8 scale bytes packed into frag_s back to real bf16/half values before + // they multiply the dequantized weights. // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } + // The half2 + kFE8M0fnu specialization of dequant_fp8_scales is not + // defined on purpose (fp16 cannot safely represent 2^e for large |e|, so + // fp16+MXFP4 kernels are intentionally not instantiated in + // generate_kernels.py). Guard that combination at compile time so the + // unused code path is eliminated and no unresolved extern is emitted. + if constexpr ((s_type == sglang::kFE4M3fn || s_type == sglang::kFE8M0fnu) && + !(std::is_same::value && + s_type == sglang::kFE8M0fnu)) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. @@ -1885,8 +1890,18 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } start_pipes(); } diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index 57334663ad48..20f7de059f37 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -274,6 +274,46 @@ int get_kernel_cache_size( return total_size; } +sglang::ScalarType infer_scale_type( + const sglang::ScalarType& q_type, + const torch::Tensor& b_scales, + at::ScalarType compute_dtype) { + if (q_type == sglang::kFE2M1f) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + return sglang::kFE4M3fn; + } + if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + return sglang::kFE8M0fnu; + } + TORCH_CHECK( + false, + "When q_type=float4_e2m1f, b_scales must use float8_e4m3fn (NVFP4) " + "or float8_e8m0fnu (MXFP4). Got scalar_type=", + b_scales.scalar_type()); + return sglang::kFloat16; + } + + if (b_scales.scalar_type() == at::ScalarType::Half) { + return sglang::kFloat16; + } + if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + return sglang::kBFloat16; + } + + if (compute_dtype == at::ScalarType::Half) { + return sglang::kFloat16; + } + if (compute_dtype == at::ScalarType::BFloat16) { + return sglang::kBFloat16; + } + + TORCH_CHECK( + false, + "Unsupported scale dtype for Marlin MoE: ", + b_scales.scalar_type()); + return sglang::kFloat16; +} + bool is_valid_config( thread_config_t const& th_config, bool m_block_size_8, @@ -325,27 +365,25 @@ bool is_valid_config( return cache_size + 512 <= max_shared_mem; } -#define _GET_IF( \ - W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if ( \ - q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = W_TYPE == sglang::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? sglang::kFE4M3fn : sglang::kFE8M0fnu) \ - : (std::is_same::value ? sglang::kFloat16 : sglang::kBFloat16); \ - kernel = Marlin< \ - scalar_t, \ - W_TYPE.id(), \ - S_TYPE.id(), \ - NUM_THREADS, \ - THREAD_M_BLOCKS, \ - THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, \ - pipe_stages, \ - GROUP_BLOCKS, \ - IS_ZP_FLOAT>; \ +#define _GET_IF( \ + W_TYPE, S_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && s_type == S_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + S_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -354,123 +392,124 @@ bool is_valid_config( // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) // FP4: cases for nvfp4(e2m1) (group_blocks == 1) -#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - -#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - -#define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) +#define COMMON_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE, S_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE, S_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define NVFP4_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF(W_TYPE, S_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define MXFP4_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF(W_TYPE, S_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) // We currently have 4-bit models only with group_blocks == 4 -#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) +#define FZP_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) -#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) +#define FZP_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) -#define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) +#define FZP_GET_IF(W_TYPE, S_TYPE) \ + FZP_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) // We currently have 4-bit models only with group_blocks == 4 -#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) +#define ACT_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) -#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) +#define ACT_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) -#define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) +#define ACT_GET_IF(W_TYPE, S_TYPE) \ + ACT_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) template MarlinFuncPtr get_marlin_kernel( const sglang::ScalarType q_type, + const sglang::ScalarType s_type, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, @@ -480,25 +519,28 @@ MarlinFuncPtr get_marlin_kernel( int group_blocks, int num_threads, bool is_zp_float) { - int num_bits = q_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - - COMMON_GET_IF(sglang::kU4) - COMMON_GET_IF(sglang::kU4B8) - COMMON_GET_IF(sglang::kU8B128) - - NVFP4_GET_IF(sglang::kFE2M1f) - - BIGGROUP_GET_IF(sglang::kFE4M3fn) - - ACT_GET_IF(sglang::kU4B8) - ACT_GET_IF(sglang::kU8B128) - if (std::is_same::value) { + if constexpr (std::is_same::value) { if (false) { } - MXFP4_GET_IF(sglang::kFE2M1f) + COMMON_GET_IF(sglang::kU4, sglang::kFloat16) + COMMON_GET_IF(sglang::kU4B8, sglang::kFloat16) + COMMON_GET_IF(sglang::kU8B128, sglang::kFloat16) + NVFP4_GET_IF(sglang::kFE2M1f, sglang::kFE4M3fn) + BIGGROUP_GET_IF(sglang::kFE4M3fn, sglang::kFloat16) + ACT_GET_IF(sglang::kU4B8, sglang::kFloat16) + ACT_GET_IF(sglang::kU8B128, sglang::kFloat16) + } else { + if (false) { + } + COMMON_GET_IF(sglang::kU4, sglang::kBFloat16) + COMMON_GET_IF(sglang::kU4B8, sglang::kBFloat16) + COMMON_GET_IF(sglang::kU8B128, sglang::kBFloat16) + NVFP4_GET_IF(sglang::kFE2M1f, sglang::kFE4M3fn) + BIGGROUP_GET_IF(sglang::kFE4M3fn, sglang::kBFloat16) + ACT_GET_IF(sglang::kU4B8, sglang::kBFloat16) + ACT_GET_IF(sglang::kU8B128, sglang::kBFloat16) + MXFP4_GET_IF(sglang::kFE2M1f, sglang::kFE8M0fnu) } return kernel; @@ -507,6 +549,7 @@ MarlinFuncPtr get_marlin_kernel( template exec_config_t determine_exec_config( const sglang::ScalarType& q_type, + const sglang::ScalarType& s_type, int prob_m, int prob_n, int prob_k, @@ -567,6 +610,7 @@ exec_config_t determine_exec_config( auto kernel = get_marlin_kernel( q_type, + s_type, thread_m_blocks, th_config.thread_n / 16, th_config.thread_k / 16, @@ -624,6 +668,7 @@ void marlin_mm( int prob_k, void* workspace, sglang::ScalarType const& q_type, + sglang::ScalarType const& s_type, bool has_bias, bool has_act_order, bool is_k_full, @@ -677,6 +722,17 @@ void marlin_mm( } } + if (q_type == sglang::kFE2M1f) { + TORCH_CHECK( + (group_size == 16 && s_type == sglang::kFE4M3fn) || + (group_size == 32 && s_type == sglang::kFE8M0fnu), + "float4_e2m1f expects s_type=float8_e4m3fn for group_size=16 (NVFP4) " + "or s_type=float8_e8m0fnu for group_size=32 (MXFP4). Got group_size=", + group_size, + ", s_type=", + s_type.str()); + } + int num_bits = q_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; @@ -742,6 +798,7 @@ void marlin_mm( // Auto config exec_cfg = determine_exec_config( q_type, + s_type, prob_m, prob_n, prob_k, @@ -812,6 +869,7 @@ void marlin_mm( auto kernel = get_marlin_kernel( q_type, + s_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, @@ -845,7 +903,9 @@ void marlin_mm( ", thread_k_blocks = ", thread_k_blocks, ", num_bits = ", - num_bits); + num_bits, + ", s_type = ", + s_type.str()); } cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); @@ -889,6 +949,9 @@ torch::Tensor moe_wna16_marlin_gemm( bool use_fp32_reduce, bool is_zp_float) { sglang::ScalarType const b_q_type = sglang::ScalarType::from_id(b_q_type_id); + sglang::ScalarType const s_type = + MARLIN_NAMESPACE_NAME::infer_scale_type( + b_q_type, b_scales, a.scalar_type()); int pack_factor = 32 / b_q_type.size_bits(); if (moe_block_size != 8) { @@ -1123,13 +1186,10 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; - if (b_q_type == sglang::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); + if (s_type == sglang::kFE4M3fn) { + scales_ptr = b_scales.data_ptr(); + } else if (s_type == sglang::kFE8M0fnu) { + scales_ptr = b_scales.data_ptr(); } else { scales_ptr = b_scales.data_ptr(); } @@ -1159,6 +1219,7 @@ torch::Tensor moe_wna16_marlin_gemm( size_k, workspace.data_ptr(), b_q_type, + s_type, has_bias, has_act_order, is_k_full, @@ -1175,13 +1236,10 @@ torch::Tensor moe_wna16_marlin_gemm( is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; - if (b_q_type == sglang::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); + if (s_type == sglang::kFE4M3fn) { + scales_ptr = b_scales.data_ptr(); + } else if (s_type == sglang::kFE8M0fnu) { + scales_ptr = b_scales.data_ptr(); } else { scales_ptr = b_scales.data_ptr(); } @@ -1211,6 +1269,7 @@ torch::Tensor moe_wna16_marlin_gemm( size_k, workspace.data_ptr(), b_q_type, + s_type, has_bias, has_act_order, is_k_full,