[ROCm][GPTQ][Bugfix] Fix GPTQ GEMM kernel output zeroing race condition#30719
vllm-bot merged 10 commits into vllm-project:main
Conversation
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Code Review
This PR correctly identifies and fixes a race condition in the GPTQ GEMM kernels by moving the output tensor zeroing from inside the CUDA kernels to the host side before the kernel launch. The changes are in the right direction, but the fix is incomplete. I've left a critical comment pointing out that the in-kernel zeroing logic also needs to be removed from several other kernel variants to fully resolve the bug.
`csrc/quantization/gptq/q_gemm.cu` (lines 236-239)
This change correctly addresses the race condition for the 4-bit kernel. However, the fix is incomplete as the same problematic in-kernel zeroing logic exists in several other kernel variants. This is a critical issue because the bug will persist for other quantization bit-widths.
Please apply the same fix (i.e., remove the in-kernel zeroing) to the following kernels as well:
- `gemm_half_q_half_gptq_2bit_kernel` (lines 375-378)
- `gemm_half_q_half_gptq_3bit_kernel` (lines 497-500)
- `gemm_half_q_half_gptq_8bit_kernel` (lines 626-629)
- `gemm_half_q_half_alt_4bit_kernel` (lines 1227-1229)
- `gemm_half_q_half_alt_8bit_kernel` (lines 1322-1324)
Additionally, it's better to remove this dead code entirely rather than commenting it out.
Already done :)

maybe we should just deprecate this kernel...

For now, I think we can merge this PR though, since it resolves the GPTQ test bug on ROCm.
mgoin
left a comment
This makes sense and simplifies the kernel, LGTM! I do think we were planning to deprecate this kernel now that @jinzhen-lin added SM75 support for Marlin #29901, but if ROCm needs this kernel we can keep it for now. I would highly recommend the ROCm team investigating if they can reuse the Marlin kernels
@mgoin There are some failures due to

cc @tjtanaa Can we merge this one too? The failing tests are known to be problematic.
…on (vllm-project#30719) Signed-off-by: Andreas Karatzas <akaratza@amd.com>
…on (vllm-project#30719) Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Summary
Fixes a race condition in the GPTQ GEMM kernels that caused incorrect results when `input_size > BLOCK_KN_SIZE` (128).

Problem

The GPTQ GEMM kernels use multiple thread blocks along the k-dimension (`gridDim.z > 1`) that accumulate partial results via `atomicAdd`. The output tensor was being zeroed inside the kernel by the block with `blockIdx.z == 0`. Since `__syncthreads()` only synchronizes threads within the same block, not across different blocks, this creates a race condition where:

- Blocks with `z > 0` may `atomicAdd` their results before block `z = 0` finishes zeroing
- Block `z = 0` may overwrite results that other blocks have already added

This caused numerical errors up to 45x the expected values, particularly when:

- `input_size > 128` (triggers multiple k-blocks)

Solution

Allocate the output tensor with `torch::zeros()` instead of `torch::empty()`, so it is already zeroed on the host side before the kernel launch.