diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py
index 8a0ec4729a69..3c4b3ac12756 100644
--- a/python/sglang/srt/layers/quantization/mxfp4.py
+++ b/python/sglang/srt/layers/quantization/mxfp4.py
@@ -139,8 +139,9 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     from triton_kernels.tensor_details import layout
 
     if is_sm120_supported():
-        # SM120 (Blackwell desktop) doesn't support persistent kernels / TMA block layout
-        # Use StridedLayout and disable persistent kernels to avoid assertion errors
+        # SM120 desktop Blackwell does not support the persistent/TMA MXFP4 path.
+        # This MXFP4 path uses StridedLayout and the non-persistent kernel with
+        # block_k=128 so the selected tile stays within the per-block shared-memory budget.
         from triton_kernels.tensor_details.layout import StridedLayout
 
         value_layout = StridedLayout
@@ -149,6 +150,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
         scale_layout_opts = {}
         constraints = {
             "is_persistent": False,
+            "block_k": 128,
             "num_stages": 1,
         }
         opt_flags.update_opt_flags_constraints(constraints)