diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index 1bfb7ce8f967..8a28436d802e 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -121,6 +121,8 @@ SGLang supports various environment variables that can be used to configure its | `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` | | `SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2` | Apply per token group quantization kernel with fused silu and mul and masked m | `false` | | `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` | +| `SGLANG_FORCE_NVFP4_MARLIN` | Force using NVFP4 Marlin fallback kernels even on Blackwell GPUs with native FP4 support | `false` | +| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (deprecated) | Select backend for `mm_fp4` on Blackwell GPUs. **DEPRECATED**: Please use `--fp4-gemm-backend` instead. | `` | | `SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN` | Quantize q_b_proj from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` | | `SGLANG_MOE_NVFP4_DISPATCH` | Use nvfp4 for moe dispatch (on flashinfer_cutlass or flashinfer_cutedsl moe runner backend) | `"false"` | | `SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE` | Quantize moe of nextn layer from BF16 to FP8 when launching DeepSeek NVFP4 checkpoint | `false` | diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h index 6c4112e633fd..651710a963f7 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin/marlin_template.h @@ -484,11 +484,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? 
thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -540,8 +540,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -563,15 +562,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -876,7 +867,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h index bf7dcb202301..566fa5f59606 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h @@ -626,11 +626,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == host::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -682,8 +681,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -705,15 +703,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -1038,18 +1028,15 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (w_type_id != host::kFE2M1f.id()) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1243,17 +1230,19 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == host::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } +#ifdef 
SGL_MOE_MARLIN_FP4 + // Convert FP8 per-group scales to BF16/FP16 before applying them. + // Required for kFE2M1f (NVFP4): frag_s holds raw float8_e4m3fn bytes; + // without this conversion scale would misinterpret them as + // BF16/FP16, producing NaN/Inf multipliers. + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } +#endif // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh index 81c021dc8ecc..d89954200c88 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/moe_wna16_marlin.cuh @@ -453,7 +453,9 @@ MarlinFuncPtr get_marlin_kernel( COMMON_GET_IF(host::kU4B8) COMMON_GET_IF(host::kU8B128) +#ifdef SGL_MOE_MARLIN_FP4 NVFP4_GET_IF(host::kFE2M1f) +#endif BIGGROUP_GET_IF(host::kFE4M3fn) diff --git a/python/sglang/jit_kernel/moe_wna16_marlin.py b/python/sglang/jit_kernel/moe_wna16_marlin.py index e9a8cd25372b..0ddd6ef717d5 100644 --- a/python/sglang/jit_kernel/moe_wna16_marlin.py +++ b/python/sglang/jit_kernel/moe_wna16_marlin.py @@ -31,6 +31,24 @@ def _jit_moe_wna16_marlin_module(dtype: torch.dtype) -> Module: ) +@cache_once +def _jit_moe_wna16_marlin_fp4_module(dtype: torch.dtype) -> Module: + """Separate JIT module with NVFP4 (kFE2M1f) kernel instantiations enabled.""" + args = make_cpp_args(dtype) + return load_jit( + "moe_wna16_marlin_fp4", + *args, + cuda_files=["gemm/marlin_moe/moe_wna16_marlin.cuh"], + extra_cuda_cflags=["-DSGL_MOE_MARLIN_FP4"], + cuda_wrappers=[ + ( + "moe_wna16_marlin_gemm", + f"moe_wna16_marlin_gemm<{args}>", + ) + ], + ) + 
+ def _or_empty( t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype ) -> torch.Tensor: @@ -134,7 +152,11 @@ def moe_wna16_marlin_gemm( b_bias_t = _or_empty(b_bias_or_none, device, a.dtype) global_scale_t = _or_empty(global_scale_or_none, device, a.dtype) - module = _jit_moe_wna16_marlin_module(a.dtype) + is_fp4 = global_scale_or_none is not None and global_scale_or_none.numel() > 0 + if is_fp4: + module = _jit_moe_wna16_marlin_fp4_module(a.dtype) + else: + module = _jit_moe_wna16_marlin_module(a.dtype) module.moe_wna16_marlin_gemm( a, c, diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index adc21301ba15..7b5ce2a7e901 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -332,6 +332,7 @@ class Envs: SGLANG_CPU_QUANTIZATION = EnvBool(False) SGLANG_USE_DYNAMIC_MXFP4_LINEAR = EnvBool(False) SGLANG_FORCE_FP8_MARLIN = EnvBool(False) + SGLANG_FORCE_NVFP4_MARLIN = EnvBool(False) SGLANG_MOE_NVFP4_DISPATCH = EnvBool(False) SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN = EnvBool(False) SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 4410f07f327e..b1bb618ce5ce 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -23,6 +23,13 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 +def _get_fp4_scalar_type(): + from sglang.srt.layers.quantization.utils import get_scalar_types + + _, scalar_types = get_scalar_types() + return scalar_types.float4_e2m1f + + @register_custom_op(out_shape="hidden_states") def fused_marlin_moe( hidden_states: torch.Tensor, @@ -46,6 +53,8 @@ def fused_marlin_moe( is_k_full: bool = True, inplace: bool = False, routed_scaling_factor: Optional[float] = None, + 
w1_global_scale: Optional[torch.Tensor] = None, + w2_global_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -76,6 +85,13 @@ def fused_marlin_moe( """ from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size + # Detect FP4 Marlin mode (when global scales are provided) + _is_fp4_marlin = w1_global_scale is not None + if _is_fp4_marlin: + assert ( + w2_global_scale is not None + ), "Both w1_global_scale and w2_global_scale must be provided for FP4 Marlin mode" + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" assert hidden_states.shape[1] == w2.shape[2] // ( @@ -85,12 +101,14 @@ def fused_marlin_moe( assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert ( - hidden_states.dtype == w1_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" - assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + # For FP4 Marlin, scales are in special float8_e4m3fn format (not input dtype) + if not _is_fp4_marlin: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -121,8 +139,13 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = 
get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + # FP4 Marlin uses float4_e2m1f scalar type (not uint4b8/uint8b128) + if _is_fp4_marlin: + scalar_type1 = _get_fp4_scalar_type() + scalar_type2 = _get_fp4_scalar_type() + else: + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -150,7 +173,7 @@ def fused_marlin_moe( w1, None, # b_bias_or_none w1_scale, - None, # global_scale_or_none + w1_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w1_zeros, g_idx1, sort_indices1, @@ -184,7 +207,7 @@ def fused_marlin_moe( w2, None, # b_bias_or_none w2_scale, - None, # global_scale_or_none + w2_global_scale, # None for INT4/INT8, tensor for FP4 Marlin w2_zeros, g_idx2, sort_indices2, diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py index 45104dd27805..429b28697d23 100644 --- a/python/sglang/srt/layers/moe/moe_runner/marlin.py +++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py @@ -69,8 +69,13 @@ class MarlinMoeQuantInfo(MoeQuantInfo): w13_qzeros: Optional[torch.Tensor] = None w2_qzeros: Optional[torch.Tensor] = None - # Optional + # FP4 Marlin specific (Optional) + w13_global_scale: Optional[torch.Tensor] = None + w2_global_scale: Optional[torch.Tensor] = None + + # EP support (Optional) expert_map: Optional[torch.Tensor] = None + global_num_experts: int = -1 @register_fused_func("none", "marlin") @@ -106,6 +111,7 @@ def fused_experts_none_to_marlin( gating_output=topk_output.router_logits, topk_weights=topk_output.topk_weights, topk_ids=topk_output.topk_ids, + global_num_experts=quant_info.global_num_experts, expert_map=quant_info.expert_map, g_idx1=quant_info.w13_g_idx, g_idx2=quant_info.w2_g_idx, @@ -118,6 +124,8 @@ def fused_experts_none_to_marlin( is_k_full=quant_info.is_k_full, 
inplace=runner_config.inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + w1_global_scale=quant_info.w13_global_scale, + w2_global_scale=quant_info.w2_global_scale, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 1072dac7b3ca..0d5db7a5a010 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -16,6 +16,10 @@ CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.modelopt_quant import ( enable_flashinfer_fp4_gemm, fp4_gemm, @@ -34,7 +38,7 @@ def __init__(self): @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -47,6 +51,7 @@ def create_weights( ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -91,6 +96,20 @@ def create_weights( layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + global_scale = layer.weight_global_scale.max().to(torch.float32) + layer.weight_global_scale = Parameter(global_scale, requires_grad=False) + 
prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight_packed", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_global_scale", + ) + layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False global_input_scale = layer.input_global_scale.max().to(torch.float32) layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) @@ -136,6 +155,18 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight_packed, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_global_scale, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype w_n, _ = layer.weight_packed.shape output_shape = [x.shape[0], w_n] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..7824a3bcce60 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -17,6 +17,10 @@ CompressedTensorsMoEScheme, ) from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, @@ -38,19 +42,27 @@ class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme): def __init__(self): - if not is_blackwell_supported(): + self.group_size = 16 + + if should_use_fp4_marlin_fallback(): + 
logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." + ) + self.use_marlin_fallback = True + self.use_flashinfer_trtllm = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." ) - self.group_size = 16 - self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() + else: + self.use_marlin_fallback = False + self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() @classmethod def get_min_capability(cls) -> int: - # Requires sm100(blackwell) architecture - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 def create_weights( self, @@ -64,6 +76,7 @@ def create_weights( from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported layer.params_dtype = params_dtype + layer.intermediate_size_per_partition = intermediate_size_per_partition w13_weight = torch.nn.Parameter( torch.empty( @@ -175,6 +188,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) delattr(layer, "w2_weight_packed") + if self.use_marlin_fallback: + # CompressedTensors checkpoint: global_scale is stored as the inverse. + # Actual dequant scale = 1 / stored_value. We create w*_weight_scale_2 + # with the actual scale before calling prepare_moe_fp4_layer_for_marlin(). 
+ layer.w13_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w13_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E, 2] + layer.w2_weight_scale_2 = torch.nn.Parameter( + (1.0 / layer.w2_weight_global_scale).to(layer.params_dtype), + requires_grad=False, + ) # [E] + prepare_moe_fp4_layer_for_marlin(layer) + return + if self.use_flashinfer_trtllm: w, s = reorder_w1w3_to_w3w1( layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 @@ -303,7 +331,10 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + else: + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) def apply_weights( self, @@ -313,6 +344,33 @@ def apply_weights( from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = self.moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return self.runner.run(dispatch_output, quant_info) + x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output diff --git 
a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..5a9fb3cef84b --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py + +"""NVFP4 Marlin fallback: run FP4-quantized models on non-Blackwell GPUs via Marlin kernel.""" + +import logging +from typing import Optional + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) +from sglang.srt.layers.quantization.utils import get_scalar_types +from sglang.srt.utils import direct_register_custom_op, get_device_capability, is_cuda + +_is_cuda = is_cuda() +if _is_cuda: + from sglang.jit_kernel.gptq_marlin import gptq_marlin_gemm + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + +ScalarType, scalar_types = get_scalar_types() + +logger = logging.getLogger(__name__) + +# NVFP4 always uses group_size=16 +FP4_MARLIN_GROUP_SIZE = 16 + + +def is_fp4_marlin_supported() -> bool: + """Check if the current GPU supports FP4 Marlin fallback (CUDA SM >= 75).""" + if not _is_cuda: + return False + if torch.version.hip is not None: + return False + major, minor = get_device_capability() + if major is None or minor is None: + return False + return (major * 10 + minor) >= 75 + + +def should_use_fp4_marlin_fallback() -> bool: + """True if non-Blackwell (or forced) AND Marlin kernel available (SM >= 75).""" + from sglang.srt.environ import envs + from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported + + force = envs.SGLANG_FORCE_NVFP4_MARLIN.get() + return (force or not is_blackwell_supported()) and is_fp4_marlin_supported() + + +def 
nvfp4_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor: + """Convert NVFP4 scales from FP8-S1E4M3 to FP8-S0E5M3 format for Marlin. + + The int16 <<1 may wrap for large scales (e.g. 448*128=57344), but the BIT + PATTERN is preserved correctly — the kernel reads raw bytes, not int16 values. + """ + marlin_scales = marlin_scales.to(torch.half) + + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes scales >= 0, but encountered negative scales. " + "Accuracy may be degraded. The scales are converted from FP8-S1E4M3 " + "to a special FP8-S0E5M3 format to speed up dequantization." + ) + + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def nvfp4_marlin_process_global_scale(global_scale: torch.Tensor) -> torch.Tensor: + """Pre-adjust global scale with FP4/FP16/BF16 exponent bias for Marlin kernel.""" + assert global_scale.dtype in [ + torch.half, + torch.bfloat16, + ], f"global_scale dtype must be half or bfloat16, got {global_scale.dtype}" + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + """Apply FP4-quantized linear via Marlin kernel (non-Blackwell fallback).""" + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = 
input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype, + ) + + output = gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_global_scale.reshape(-1), + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) + + if bias is not None: + output.add_(bias) + + return output.reshape(out_shape) + + +def fake_apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_global_scale: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + out_shape = input.shape[:-1] + (size_n,) + return torch.empty(out_shape, dtype=input.dtype, device=input.device) + + +direct_register_custom_op( + op_name="apply_fp4_marlin_linear", + op_func=apply_fp4_marlin_linear, + mutates_args=[], + fake_impl=fake_apply_fp4_marlin_linear, +) + + +def prepare_fp4_layer_for_marlin( + layer: torch.nn.Module, + weight_attr: str = "weight", + weight_scale_attr: str = "weight_scale", + weight_global_scale_attr: str = "weight_global_scale", +) -> None: + """Repack NVFP4 linear layer weights into Marlin format in-place.""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + weight = getattr(layer, weight_attr) + assert weight.shape == (part_size_n, part_size_k // 2), ( + f"Expected {weight_attr} shape ({part_size_n}, {part_size_k // 2}), " + f"got {weight.shape}" + ) + + device = weight.device + + # WORKSPACE + layer.marlin_workspace = marlin_make_workspace(device) + + # WEIGHT: repack from NVFP4 native layout to Marlin tile layout + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = weight.data.view(torch.int32).T.contiguous() + del weight + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) + del qweight + setattr(layer, weight_attr, torch.nn.Parameter(marlin_qweight, requires_grad=False)) + + # WEIGHT SCALES: transpose, permute, convert to FP8-S0E5M3 + weight_scale = getattr(layer, weight_scale_attr) + weight_scale = weight_scale.data.T.contiguous().to(param_dtype) + weight_scale = marlin_permute_scales( + s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + weight_scale = nvfp4_marlin_process_scales(weight_scale) + setattr( + layer, weight_scale_attr, torch.nn.Parameter(weight_scale, requires_grad=False) + ) + + # GLOBAL SCALE: Pre-adjust exponent bias for Marlin kernel. 
+ weight_global_scale = getattr(layer, weight_global_scale_attr) + weight_global_scale = weight_global_scale.to(param_dtype) + weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale) + setattr( + layer, + weight_global_scale_attr, + torch.nn.Parameter(weight_global_scale, requires_grad=False), + ) + + # BIAS (if present): Permute for Marlin's fast access pattern + if hasattr(layer, "bias") and layer.bias is not None: + assert layer.bias.shape == (part_size_n,) + bias = marlin_permute_bias(layer.bias) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + """Repack NVFP4 MoE weights into Marlin format in-place (per-expert).""" + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel for MoE layers. This may " + "degrade performance for compute-heavy workloads." 
+ ) + + e = layer.num_local_experts + k = layer.w13_weight.shape[2] * 2 # hidden_size (packed: K//2 per uint8) + n = layer.intermediate_size_per_partition + param_dtype = layer.params_dtype + num_shards = 2 if layer.moe_runner_config.is_gated else 1 + + device = layer.w13_weight.device + perm = torch.empty(0, dtype=torch.int, device=device) + + # (size_n, size_k) for each projection + sizes = {"w13": (n * num_shards, k), "w2": (k, n)} + + # --- WEIGHT REPACKING --- + for name in ["w13_weight", "w2_weight"]: + prefix = name.split("_")[0] # "w13" or "w2" + size_n, size_k = sizes[prefix] + weight = getattr(layer, name) + + assert weight.shape == (e, size_n, size_k // 2), ( + f"Expected {name} shape ({e}, {size_n}, {size_k // 2}), " + f"got {weight.shape}" + ) + + repacked = [] + for i in range(e): + qweight = weight.data[i].view(torch.int32).T.contiguous() + repacked.append( + gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + ) + + del weight + setattr( + layer, name, torch.nn.Parameter(torch.stack(repacked), requires_grad=False) + ) + + # --- WEIGHT SCALE PROCESSING --- + for prefix in ["w13", "w2"]: + size_n, size_k = sizes[prefix] + scales = getattr(layer, prefix + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, prefix + "_weight_scale_2").to(param_dtype) + + processed = [] + for i in range(e): + s = marlin_permute_scales( + s=scales.data[i].T, + size_k=size_k, + size_n=size_n, + group_size=FP4_MARLIN_GROUP_SIZE, + ) + processed.append(nvfp4_marlin_process_scales(s)) + + del scales + setattr( + layer, + prefix + "_weight_scale", + torch.nn.Parameter(torch.stack(processed), requires_grad=False), + ) + + if global_scale.dim() > 1: + global_scale = global_scale.max(dim=-1).values + global_scale = nvfp4_marlin_process_global_scale(global_scale) + setattr( + layer, + prefix + "_weight_scale_2", + torch.nn.Parameter(global_scale, requires_grad=False), + ) diff --git 
a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index dc4fa71fcd20..7dcaeb79c745 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -40,6 +40,11 @@ is_blackwell_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + prepare_moe_fp4_layer_for_marlin, + should_use_fp4_marlin_fallback, +) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, @@ -1128,7 +1133,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 100 + return 75 # SM75+ (Turing) supports Marlin FP4 fallback; SM100 for native FP4 @staticmethod def common_group_size(cfg: dict) -> int: @@ -1302,6 +1307,7 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes + layer.params_dtype = params_dtype layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition @@ -1356,6 +1362,20 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if should_use_fp4_marlin_fallback(): + # Marlin FP4 fallback: consolidate global scale then repack weights + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.weight_scale_2_marlin = Parameter(weight_scale_2, requires_grad=False) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + layer.use_marlin_fallback = True + return + + layer.use_marlin_fallback = False input_scale_2 = 
layer.input_scale.max().to(torch.float32) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) @@ -1456,6 +1476,18 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if layer.use_marlin_fallback: + return torch.ops.sglang.apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + output_dtype = x.dtype x_m, _ = x.shape @@ -1509,15 +1541,24 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: ModelOptFp4Config): self.quant_config = quant_config - if not is_blackwell_supported(): + + if should_use_fp4_marlin_fallback(): + logger.warning_once( + "GPU is not Blackwell (SM100+). Using Marlin FP4 fallback kernel " + "for MoE layers. Weights remain compressed in FP4 format." + ) + self.use_marlin_fallback = True + self.enable_flashinfer_trtllm_moe = False + elif not is_blackwell_supported(): raise ValueError( "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." + " quantization. Please use SM75+ (Turing or newer)." + ) + else: + self.use_marlin_fallback = False + self.enable_flashinfer_trtllm_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm() ) - self.enable_flashinfer_trtllm_moe = ( - get_moe_runner_backend().is_flashinfer_trtllm() - ) self._cache_permute_indices = {} @property @@ -1671,6 +1712,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: Only supports pre-quantized checkpoints with FP8 weights and scales. 
""" + if self.use_marlin_fallback: + prepare_moe_fp4_layer_for_marlin(layer) + return # GEMM 1 scale processing if layer.moe_runner_config.is_gated: @@ -1883,6 +1927,9 @@ def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): self.moe_runner_config = moe_runner_config + if self.use_marlin_fallback: + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + return if get_moe_runner_backend().is_flashinfer_trtllm(): self.runner = MoeRunner( MoeRunnerBackend.FLASHINFER_TRTLLM, moe_runner_config @@ -1905,6 +1952,34 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config + # Marlin FP4 fallback path for non-Blackwell GPUs (SM75-SM99) + if self.use_marlin_fallback: + from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo + + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = moe_runner_config.num_experts + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale, + w2_scales=layer.w2_weight_scale, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + w13_global_scale=layer.w13_weight_scale_2, + w2_global_scale=layer.w2_weight_scale_2, + expert_map=expert_map, + global_num_experts=global_num_experts, + ) + return self.runner.run(dispatch_output, quant_info) + # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when # backend is flashinfer_trtllm if hasattr(layer, "gemm1_weights_fp4_shuffled"): diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 60e646fcf7c0..cc823234b821 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py 
@@ -186,6 +186,16 @@ def get_quant_config( if not isinstance(hf_quant_config, dict): hf_quant_config = hf_quant_config.to_dict() hf_quant_config["packed_modules_mapping"] = packed_modules_mapping + # For modelopt, route to FP4 vs FP8 config based on quant_algo + if model_config.quantization.startswith("modelopt"): + quant_algo = hf_quant_config.get("quant_algo") + if quant_algo is None: + quant_algo = hf_quant_config.get("quantization", {}).get("quant_algo") + if quant_algo is not None: + if quant_algo == "FP8" or model_config.quantization == "modelopt_fp8": + return ModelOptFp8Config.from_config(hf_quant_config) + if "FP4" in quant_algo: + return ModelOptFp4Config.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config) # In case of bitsandbytes/QLoRA, get quant config from the adapter model. diff --git a/sgl-kernel/csrc/gemm/marlin/marlin_template.h b/sgl-kernel/csrc/gemm/marlin/marlin_template.h index 01eb338782c4..19f5d5477c4e 100644 --- a/sgl-kernel/csrc/gemm/marlin/marlin_template.h +++ b/sgl-kernel/csrc/gemm/marlin/marlin_template.h @@ -487,11 +487,11 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) - : 1; + // FP4 (kFE2M1f) uses FP8 scales (1 byte/element), others use FP16 (2 bytes) + int s_gl_stride = prob_n / (w_type == sglang::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (w_type == sglang::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? 
thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -543,8 +543,7 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } } auto s_sh_wr = threadIdx.x; @@ -566,15 +565,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -879,7 +870,7 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; diff --git a/test/registered/quant/test_nvfp4_marlin_fallback.py b/test/registered/quant/test_nvfp4_marlin_fallback.py new file mode 100644 index 000000000000..348294d5e565 --- /dev/null +++ b/test/registered/quant/test_nvfp4_marlin_fallback.py @@ -0,0 +1,788 @@ +"""Tests for NVFP4 Marlin fallback on non-Blackwell GPUs (SM75+).""" + +import unittest + +import torch + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import CustomTestCase + +register_cuda_ci(est_time=480, suite="stage-b-test-1-gpu-large") + +_FP4_MARLIN_GROUP_SIZE = 16 + +_FP4_E2M1_LUT_VALUES = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def _check_requirements(): + from sglang.srt.utils import is_cuda + + if not is_cuda(): + return False + from sglang.srt.layers.quantization.marlin_utils_fp4 import is_fp4_marlin_supported + + if not is_fp4_marlin_supported(): + return False + return True + + +def _dequant_fp4_weights( + raw_weight: torch.Tensor, device: torch.device +) -> torch.Tensor: + """Dequantize uint8-packed FP4 E2M1 weights to float32 via lookup table.""" + lut = torch.tensor(_FP4_E2M1_LUT_VALUES, dtype=torch.float32, device=device) + lo = (raw_weight.int() & 0x0F).long() + hi = ((raw_weight.int() >> 4) & 0x0F).long() + return torch.stack([lut[lo], lut[hi]], dim=-1).reshape( + raw_weight.shape[0], raw_weight.shape[1] * 2 + ) + + +class _FakeLayer(torch.nn.Module): + """Minimal stand-in for a quantized layer in unit tests.""" + + pass + + +# --------------------------------------------------------------------------- +# Linear (non-MoE) tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinLinear(CustomTestCase): + """Test the FP4 Marlin linear layer fallback (non-MoE).""" + + def setUp(self): + if not _check_requirements(): 
+ self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + + # -- helpers ------------------------------------------------------------- + + def _make_fake_fp4_layer(self, N, K): + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device=self.device), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + return layer + + def _run_fp4_marlin_vs_reference(self, M, N, K): + """Prepare a layer, run the Marlin kernel, return (kernel_out, ref_out).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + raw_weight = torch.randint( + 0, 256, (N, K // 2), dtype=torch.uint8, device=self.device + ) + dq_weight = _dequant_fp4_weights(raw_weight, self.device) + + raw_scale = torch.full( + (N, K // _FP4_MARLIN_GROUP_SIZE), + 1.0, + dtype=torch.float8_e4m3fn, + device=self.device, + ) + global_scale_val = torch.tensor(1.0, dtype=torch.float32, device=self.device) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + ref_output = (x.float() @ dq_weight.T).to(self.dtype) + + layer = self._make_fake_fp4_layer(N, K) + layer.weight = torch.nn.Parameter(raw_weight, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(raw_scale, requires_grad=False) + layer.weight_scale_2_marlin = torch.nn.Parameter( + global_scale_val.to(self.dtype), requires_grad=False + ) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + 
weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + marlin_output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + return marlin_output, ref_output + + # -- tests --------------------------------------------------------------- + + def test_prepare_and_apply_fp4_marlin_linear(self): + """Smoke test: shape and dtype are correct after prepare + apply.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertTrue(hasattr(layer, "marlin_workspace")) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + output = apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + self.assertEqual(output.shape, (M, N)) + self.assertEqual(output.dtype, self.dtype) + + def test_fp4_marlin_numerical_correctness(self): + """Kernel output vs BF16 dequant reference (cosine sim, MAE, assert_close).""" + N, K, M = 256, 256, 32 + marlin_output, ref_output = self._run_fp4_marlin_vs_reference(M, N, K) + + self.assertEqual(marlin_output.shape, ref_output.shape) + self.assertEqual(marlin_output.dtype, ref_output.dtype) + + cos_sim = torch.nn.functional.cosine_similarity( + marlin_output.float().flatten(), ref_output.float().flatten(), dim=0 + ) + self.assertGreater( + cos_sim.item(), + 0.99, + f"Cosine similarity {cos_sim.item():.6f} too low", + ) + + rel_mae = torch.mean( + torch.abs(marlin_output.float() - 
ref_output.float()) + ) / torch.mean(torch.abs(ref_output.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + torch.testing.assert_close(marlin_output, ref_output, atol=1e-1, rtol=1e-1) + + def test_fp4_marlin_multiple_shapes(self): + """Numerical correctness across various (M, N, K) dimensions.""" + shapes = [ + (1, 256, 256), + (16, 512, 128), + (64, 128, 512), + (32, 256, 256), + ] + for M, N, K in shapes: + with self.subTest(M=M, N=N, K=K): + marlin_out, ref_out = self._run_fp4_marlin_vs_reference(M, N, K) + rel_mae = torch.mean( + torch.abs(marlin_out.float() - ref_out.float()) + ) / torch.mean(torch.abs(ref_out.float())) + self.assertLess( + rel_mae.item(), + 0.04, + f"Shape ({M},{N},{K}): relative MAE {rel_mae.item():.6f} >= 0.04", + ) + + def test_fp4_marlin_linear_with_bias(self): + """Verify output_with_bias == output_no_bias + bias.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + bias = torch.randn(N, dtype=self.dtype, device=self.device) + + common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + output_no_bias = apply_fp4_marlin_linear(input=x, **common) + output_with_bias = apply_fp4_marlin_linear(input=x, bias=bias, **common) + + torch.testing.assert_close( + output_with_bias, output_no_bias + bias, atol=1e-5, rtol=1e-5 + ) + + def test_fp4_marlin_registered_op_numerical(self): + """torch.ops.sglang.apply_fp4_marlin_linear matches the direct Python call.""" + import 
sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K, M = 256, 128, 16 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) + + common = dict( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + direct_out = apply_fp4_marlin_linear(**common) + op_out = torch.ops.sglang.apply_fp4_marlin_linear(**common) + + self.assertEqual(op_out.shape, direct_out.shape) + self.assertEqual(op_out.dtype, direct_out.dtype) + torch.testing.assert_close(op_out, direct_out, atol=0, rtol=0) + + def test_fp4_marlin_3d_input(self): + """Verify correct reshape for 3-D input (batch, seq_len, K).""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + batch, seq_len = 2, 8 + layer = self._make_fake_fp4_layer(N, K) + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + x_3d = torch.randn(batch, seq_len, K, dtype=self.dtype, device=self.device) + x_2d = x_3d.reshape(-1, K) + + common = dict( + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_scale_2_marlin, + workspace=layer.marlin_workspace, + size_n=N, + size_k=K, + ) + + out_3d = apply_fp4_marlin_linear(input=x_3d, **common) + out_2d = apply_fp4_marlin_linear(input=x_2d, **common) + + self.assertEqual(out_3d.shape, (batch, seq_len, N)) + self.assertEqual(out_3d.dtype, self.dtype) + 
torch.testing.assert_close(out_3d.reshape(-1, N), out_2d, atol=0, rtol=0) + + def test_fake_apply_fp4_marlin_linear(self): + """Fake impl for PCG tracing must return the correct shape and dtype.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + fake_apply_fp4_marlin_linear, + ) + + N, K = 256, 128 + + for input_shape in [(16, K), (2, 8, K)]: + with self.subTest(input_shape=input_shape): + x = torch.randn(*input_shape, dtype=self.dtype, device=self.device) + out = fake_apply_fp4_marlin_linear( + input=x, + weight=torch.empty(0, device=self.device), + weight_scale=torch.empty(0, device=self.device), + weight_global_scale=torch.empty(0, device=self.device), + workspace=torch.empty(0, device=self.device), + size_n=N, + size_k=K, + ) + expected_shape = input_shape[:-1] + (N,) + self.assertEqual(out.shape, expected_shape) + self.assertEqual(out.dtype, self.dtype) + + def test_prepare_rejects_bad_weight_shape(self): + """prepare_fp4_layer_for_marlin must raise on mismatched weight shape.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = _FakeLayer() + layer.params_dtype = self.dtype + layer.input_size_per_partition = K + layer.output_size_per_partition = N + + layer.weight = torch.nn.Parameter( + torch.randint( + 0, 256, (N + 1, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.ones( + N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.weight_scale_2_marlin = torch.nn.Parameter( + torch.tensor(1.0, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + with self.assertRaises(AssertionError): + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + def test_prepare_fp4_layer_permutes_bias(self): + 
"""prepare_fp4_layer_for_marlin must permute layer.bias when present.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + N, K = 256, 128 + layer = self._make_fake_fp4_layer(N, K) + original_bias = torch.randn(N, dtype=self.dtype, device=self.device) + layer.bias = torch.nn.Parameter(original_bias.clone(), requires_grad=False) + + prepare_fp4_layer_for_marlin( + layer, + weight_attr="weight", + weight_scale_attr="weight_scale", + weight_global_scale_attr="weight_scale_2_marlin", + ) + + self.assertEqual(layer.bias.shape, (N,)) + self.assertEqual(layer.bias.dtype, self.dtype) + self.assertFalse( + torch.equal(layer.bias.data, original_bias), + "Bias should be permuted by prepare_fp4_layer_for_marlin", + ) + + def test_fp4_marlin_custom_op_registration(self): + """apply_fp4_marlin_linear must be registered as torch.ops.sglang for PCG.""" + import sglang.srt.layers.quantization.marlin_utils_fp4 # noqa: F401 + + self.assertTrue( + hasattr(torch.ops.sglang, "apply_fp4_marlin_linear"), + "apply_fp4_marlin_linear not registered as a custom op", + ) + + def test_nvfp4_marlin_scale_values_correctness(self): + """Verify scale conversion produces analytically correct values.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + nvfp4_marlin_process_scales, + ) + + # -- global scale: BF16 -- + # fp4_exp=2, target_exp=8 => bias = 2^7 - 2^1 = 126 + # result = 1.0 * 2^(126-7) = 2^119 + gs_bf16 = torch.tensor(1.0, dtype=torch.bfloat16, device=self.device) + result_bf16 = nvfp4_marlin_process_global_scale(gs_bf16) + expected_bf16 = torch.tensor(2.0**119, dtype=torch.bfloat16, device=self.device) + self.assertEqual( + result_bf16.item(), + expected_bf16.item(), + f"BF16 global_scale(1.0): expected 2^119, got {result_bf16.item()}", + ) + self.assertEqual(result_bf16.dtype, torch.bfloat16) + + # -- global scale: FP16 -- + # fp4_exp=2, target_exp=5 => bias = 2^4 - 2^1 = 14 + # 
result = 1.0 * 2^(14-7) = 128 + gs_fp16 = torch.tensor(1.0, dtype=torch.float16, device=self.device) + result_fp16 = nvfp4_marlin_process_global_scale(gs_fp16) + self.assertEqual( + result_fp16.item(), + 128.0, + f"FP16 global_scale(1.0): expected 128.0, got {result_fp16.item()}", + ) + self.assertEqual(result_fp16.dtype, torch.float16) + + # -- global scale: linearity -- + gs_2 = torch.tensor(2.0, dtype=torch.bfloat16, device=self.device) + result_2 = nvfp4_marlin_process_global_scale(gs_2) + self.assertAlmostEqual( + result_2.item(), + 2.0 * result_bf16.item(), + places=0, + msg="Global scale processing should be linear", + ) + + # -- per-group scales: structural properties -- + N, K_div_group = 64, 16 + raw_scale = torch.ones( + N, K_div_group, dtype=torch.float8_e4m3fn, device=self.device + ).to(self.dtype) + processed = nvfp4_marlin_process_scales(raw_scale) + + self.assertEqual(processed.dtype, torch.float8_e4m3fn) + self.assertEqual(processed.shape, (N, K_div_group)) + self.assertFalse(torch.isnan(processed.to(self.dtype)).any()) + + # Deterministic + self.assertTrue(torch.equal(processed, nvfp4_marlin_process_scales(raw_scale))) + + # Large scales (448 = FP8 E4M3 max) must not produce NaN + large_scale = torch.full( + (N, K_div_group), 448.0, dtype=self.dtype, device=self.device + ) + proc_large = nvfp4_marlin_process_scales(large_scale) + self.assertFalse(torch.isnan(proc_large.to(self.dtype)).any()) + self.assertEqual(proc_large.shape, (N, K_div_group)) + + +# --------------------------------------------------------------------------- +# MoE tests +# --------------------------------------------------------------------------- +class TestNvfp4MarlinMoe(CustomTestCase): + """Test the FP4 Marlin MoE fallback.""" + + def setUp(self): + if not _check_requirements(): + self.skipTest("Requirements not met (CUDA unavailable or SM < 75)") + self.device = torch.device("cuda") + self.dtype = torch.bfloat16 + try: + from sglang.jit_kernel.gptq_marlin_repack import 
gptq_marlin_repack + + self._gptq_marlin_repack = gptq_marlin_repack + except ImportError: + self.skipTest("gptq_marlin_repack JIT compilation not available") + self._perm = torch.empty(0, dtype=torch.int, device=self.device) + + # -- helpers ------------------------------------------------------------- + + def _repack_fp4_weight(self, raw_fp4, size_k, size_n): + """Repack raw uint8 FP4 weights into Marlin tile layout.""" + qw = raw_fp4.view(torch.int32).T.contiguous() + return self._gptq_marlin_repack(qw, self._perm, size_k, size_n, num_bits=4) + + def _make_marlin_scale(self, size_k, size_n): + from sglang.srt.layers.quantization.marlin_utils import marlin_permute_scales + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_scales, + ) + + raw = torch.ones( + size_k // _FP4_MARLIN_GROUP_SIZE, + size_n, + dtype=self.dtype, + device=self.device, + ) + permuted = marlin_permute_scales(raw, size_k, size_n, _FP4_MARLIN_GROUP_SIZE) + return nvfp4_marlin_process_scales(permuted) + + def _make_processed_global_scale(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + nvfp4_marlin_process_global_scale, + ) + + return nvfp4_marlin_process_global_scale( + torch.tensor(1.0, dtype=self.dtype, device=self.device) + ) + + # -- tests --------------------------------------------------------------- + + def test_fused_marlin_moe_fp4(self): + """Smoke test: shape, dtype, no NaN for multi-expert MoE.""" + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, M = 4, 128, 64, 2, 8 + + def _rand_weight(size_k, size_n): + raw = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=self.device + ) + return self._repack_fp4_weight(raw, size_k, size_n) + + w1 = torch.stack([_rand_weight(K, 2 * N) for _ in range(E)]) + w2 = torch.stack([_rand_weight(N, K) for _ in range(E)]) + w1_scale = torch.stack([self._make_marlin_scale(K, 2 * N) for _ in range(E)]) + 
w2_scale = torch.stack([self._make_marlin_scale(N, K) for _ in range(E)]) + + gs = self._make_processed_global_scale() + w1_gs = gs.expand(E) + w2_gs = gs.expand(E) + + hidden = torch.randn(M, K, dtype=self.dtype, device=self.device) + gating = torch.randn(M, E, dtype=self.dtype, device=self.device) + topk_weights, topk_ids = torch.topk(torch.softmax(gating, dim=-1), topk, dim=-1) + + output = fused_marlin_moe( + hidden_states=hidden, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + self.assertEqual(output.shape, (M, K)) + self.assertEqual(output.dtype, self.dtype) + self.assertFalse(torch.isnan(output).any(), "Output contains NaN!") + + def test_fused_marlin_moe_fp4_numerical(self): + """E=1, topk=1 MoE output vs dequant reference (SiLU-gated).""" + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + + E, K, N, topk, M = 1, 128, 64, 1, 8 + + raw_w1 = torch.randint( + 0, 256, (2 * N, K // 2), dtype=torch.uint8, device=self.device + ) + raw_w2 = torch.randint( + 0, 256, (K, N // 2), dtype=torch.uint8, device=self.device + ) + dq_w1 = _dequant_fp4_weights(raw_w1, self.device) + dq_w2 = _dequant_fp4_weights(raw_w2, self.device) + + w1 = self._repack_fp4_weight(raw_w1, K, 2 * N).unsqueeze(0) + w2 = self._repack_fp4_weight(raw_w2, N, K).unsqueeze(0) + w1_scale = self._make_marlin_scale(K, 2 * N).unsqueeze(0) + w2_scale = self._make_marlin_scale(N, K).unsqueeze(0) + + gs = self._make_processed_global_scale() + w1_gs = gs.unsqueeze(0) + w2_gs = gs.unsqueeze(0) + + x = torch.randn(M, K, dtype=self.dtype, device=self.device) * 0.1 + gating = torch.ones(M, E, dtype=self.dtype, device=self.device) + topk_weights = torch.ones(M, topk, dtype=self.dtype, device=self.device) + topk_ids = torch.zeros(M, topk, dtype=torch.int64, device=self.device) + + output = fused_marlin_moe( + 
hidden_states=x, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + gating_output=gating, + topk_weights=topk_weights, + topk_ids=topk_ids, + num_bits=4, + w1_global_scale=w1_gs, + w2_global_scale=w2_gs, + ) + + gate_up = x.float() @ dq_w1.T + gate, up = gate_up[:, :N], gate_up[:, N:] + ref_output = ((torch.nn.functional.silu(gate) * up) @ dq_w2.T).to(self.dtype) + + self.assertEqual(output.shape, ref_output.shape) + self.assertFalse(torch.isinf(output).any(), "MoE output contains Inf") + self.assertFalse(torch.isnan(output).any(), "MoE output contains NaN") + + finite = torch.isfinite(ref_output) & torch.isfinite(output) + if finite.any(): + cos_sim = torch.nn.functional.cosine_similarity( + output[finite].float().flatten(), + ref_output[finite].float().flatten(), + dim=0, + ) + self.assertGreater( + cos_sim.item(), + 0.90, + f"MoE cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_prepare_moe_fp4_layer_for_marlin(self): + """Weight repacking produces correct shapes for all expert tensors.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_fp4_layer_for_marlin, + ) + + E, K, N = 4, 128, 64 + + class _FakeMoeRunnerConfig: + is_gated = True + + layer = _FakeLayer() + layer.num_local_experts = E + layer.intermediate_size_per_partition = N + layer.params_dtype = self.dtype + layer.moe_runner_config = _FakeMoeRunnerConfig() + + layer.w13_weight = torch.nn.Parameter( + torch.randint( + 0, 256, (E, 2 * N, K // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w2_weight = torch.nn.Parameter( + torch.randint( + 0, 256, (E, K, N // 2), dtype=torch.uint8, device=self.device + ), + requires_grad=False, + ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + E, + 2 * N, + K // _FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w2_weight_scale = torch.nn.Parameter( + torch.ones( + E, + K, + N // 
_FP4_MARLIN_GROUP_SIZE, + dtype=torch.float8_e4m3fn, + device=self.device, + ), + requires_grad=False, + ) + layer.w13_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, 2, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + layer.w2_weight_scale_2 = torch.nn.Parameter( + torch.ones(E, dtype=torch.float32, device=self.device), + requires_grad=False, + ) + + prepare_moe_fp4_layer_for_marlin(layer) + + self.assertEqual(layer.w13_weight.shape[0], E) + self.assertEqual(layer.w2_weight.shape[0], E) + self.assertEqual(layer.w13_weight_scale_2.shape, (E,)) + self.assertEqual(layer.w2_weight_scale_2.shape, (E,)) + + +# --------------------------------------------------------------------------- +# Support / capability tests +# --------------------------------------------------------------------------- +class TestFp4MarlinSupport(CustomTestCase): + """Test the capability detection functions.""" + + def test_is_fp4_marlin_supported(self): + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + result = is_fp4_marlin_supported() + if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + expected = sm >= 75 + self.assertEqual(result, expected) + elif torch.version.hip is not None: + self.assertFalse(result, "FP4 Marlin should not be supported on ROCm/HIP") + + def test_min_capability_changed(self): + """get_min_capability() must return 75 (not 100).""" + from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config + + cap = ModelOptFp4Config.get_min_capability() + self.assertEqual(cap, 75, f"Expected 75, got {cap}") + + def test_should_use_fp4_marlin_fallback(self): + """should_use_fp4_marlin_fallback returns True on non-Blackwell SM>=75.""" + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + should_use_fp4_marlin_fallback, + ) + + result = should_use_fp4_marlin_fallback() + self.assertIsInstance(result, bool) + 
+ if torch.cuda.is_available() and torch.version.hip is None: + cap = torch.cuda.get_device_capability() + sm = cap[0] * 10 + cap[1] + is_blackwell = sm >= 100 + if is_blackwell: + self.assertFalse( + result, + "Blackwell GPUs should NOT use Marlin fallback (native FP4)", + ) + elif sm >= 75: + self.assertTrue( + result, + f"SM{sm} should use Marlin fallback, but got False", + ) + else: + self.assertFalse( + result, + f"SM{sm} should not support FP4 Marlin at all", + ) + + +if __name__ == "__main__": + unittest.main(verbosity=3) diff --git a/test/registered/unit/model_loader/test_modelopt_loader.py b/test/registered/unit/model_loader/test_modelopt_loader.py index 7f9652c0e5db..9ad6183a0b0a 100644 --- a/test/registered/unit/model_loader/test_modelopt_loader.py +++ b/test/registered/unit/model_loader/test_modelopt_loader.py @@ -646,7 +646,11 @@ def test_mixed_precision_override_does_not_hijack_w4afp8(self): ) def test_mixed_precision_uses_nvfp4_min_capability(self): - self.assertEqual(ModelOptMixedPrecisionConfig.get_min_capability(), 100) + """NVFP4 supports SM75+ (Turing) via Marlin fallback; min_capability must be >= 75.""" + cap = ModelOptMixedPrecisionConfig.get_min_capability() + self.assertGreaterEqual( + cap, 75, f"NVFP4 requires SM75+ (Marlin fallback); got min_capability={cap}" + ) def test_mixed_precision_quant_layer_resolution_after_mapping(self): quant_config = ModelOptMixedPrecisionConfig.from_config(