[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization#31179
yewentao256 merged 1 commit into vllm-project:main
Conversation
Code Review
This pull request addresses a bug in silu_mul_per_token_group_quant_fp8_colmajor where a hardcoded torch.float8_e4m3fn dtype and its associated min/max values were used, causing accuracy issues on ROCm platforms that expect torch.float8_e4m3fnuz. The changes correctly use current_platform.fp8_dtype() to determine the appropriate float8 data type and apply platform-specific min/max values for quantization, aligning the function's behavior with per_token_group_quant_fp8 and ensuring correctness on ROCm. The changes are correct and well-implemented.
@hongxiayang @jithunnair-amd This is ready for review and addresses critical FP8 dtype handling for ROCm on the new Strix Halo architecture.
yewentao256 left a comment:
LGTM, thanks for the work!
Hi maintainers, this PR has been approved by @yewentao256. The CI failures appear to be known flaky tests. Would it be possible to trigger a merge or a final re-run of the failing jobs? The fix itself is straightforward. Thank you!
@vllm-bot rerun ci
…nt_fp8_colmajor

The function was hardcoding the `torch.float8_e4m3fn` dtype and using its default min/max values. On ROCm platforms that use `torch.float8_e4m3fnuz`, this causes an incorrect dtype and accuracy issues.

This fix:
- Uses `current_platform.fp8_dtype()` instead of the hardcoded dtype
- Applies the same ROCm-aware fp8 min/max logic (224.0 for fnuz) that is already used in `per_token_group_quant_fp8()` in the same file

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
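As a hedged illustration of the dtype choice the commit message describes: vLLM itself calls `current_platform.fp8_dtype()`, while the standalone helper below (a hypothetical name, not vLLM's API) only mirrors the platform-dependent decision.

```python
# Illustrative sketch only: vLLM uses current_platform.fp8_dtype().
# This standalone helper just mirrors the platform-dependent choice.
def platform_fp8_dtype_name(uses_fnuz: bool) -> str:
    """Pick the FP8 dtype name for the current platform.

    ROCm builds whose hardware expects the 'fnuz' variant need
    torch.float8_e4m3fnuz; other platforms use the standard
    torch.float8_e4m3fn.
    """
    return "float8_e4m3fnuz" if uses_fnuz else "float8_e4m3fn"
```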
Force-pushed from 2f4658d to f3b8abe
…project#31179) Signed-off-by: c0de128 <kevin.mckay@outlook.com>
…project#31179) Signed-off-by: c0de128 <kevin.mckay@outlook.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Summary
Fix the hardcoded `torch.float8_e4m3fn` dtype in `silu_mul_per_token_group_quant_fp8_colmajor()` that causes an incorrect dtype and accuracy issues on ROCm platforms using `torch.float8_e4m3fnuz`.

Problem
The function in `vllm/model_executor/layers/quantization/utils/fp8_utils.py` was:
- hardcoding the `torch.float8_e4m3fn` dtype for the output tensor (line 629)
- using `finfo.min`/`max` values from `torch.float8_e4m3fn` (lines 640-642)

On ROCm platforms that use `torch.float8_e4m3fnuz`, this produces the wrong output dtype and accuracy issues.

Solution
Apply the same pattern already used in `per_token_group_quant_fp8()` in the same file (lines 766-770): use `current_platform.fp8_dtype()` for the output dtype and the ROCm-aware fp8 min/max values (224.0 for fnuz).

Test Plan
This is a consistency fix that aligns the function with the existing ROCm-aware pattern used elsewhere in the same file. The fix ensures the output dtype and clamp range match the platform's FP8 format.
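A minimal sketch of the clamp-range logic described above, in pure Python with no torch dependency. The function name is illustrative, not vLLM's kernel; 448.0 is `torch.finfo(torch.float8_e4m3fn).max`, and 224.0 is the ROCm fnuz cap this PR applies.

```python
# Hedged sketch of per-token-group FP8 clamping; not vLLM's actual kernel.
# 448.0 is torch.finfo(torch.float8_e4m3fn).max; 224.0 is the ROCm fnuz
# cap this PR applies, matching per_token_group_quant_fp8().
def quantize_group(values: list[float], uses_fnuz: bool) -> list[float]:
    fp8_max = 224.0 if uses_fnuz else 448.0
    fp8_min = -fp8_max
    # Scale the group so its largest magnitude maps onto fp8_max,
    # then clamp into the platform's representable FP8 range.
    amax = max(abs(v) for v in values) or 1.0
    scale = fp8_max / amax
    return [min(max(v * scale, fp8_min), fp8_max) for v in values]
```

For example, `quantize_group([1.0, -2.0], True)` scales by 224.0/2.0 = 112.0, yielding `[112.0, -224.0]`; with the hardcoded e4m3fn range on fnuz hardware, values would instead be scaled against 448.0 and land outside the representable range.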
Related
This is similar to the pattern established for other FP8 quantization functions in this file that already handle ROCm fnuz correctly.