diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index d0264c4d154c..8583b79fd58f 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); reinterpret_cast(out)[outOffset >> 1] = packed64; } else { - out[inOffset] = out_val; + int64_t outOffset = + rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + out[outOffset] = out_val; } } } @@ -140,7 +142,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + int sf_n_unpadded = int(n / CVT_FP4_ELTS_PER_THREAD); int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); int grid_x = std::min( diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index c27fb69d44be..b521b4707a4d 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -109,7 +109,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, - int32_t sf_n_unpadded, Type const* __restrict__ in, + int32_t sf_n_unpadded, int32_t num_packed_cols, + Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { @@ -131,7 +132,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Iterate over all rows and cols including padded ones - // ensures we visit every single scale factor address to initialize it. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < sf_n_unpadded) { + if (colIdx < num_packed_cols) { PackedVec in_vec; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; @@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, reinterpret_cast(sf_out)); }); } else { - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, auto input_ptr = static_cast(input.data_ptr()); // NOTE: We don't support e8m0 scales at this moment. vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, input_ptr, - input_sf_ptr, + <<>>(m, n, sf_n_unpadded, num_packed_cols, + input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); });