[FlashInfer v0.6.6][RL] Support fp8-last-n-bf16 RL for flashinfer_trtllm_routed moe backend #20214

Merged
Fridge003 merged 11 commits into sgl-project:main from zianglih:agent-flashinfer-bf16-routed
Mar 22, 2026
Conversation

@zianglih (Contributor)

@zianglih zianglih commented Mar 9, 2026

Motivation

This PR is the last missing piece of Miles Blackwell mxfp8 RL training: radixark/miles#615 , radixark/miles#614 .

Dependencies:

FlashInfer v0.6.5 has a bug (flashinfer-ai/flashinfer#2697 (comment)), so this PR cannot be merged until the v0.6.6 release.
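
A dependency pin like this is typically enforced with a version gate at import time. The sketch below is hypothetical (the function and module names are assumptions, not SGLang's actual code); it only illustrates requiring at least v0.6.6 before enabling the new path:

```python
# Hypothetical version gate for the bf16 routed MoE path.
# v0.6.5 has a known routing bug (flashinfer-ai/flashinfer#2697),
# so require at least v0.6.6. Names here are illustrative only.

def parse_version(v: str) -> tuple:
    """Parse an 'X.Y.Z' version string into a comparable tuple of ints."""
    return tuple(int(part) for part in v.split(".")[:3])

def supports_bf16_routed_moe(flashinfer_version: str) -> bool:
    """True if this FlashInfer version can run trtllm_bf16_routed_moe safely."""
    return parse_version(flashinfer_version) >= (0, 6, 6)
```

For example, `supports_bf16_routed_moe("0.6.5")` is False while `"0.6.6"` and any later version pass.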

Modifications

  • Integrate the bf16 trtllm_bf16_routed_moe kernel from FlashInfer v0.6.6, mirroring the existing fp8/mxfp8 code paths from [FlashInfer v0.6.4] [RL] Integrate FlashInfer mxfp8 gemm, MoE, and routed MoE #19537
  • Expand test coverage in test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
  • Add a weight update test for mxfp8-last-n-bf16 checkpoint + flashinfer_trtllm_routed moe backend in test/registered/rl/test_update_weights_blackwell.py
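
Conceptually, the new kernel slots in next to the existing fp8/mxfp8 routed paths as one more dtype-keyed backend. The registry below is a hypothetical sketch (all names are illustrative; SGLang's real dispatch differs), showing only the shape of the change:

```python
# Hypothetical registry mapping MoE compute dtypes to routed-MoE runners.
# All function and dict names are illustrative, not SGLang's actual code.

def run_fp8_routed(x):
    return ("fp8_routed", x)

def run_mxfp8_routed(x):
    return ("mxfp8_routed", x)

def run_bf16_routed(x):
    # New in this PR: backed by FlashInfer's trtllm_bf16_routed_moe.
    return ("bf16_routed", x)

MOE_ROUTED_BACKENDS = {
    "fp8": run_fp8_routed,
    "mxfp8": run_mxfp8_routed,
    "bf16": run_bf16_routed,
}

def dispatch_routed_moe(dtype: str, x):
    """Pick the routed-MoE runner for the given compute dtype."""
    runner = MOE_ROUTED_BACKENDS.get(dtype)
    if runner is None:
        raise ValueError(f"unsupported routed MoE dtype: {dtype}")
    return runner(x)
```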

Accuracy Tests

root@B200-78:/sgl-workspace/sglang# python3 -m pytest -q test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
......                                                                                                         [100%]
================================================== warnings summary ==================================================
../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendFP8::test_gsm8k
registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendMXFP8::test_gsm8k
registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendBF16::test_gsm8k
registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendFP8Routed::test_gsm8k
registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendMXFP8Routed::test_gsm8k
registered/backends/test_flashinfer_trtllm_gen_moe_backend.py::TestFlashinferTrtllmGenMoeBackendBF16Routed::test_gsm8k
  /sgl-workspace/sglang/python/sglang/test/few_shot_gsm8k.py:54: DeprecationWarning: Including the scheme in --host ('http://127.0.0.1') is deprecated. Pass just the hostname (e.g. '127.0.0.1') instead.
    set_default_backend(RuntimeEndpoint(normalize_base_url(args.host, args.port)))

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
6 passed, 7 warnings in 615.78s (0:10:15)
root@B200-78:/sgl-workspace/sglang# python3 -m pytest -q test/registered/rl/test_update_weights_blackwell.py
...                                                                                                   [100%]
============================= warnings summary =============================
../../usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428
  /usr/local/lib/python3.12/dist-packages/_pytest/config/__init__.py:1428: PytestConfigWarning: Unknown config option: asyncio_mode
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
3 passed, 1 warning, 9 subtests passed in 442.51s (0:07:22)

Benchmarking and Profiling

For softmax MoE models like Qwen3-30B-A3B, enabling piecewise CUDA graph causes unstable numerics, which may be related to torch compilation:

import torch
import torch.nn.functional as F

def fused_topk_softmax_torch_raw_logits(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert (
        hidden_states.shape[0] == gating_output.shape[0]
    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
    # Select top-k experts directly from the raw gating logits.
    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1, sorted=False)
    logits = gating_output.float()
    topk_weights = logits.gather(1, topk_ids)
    if renormalize:
        # Softmax over only the selected logits renormalizes the weights.
        topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
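
For intuition, here is a stdlib-only sketch of the same top-k + softmax renormalization for a single token's gating logits (the helper name is illustrative, not part of the codebase):

```python
# Pure-Python illustration of top-k expert selection followed by softmax
# renormalization over only the selected logits, for one token.
import math

def topk_softmax(logits, k):
    """Return (weights, indices) for the top-k logits, softmax-renormalized."""
    indices = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    picked = [logits[i] for i in indices]
    m = max(picked)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in picked]
    total = sum(exps)
    return [e / total for e in exps], indices
```

For `topk_softmax([2.0, 0.5, 1.0, -1.0], 2)` the selected experts are indices 0 and 2, and the weights are the softmax of [2.0, 1.0], which sums to 1.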

bf16 checkpoint with flashinfer_trtllm (baseline)

python -m sglang.launch_server --kv-cache-dtype bf16 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.967
Invalid: 0.000
Latency: 8.572 s
Output throughput: 19904.440 token/s
Accuracy: 0.967
Invalid: 0.000
Latency: 8.504 s
Output throughput: 20070.594 token/s
Accuracy: 0.967
Invalid: 0.000
Latency: 8.234 s
Output throughput: 20726.172 token/s

bf16 checkpoint with flashinfer_trtllm_routed

python -m sglang.launch_server --kv-cache-dtype bf16 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm_routed &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.965
Invalid: 0.000
Latency: 9.317 s
Output throughput: 18291.868 token/s
Accuracy: 0.966
Invalid: 0.000
Latency: 9.254 s
Output throughput: 18426.035 token/s
Accuracy: 0.965
Invalid: 0.000
Latency: 8.928 s
Output throughput: 19086.286 token/s

bf16 checkpoint with flashinfer_trtllm_routed with --disable-piecewise-cuda-graph

python -m sglang.launch_server --kv-cache-dtype bf16 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm_routed --disable-piecewise-cuda-graph &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.965
Invalid: 0.000
Latency: 9.526 s
Output throughput: 17901.575 token/s
Accuracy: 0.965
Invalid: 0.000
Latency: 9.278 s
Output throughput: 18392.800 token/s
Accuracy: 0.965
Invalid: 0.000
Latency: 8.978 s
Output throughput: 18991.882 token/s

mxfp8-last-n-bf16 checkpoint

python -m sglang.launch_server --kv-cache-dtype bf16 --model zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm_routed &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.970
Invalid: 0.000
Latency: 10.267 s
Output throughput: 16594.531 token/s
Accuracy: 0.971
Invalid: 0.000
Latency: 10.003 s
Output throughput: 17022.217 token/s
Accuracy: 0.969
Invalid: 0.000
Latency: 9.421 s
Output throughput: 18088.243 token/s
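
The "mxfp8-last-8-bf16" checkpoint name suggests a per-layer dtype policy: quantize layers to mxfp8 except the final n, which stay in bf16. A minimal sketch of that policy, assuming a simple index threshold (the function is hypothetical, not the checkpoint converter's actual code):

```python
# Hypothetical "last-n-bf16" layer-dtype policy: the final `last_n_bf16`
# layers stay in bf16, all earlier layers are quantized to mxfp8.

def layer_dtype(layer_idx: int, num_layers: int, last_n_bf16: int) -> str:
    """Return the compute dtype for a given layer index."""
    return "bf16" if layer_idx >= num_layers - last_n_bf16 else "mxfp8"
```

With 48 layers and n=8, layers 0-39 would be mxfp8 and layers 40-47 bf16.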

mxfp8-last-n-bf16 checkpoint with --disable-piecewise-cuda-graph

python -m sglang.launch_server --kv-cache-dtype bf16 --model zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm_routed --disable-piecewise-cuda-graph &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.970
Invalid: 0.000
Latency: 9.863 s
Output throughput: 17271.128 token/s
Accuracy: 0.970
Invalid: 0.000
Latency: 9.806 s
Output throughput: 17368.509 token/s
Accuracy: 0.970
Invalid: 0.000
Latency: 9.632 s
Output throughput: 17686.235 token/s

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@github-actions bot added the documentation (Improvements or additions to documentation) and quant (LLM Quantization) labels on Mar 9, 2026
@ziang-and force-pushed the agent-flashinfer-bf16-routed branch 9 times, most recently from 46e7e79 to 4dabe69, on March 10, 2026 22:57
@zianglih changed the title from "[FlashInfer v0.6.5][RL] Integrate trtllm_bf16_routed_moe" to "[FlashInfer v0.6.6][RL] Integrate trtllm_bf16_routed_moe" on Mar 10, 2026
@ziang-and force-pushed the agent-flashinfer-bf16-routed branch from 4dabe69 to 78b51be on March 12, 2026 23:08
@zianglih marked this pull request as ready for review on March 13, 2026 07:02
@ziang-and force-pushed the agent-flashinfer-bf16-routed branch from 78b51be to c44fda8 on March 15, 2026 04:44
@zianglih (Contributor, Author)

We want a dedicated Blackwell test file because mxfp/nvfp data types usually require an extra weight-swizzling step after a weight update, and these behaviors deserve separate coverage.

@zianglih changed the title to "[FlashInfer v0.6.6][RL] Integrate trtllm_bf16_routed_moe and support fp8-last-n-bf16 weight update" on Mar 15, 2026
@zianglih changed the title to "[FlashInfer v0.6.6][RL] Support fp8-last-n-bf16 RL for flashinfer_trtllm_routed moe" on Mar 15, 2026
@zianglih changed the title to "[FlashInfer v0.6.6][RL] Support fp8-last-n-bf16 RL for flashinfer_trtllm_routed moe backend" on Mar 15, 2026
@ziang-and force-pushed the agent-flashinfer-bf16-routed branch from 6d16283 to c47bd31 on March 18, 2026 06:36
@zianglih requested a review from Fridge003 on March 18, 2026 07:28
@zianglih (Contributor, Author)

I have renamed test_update_weights_blackwell.py to test_update_weights_from_disk_mxfp8.py and only test /update_weights_from_disk, as it performs a full model weight update and is sufficient to expose any swizzling/re-layout bugs for mxfp/nvfp layers.

In the future, when we expand test coverage to formats like mxfp4/nvfp4, we can rename the test.

@Fridge003 (Collaborator)

/tag-and-rerun-ci

@Fridge003 (Collaborator)

/rerun-stage stage-c-test-4-gpu-b200

@github-actions (bot)

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies).

@github-actions (bot)

🔗 View workflow run

@Fridge003 (Collaborator)

/rerun-ut test_update_weights_from_disk_mxfp8.py

@github-actions (bot)

/rerun-ut is not available for fork PRs (security restriction).

Please ask a maintainer to add the run-ci label and use the normal CI flow, or use /rerun-failed-ci to rerun workflows that have already passed the gate.

@Fridge003 (Collaborator)

/rerun-stage stage-c-test-4-gpu-b200

1 similar comment from @Fridge003 (/rerun-stage stage-c-test-4-gpu-b200)

@Fridge003 Fridge003 merged commit ce05414 into sgl-project:main Mar 22, 2026
270 of 295 checks passed
OrangeRedeng pushed a commit to OrangeRedeng/sglang that referenced this pull request Mar 22, 2026

Labels

documentation (Improvements or additions to documentation), quant (LLM Quantization), run-ci
