[FlashInfer v0.6.6][RL] Support fp8-last-n-bf16 RL for flashinfer_trtllm_routed moe backend (#20214)
Conversation
We want a dedicated Blackwell test file because mxfp/nvfp data types usually require an extra weight-swizzling step after a weight update. We want a separate file for testing these behaviors.
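To illustrate the behavior described above, here is a hypothetical sketch (the names `MoEWeights` and `swizzle_for_trtllm` are illustrative only, not real SGLang or FlashInfer APIs): a backend that stores mxfp/nvfp weights in a kernel-specific swizzled layout must re-run the swizzle whenever an RL weight update overwrites the plain weights, otherwise the device copy goes stale.

```python
# Hypothetical sketch, not the actual SGLang/FlashInfer implementation.

def swizzle_for_trtllm(flat):
    """Toy stand-in for a kernel-required weight layout: here we simply
    interleave even- and odd-indexed elements."""
    return flat[0::2] + flat[1::2]

class MoEWeights:
    def __init__(self, flat):
        self.raw = list(flat)                        # layout from the checkpoint
        self.device = swizzle_for_trtllm(self.raw)   # layout the kernel consumes

    def update_weights(self, new_flat):
        # An RL weight update replaces the raw weights; the swizzled copy
        # is now stale and must be rebuilt before the next forward pass.
        self.raw = list(new_flat)
        self.device = swizzle_for_trtllm(self.raw)

w = MoEWeights([0, 1, 2, 3])
w.update_weights([4, 5, 6, 7])
print(w.device)  # [4, 6, 5, 7]
```

A dedicated test file can then assert that the swizzled copy matches the raw weights after every update path.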
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
I have renamed the file. In the future, when we expand the test coverage to e.g. mxfp4/nvfp4, we can rename the test.
/tag-and-rerun-ci

/rerun-stage stage-c-test-4-gpu-b200

✅ Triggered

/rerun-ut test_update_weights_from_disk_mxfp8.py

❌ Please ask a maintainer to add the

/rerun-stage stage-c-test-4-gpu-b200
Motivation
This PR is the last missing piece of Miles' Blackwell mxfp8 RL training: radixark/miles#615, radixark/miles#614.
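For context, the "fp8-last-n-bf16" scheme in the title keeps the last N transformer layers in bf16 while the remaining layers use (mx)fp8, matching checkpoints such as the Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16 model used below. A minimal sketch of that per-layer dtype selection (the helper `layer_dtype` is hypothetical, not SGLang's actual config API, and 48 layers is assumed for Qwen3-30B-A3B):

```python
def layer_dtype(layer_idx, num_layers, last_n_bf16):
    """Hypothetical helper: the last `last_n_bf16` layers stay bf16,
    all earlier layers use mxfp8."""
    return "bf16" if layer_idx >= num_layers - last_n_bf16 else "mxfp8"

# Assumed 48-layer model with the last 8 layers kept in bf16:
dtypes = [layer_dtype(i, 48, 8) for i in range(48)]
print(dtypes[:2], dtypes[-2:])  # ['mxfp8', 'mxfp8'] ['bf16', 'bf16']
```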
Dependencies:
- trtllm_bf16_routed_moe: bf16 routed MoE, flashinfer-ai/flashinfer#2594
- FlashInfer v0.6.5 has a bug (flashinfer-ai/flashinfer#2697 (comment)), so this PR cannot be merged until the v0.6.6 release.
Modifications
- Integrate trtllm_bf16_routed_moe from FlashInfer v0.6.6, mirroring the existing fp8/mxfp8 code path in [FlashInfer v0.6.4] [RL] Integrate FlashInfer mxfp8 gemm, MoE, and routed MoE #19537
- Add tests in test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py and test/registered/rl/test_update_weights_blackwell.py

Accuracy Tests
Benchmarking and Profiling
For softmax MoE models like Qwen3-30B-A3B, enabling piecewise CUDA graph causes unstable numerics, which might be related to torch compile of sglang/python/sglang/srt/layers/moe/topk.py (lines 447 to 463 in 538acb4).
Benchmarked configurations:
- bf16 checkpoint with flashinfer_trtllm (baseline)
- bf16 checkpoint with flashinfer_trtllm_routed
- bf16 checkpoint with flashinfer_trtllm_routed with --disable-piecewise-cuda-graph
- mxfp8-last-n-bf16 checkpoint
- mxfp8-last-n-bf16 checkpoint with --disable-piecewise-cuda-graph

python -m sglang.launch_server --kv-cache-dtype bf16 --model zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16 --fp8-gemm-backend flashinfer_trtllm --moe-runner-backend flashinfer_trtllm_routed --disable-piecewise-cuda-graph & python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum

Accuracy: 0.970 Invalid: 0.000 Latency: 9.863 s Output throughput: 17271.128 token/s
Accuracy: 0.970 Invalid: 0.000 Latency: 9.806 s Output throughput: 17368.509 token/s
Accuracy: 0.970 Invalid: 0.000 Latency: 9.632 s Output throughput: 17686.235 token/s

Checklist
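As an aside, the bench_sglang.py result lines above can be compared mechanically; a small sketch (the regex is assumed from the printed format shown above, not part of the benchmark script):

```python
import re

# Pattern assumed from the "Accuracy: ... token/s" lines printed above.
LINE_RE = re.compile(
    r"Accuracy: (?P<acc>[\d.]+) Invalid: (?P<inv>[\d.]+) "
    r"Latency: (?P<lat>[\d.]+) s Output throughput: (?P<tput>[\d.]+) token/s"
)

def parse_bench(line):
    """Extract the four benchmark metrics from one output line as floats."""
    m = LINE_RE.search(line)
    return {k: float(v) for k, v in m.groupdict().items()}

r = parse_bench("Accuracy: 0.970 Invalid: 0.000 Latency: 9.632 s "
                "Output throughput: 17686.235 token/s")
print(r["acc"], r["tput"])  # 0.97 17686.235
```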
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci