-
Notifications
You must be signed in to change notification settings - Fork 841
fix: Fix NaN output in mxfp8_quantize for very small input values #2441
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,5 +156,129 @@ def test_mxfp8_quantize_alignment_torch_device( | |
| ) | ||
|
|
||
|
|
||
@pytest.mark.parametrize("m", [1, 128, 2048])
@pytest.mark.parametrize("k", [1024])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout):
    """Test that very small denormalized inputs do not produce NaN.

    This test covers a bug where inputs small enough to cause E8M0 scale factor
    underflow would result in NaN outputs due to 0 * infinity computations.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    torch.random.manual_seed(42)

    # Create very small denormalized values (below float32 normal range ~1.17e-38).
    # These values caused NaN in the original buggy implementation.
    a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()

    # Scale factors are not inspected here; the underscore prefix marks the
    # unpacked variable as intentionally unused (silences Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # The primary check: no NaN values should be produced.
    nan_count = torch.isnan(a_fp8.float()).sum().item()
    assert nan_count == 0, f"Found {nan_count} NaN values in output (expected 0)"

    # Secondary check: no Inf values should be produced.
    inf_count = torch.isinf(a_fp8.float()).sum().item()
    assert inf_count == 0, f"Found {inf_count} Inf values in output (expected 0)"
|
|
||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_all_zeros(dtype, is_sf_swizzled_layout):
    """Test that all-zero inputs produce all-zero outputs without NaN."""
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    m, k = 128, 1024
    a = torch.zeros([m, k], dtype=dtype, device="cuda").contiguous()

    # Scale factors are not inspected here; the underscore prefix marks the
    # unpacked variable as intentionally unused (silences Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # No NaN values for an all-zero input.
    assert not torch.isnan(a_fp8.float()).any(), "NaN found in output for zero input"

    # All outputs should be zero.
    assert (a_fp8.float() == 0).all(), "Non-zero output for zero input"
|
|
||
|
|
||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout):
    """Test mixed inputs: some blocks with normal values, some with denormals.

    This mimics real-world scenarios where different regions of a tensor
    may have vastly different magnitudes.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    torch.random.manual_seed(123)

    m, k = 256, 1024
    a = torch.randn([m, k], dtype=torch.float32)

    # Make some rows have very small values (denormals):
    #   Rows 0-63:    normal magnitude
    #   Rows 64-127:  very small (denormal range)
    #   Rows 128-191: normal magnitude
    #   Rows 192-255: extremely small
    a[64:128, :] *= 1e-38
    a[192:256, :] *= 1e-40

    a = a.to(dtype).cuda().contiguous()

    # Scale factors are not inspected here; the underscore prefix marks the
    # unpacked variable as intentionally unused (silences Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # No NaN values should be produced anywhere; report the first offender
    # to make magnitude-region-specific failures easy to localize.
    nan_mask = torch.isnan(a_fp8.float())
    nan_count = nan_mask.sum().item()
    if nan_count > 0:
        nan_positions = torch.where(nan_mask)
        first_nan_row = nan_positions[0][0].item()
        first_nan_col = nan_positions[1][0].item()
        pytest.fail(
            f"Found {nan_count} NaN values. First NaN at row={first_nan_row}, col={first_nan_col}"
        )
|
|
||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_single_denormal_in_block(dtype, is_sf_swizzled_layout):
    """Test a block where most values are normal but one is a tiny denormal.

    This specifically tests the scenario from the original bug report where
    a single float32 denormal value in a block would become NaN due to
    0 * infinity when FTZ mode flushes it to zero.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    m, k = 64, 1024
    # Start with small but normal-range values.
    a = torch.full([m, k], 1e-36, dtype=torch.float32)

    # Insert a few extremely small values (float32 denormals) at specific
    # positions. These are the values that triggered NaN in the original bug.
    denormal_positions = [(0, 498), (0, 911), (32, 100), (63, 512)]
    for row, col in denormal_positions:
        a[row, col] = 9.18e-40  # A float32 denormal value

    a = a.to(dtype).cuda().contiguous()

    # Scale factors are not inspected here; the underscore prefix marks the
    # unpacked variable as intentionally unused (silences Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # Check that no NaN is produced.
    nan_mask = torch.isnan(a_fp8.float())
    assert not nan_mask.any(), f"Found NaN at positions: {torch.where(nan_mask)}"
|
|
||
if __name__ == "__main__":
    # Allow running this test module directly (python this_file.py) without
    # invoking pytest from the command line.
    pytest.main([__file__, "-v"])
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This compute capability check is duplicated across all four new tests and some existing ones. To improve maintainability and reduce code duplication, consider extracting this logic into a shared helper function or a pytest fixture.
For example, you could define a helper function:
Then call `_require_sm100()` at the start of each relevant test. This would make the tests cleaner and avoid repeating the same logic.