Align triton_kernels with Triton 3.6.0 and fix SM120 MXFP4 MoE performance #24281

mmangkad wants to merge 3 commits into

Conversation
Code Review
This pull request updates the Triton dependency to version 3.6.0 and adapts the MoE and quantization layers accordingly. Key changes include a local implementation of the routing function in `topk.py`, which introduces a regression by disabling simulated expert parallelism, and the addition of a monkey-patching utility in `mxfp4.py` to enforce a minimum warp count for specific matmuls on SM120 hardware. Additionally, the `is_triton_kernels_available` check was expanded to include ragged tensor metadata, and `FusedActivation` calls were updated to include explicit reduction parameters. One piece of feedback notes the loss of functionality for simulated expert parallelism.
/rerun-failed-ci
Could you align with the new 3.7.0? It fixes behavior on AGX Thor and DGX Spark.
Disclosure: Atlas maintainer. We carry a patch on roughly this same gap; dropping it here in case it shortens the review. The Triton MXFP4 path on … What worked for us was bypassing the PTX path entirely with a software E2M1 conversion. We patch FlashInfer's CUTLASS headers at container build time:
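Roughly, the idea is the following software decode (a minimal Python sketch for illustration only; the real patch is C++ against the CUTLASS headers, and the nibble packing order here is my assumption, not FlashInfer's actual layout):

```python
import numpy as np

# E2M1 magnitude table: 1 sign bit, 2 exponent bits, 1 mantissa bit.
# The index is the low 3 bits of the nibble.
_E2M1_MAGNITUDES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
                            dtype=np.float32)

def decode_mxfp4_nibbles(packed: np.ndarray) -> np.ndarray:
    """Decode uint8 bytes holding two E2M1 values each (low nibble
    first -- an assumption) into float32, with no PTX cvt instructions."""
    lo = packed & 0x0F
    hi = packed >> 4
    # Interleave low/high nibbles so output order matches byte order.
    nibbles = np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)
    sign = np.where(nibbles & 0x8, -1.0, 1.0).astype(np.float32)
    return sign * _E2M1_MAGNITUDES[nibbles & 0x7]
```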
The patch is roughly 30 lines and disables …
For the SGLang-internal route specifically, the equivalent NVFP4-via-Marlin-W4A16 fallback is probably the path of least resistance, since you avoid the FlashInfer dependency in the kernel build. Either way, happy to land the FlashInfer patch as a PR if it'd help reviewers compare approaches.
@mmangkad Can you please fix the conflicts?
To be included in #25312 |
Summary
- Aligns the vendored `triton_kernels` source with the Triton 3.6.0 version shipped by Torch.
- Updates the `triton_kernels` MoE integration for Triton 3.6.0 API changes:
  - `GatherIndx`, `RoutingData`, and `ScatterIndx` moved from `triton_kernels.routing` to `triton_kernels.matmul_ogs` (see the import shim sketch below).
  - `triton_kernels.routing` is no longer exposed, so SGLang now rebuilds routing from `triton_kernels.topk` and ragged tensor metadata.
  - The `swiglu` fused activation now passes `reduction_n=2` through `FnSpecs`.
- Tightens `is_triton_kernels_available()` so stale pre-3.6 installs are not treated as compatible.
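A minimal sketch of how the symbol relocation can be absorbed at import time (this is the shape of the fix, not necessarily SGLang's exact code):

```python
# Hedged sketch: absorb the Triton 3.6.0 symbol relocation at import time.
try:
    # Triton >= 3.6.0: routing metadata types now live in matmul_ogs.
    from triton_kernels.matmul_ogs import GatherIndx, RoutingData, ScatterIndx
except ImportError:
    # Older triton_kernels snapshots exposed them from the routing module.
    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
```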
The SM120 performance issue comes from this `triton_kernels` heuristic change: for small decode/ragged MoE batches on SM120 it can select `num_warps=1`, which makes throughput very slow. Without this patch, I was seeing SM120 decode at around 35-36 tokens/s. This patch restores the old 4-warp floor only for SM120 non-persistent `StridedLayout` MXFP4 matmuls.
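For illustration, a minimal sketch of such a warp-floor monkey-patch, assuming the heuristic returns a dataclass with a `num_warps` field; the actual patch in `mxfp4.py` additionally gates on SM120, non-persistent, `StridedLayout`, and MXFP4:

```python
import dataclasses
import functools

def wrap_with_warp_floor(module, attr: str, min_warps: int = 4):
    """Monkey-patch the opt-flags heuristic at module.attr so the
    num_warps it picks never drops below min_warps."""
    original = getattr(module, attr)

    @functools.wraps(original)
    def patched(*args, **kwargs):
        flags = original(*args, **kwargs)
        if flags.num_warps < min_warps:
            # Assumes the heuristic's result is a dataclass.
            flags = dataclasses.replace(flags, num_warps=min_warps)
        return flags

    setattr(module, attr, patched)
```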
This also removes the explicit SM120 `block_k=128` override added in #20040. That override was added because the older `triton_kernels` path could hit `assert num_stages >= 1` during GPT-OSS startup on SM120; I think this was most likely being tripped during PCG warmup/capture. With the Triton 3.6.0 `triton_kernels` path, I can no longer reproduce that failure without the override, so this PR lets Triton choose its default `block_k` again, which is currently 256 for this path.
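For reference, the removed override was of this shape (a sketch assuming `triton_kernels`' opt-flags constraint mechanism; the import path and exact SGLang call site are assumptions):

```python
# Hedged sketch: pin block_k via triton_kernels' opt-flags constraints.
from triton_kernels.matmul_ogs_details import opt_flags

# Force block_k=128 on SM120 to dodge the old `assert num_stages >= 1`
# failure; this PR removes the pin so the heuristic can pick 256 again.
opt_flags.update_opt_flags_constraints({"block_k": 128})
```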
The SM120 `num_warps` patch is ugly, but I think it is the only practical way to restore SM120 performance from SGLang right now: `triton_kernels` does not expose `num_warps` as an opt-flag constraint, so the patch narrowly adjusts the heuristic only for the SM120 non-persistent `StridedLayout` MXFP4 path.

Accuracy Tests
H200 (MXFP4):
```
python -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 256 --reasoning-effort low --base-url http://127.0.0.1:30000/v1
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_212928.html
{'chars': np.float64(52.16729797979798), 'chars:std': np.float64(218.80938828184415), 'score': np.float64(0.5744949494949495), 'score:std': np.float64(0.494419358945162)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_212928.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_212928_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20260502_212928', 'metric': 0.5744949494949495}]
```

GB300 (MXFP4):
```
python -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 512 --reasoning-effort low --base-url http://127.0.0.1:30000/v1
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_214250.html
{'chars': np.float64(52.530934343434346), 'chars:std': np.float64(200.0052121862134), 'score': np.float64(0.5549242424242424), 'score:std': np.float64(0.4969741719587881)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_214250.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_214250_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20260502_214250', 'metric': 0.5549242424242424}]
```

GB300 (BF16):
```
python -m gpt_oss.evals --model lmsys/gpt-oss-20b-bf16 --eval gpqa --n-threads 512 --reasoning-effort low --base-url http://127.0.0.1:30000/v1
Writing report to /tmp/gpqa_lmsys__gpt-oss-20b-bf16-low_temp1.0_20260502_214425.html
{'chars': np.float64(50.92550505050505), 'chars:std': np.float64(214.67684371696978), 'score': np.float64(0.5568181818181818), 'score:std': np.float64(0.4967612044180544)}
Writing results to /tmp/gpqa_lmsys__gpt-oss-20b-bf16-low_temp1.0_20260502_214425.json
Writing all results to /tmp/gpqa_lmsys__gpt-oss-20b-bf16-low_temp1.0_20260502_214425_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'lmsys__gpt-oss-20b-bf16-low_temp1.0_20260502_214425', 'metric': 0.5568181818181818}]
```

RTX PRO 6000 (MXFP4):
```
python -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 256 --reasoning-effort low --base-url http://127.0.0.1:30000/v1
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_223913.html
{'chars': np.float64(52.41856060606061), 'chars:std': np.float64(206.45098944881033), 'score': np.float64(0.553030303030303), 'score:std': np.float64(0.4971798336221153)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_223913.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20260502_223913_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20260502_223913', 'metric': 0.553030303030303}]
```