133 changes: 97 additions & 36 deletions include/flashinfer/comm/trtllm_allreduce_fusion.cuh
@@ -531,6 +531,26 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c
return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
}

// Convert a single float2 pair to e2m1 (2 float32 -> 2 e2m1 values packed into one uint8_t)
// Optimization: allows pipelined processing to reduce register usage
// Note: the "=r" constraint always allocates a 32-bit register regardless of the variable type
inline __device__ uint8_t fp32_pair_to_e2m1(float2 pair) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val32;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"mov.b32 %0, {byte0, 0, 0, 0};\n"
"}"
: "=r"(val32)
: "f"(pair.x), "f"(pair.y));
return static_cast<uint8_t>(val32 & 0xFF); // Extract low 8 bits
#else
return 0;
#endif
}
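
For orientation, a minimal sketch (mine, not part of the diff) of how four of these per-pair results compose into the uint32_t that the conversion loop further below builds up; the helper name is hypothetical:

```cuda
// Illustrative only: mirrors the pipelined loop added to cvt_warp_fp16_to_fp4 below.
__device__ inline uint32_t pack_four_e2m1_pairs(const float2 (&pairs)[4]) {
  uint32_t packed = 0;
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Byte i holds elements 2*i and 2*i + 1, converted by fp32_pair_to_e2m1 above.
    packed |= static_cast<uint32_t>(fp32_pair_to_e2m1(pairs[i])) << (i * 8);
  }
  return packed;
}
```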

#if CUDA_VERSION >= 12080
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
@@ -602,6 +622,9 @@ template <typename T, uint32_t VEC_SIZE, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Pre-compute constant: reciprocal of 6.0 (maximum value of e2m1)
static constexpr float RECIPROCAL_6 = 1.0f / 6.0f;

// Get absolute maximum values among the local 8 values.
auto localMax = maths::cuda_abs(get_vec2_element(vec, 0));

@@ -613,57 +636,62 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
// Get the absolute maximum among all 16 values (two threads).
localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
// Optimization: compute vecMax and reuse localMax space (localMax no longer needed)
float vecMax = float(maths::cuda_max(localMax.x, localMax.y));

// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * maths::reciprocal_approximate_ftz(6.0f));
// 8 bits representation of the SF.
// Optimization: compute quantized SF directly, avoid storing intermediate SFValue
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
float quantized_sf;
if constexpr (UE8M0_SF) {
#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080)
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
float sf_value = SFScaleVal * (vecMax * RECIPROCAL_6);
tmp.__x = __nv_cvt_float_to_e8m0(sf_value, __NV_SATFINITE, cudaRoundPosInf);
quantized_sf = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
#else
#error "FP8 E8M0 support requires CUDA 12.8 or newer."
#endif
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFScaleVal * (vecMax * RECIPROCAL_6));
fp8SFVal = tmp.__x;
SFValue = static_cast<float>(tmp);
quantized_sf = static_cast<float>(tmp);
}
// Get the output scale.
// Get the output scale directly (optimization: avoid storing intermediate SFValue)
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)
float outputScale = SFValue != 0 ? maths::reciprocal_approximate_ftz(
SFValue * maths::reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
// Optimization: mathematically equivalent to SFScaleVal / quantized_sf, but more efficient
// (reduces 1 reciprocal call and 1 multiply operation)
float outputScale = quantized_sf != 0 ? SFScaleVal / quantized_sf : 0.0f;
@Edenzzzz (Contributor) commented on Dec 28, 2025:
Interesting, I thought division would take a few times longer than the reciprocal approximation; curious whether you ablated this.

EDIT: I tried it and it doesn't make a noticeable difference.
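
As a side note, a tiny host-side check (my sketch, not part of the PR; q stands in for quantized_sf and s for SFScaleVal) showing the two formulations agree up to rounding, since 1/(q * (1/s)) = s/q for q != 0:

```cuda
#include <cassert>
#include <cmath>

// Hypothetical standalone versions of the two outputScale formulations.
static float output_scale_old(float q, float s) { return q != 0.0f ? 1.0f / (q * (1.0f / s)) : 0.0f; }
static float output_scale_new(float q, float s) { return q != 0.0f ? s / q : 0.0f; }

int main() {
  const float s = 448.0f;  // example global scale, arbitrary
  const float qs[] = {0.0f, 0.5f, 1.5f, 6.0f};
  for (float q : qs) {
    float a = output_scale_old(q, s);
    float b = output_scale_new(q, s);
    assert(std::fabs(a - b) <= 1e-4f * std::fabs(b) + 1e-6f);
  }
  return 0;
}
```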


if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}

// Convert the input to float.
float2 fp2Vals[details::CVT_FP4_ELTS_PER_THREAD / 2];
// Convert the input to float and quantize (pipelined to reduce register usage).
// Optimization: use single float2 instead of array to reduce register pressure from 32 bytes to 8
Collaborator:
Do you have any profiling results showing the register usage (e.g. from cuobjdump or ncu)?

Contributor (Author):

Oh, I will provide them in the next few days.

// bytes
uint32_t e2m1Vec = 0;

#pragma unroll
for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
// Reuse single float2 register instead of array
float2 fp2Val;
if constexpr (std::is_same_v<T, half>) {
fp2Vals[i] = __half22float2(get_vec2_element(vec, i));
fp2Val = __half22float2(get_vec2_element(vec, i));
} else {
fp2Vals[i] = __bfloat1622float2(get_vec2_element(vec, i));
fp2Val = __bfloat1622float2(get_vec2_element(vec, i));
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
fp2Val.x *= outputScale;
fp2Val.y *= outputScale;

// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Convert pair immediately and pack into result
uint8_t e2m1Pair = fp32_pair_to_e2m1(fp2Val);
e2m1Vec |= (static_cast<uint32_t>(e2m1Pair) << (i * 8));
}

// Write the e2m1 values to global memory.
return e2m1Vec;
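
Reading the fragments above together, the per-16-element quantization recipe appears to be (my summary, not text from the PR; fp8 is E4M3, or UE8M0 when UE8M0_SF is set):

$$
\mathrm{SF}=\mathrm{fp8}\!\left(\texttt{SFScaleVal}\cdot\frac{\max_i |x_i|}{6}\right),\qquad
\texttt{outputScale}=\frac{\texttt{SFScaleVal}}{\mathrm{fp32}(\mathrm{SF})},\qquad
y_i=\mathrm{e2m1}\!\left(x_i\cdot\texttt{outputScale}\right)
$$

with the maximum taken over the 16 values shared by the two paired threads, and the fp8 SF byte written out through SFout.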
@@ -1105,22 +1133,17 @@ template <typename T, uint32_t VEC_SIZE, int NRanks, bool Fp32Acc>
__device__ __forceinline__ vec_t<T, VEC_SIZE> allreduce_sum(vec_t<T, VEC_SIZE>* vals) {
if constexpr (Fp32Acc) {
static_assert(!std::is_same_v<T, float>);
float acc_f32[VEC_SIZE];
// Optimization: process one element at a time to reduce register usage
// Instead of storing acc_f32[VEC_SIZE] (32 bytes), process and convert immediately
vec_t<T, VEC_SIZE> acc;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc_f32[i] = static_cast<float>(reinterpret_cast<T*>(&vals[0])[i]);
}
#pragma unroll
for (int r = 1; r < NRanks; ++r) {
float acc_f32 = static_cast<float>(reinterpret_cast<T*>(&vals[0])[i]);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc_f32[i] += static_cast<float>(reinterpret_cast<T*>(&vals[r])[i]);
for (int r = 1; r < NRanks; ++r) {
acc_f32 += static_cast<float>(reinterpret_cast<T*>(&vals[r])[i]);
}
}
vec_t<T, VEC_SIZE> acc;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc[i] = static_cast<T>(acc_f32[i]);
acc[i] = static_cast<T>(acc_f32);
}
return acc;
} else {
@@ -1402,16 +1425,54 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
max_registers = utils::getSMRegisters();
}
int max_threads_per_block = min(max_registers / registers_per_thread, 1024);

int block_size = threads_per_block;

// FP4 optimization: apply BEFORE SM count check to avoid being overridden
// This allows FP4 to use smaller block_size even when cluster_num is large
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
@timlee0212 (Contributor) commented on Jan 12, 2026:
Hi @Bruce-x-1997, will this occupancy optimization be triggered in the DeepSeek case mentioned here? 7168/8 = 896 is not divisible by either 160 or 192, so 128 will be used, which is typically a preferred block_size. I have tried the other optimizations mentioned in this PR but did not find a noticeable difference (~0.05 µs faster); I haven't checked the register usage, though. Does the 10~15% improvement mentioned in the PR result from using this block size (and possibly a different cluster size) instead of the original heuristic?
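
For reference, the arithmetic in this question checks out (my illustration, using the 7168 and VEC_SIZE = 8 figures from the comment above):

```cuda
// DeepSeek-style hidden_dim = 7168 with VEC_SIZE = 8 gives threads_per_token = 896.
static_assert(7168 / 8 == 896, "threads_per_token");
// 896 is divisible by 128 but not by 160 or 192, so the heuristic below falls through to 128...
static_assert(896 % 160 != 0 && 896 % 192 != 0 && 896 % 128 == 0, "branch selection");
// ...and cluster_size = 896 / 128 = 7, which is already <= 8, so no capping is needed.
static_assert(896 / 128 == 7, "cluster_size");
```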

// Try to use 160 as block_size if possible (better occupancy for FP4)
if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
block_size = 160;
cluster_size = threads_per_token / 160;
if (cluster_size > 8) cluster_size = 8;
}
// Fallback: try 192, 128 if 160 doesn't work
else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block && 192 >= 128) {
block_size = 192;
cluster_size = threads_per_token / 192;
if (cluster_size > 8) cluster_size = 8;
} else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
block_size = 128;
cluster_size = threads_per_token / 128;
if (cluster_size > 8) cluster_size = 8;
}
// Update threads_per_block to match block_size for SM count check
threads_per_block = block_size;
}
Comment on lines +1433 to +1452
Contributor (severity: medium):

The logic for selecting a special block size for FP4 kernels is a bit repetitive and contains some redundant conditions. This can be refactored to be more concise and easier to read.

Specifically:

  • The condition 160 >= 128 (and similar for 192) is always true and can be removed.
  • The block of code to update cluster_size and cap it at 8 is repeated in each branch.

I've suggested a refactoring that consolidates this logic, making it cleaner and more maintainable without changing the functionality.

  if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
    int new_block_size = 0;
    // Try to use 160 as block_size if possible (better occupancy for FP4)
    if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block) {
      new_block_size = 160;
    } else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block) {
      // Fallback: try 192, 128 if 160 doesn't work
      new_block_size = 192;
    } else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
      new_block_size = 128;
    }

    if (new_block_size > 0) {
      block_size = new_block_size;
      cluster_size = threads_per_token / block_size;
      if (cluster_size > 8) {
        cluster_size = 8;
      }
      // Update threads_per_block to match block_size for SM count check
      threads_per_block = block_size;
    }
  }

Collaborator:
gemini's suggestion looks reasonable.

Comment on lines +1433 to +1452
Contributor:
⚠️ Potential issue | 🟠 Major

Potential correctness issue: threads_per_token may not be fully covered when cluster_size is capped.

When threads_per_token / block_size > 8, the cluster_size is capped at 8 but block_size is not adjusted to compensate. This breaks the invariant threads_per_token == block_size * cluster_size.

Example with hidden_dim=12800, VEC_SIZE=8:

  • threads_per_token = 1600
  • block_size = 160, cluster_size = 10 → capped to 8
  • Effective threads = 160 * 8 = 1280 < 1600

This would result in some elements not being processed.

🔎 Suggested fix: Recalculate block_size after capping cluster_size
     if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
       block_size = 160;
       cluster_size = threads_per_token / 160;
-      if (cluster_size > 8) cluster_size = 8;
+      if (cluster_size > 8) {
+        cluster_size = 8;
+        // Recalculate block_size to ensure full coverage
+        block_size = threads_per_token / cluster_size;
+      }
     }

Apply similar logic to the 192 and 128 fallback cases.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In include/flashinfer/comm/trtllm_allreduce_fusion.cuh around lines 1433-1452,
the current FP4 branch caps cluster_size at 8 without adjusting block_size,
breaking the invariant threads_per_token == block_size * cluster_size and
leaving some threads unprocessed; after capping cluster_size recompute
block_size = threads_per_token / cluster_size (and if needed enforce block_size
<= max_threads_per_block and >=128) for the 160, 192 and 128 branches so the
product exactly covers threads_per_token, then update threads_per_block =
block_size for the SM check.
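
Putting the two review suggestions together, one possible consolidated form (a sketch under my own assumptions, not a committed change; the helper name and signature are hypothetical):

```cuda
// Picks an FP4-friendly block size and keeps block_size * cluster_size == threads_per_token
// even after capping the cluster size at 8. Returns false if no candidate divides evenly.
static bool select_fp4_block_size(int threads_per_token, int max_threads_per_block,
                                  int& block_size, int& cluster_size) {
  const int candidates[] = {160, 192, 128};
  for (int candidate : candidates) {
    if (threads_per_token % candidate != 0 || candidate > max_threads_per_block) continue;
    block_size = candidate;
    cluster_size = threads_per_token / candidate;
    if (cluster_size > 8) {
      cluster_size = 8;
      // All candidates are multiples of 8, so this division is exact; the caller would
      // still need to check block_size <= max_threads_per_block (and <= 1024).
      block_size = threads_per_token / cluster_size;
    }
    return true;
  }
  return false;
}
```

The launcher would then set threads_per_block = block_size whenever this returns true, as the existing code does after its FP4 branch.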


// SM count check: adjust if cluster_num * cluster_size > sm_count
// But respect FP4 optimization if already applied
while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
threads_per_block <= max_threads_per_block / 2) {
threads_per_block *= 2;
cluster_size /= 2;
// If FP4 optimization was applied, update block_size to match
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
block_size = threads_per_block;
}
}
FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
"not oneshot, or threads_per_block < nranks");
int block_size = threads_per_block;

// Update block_size if not FP4 (FP4 already set it above)
if constexpr (GetQuantType<Pattern> != QuantType::kFP4) {
block_size = threads_per_block;
}

// Check conditions using the final block_size (not threads_per_block)
FLASHINFER_CHECK(oneshot || block_size >= params.nranks, "not oneshot, or block_size < nranks");
FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
"block_size > 1024 or cluster_size <= 0");

int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
cudaLaunchConfig_t cfg;
cudaLaunchAttribute attribute[2];