Fix SM120 triton_kernels MXFP4 block_k for GPT-OSS#20040

Merged
Kangyan-Zhou merged 1 commit into sgl-project:main from mmangkad-dev:fix-sm120-triton-kernels
Mar 6, 2026

Conversation

@mmangkad
Contributor

@mmangkad mmangkad commented Mar 6, 2026

Motivation

On SM120, the triton_kernels MXFP4 path can pick a tile that exceeds the per-block shared-memory budget, which trips assert num_stages >= 1 during GPT-OSS startup. This PR sets block_k=128 for the SM120 MXFP4 path, the largest power-of-two tile that satisfies both this kernel's requirements and the SM120 shared-memory limit.
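The "largest power-of-two tile that fits" reasoning can be sketched as follows. This is an illustration only: the budget and per-unit byte cost below are assumed numbers, not the actual triton_kernels cost model.

```python
def pick_block_k(smem_budget_bytes: int, per_k_cost_bytes: int,
                 candidate: int = 256) -> int:
    """Return the largest power-of-two block_k whose (illustrative)
    per-stage shared-memory footprint fits within the budget."""
    while candidate > 32 and candidate * per_k_cost_bytes > smem_budget_bytes:
        candidate //= 2
    return candidate

# With a hypothetical ~100 KB per-block budget and ~768 B of shared memory
# per unit of block_k, block_k=256 overflows while block_k=128 fits.
print(pick_block_k(100 * 1024, 768))
```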

Full crash log
sglang serve --model-path openai/gpt-oss-120b --reasoning-parser gpt-oss --tool-call-parser gpt-oss
...
[2026-03-06 08:43:46] Using KV cache dtype: torch.bfloat16
[2026-03-06 08:43:46] Use sliding window memory pool. full_layer_tokens=318715, swa_layer_tokens=254972
[2026-03-06 08:43:46] KV Cache is allocated. #tokens: 254972, K size: 4.38 GB, V size: 4.38 GB
[2026-03-06 08:43:46] KV Cache is allocated. #tokens: 318715, K size: 5.47 GB, V size: 5.47 GB
[2026-03-06 08:43:46] SWAKVPool mem usage: 19.70 GB, swa size: 254972, full size: 318715
[2026-03-06 08:43:46] Memory pool end. avail mem=12.33 GB
[2026-03-06 08:43:46] Capture cuda graph begin. This can take up to several minutes. avail mem=12.25 GB
[2026-03-06 08:43:46] Capture cuda graph bs [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256]
Capturing batches (bs=1 avail_mem=11.50 GB): 100%|██████████| 36/36 [01:09<00:00,  1.93s/it]
[2026-03-06 08:44:56] Capture cuda graph end. Time elapsed: 69.83 s. mem usage=0.75 GB. avail mem=11.49 GB.
[2026-03-06 08:44:56] Capture piecewise CUDA graph begin. avail mem=11.49 GB
[2026-03-06 08:44:56] Capture cuda graph num tokens [4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072, 3328, 3584, 3840, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]
[2026-03-06 08:45:00] install_torch_compiled
Compiling num tokens (num_tokens=8192):   0%|          | 0/58 [00:00<?, ?it/s]/root/.local/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:1692: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
[2026-03-06 08:45:08] Initializing SGLangBackend
[2026-03-06 08:45:08] SGLangBackend __call__
[2026-03-06 08:45:10] Compiling a graph for dynamic shape takes 0.64 s
[2026-03-06 08:45:10] Computation graph saved to /root/.cache/sglang/torch_compile_cache/rank_0_0/backbone/computation_graph_1772786710.2688763.py
Compiling num tokens (num_tokens=8192):   0%|          | 0/58 [00:12<?, ?it/s]
[2026-03-06 08:45:13] Piecewise CUDA Graph failed with error: 
Piecewise CUDA Graph is enabled by default as an experimental feature.
To work around this error, add --disable-piecewise-cuda-graph to your launch command.
Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose
[2026-03-06 08:45:13] Scheduler hit an exception: Traceback (most recent call last):
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 3237, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 365, in __init__
    self.init_model_worker()
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 561, in init_model_worker
    self.init_tp_model_worker()
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 519, in init_tp_model_worker
    self.tp_worker = TpModelWorker(
                     ^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 258, in __init__
    self._init_model_runner()
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 341, in _init_model_runner
    self._model_runner = ModelRunner(
                         ^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 416, in __init__
    self.initialize(min_per_gpu_memory)
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 640, in initialize
    self.init_piecewise_cuda_graphs()
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 2247, in init_piecewise_cuda_graphs
    self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 306, in __init__
    self.warmup_compile(num_tokens=num_tokens)
  File "/sglang/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 403, in warmup_compile
    _ = self.model_runner.model.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/gpt_oss.py", line 635, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/compilation/compile.py", line 194, in trampoline
    _ensure_compiled(self, *args, **kwargs)
  File "/sglang/python/sglang/srt/compilation/compile.py", line 185, in _ensure_compiled
    compiled_callable(*args, **kwargs)
  File "/root/.local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/gpt_oss.py", line 541, in forward
    def forward(
  File "/root/.local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.74", line 269, in forward
    submod_2 = self.submod_2(getitem_3, s72, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_bias_, l_self_modules_layers_modules_0_layer_communicator_post_attention_layernorm_parameters_weight_, getitem_4, l_self_modules_layers_modules_1_layer_communicator_input_layernorm_parameters_weight_, l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_weight_, l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_bias_, l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_, l_positions_);  getitem_3 = l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_bias_ = l_self_modules_layers_modules_0_layer_communicator_post_attention_layernorm_parameters_weight_ = l_self_modules_layers_modules_1_layer_communicator_input_layernorm_parameters_weight_ = l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_weight_ = l_self_modules_layers_modules_1_modules_self_attn_modules_qkv_proj_parameters_bias_ = None
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/compilation/cuda_piecewise_backend.py", line 111, in __call__
    return self.compiled_graph_for_general_shape(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 837, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 413, in __call__
    raise e
  File "/root/.local/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.3", line 9, in forward
    moe_impl = torch.ops.sglang.moe_impl(0, linear);  linear = None
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/gpt_oss.py", line 206, in moe_impl
    final_hidden_states = moe_fusion.experts(hidden_states, topk_output)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 961, in forward
    return self.forward_impl(hidden_states, topk_output)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 990, in forward_impl
    combine_input = self.run_moe_core(
                    ^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 1011, in run_moe_core
    return self.quant_method.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/quantization/mxfp4.py", line 902, in apply
    return self.runner.run(dispatch_output, quant_info)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/moe_runner/runner.py", line 96, in run
    runner_output = self.runner_core.run(runner_input, quant_info, running_state)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/moe_runner/triton_kernels.py", line 115, in run
    output = triton_kernel_fused_experts_with_bias(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py", line 306, in triton_kernel_fused_experts_with_bias
    matmul_ogs(
  File "/root/.local/lib/python3.12/site-packages/triton_kernels/matmul_ogs.py", line 370, in matmul_ogs
    opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/triton_kernels/matmul_ogs_details/opt_flags.py", line 302, in make_opt_flags
    return make_default_opt_flags_nvidia(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/lib/python3.12/site-packages/triton_kernels/matmul_ogs_details/opt_flags.py", line 218, in make_default_opt_flags_nvidia
    assert num_stages >= 1
           ^^^^^^^^^^^^^^^
AssertionError
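The assertion at the bottom of the trace is the symptom rather than the cause: pipelining stages are derived from how many tile buffers fit in the per-block shared-memory budget, so an oversized tile yields zero stages. A hedged sketch of that failure mode (the formula is illustrative, not the actual opt_flags code):

```python
def num_stages_for_tile(smem_budget_bytes: int, tile_smem_bytes: int) -> int:
    """Illustrative: pipeline stages = how many copies of the tile's
    shared-memory footprint fit in the per-block budget."""
    return smem_budget_bytes // tile_smem_bytes

# A tile larger than the whole budget leaves room for zero stages,
# which is exactly what `assert num_stages >= 1` rejects.
```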

@b8zhong

Accuracy Tests

python -m gpt_oss.evals --model openai/gpt-oss-120b --eval gpqa --n-threads 2048 --reasoning-effort low --base-url http://127.0.0.1:30000/v1

Writing report to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260306_124153.html
{'chars': np.float64(97.75126262626263), 'chars:std': np.float64(264.199135273403), 'score': np.float64(0.6527777777777778), 'score:std': np.float64(0.4760871250278503)}
Writing results to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260306_124153.json
Writing all results to /tmp/gpqa_openai__gpt-oss-120b-low_temp1.0_20260306_124153_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-120b-low_temp1.0_20260306_124153', 'metric': 0.6527777777777778}]

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical issue affecting the GPT-OSS model on SM120 GPUs when utilizing MXFP4 quantization. Previously, the system would crash due to an out-of-budget shared memory allocation during kernel compilation. The implemented solution configures the kernel to use a specific block size (block_k=128) that is compatible with the SM120 architecture's memory constraints, ensuring stable operation without compromising accuracy, as validated by provided accuracy tests.

Highlights

  • Issue Addressed: Fixed a crash on SM120 (Blackwell desktop) when using the triton_kernels MXFP4 path for GPT-OSS, which occurred due to selecting a tile that exceeded the per-block shared-memory budget, leading to an AssertionError.
  • Solution Implemented: Explicitly set block_k=128 for the SM120 MXFP4 path. This ensures that the selected tile fits within the SM120 shared-memory limit and avoids the AssertionError.
  • Technical Detail: The fix leverages StridedLayout and a non-persistent kernel configuration for SM120, as it does not support persistent kernels or TMA block layout for MXFP4, ensuring compatibility and stability.
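A minimal sketch of what the highlighted change amounts to. The helper and key names here are hypothetical; the real edit lives in python/sglang/srt/layers/quantization/mxfp4.py and feeds constraints into triton_kernels' opt-flags machinery.

```python
def mxfp4_opt_flags_constraints(sm: int) -> dict:
    """Hypothetical helper: build per-architecture constraints for the
    MXFP4 matmul path, mirroring the behavior this PR describes."""
    constraints: dict = {}
    if sm == 120:
        # SM120 does not support persistent kernels or the TMA block
        # layout for MXFP4, and block_k must be capped so the tile's
        # shared-memory footprint fits the per-block budget.
        constraints["is_persistent"] = False
        constraints["block_k"] = 128
    return constraints
```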


Changelog
  • python/sglang/srt/layers/quantization/mxfp4.py
    • Updated comments to clarify the SM120 MXFP4 path's behavior and the reason for using StridedLayout and block_k=128.
    • Added block_k: 128 to the constraints for SM120 to explicitly set the block size, preventing shared memory overflow.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a runtime assertion failure on SM120 hardware by constraining block_k to 128 for the MXFP4 Triton kernel path. No security vulnerabilities were found. A suggestion has been made to use a named constant to improve code maintainability.

Collaborator

@b8zhong b8zhong left a comment


Thanks! This makes sense... it was a bit puzzling. I also encountered this bug at BS > 1

@b8zhong
Collaborator

b8zhong commented Mar 6, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Mar 6, 2026
@Kangyan-Zhou Kangyan-Zhou merged commit 759700c into sgl-project:main Mar 6, 2026
146 of 163 checks passed
@mmangkad mmangkad deleted the fix-sm120-triton-kernels branch March 6, 2026 19:13
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Mar 6, 2026
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026

3 participants