Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions include/flashinfer/comm/trtllm_allreduce_fusion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1430,7 +1430,14 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par

// FP4 optimization: apply BEFORE SM count check to avoid being overridden
// This allows FP4 to use smaller block_size even when cluster_num is large
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
// NOTE: Only apply for oneshot mode or small token counts (<=16).
// For larger token counts with twoshot (oneshot=false), the original heuristic
// performs better. See benchmark data showing regression at token_num >= 32.
constexpr int kFP4OptimizationTokenThreshold = 16;
bool apply_fp4_optimization = (GetQuantType<Pattern> == QuantType::kFP4) &&
(oneshot || token_num <= kFP4OptimizationTokenThreshold);

if (apply_fp4_optimization) {
// Try to use 160 as block_size if possible (better occupancy for FP4)
if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
block_size = 160;
Expand Down Expand Up @@ -1458,13 +1465,13 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
threads_per_block *= 2;
cluster_size /= 2;
// If FP4 optimization was applied, update block_size to match
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
if (apply_fp4_optimization) {
block_size = threads_per_block;
}
Comment on lines +1468 to 1470
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For simplification, this conditional update inside the loop can be removed. block_size is not read within the loop, so it can be updated once after the loop finishes. This simplifies the logic when combined with the change I'm suggesting below.

 

}

// Update block_size if not FP4 (FP4 already set it above)
if constexpr (GetQuantType<Pattern> != QuantType::kFP4) {
// Update block_size if FP4 optimization was not applied
if (!apply_fp4_optimization) {
block_size = threads_per_block;
}
Comment on lines +1474 to 1476
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic can be simplified by making the assignment unconditional. The block_size should be updated to the final threads_per_block value after the loop, regardless of whether the FP4 optimization was applied. This makes the code cleaner.

  block_size = threads_per_block;


Expand Down
Loading