-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[BugFix] Fix fp4 quant kernel on CUDA 12.8 #35210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,7 +109,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) | |
| template <class Type, bool UE8M0_SF = false> | ||
| __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) | ||
| cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, | ||
| int32_t sf_n_unpadded, Type const* __restrict__ in, | ||
| int32_t sf_n_unpadded, int32_t num_packed_cols, | ||
|
Comment on lines 111 to +112
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The addition of the `num_packed_cols` parameter gives the kernel the true number of packed input columns (`n / CVT_FP4_ELTS_PER_THREAD`), separate from `sf_n_unpadded`, so the input-bounds check no longer depends on the scale-factor layout. |
||
| Type const* __restrict__ in, | ||
| float const* __restrict__ SFScale, | ||
| uint32_t* __restrict__ out, | ||
| uint32_t* __restrict__ SFout) { | ||
|
|
@@ -131,7 +132,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) | |
| // Iterate over all rows and cols including padded ones - | ||
| // ensures we visit every single scale factor address to initialize it. | ||
| for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is crucial for correcting the bug. By changing the column index comparison from |
||
| if (colIdx < sf_n_unpadded) { | ||
| if (colIdx < num_packed_cols) { | ||
| PackedVec in_vec; | ||
| int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; | ||
|
|
||
|
|
@@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, | |
| reinterpret_cast<uint32_t*>(sf_out)); | ||
| }); | ||
| } else { | ||
| int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x)); | ||
| int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; | ||
| int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x)); | ||
| int grid_x = std::min( | ||
| m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); | ||
| dim3 grid(grid_x, grid_y); | ||
|
|
@@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, | |
| auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); | ||
| // NOTE: We don't support e8m0 scales at this moment. | ||
| vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false> | ||
| <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr, | ||
| input_sf_ptr, | ||
| <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols, | ||
|
Comment on lines 236 to +237
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing |
||
| input_ptr, input_sf_ptr, | ||
| reinterpret_cast<uint32_t*>(output_ptr), | ||
| reinterpret_cast<uint32_t*>(sf_out)); | ||
| }); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this name is misleading now; it should be `num_packed_cols`, like in `nvfp4_quant_kernels.cu`, right?