From 9efd43922e19493473085901d633225746717f78 Mon Sep 17 00:00:00 2001
From: "bruce.xu"
Date: Wed, 24 Dec 2025 12:12:35 +0000
Subject: [PATCH] optimize for nvfp4

Signed-off-by: bruce.xu
---
 .../comm/trtllm_allreduce_fusion.cuh | 133 +++++++++++++-----
 1 file changed, 97 insertions(+), 36 deletions(-)

diff --git a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh
index 8ec84c4a03..7675ec1481 100644
--- a/include/flashinfer/comm/trtllm_allreduce_fusion.cuh
+++ b/include/flashinfer/comm/trtllm_allreduce_fusion.cuh
@@ -531,6 +531,26 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c
   return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
 }
 
+// Convert a single float2 pair to e2m1 (2 x float32 -> 2 x e2m1, returned in one uint8_t).
+// Optimization: allows pipelined per-pair processing to reduce register usage.
+// Note: the "=r" constraint always allocates a 32-bit register regardless of the variable type.
+inline __device__ uint8_t fp32_pair_to_e2m1(float2 pair) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val32;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "mov.b32 %0, {byte0, 0, 0, 0};\n"
+      "}"
+      : "=r"(val32)
+      : "f"(pair.x), "f"(pair.y));
+  return static_cast<uint8_t>(val32 & 0xFF);  // Extract the low 8 bits.
+#else
+  return 0;
+#endif
+}
+
 #if CUDA_VERSION >= 12080
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 // NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
@@ -602,6 +622,9 @@ template
 __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t& vec, float SFScaleVal, uint8_t* SFout) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  // Pre-compute constant: reciprocal of 6.0 (the maximum e2m1 value).
+  static constexpr float RECIPROCAL_6 = 1.0f / 6.0f;
+
   // Get absolute maximum values among the local 8 values.
   auto localMax = maths::cuda_abs(get_vec2_element(vec, 0));
 
@@ -613,57 +636,62 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t& vec, float SFScaleV
   // Get the absolute maximum among all 16 values (two threads).
   localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
   // Get the final absolute maximum values.
+  // Optimization: once vecMax is computed, localMax is dead and its registers can be reused.
   float vecMax = float(maths::cuda_max(localMax.x, localMax.y));
 
   // Get the SF (max value of the vector / max value of e2m1).
   // maximum value of e2m1 = 6.0.
-  // TODO: use half as compute data type.
-  float SFValue = SFScaleVal * (vecMax * maths::reciprocal_approximate_ftz(6.0f));
-  // 8 bits representation of the SF.
+  // Optimization: compute the quantized SF directly and avoid storing an intermediate SFValue.
   uint8_t fp8SFVal;
-  // Write the SF to global memory (STG.8).
+  float quantized_sf;
   if constexpr (UE8M0_SF) {
 #if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080)
     __nv_fp8_e8m0 tmp;
-    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
-    SFValue = static_cast<float>(tmp);
+    float sf_value = SFScaleVal * (vecMax * RECIPROCAL_6);
+    tmp.__x = __nv_cvt_float_to_e8m0(sf_value, __NV_SATFINITE, cudaRoundPosInf);
+    quantized_sf = static_cast<float>(tmp);
     fp8SFVal = tmp.__x;
 #else
 #error "FP8 E8M0 support requires CUDA 12.8 or newer."
 #endif
   } else {
     // Here SFValue is always positive, so E4M3 is the same as UE4M3.
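// Host-side reference sketch of the rounding that the new fp32_pair_to_e2m1 helper requests
// via cvt.rn.satfinite.e2m1x2.f32, assuming round-to-nearest-even over the e2m1 value set
// {0, 0.5, 1, 1.5, 2, 3, 4, 6} with saturation at +/-6 (NaN/Inf handling is not modeled).
// The function name fp32_to_e2m1_ref is illustrative only and is not part of the library.
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t fp32_to_e2m1_ref(float x) {
  // Magnitudes representable in e2m1; the array index is the 3-bit magnitude code.
  static const float kGrid[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  const uint8_t sign = std::signbit(x) ? 0x8 : 0x0;  // sign lands in bit 3 of the nibble
  const float mag = std::fabs(x);
  int best = 0;
  float best_err = std::fabs(mag - kGrid[0]);
  for (int i = 1; i < 8; ++i) {
    const float err = std::fabs(mag - kGrid[i]);
    // On an exact tie, prefer the even code (round-to-nearest-even on the mantissa bit).
    if (err < best_err || (err == best_err && (i & 1) == 0)) {
      best = i;
      best_err = err;
    }
  }
  return sign | static_cast<uint8_t>(best);  // Magnitudes above 6.0 saturate to code 7 (6.0).
}

int main() {
  std::printf("e2m1(2.5f)  -> 0x%x\n", fp32_to_e2m1_ref(2.5f));   // ties to 2.0 (code 0x4)
  std::printf("e2m1(-7.0f) -> 0x%x\n", fp32_to_e2m1_ref(-7.0f));  // saturates to -6.0 (0xf)
  return 0;
}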
-    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFScaleVal * (vecMax * RECIPROCAL_6));
     fp8SFVal = tmp.__x;
-    SFValue = static_cast<float>(tmp);
+    quantized_sf = static_cast<float>(tmp);
   }
-  // Get the output scale.
+  // Get the output scale directly (optimization: avoid storing an intermediate SFValue).
   // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
-  float outputScale = SFValue != 0 ? maths::reciprocal_approximate_ftz(
-                                         SFValue * maths::reciprocal_approximate_ftz(SFScaleVal))
-                                   : 0.0f;
+  // Optimization: mathematically equivalent to SFScaleVal / quantized_sf, but more efficient
+  // (saves one reciprocal call and one multiply).
+  float outputScale = quantized_sf != 0 ? SFScaleVal / quantized_sf : 0.0f;
   if (SFout) {
     // Write the SF to global memory (STG.8).
     *SFout = fp8SFVal;
   }
 
-  // Convert the input to float.
-  float2 fp2Vals[details::CVT_FP4_ELTS_PER_THREAD / 2];
+  // Convert the input to float and quantize (pipelined to reduce register usage).
+  // Optimization: use a single float2 instead of an array to cut register pressure
+  // from 32 bytes to 8 bytes.
+  uint32_t e2m1Vec = 0;
 #pragma unroll
   for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    // Reuse a single float2 register instead of an array element.
+    float2 fp2Val;
     if constexpr (std::is_same_v) {
-      fp2Vals[i] = __half22float2(get_vec2_element(vec, i));
+      fp2Val = __half22float2(get_vec2_element(vec, i));
     } else {
-      fp2Vals[i] = __bfloat1622float2(get_vec2_element(vec, i));
+      fp2Val = __bfloat1622float2(get_vec2_element(vec, i));
     }
-    fp2Vals[i].x *= outputScale;
-    fp2Vals[i].y *= outputScale;
-  }
+    fp2Val.x *= outputScale;
+    fp2Val.y *= outputScale;
 
-  // Convert to e2m1 values.
-  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+    // Convert the pair immediately and pack it into the result.
+    uint8_t e2m1Pair = fp32_pair_to_e2m1(fp2Val);
+    e2m1Vec |= (static_cast<uint32_t>(e2m1Pair) << (i * 8));
+  }
 
   // Write the e2m1 values to global memory.
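// Host-side sketch of the rewritten output-scale recipe, in exact arithmetic:
//   old: reciprocal(quantized_sf * reciprocal(SFScaleVal))
//   new: SFScaleVal / quantized_sf
// Algebraically 1 / (q * (1 / s)) == s / q, so the two agree up to rounding; the old device
// path used the approximate ftz reciprocal, so last-bit differences are possible.
// The numeric values below are made up purely for illustration.
#include <cstdio>

int main() {
  const float sf_scale = 0.4375f;    // hypothetical global SF scale (SFScaleVal)
  const float quantized_sf = 28.0f;  // hypothetical fp8-decoded per-vector SF
  const float old_scale = 1.0f / (quantized_sf * (1.0f / sf_scale));
  const float new_scale = sf_scale / quantized_sf;
  std::printf("old = %.9g, new = %.9g\n", old_scale, new_scale);
  return 0;
}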
   return e2m1Vec;
@@ -1105,22 +1133,17 @@ template
 __device__ __forceinline__ vec_t allreduce_sum(vec_t* vals) {
   if constexpr (Fp32Acc) {
     static_assert(!std::is_same_v);
-    float acc_f32[VEC_SIZE];
+    // Optimization: process one element at a time to reduce register usage.
+    // Instead of storing acc_f32[VEC_SIZE] (32 bytes), process and convert immediately.
+    vec_t acc;
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
-      acc_f32[i] = static_cast<float>(reinterpret_cast(&vals[0])[i]);
-    }
-#pragma unroll
-    for (int r = 1; r < NRanks; ++r) {
+      float acc_f32 = static_cast<float>(reinterpret_cast(&vals[0])[i]);
 #pragma unroll
-      for (int i = 0; i < VEC_SIZE; ++i) {
-        acc_f32[i] += static_cast<float>(reinterpret_cast(&vals[r])[i]);
+      for (int r = 1; r < NRanks; ++r) {
+        acc_f32 += static_cast<float>(reinterpret_cast(&vals[r])[i]);
       }
-    }
-    vec_t acc;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      acc[i] = static_cast(acc_f32[i]);
+      acc[i] = static_cast(acc_f32);
     }
     return acc;
   } else {
@@ -1402,16 +1425,54 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams const& par
     max_registers = utils::getSMRegisters();
   }
   int max_threads_per_block = min(max_registers / registers_per_thread, 1024);
+
+  int block_size = threads_per_block;
+
+  // FP4 optimization: apply BEFORE the SM count check so it is not overridden.
+  // This allows FP4 to use a smaller block_size even when cluster_num is large.
+  if constexpr (GetQuantType == QuantType::kFP4) {
+    // Try to use 160 as block_size if possible (better occupancy for FP4).
+    if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
+      block_size = 160;
+      cluster_size = threads_per_token / 160;
+      if (cluster_size > 8) cluster_size = 8;
+    }
+    // Fallback: try 192, then 128, if 160 does not work.
+    else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block && 192 >= 128) {
+      block_size = 192;
+      cluster_size = threads_per_token / 192;
+      if (cluster_size > 8) cluster_size = 8;
+    } else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
+      block_size = 128;
+      cluster_size = threads_per_token / 128;
+      if (cluster_size > 8) cluster_size = 8;
+    }
+    // Update threads_per_block to match block_size for the SM count check.
+    threads_per_block = block_size;
+  }
+
+  // SM count check: adjust if cluster_num * cluster_size > sm_count,
+  // but respect the FP4 optimization if it was already applied.
   while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
          threads_per_block <= max_threads_per_block / 2) {
     threads_per_block *= 2;
     cluster_size /= 2;
+    // If the FP4 optimization was applied, update block_size to match.
+    if constexpr (GetQuantType == QuantType::kFP4) {
+      block_size = threads_per_block;
+    }
   }
-  FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
-                   "not oneshot, or threads_per_block < nranks");
-  int block_size = threads_per_block;
+
+  // Update block_size if not FP4 (FP4 already set it above).
+  if constexpr (GetQuantType != QuantType::kFP4) {
+    block_size = threads_per_block;
+  }
+
+  // Check conditions using the final block_size (not threads_per_block).
+  FLASHINFER_CHECK(oneshot || block_size >= params.nranks, "not oneshot, or block_size < nranks");
   FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
                    "block_size > 1024 or cluster_size <= 0");
+
   int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
   cudaLaunchConfig_t cfg;
   cudaLaunchAttribute attribute[2];
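// Host-side sketch of the FP4 block-size heuristic added above: prefer a 160-thread block,
// then 192, then 128, and cap cluster_size at 8; otherwise keep the generic choice. The
// standalone function and parameter names below are illustrative; in the patch this logic
// runs inline in allreduce_fusion_kernel_launcher before the SM-count adjustment loop.
#include <cstdio>

struct BlockChoice {
  int block_size;
  int cluster_size;
};

static BlockChoice pick_fp4_block(int threads_per_token, int max_threads_per_block,
                                  int fallback_block, int fallback_cluster) {
  const int candidates[] = {160, 192, 128};
  for (int bs : candidates) {
    if (threads_per_token % bs == 0 && bs <= max_threads_per_block) {
      int cs = threads_per_token / bs;
      if (cs > 8) cs = 8;  // cap the cluster size, as in the patch
      return {bs, cs};
    }
  }
  return {fallback_block, fallback_cluster};  // none of the candidates divide evenly
}

int main() {
  // Example: 320 threads per token prefers a 160-thread block with a 2-block cluster.
  const BlockChoice c = pick_fp4_block(/*threads_per_token=*/320, /*max_threads_per_block=*/1024,
                                       /*fallback_block=*/256, /*fallback_cluster=*/2);
  std::printf("block_size=%d cluster_size=%d\n", c.block_size, c.cluster_size);
  return 0;
}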