Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
- out[inOffset] = out_val;
+ int64_t outOffset =
+ rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+ out[outOffset] = out_val;
}
}
}
Expand Down Expand Up @@ -140,7 +142,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));

- int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
+ int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD);
Comment on lines -143 to +145
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this name is misleading now; it should be num_packed_cols, like in nvfp4_quant_kernels.cu, right?


int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_x = std::min(
Expand Down
12 changes: 7 additions & 5 deletions csrc/quantization/fp4/nvfp4_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
- int32_t sf_n_unpadded, Type const* __restrict__ in,
+ int32_t sf_n_unpadded, int32_t num_packed_cols,
Comment on lines 111 to +112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The addition of num_packed_cols to the kernel signature is necessary to correct the thread-column index bound. This change ensures that the kernel operates correctly when CVT_FP4_ELTS_PER_THREAD == 8 on CUDA 12.8. It's critical to ensure that all call sites of this kernel are updated to pass this new parameter.

cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
                             int32_t sf_n_unpadded, int32_t num_packed_cols,

+ Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
Expand All @@ -131,7 +132,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change is crucial for correcting the bug. By changing the column index comparison from colIdx < sf_n_unpadded to colIdx < num_packed_cols, the kernel now correctly iterates over the appropriate number of thread columns when CVT_FP4_ELTS_PER_THREAD == 8. This ensures that the correct elements are processed, resolving the original issue.

if (colIdx < num_packed_cols) {

- if (colIdx < sf_n_unpadded) {
+ if (colIdx < num_packed_cols) {
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;

Expand Down Expand Up @@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
- int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
+ int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
+ int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
int grid_x = std::min(
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
Expand All @@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
- <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr,
- input_sf_ptr,
+ <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
Comment on lines 236 to +237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Passing num_packed_cols to the kernel launch is essential for the fix. This ensures that the kernel receives the correct number of packed columns, which is used to bound the thread-column index. This change, combined with the updated column index comparison within the kernel, resolves the bug on CUDA 12.8 when CVT_FP4_ELTS_PER_THREAD == 8.

vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
          <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,

+ input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
Expand Down