[Bugfix] Handle NaN in QuantFP8 Native Forward #41427
Conversation
Signed-off-by: Alex Brooks <albrooks@redhat.com>
Code Review
This pull request introduces NaN handling in the forward_native method of QuantFP8 by converting NaNs to zeros, aligning its behavior with the CUDA kernel. It also adds a test case to verify this behavior. Feedback suggests moving the NaN conversion to the start of the method to ensure group quantization is also covered, expanding the test suite to include group quantization shapes, and strengthening the test assertions to verify exact alignment across the entire tensor.
```python
# Replace NaN with 0 to match the CUDA kernel's behavior, since the underlying
# CUDA kernels use fmaxf, which won't propagate NaNs if we have numeric values.
x_f = torch.nan_to_num(x.to(torch.float32), nan=0.0)
```
The NaN handling logic is currently placed after the check for is_group_quant (lines 186-188). This means that dynamic group quantization (e.g., block quantization) will still use the original tensor x containing NaNs when calling _quantize_group_native, leading to NaN scales and outputs in that path. To ensure consistent behavior across all native quantization modes, this nan_to_num conversion should be moved to the beginning of the forward_native method, before the is_group_quant check.
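To illustrate the suggested reordering, here is a torch-free sketch (not the actual vLLM source; `_sanitize`, `_scale`, and the list-based grouping are illustrative stand-ins for `torch.nan_to_num` and the real quantization helpers) showing why sanitizing before the `is_group_quant` branch covers both paths:

```python
import math

def _sanitize(xs):
    # Stand-in for torch.nan_to_num(x, nan=0.0) on a flat list of floats
    return [0.0 if math.isnan(v) else v for v in xs]

def _scale(xs, fp8_max=448.0):
    # Dynamic scale from the absolute max; NaN-free once inputs are sanitized
    return max(abs(v) for v in xs) / fp8_max

def forward_native(x, is_group_quant=False, group=2):
    # Sanitize FIRST, before the is_group_quant branch, so the group
    # quantization path can no longer produce NaN scales either
    x = _sanitize(x)
    if is_group_quant:
        groups = [x[i:i + group] for i in range(0, len(x), group)]
        return [_scale(g) for g in groups]
    return _scale(x)

scales = forward_native([1.0, float("nan"), -4.0, 2.0], is_group_quant=True)
assert not any(math.isnan(s) for s in scales)
```

Placing the conversion after the branch instead would leave `forward_native`'s group path operating on the raw NaN-containing tensor, which is exactly the gap the comment describes.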
```python
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("group_shape_name", ["PER_TOKEN", "PER_TENSOR"])
```
The test currently only covers PER_TOKEN and PER_TENSOR quantization. It should also include a group quantization shape (e.g., a block size like (1, 128)) to ensure that the is_group_quant path in forward_native (which calls _quantize_group_native) also handles NaNs correctly. This is particularly important as the current implementation in input_quant_fp8.py appears to skip NaN handling for the group quantization path.
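A hedged sketch of the suggested expansion (the test name, `GROUP_SHAPES` list, and elided body are illustrative, not the actual vLLM test file): adding a block shape such as `(1, 128)` would route one case through the group-quant path.

```python
import pytest

# Hypothetical extension of the parametrization: a tuple entry stands in
# for a block/group shape, so the is_group_quant branch of forward_native
# (which calls _quantize_group_native) is exercised alongside the
# per-token and per-tensor cases.
GROUP_SHAPES = ["PER_TOKEN", "PER_TENSOR", (1, 128)]

@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
def test_nan_alignment_sketch(group_shape):
    # Body elided; the point is coverage of all three quantization modes
    assert group_shape in GROUP_SHAPES
```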
```python
# Quantized outputs should match at non-NaN input positions.
valid = ~torch.isnan(x)
torch.testing.assert_close(out_native.float()[valid], out_cuda.float()[valid])
```
Since the goal of this PR is to align the native forward behavior with the CUDA kernel (which treats NaNs as 0), we should verify that the outputs match exactly across the entire tensor, including at the positions where NaNs were injected. Using [valid] masking allows the outputs to differ at NaN positions as long as they are not NaN, which doesn't fully guarantee alignment.
```diff
-torch.testing.assert_close(out_native.float()[valid], out_cuda.float()[valid])
+torch.testing.assert_close(out_native.float(), out_cuda.float())
```
Purpose
Encountered while investigating this issue in Granite Speech: #41284
Also related fix for the source of the NaNs: #41424
While digging into the cause of FP8 outputs turning into garbage for 0.17 and 0.20, I eventually found that the cause was bias values being corrupted, which was producing a bunch of NaNs. It looks like the reason this didn't become a problem until recently is that the native forward doesn't handle NaNs, but the underlying CUDA kernels do, because they use `fmaxf` etc., which implicitly handle NaNs. This PR adds handling for NaNs to the native forward implementation and a check to ensure that it lines up with the CUDA forward.
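The `fmaxf` behavior referenced above can be modeled in a few lines of pure Python (this is a sketch of the C99 semantics, not vLLM code): when exactly one operand is NaN, `fmaxf` returns the other operand, so an amax-style reduction built on it silently drops NaNs instead of propagating them.

```python
import math

def fmaxf(a: float, b: float) -> float:
    # Model of C's fmaxf: if exactly one operand is NaN, the other
    # operand is returned, so NaNs are dropped rather than propagated
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return a if a > b else b

# A CUDA-style amax reduction built on fmaxf stays numeric as long as
# at least one element is numeric
vals = [1.0, float("nan"), -3.0]
amax = 0.0
for v in vals:
    amax = fmaxf(amax, abs(v))
print(amax)  # 3.0
```

A plain `max()` or a comparison-based reduction does not give this guarantee, which is why the native (non-CUDA) path ended up with NaN scales.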
Test Plan
Also gives normal outputs with the broken example here #41424, although the other PR is the correct fix for the root cause of the NaNs, while this one is to align the behaviors between the forward implementations.
Test Result
Test fails on main with `Native scales contain NaN` and passes on this branch.

@DarkLight1337 @robertgshaw2-redhat Could you PTAL?