Conversation
…er-ai#2086)" This reverts commit 9a79b78.
Signed-off-by: bruce.xu <bruce.xu@gmicloud.ai>
Summary of Changes

This pull request aims to significantly enhance the performance of NVIDIA FP4 (nvfp4) and FP8 Mixture-of-Experts (MoE) implementations, addressing an observed gap where FP4's speedup over FP8 was less than its theoretical potential.
Code Review
This pull request introduces performance optimizations for nvfp4, primarily within `trtllm_allreduce_fusion.cuh`. The changes focus on reducing register pressure and tuning launch configurations, which are well implemented and should yield performance gains. A new `tile_tokens_dim` parameter is also added for MoE benchmarks and kernels, with deprecation warnings for its use in fp8 kernels. The overall changes are consistent and improve the codebase. I have one suggestion to refactor a piece of logic for better readability.
```cpp
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
  // Try to use 160 as block_size if possible (better occupancy for FP4)
  if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
    block_size = 160;
    cluster_size = threads_per_token / 160;
    if (cluster_size > 8) cluster_size = 8;
  }
  // Fallback: try 192, 128 if 160 doesn't work
  else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block && 192 >= 128) {
    block_size = 192;
    cluster_size = threads_per_token / 192;
    if (cluster_size > 8) cluster_size = 8;
  } else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
    block_size = 128;
    cluster_size = threads_per_token / 128;
    if (cluster_size > 8) cluster_size = 8;
  }
  // Update threads_per_block to match block_size for SM count check
  threads_per_block = block_size;
}
```
The logic for selecting `block_size` for FP4 kernels can be simplified for better readability and maintainability. The conditions `160 >= 128` and `192 >= 128` are always true and can be removed. Also, the logic for capping `cluster_size` is repeated. Consider refactoring this block to reduce redundancy.
```cpp
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
  int new_block_size = 0;
  // Try to use 160 as block_size if possible (better occupancy for FP4)
  if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block) {
    new_block_size = 160;
  }
  // Fallback: try 192, 128 if 160 doesn't work
  else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block) {
    new_block_size = 192;
  } else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
    new_block_size = 128;
  }
  if (new_block_size > 0) {
    block_size = new_block_size;
    cluster_size = threads_per_token / new_block_size;
    if (cluster_size > 8) {
      cluster_size = 8;
    }
  }
  // Update threads_per_block to match block_size for SM count check
  threads_per_block = block_size;
}
```
Actionable comments posted: 0
⚠️ Outside diff range comments (1)
`flashinfer/fused_moe/core.py` (1)

2106-2138: FP4 wrappers missing default for `tile_tokens_dim` breaks backward compatibility

Both `trtllm_fp4_block_scale_moe` and `trtllm_fp4_block_scale_routed_moe` declare `tile_tokens_dim: Optional[int]` without a default value, positioned between `routed_scaling_factor` and `routing_method_type`. The FP8 counterpart (`trtllm_fp8_block_scale_moe`) defaults `tile_tokens_dim` to `None`, remaining backward-compatible.

This breaks existing code that uses positional arguments through `routing_method_type`, causing a `TypeError` for the missing required argument `tile_tokens_dim`. Add `= None` to both FP4 function signatures to match the FP8 pattern and preserve backward compatibility.
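A minimal sketch of the suggested fix (hypothetical, abbreviated signature; the real wrappers take many more parameters):

```python
from typing import Optional


# Abbreviated stand-in for the FP4 wrapper: giving tile_tokens_dim a
# "= None" default restores compatibility for call sites that never
# passed it, matching the FP8 wrapper's behavior.
def fp4_moe_stub(
    routed_scaling_factor: Optional[float],
    tile_tokens_dim: Optional[int] = None,  # default added, per the review
    routing_method_type: int = 0,
):
    return (routed_scaling_factor, tile_tokens_dim, routing_method_type)


# A pre-existing caller that skips tile_tokens_dim no longer raises TypeError.
result = fp4_moe_stub(2.5, routing_method_type=1)
```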
🧹 Nitpick comments (4)
`tests/moe/test_trtllm_gen_fused_moe.py` (1)

185-215: `tile_tokens_dim` wiring in test MoE calls looks correct; consider using keyword for FP8 paths

The added `tile_tokens_dim=None` for the FP4 graph path and the inserted `None` positional arguments for the FP8 block-scale and FP8 per-tensor paths match the updated Python wrappers in `flashinfer.fused_moe.core` (the new parameter sits between `routed_scaling_factor` and `routing_method_type`). Behavior remains unchanged because `tile_tokens_dim` is `None` and is ignored by the core.

For long-term maintainability, you may want to pass `tile_tokens_dim` by keyword in the FP8 calls as you already do for the FP4 path; that would make these tests more robust to any future reordering of the wrapper's trailing parameters.

Also applies to: 771-797, 947-973
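As a quick illustration of why the keyword form is more robust (hypothetical, abbreviated signature; the real wrappers take many more parameters):

```python
# Hypothetical wrapper with trailing optional parameters, loosely modeled
# on the MoE wrappers discussed above.
def fused_moe_stub(hidden_states, routed_scaling_factor=None,
                   tile_tokens_dim=None, routing_method_type=0):
    return {"tile": tile_tokens_dim, "routing": routing_method_type}


# Positional form silently depends on the parameter order staying fixed.
positional = fused_moe_stub([0.0], None, None, 1)

# Keyword form keeps working even if trailing parameters are reordered later.
keyword = fused_moe_stub([0.0], tile_tokens_dim=None, routing_method_type=1)
```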
`benchmarks/routines/moe.py` (2)

119-125: CLI `tile_tokens_dim` threading is coherent, but currently acts as a metadata knob only

The new `--tile_tokens_dim` argument is parsed, threaded into all three TRT-LLM MoE benchmarks, and written into the result dicts. The call sites correctly pass it into the updated Python wrappers (`trtllm_fp4_block_scale_moe`, `trtllm_fp8_block_scale_moe`, `trtllm_fp8_per_tensor_scale_moe`) in the right positional/keyword slots, so everything is consistent within this repo.

However, the current core wrappers only use `tile_tokens_dim` to gate a one-time deprecation warning and do not forward its value into the underlying C++ runner, so changing `--tile_tokens_dim` does not actually influence the kernel configuration yet. From a benchmark user's perspective this behaves more like an informational field than a tuning knob.

If the intent is:
- Just compatibility / logging: consider documenting here (or in the help text) that the flag is deprecated and ignored by the kernel, and is present only for backward compatibility / reporting.
- Real tuning in the future: once you wire `tile_tokens_dim` through `flashinfer.fused_moe.core` into the C++ launcher, this plumbing should already be in the right place; at that point you may also want to validate that the provided value is within the supported tile set.

Also applies to: 563-564, 682-713, 765-784, 1188-1189, 1323-1324, 1384-1385, 1451-1452, 1530-1531, 1588-1588
1280-1300: BlockMajorK heuristic for `tile_tokens_dim` is reasonable, but please clarify intent

The BlockMajorK override:
- Computes `tokens_per_expert ≈ (num_tokens * top_k) / local_num_experts` with sensible guards.
- Rounds to the next power of two, then clamps to `[8, 64]`.
- Logs when overriding a user-supplied value.

That's a sane heuristic and matches the idea of choosing a tile size proportional to tokens per expert. Given that `tile_tokens_dim` is currently ignored by the core (other than logging a warning), this override only affects the metadata recorded in results.

If you plan to make `tile_tokens_dim` drive the actual kernel selection later, this heuristic is a good starting point, but you may want to:
- Revisit the `[8, 64]` clamp against the set of tile sizes supported by the C++ runner.
- Document in the CLI help (and possibly here) that BlockMajorK may override the requested tile for better alignment with kernel constraints.
`flashinfer/fused_moe/core.py` (1)

1940-1997: `tile_tokens_dim` deprecation handling is consistent but currently discards the knob entirely

The four high-level Python wrappers:
- `trtllm_fp8_per_tensor_scale_moe`
- `trtllm_fp8_block_scale_moe`
- `trtllm_fp4_block_scale_moe`
- `trtllm_fp4_block_scale_routed_moe`

now all accept a `tile_tokens_dim` parameter and emit a one-time deprecation warning when it is not `None`. However, none of them forward this argument into the underlying TRT-LLM custom ops; it is only used to decide whether to log a warning, and then discarded.

Given that the rest of this PR adds CLI and benchmark wiring plus heuristics around `tile_tokens_dim`, it's worth making the intent explicit:
- If the goal is pure deprecation / backward compatibility, this is fine functionally, but consider adjusting the warning text ("will no longer be supported after v0.5.0") to match the current versioning story (we're already past 0.5.0) or to use a vaguer "in a future release". It might also help to state explicitly in the docstrings that the parameter is ignored and exists only for compatibility, to avoid users trying to tune it.
- If, instead, you eventually want `tile_tokens_dim` to control `tile_N` in the C++ runner, you'll need a follow-up change that threads this value into the appropriate C++ init APIs and/or configuration structures so it actually affects kernel selection.

Right now there is no functional effect from any non-`None` value, beyond triggering the warning.

Also applies to: 2021-2077, 2106-2183, 2243-2322
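The one-time-warning pattern described here can be sketched as follows (a rough model; `warning_once` is loosely based on the helper referenced from `flashinfer/jit/core.py`, and `moe_wrapper_stub` is a hypothetical stand-in for the real wrappers):

```python
import warnings

_emitted: set[str] = set()


def warning_once(message: str) -> None:
    """Emit a given deprecation message at most once per process."""
    if message not in _emitted:
        _emitted.add(message)
        warnings.warn(message, DeprecationWarning, stacklevel=2)


def moe_wrapper_stub(tile_tokens_dim=None):
    # The knob is accepted but not forwarded to the kernel; a non-None
    # value only triggers a single deprecation warning, then is discarded.
    if tile_tokens_dim is not None:
        warning_once("tile_tokens_dim is deprecated and currently ignored")
    return "ok"
```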
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
- `benchmarks/samples/sample_testlist_output.csv` is excluded by `!**/*.csv`

📒 Files selected for processing (11)
- `benchmarks/README.md`
- `benchmarks/bench_trtllm_gen_fused_moe_autotuner.py`
- `benchmarks/routines/flashinfer_benchmark_utils.py`
- `benchmarks/routines/moe.py`
- `benchmarks/samples/sample_testlist_output.txt`
- `csrc/trtllm_fused_moe_kernel_launcher.cu`
- `flashinfer/fused_moe/core.py`
- `include/flashinfer/comm/trtllm_allreduce_fusion.cuh`
- `tests/moe/test_trtllm_gen_fused_moe.py`
- `tests/moe/test_trtllm_gen_routed_fused_moe.py`
- `version.txt`
🧰 Additional context used
🧬 Code graph analysis (2)

`flashinfer/fused_moe/core.py` (1)
- `flashinfer/jit/core.py` (1): `warning_once` (78-83)

`benchmarks/routines/moe.py` (3)
- `csrc/trtllm_fused_moe_kernel_launcher.cu` (17): `args` (142-144, 419-428, 419-421, 536-558, 536-538, 728-752, 728-730, 979-1006, 979-982), `top_k` (490-517, 490-493, 683-710, 683-687, 913-938, 913-916, 1223-1249, 1223-1226)
- `flashinfer/fused_moe/core.py` (1): `WeightLayout` (163-170)
- `include/flashinfer/trtllm/fused_moe/runner.h` (2): `top_k` (270-270), `local_num_experts` (277-277)
🔇 Additional comments (13)
`version.txt` (1)

1-1: LGTM: Version bump is appropriate.

The version increment from 0.5.2 to 0.5.3 is suitable for a performance optimization release.
`benchmarks/README.md` (1)

169-170: LGTM: Documentation accurately describes the new parameter.

The `tile_tokens_dim` parameter is properly documented with its default value.

`csrc/trtllm_fused_moe_kernel_launcher.cu` (2)
1389-1391: Clarify whether this is actually incorrect behavior.

The comments state "This seems incorrect but we match the original behavior." If there's a genuine bug where `tile_N` is incorrectly passed as `tile_tokens_dim`, it should be fixed rather than documented and perpetuated. However, if `tile_N` (tile size in the N/token dimension) is semantically equivalent to `tile_tokens_dim`, the comment is misleading and should be revised or removed.

Please verify:
- Are `tile_N` and `tile_tokens_dim` semantically equivalent concepts?
- If they differ, what is the correct value to pass?
- If the behavior is correct, update the comment to clarify rather than suggest incorrectness.

1473-1475: Same concern: clarify whether this behavior is correct.

This is a duplicate of the concern at lines 1389-1391. The comment suggests incorrect behavior but preserves it. Please verify whether this is genuinely a bug or just confusing naming.
`benchmarks/samples/sample_testlist_output.txt` (2)

295-295: LGTM: Test output correctly includes the new parameter.

The `tile_tokens_dim=8` additions to test configurations are consistent with the documented default value.

Also applies to: 306-306, 317-317, 328-328, 350-350

339-339: Note: Routing method change in test configuration.

Line 339 shows `routing_method='renormalize'` in this test output. While this change appears unrelated to the `tile_tokens_dim` additions, please verify that this routing method change is intentional.

`include/flashinfer/comm/trtllm_allreduce_fusion.cuh` (4)
534-552: LGTM: Well-designed helper function for pipelined FP4 conversion.

The `fp32_pair_to_e2m1` function enables pipelined processing to reduce register usage. The implementation correctly:
- Uses inline PTX for efficient conversion on SM 10.0+
- Documents the register allocation behavior
- Provides a safe fallback for older architectures
- Extracts the packed result correctly
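The PTX path itself is hardware-specific, but the numeric effect of such a pair-to-e2m1 helper can be modeled in Python. This is a rough sketch under assumptions not stated in the review: nearest-value rounding with saturation at ±6.0, the first value packed into the low nibble, and tie-breaking simplified.

```python
# Representable non-negative e2m1 (FP4) magnitudes, indexed by 3-bit code.
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp32_to_e2m1_code(x: float) -> int:
    """Quantize one float to a 4-bit e2m1 code (sign bit + 3-bit magnitude),
    picking the nearest representable magnitude and saturating at 6.0."""
    sign = 0x8 if x < 0 else 0x0
    mag = min(abs(x), 6.0)
    code = min(range(8), key=lambda i: abs(_E2M1_VALUES[i] - mag))
    return sign | code


def fp32_pair_to_e2m1_byte(lo: float, hi: float) -> int:
    """Pack two e2m1 codes into one byte (assumed order: lo in bits 0-3)."""
    return fp32_to_e2m1_code(lo) | (fp32_to_e2m1_code(hi) << 4)
```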
625-701: LGTM: Effective optimizations to reduce register pressure.

The changes to `cvt_warp_fp16_to_fp4` are well designed:

Line 626: Pre-computing `RECIPROCAL_6` eliminates repeated division operations.

Lines 646-667: The SF computation is mathematically equivalent but more efficient:
- Computes `quantized_sf` directly from the quantized value
- Derives `outputScale = SFScaleVal / quantized_sf` in one step
- Avoids storing an intermediate `SFValue`

Lines 675-694: The conversion loop optimization significantly reduces register pressure:
- Uses a single `float2` register (8 bytes) instead of an array (32 bytes)
- Converts and packs immediately using the new `fp32_pair_to_e2m1` helper
- Maintains correctness while improving pipelining
1132-1157: LGTM: Efficient register usage optimization for FP32 accumulation.

Processing elements one at a time instead of storing an `acc_f32[VEC_SIZE]` array reduces register usage from 32 bytes to 4 bytes (a single scalar) while maintaining mathematical correctness. This is a straightforward and effective optimization for the FP32 accumulation paths.
1428-1475: LGTM: FP4-specific block size optimization improves occupancy.

The FP4 block size selection logic is well structured:
- Lines 1433-1452: The FP4 path tries specific block sizes (160, 192, 128) for better occupancy before the SM count check, preventing them from being overridden.
- Lines 1456-1464: The SM count check respects the FP4-optimized `block_size` if already set.
- Lines 1467-1469: Non-FP4 paths update `block_size` from `threads_per_block`.
- Lines 1472-1473: The final check correctly uses `block_size` instead of `threads_per_block`.

The logic correctly handles both FP4 and non-FP4 paths without conflict.
benchmarks/routines/flashinfer_benchmark_utils.py (1)
56-56: LGTM: Output schema correctly includes the new parameter.Adding
tile_tokens_dimto the MOE output columns aligns with the parameter's introduction throughout the codebase and enables benchmarks to report this metric.tests/moe/test_trtllm_gen_routed_fused_moe.py (1)
183-183: LGTM: Tests correctly pass None for the new optional parameter.The additions of
None, # tile_tokens_dimto both function calls maintain existing test behavior while accommodating the new parameter signature.Also applies to: 237-237
benchmarks/bench_trtllm_gen_fused_moe_autotuner.py (1)
99-125: Autotuner call‑sites correctly adapted to new core signaturesPassing
Noneas the newtile_tokens_dimargument for the FP8 block‑scale, FP8 per‑tensor, and FP4 autotuner paths keeps these benchmarks compatible with the updatedflashinfer.fused_moe.coreAPIs while preserving the original behavior (sincetile_tokens_dimis ignored whenNone).Looks good as a minimal, non‑functional adjustment.
Also applies to: 127-149, 265-297
📌 Description
I found that the nvfp4 implementation achieves only a 1.3-1.4x speedup over fp8 on the deepseek-v3-0324 model. Since fp4 peak FLOPS is twice that of fp8, I think there should be room for further optimization.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Added or updated relevant tests (`unittest`, etc.).

Reviewer Notes