[performance]optimize for nvfp4 #2268
```diff
@@ -531,6 +531,26 @@ __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c
   return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
 }
 
+// Convert single float2 pair to e2m1 (2 float32 -> 2 e2m1, returns uint8_t)
+// Optimization: allows pipelined processing to reduce register usage
+// Note: "=r" constraint always allocates 32-bit register regardless of variable type
+inline __device__ uint8_t fp32_pair_to_e2m1(float2 pair) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val32;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "mov.b32 %0, {byte0, 0, 0, 0};\n"
+      "}"
+      : "=r"(val32)
+      : "f"(pair.x), "f"(pair.y));
+  return static_cast<uint8_t>(val32 & 0xFF);  // Extract low 8 bits
+#else
+  return 0;
+#endif
+}
+
 #if CUDA_VERSION >= 12080
 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
 // NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
```
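For reference, e2m1 is the FP4 element format being produced here: one sign bit, two exponent bits and one mantissa bit, so a nibble can only encode the magnitudes 0, 0.5, 1, 1.5, 2, 3, 4 and 6 (hence the "maximum value of e2m1 = 6.0" used for the scale factor below). The hardware conversion is done by the cvt.rn.satfinite.e2m1x2.f32 instruction above; the host-side sketch below is only an illustration of the encoding (nearest-value rounding, which may break ties differently from the instruction's rn mode), not code from the PR:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative e2m1 encoder: pick the nearest representable magnitude
// (code 0..7 maps to {0, 0.5, 1, 1.5, 2, 3, 4, 6}) and put the sign in bit 3.
static uint8_t fp32_to_e2m1_ref(float x) {
  static const float kVals[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
  float mag = std::fabs(x) > 6.0f ? 6.0f : std::fabs(x);  // satfinite-style clamp
  int best = 0;
  for (int i = 1; i < 8; ++i) {
    if (std::fabs(kVals[i] - mag) < std::fabs(kVals[best] - mag)) best = i;
  }
  return static_cast<uint8_t>((std::signbit(x) ? 0x8 : 0x0) | best);
}

int main() {
  // Two floats -> two 4-bit codes in one byte, analogous to what
  // fp32_pair_to_e2m1 returns for a float2 (nibble order illustrative only).
  uint8_t packed = static_cast<uint8_t>((fp32_to_e2m1_ref(-1.4f) << 4) | fp32_to_e2m1_ref(5.2f));
  printf("0x%02x\n", packed);
  return 0;
}
```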
```diff
@@ -602,6 +622,9 @@ template <typename T, uint32_t VEC_SIZE, bool UE8M0_SF = false>
 __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleVal,
                                          uint8_t* SFout) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  // Pre-compute constant: reciprocal of 6.0 (maximum value of e2m1)
+  static constexpr float RECIPROCAL_6 = 1.0f / 6.0f;
+
   // Get absolute maximum values among the local 8 values.
   auto localMax = maths::cuda_abs(get_vec2_element(vec, 0));
```
```diff
@@ -613,57 +636,62 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleV
   // Get the absolute maximum among all 16 values (two threads).
   localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
   // Get the final absolute maximum values.
+  // Optimization: compute vecMax and reuse localMax space (localMax no longer needed)
   float vecMax = float(maths::cuda_max(localMax.x, localMax.y));
 
   // Get the SF (max value of the vector / max value of e2m1).
   // maximum value of e2m1 = 6.0.
   // TODO: use half as compute data type.
-  float SFValue = SFScaleVal * (vecMax * maths::reciprocal_approximate_ftz(6.0f));
   // 8 bits representation of the SF.
+  // Optimization: compute quantized SF directly, avoid storing intermediate SFValue
   uint8_t fp8SFVal;
   // Write the SF to global memory (STG.8).
+  float quantized_sf;
   if constexpr (UE8M0_SF) {
 #if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080)
     __nv_fp8_e8m0 tmp;
-    tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
-    SFValue = static_cast<float>(tmp);
+    float sf_value = SFScaleVal * (vecMax * RECIPROCAL_6);
+    tmp.__x = __nv_cvt_float_to_e8m0(sf_value, __NV_SATFINITE, cudaRoundPosInf);
+    quantized_sf = static_cast<float>(tmp);
     fp8SFVal = tmp.__x;
 #else
 #error "FP8 E8M0 support requires CUDA 12.8 or newer."
 #endif
   } else {
     // Here SFValue is always positive, so E4M3 is the same as UE4M3.
-    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFScaleVal * (vecMax * RECIPROCAL_6));
     fp8SFVal = tmp.__x;
-    SFValue = static_cast<float>(tmp);
+    quantized_sf = static_cast<float>(tmp);
   }
-  // Get the output scale.
+  // Get the output scale directly (optimization: avoid storing intermediate SFValue)
   // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
-  float outputScale = SFValue != 0 ? maths::reciprocal_approximate_ftz(
-                                         SFValue * maths::reciprocal_approximate_ftz(SFScaleVal))
-                                   : 0.0f;
+  // Optimization: mathematically equivalent to SFScaleVal / quantized_sf, but more efficient
+  // (reduces 1 reciprocal call and 1 multiply operation)
+  float outputScale = quantized_sf != 0 ? SFScaleVal / quantized_sf : 0.0f;
 
   if (SFout) {
     // Write the SF to global memory (STG.8).
     *SFout = fp8SFVal;
   }
 
-  // Convert the input to float.
-  float2 fp2Vals[details::CVT_FP4_ELTS_PER_THREAD / 2];
+  // Convert the input to float and quantize (pipelined to reduce register usage).
+  // Optimization: use single float2 instead of array to reduce register pressure from 32 bytes to 8
```
Collaborator: Do you have any profiling results showing the register usage (e.g. from cuobjdump or ncu)?

Contributor (Author): Oh, I will provide that in the next few days.
```diff
+  // bytes
+  uint32_t e2m1Vec = 0;
 
 #pragma unroll
   for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    // Reuse single float2 register instead of array
+    float2 fp2Val;
     if constexpr (std::is_same_v<T, half>) {
-      fp2Vals[i] = __half22float2(get_vec2_element(vec, i));
+      fp2Val = __half22float2(get_vec2_element(vec, i));
     } else {
-      fp2Vals[i] = __bfloat1622float2(get_vec2_element(vec, i));
+      fp2Val = __bfloat1622float2(get_vec2_element(vec, i));
     }
-    fp2Vals[i].x *= outputScale;
-    fp2Vals[i].y *= outputScale;
-  }
+    fp2Val.x *= outputScale;
+    fp2Val.y *= outputScale;
 
-  // Convert to e2m1 values.
-  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+    // Convert pair immediately and pack into result
+    uint8_t e2m1Pair = fp32_pair_to_e2m1(fp2Val);
+    e2m1Vec |= (static_cast<uint32_t>(e2m1Pair) << (i * 8));
+  }
 
   // Write the e2m1 values to global memory.
   return e2m1Vec;
```
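The outputScale change relies on a simple identity: the original expression reciprocal(quantized_sf * reciprocal(SFScaleVal)) is algebraically SFScaleVal / quantized_sf, so one division replaces two reciprocals and a multiply. A small host-side check of that identity, using exact float division in place of the device's approximate reciprocal and hypothetical values for the vector max and global scale:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical inputs: per-vector absolute max and global scale factor.
  float vecMax = 3.7f;
  float SFScaleVal = 448.0f;

  // Scale factor before FP8 quantization (vecMax / 6 scaled by SFScaleVal).
  float sf = SFScaleVal * (vecMax * (1.0f / 6.0f));
  // Pretend sf survived the FP8 round trip unchanged for this illustration.
  float quantized_sf = sf;

  // Old recipe: reciprocal(quantized_sf * reciprocal(SFScaleVal)).
  float oldScale = 1.0f / (quantized_sf * (1.0f / SFScaleVal));
  // New recipe: a single division.
  float newScale = SFScaleVal / quantized_sf;

  printf("old=%g new=%g\n", oldScale, newScale);
  assert(std::fabs(oldScale - newScale) < 1e-6f * std::fabs(newScale));
  return 0;
}
```

On the device the trade is two rcp.approx instructions plus a multiply against one full-precision division; the comment at the end of the page reports that in practice the choice makes no noticeable difference.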
```diff
@@ -1105,22 +1133,17 @@ template <typename T, uint32_t VEC_SIZE, int NRanks, bool Fp32Acc>
 __device__ __forceinline__ vec_t<T, VEC_SIZE> allreduce_sum(vec_t<T, VEC_SIZE>* vals) {
   if constexpr (Fp32Acc) {
     static_assert(!std::is_same_v<T, float>);
-    float acc_f32[VEC_SIZE];
+    // Optimization: process one element at a time to reduce register usage
+    // Instead of storing acc_f32[VEC_SIZE] (32 bytes), process and convert immediately
+    vec_t<T, VEC_SIZE> acc;
 #pragma unroll
     for (int i = 0; i < VEC_SIZE; ++i) {
-      acc_f32[i] = static_cast<float>(reinterpret_cast<T*>(&vals[0])[i]);
-    }
-#pragma unroll
-    for (int r = 1; r < NRanks; ++r) {
+      float acc_f32 = static_cast<float>(reinterpret_cast<T*>(&vals[0])[i]);
 #pragma unroll
-      for (int i = 0; i < VEC_SIZE; ++i) {
-        acc_f32[i] += static_cast<float>(reinterpret_cast<T*>(&vals[r])[i]);
+      for (int r = 1; r < NRanks; ++r) {
+        acc_f32 += static_cast<float>(reinterpret_cast<T*>(&vals[r])[i]);
       }
-    }
-    vec_t<T, VEC_SIZE> acc;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      acc[i] = static_cast<T>(acc_f32[i]);
+      acc[i] = static_cast<T>(acc_f32);
     }
     return acc;
   } else {
```
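The reviewer above asks for register numbers to back this kind of claim. One way to get a feel for the difference is to compile a stripped-down pair of kernels with nvcc -Xptxas -v (or inspect the binary with cuobjdump) and compare the reported registers per thread. The sketch below is not the PR's kernel: VEC_SIZE and NRanks are hypothetical constants, plain half arrays stand in for vec_t, and whether the compiler actually keeps fewer values live after full unrolling is exactly what the profiling request would settle:

```cuda
#include <cuda_fp16.h>

constexpr int kVecSize = 8;  // hypothetical, stands in for VEC_SIZE
constexpr int kNRanks = 4;   // hypothetical, stands in for NRanks

// Array accumulator: all kVecSize float partials are nominally live at once.
__global__ void sum_array_acc(const __half* in, __half* out) {
  float acc[kVecSize];
#pragma unroll
  for (int i = 0; i < kVecSize; ++i) acc[i] = __half2float(in[i]);
#pragma unroll
  for (int r = 1; r < kNRanks; ++r)
#pragma unroll
    for (int i = 0; i < kVecSize; ++i) acc[i] += __half2float(in[r * kVecSize + i]);
#pragma unroll
  for (int i = 0; i < kVecSize; ++i) out[i] = __float2half(acc[i]);
}

// Scalar accumulator: only one float partial is live per element.
__global__ void sum_scalar_acc(const __half* in, __half* out) {
#pragma unroll
  for (int i = 0; i < kVecSize; ++i) {
    float acc = __half2float(in[i]);
#pragma unroll
    for (int r = 1; r < kNRanks; ++r) acc += __half2float(in[r * kVecSize + i]);
    out[i] = __float2half(acc);
  }
}
```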
```diff
@@ -1402,16 +1425,54 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
     max_registers = utils::getSMRegisters();
   }
   int max_threads_per_block = min(max_registers / registers_per_thread, 1024);
 
+  int block_size = threads_per_block;
+
+  // FP4 optimization: apply BEFORE SM count check to avoid being overridden
+  // This allows FP4 to use smaller block_size even when cluster_num is large
+  if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
```
Contributor: Hi @Bruce-x-1997, will this occupancy optimization be triggered in the DeepSeek case mentioned in this PR? 7168/8 = 896 is not divisible by either 160 or 192, so 128 will be used, which is typically a preferred block_size anyway. I have tried the other optimizations in this PR but did not find a noticeable difference (~0.05us faster); I haven't checked the register usage, though. Does the 10~15% improvement mentioned in the PR come from using this block size (and possibly a different cluster size) instead of the original heuristic?
```diff
+    // Try to use 160 as block_size if possible (better occupancy for FP4)
+    if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
+      block_size = 160;
+      cluster_size = threads_per_token / 160;
+      if (cluster_size > 8) cluster_size = 8;
+    }
+    // Fallback: try 192, 128 if 160 doesn't work
+    else if (threads_per_token % 192 == 0 && 192 <= max_threads_per_block && 192 >= 128) {
+      block_size = 192;
+      cluster_size = threads_per_token / 192;
+      if (cluster_size > 8) cluster_size = 8;
+    } else if (threads_per_token % 128 == 0 && 128 <= max_threads_per_block) {
+      block_size = 128;
+      cluster_size = threads_per_token / 128;
+      if (cluster_size > 8) cluster_size = 8;
+    }
+    // Update threads_per_block to match block_size for SM count check
+    threads_per_block = block_size;
+  }
```
Comment on lines +1433 to +1452

Contributor: The logic for selecting a special block size for FP4 kernels is a bit repetitive and contains some redundant conditions; it can be refactored to be more concise and easier to read. I've suggested a refactoring that consolidates this logic, making it cleaner and more maintainable without changing the functionality.

Collaborator: gemini's suggestion looks reasonable.
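The suggested patch itself is not reproduced on this page. A minimal sketch of what such a consolidation could look like, assuming the same candidate sizes (160, 192, 128), the same divisibility and register-limit conditions, and the existing cap of 8 on cluster_size (the ">= 128" checks in the original are always true for 160 and 192, which is part of the redundancy being pointed out):

```cpp
#include <algorithm>
#include <initializer_list>

// Hypothetical consolidation of the FP4 block-size selection (not the actual
// suggested patch). Returns the chosen block size and writes the cluster size;
// falls back to the incoming block size when no candidate divides the token width.
inline int select_fp4_block_size(int threads_per_token, int max_threads_per_block,
                                 int fallback_block_size, int& cluster_size) {
  for (int candidate : {160, 192, 128}) {
    if (threads_per_token % candidate == 0 && candidate <= max_threads_per_block) {
      cluster_size = std::min(threads_per_token / candidate, 8);
      return candidate;
    }
  }
  return fallback_block_size;
}
```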
Comment on lines +1433 to +1452

Contributor: Potential correctness issue: when cluster_size exceeds 8 it is capped, but block_size is not recalculated, so block_size * cluster_size may no longer cover threads_per_token. This would result in some elements not being processed.

Suggested fix: recalculate block_size after capping cluster_size.

```diff
 if (threads_per_token % 160 == 0 && 160 <= max_threads_per_block && 160 >= 128) {
   block_size = 160;
   cluster_size = threads_per_token / 160;
-  if (cluster_size > 8) cluster_size = 8;
+  if (cluster_size > 8) {
+    cluster_size = 8;
+    // Recalculate block_size to ensure full coverage
+    block_size = threads_per_token / cluster_size;
+  }
 }
```

Apply similar logic to the 192 and 128 fallback cases.
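To make the coverage concern concrete, here is a small host-side check with a hypothetical token width chosen purely to trigger the cap (the DeepSeek case discussed above, 896 threads per token, never reaches it):

```cpp
#include <cstdio>

int main() {
  // Hypothetical token width that divides 160 with a quotient above the cap.
  int threads_per_token = 1600;
  int block_size = 160;
  int cluster_size = threads_per_token / block_size;  // 10
  if (cluster_size > 8) cluster_size = 8;             // capped, block_size unchanged

  printf("covered = %d of %d\n", block_size * cluster_size, threads_per_token);
  // covered = 1280 of 1600 -> 320 elements per token would be skipped.

  // Suggested fix: recompute block_size after capping so coverage is restored.
  block_size = threads_per_token / cluster_size;      // 200
  printf("covered after fix = %d of %d\n", block_size * cluster_size, threads_per_token);
  return 0;
}
```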
```diff
 
+  // SM count check: adjust if cluster_num * cluster_size > sm_count
+  // But respect FP4 optimization if already applied
   while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
          threads_per_block <= max_threads_per_block / 2) {
     threads_per_block *= 2;
     cluster_size /= 2;
+    // If FP4 optimization was applied, update block_size to match
+    if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
+      block_size = threads_per_block;
+    }
   }
-  FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
-                   "not oneshot, or threads_per_block < nranks");
-  int block_size = threads_per_block;
+
+  // Update block_size if not FP4 (FP4 already set it above)
+  if constexpr (GetQuantType<Pattern> != QuantType::kFP4) {
+    block_size = threads_per_block;
+  }
+
+  // Check conditions using the final block_size (not threads_per_block)
+  FLASHINFER_CHECK(oneshot || block_size >= params.nranks, "not oneshot, or block_size < nranks");
+  FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
+                   "block_size > 1024 or cluster_size <= 0");
 
   int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
   cudaLaunchConfig_t cfg;
   cudaLaunchAttribute attribute[2];
```
Reviewer: Interesting, I thought division would take a few times longer than reciprocal approximation -- curious if you ablated this.

EDIT: I tried and it doesn't make a noticeable diff.
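For context on that question: maths::reciprocal_approximate_ftz presumably wraps the single-instruction rcp.approx.ftz.f32 PTX op, whereas a plain float division is expanded by the compiler into a longer, correctly rounded sequence unless fast-math-style division is used, which is why one might expect the division to cost more; the measurement above suggests it is not the bottleneck here. A minimal sketch of the two formulations, with the PTX wrapper written out as an assumption about how the helper is implemented:

```cuda
// Assumed to mirror maths::reciprocal_approximate_ftz: one rcp.approx.ftz.f32.
__device__ __forceinline__ float rcp_approx_ftz(float x) {
  float y;
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

// Old formulation: two approximate reciprocals and a multiply.
__device__ float output_scale_rcp(float quantized_sf, float sf_scale) {
  return quantized_sf != 0.f ? rcp_approx_ftz(quantized_sf * rcp_approx_ftz(sf_scale)) : 0.f;
}

// New formulation: a single (correctly rounded) division.
__device__ float output_scale_div(float quantized_sf, float sf_scale) {
  return quantized_sf != 0.f ? sf_scale / quantized_sf : 0.f;
}
```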