Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,10 @@ __device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec<Type, CVT_ELTS_PER_THREAD>&
SFValue = static_cast<float>(tmpSFVal);
fp8SFVal = tmpSFVal.__x;
// Get the output scale (reciprocal of the SFValue).
float outputScale = vecMax != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f;
// Note: Check SFValue != 0 (not vecMax != 0) because E8M0 conversion can underflow
// very small vecMax values to zero. Using vecMax != 0 would cause division by zero
// (reciprocal of 0 = infinity), leading to NaN when multiplied with denormal inputs.
float outputScale = SFValue != 0.f ? reciprocal_approximate_ftz(SFValue) : 0.0f;

if (SFout) {
// Write the SF to global memory (STG.8).
Expand Down
124 changes: 124 additions & 0 deletions tests/utils/test_fp8_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,129 @@ def test_mxfp8_quantize_alignment_torch_device(
)


@pytest.mark.parametrize("m", [1, 128, 2048])
@pytest.mark.parametrize("k", [1024])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout):
    """Test that very small denormalized inputs do not produce NaN.

    This test covers a bug where inputs small enough to cause E8M0 scale factor
    underflow would result in NaN outputs due to 0 * infinity computations.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    torch.random.manual_seed(42)

    # Create very small denormalized values (below float32 normal range ~1.17e-38).
    # These values caused NaN in the original buggy implementation.
    a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous()

    # Scale factors are not checked by this test; underscore-prefixed to
    # mark them intentionally unused (Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # The primary check: no NaN values should be produced
    nan_count = torch.isnan(a_fp8.float()).sum().item()
    assert nan_count == 0, f"Found {nan_count} NaN values in output (expected 0)"

    # Secondary check: no Inf values should be produced
    inf_count = torch.isinf(a_fp8.float()).sum().item()
    assert inf_count == 0, f"Found {inf_count} Inf values in output (expected 0)"


@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_all_zeros(dtype, is_sf_swizzled_layout):
    """Test that all-zero inputs produce all-zero outputs without NaN."""
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    m, k = 128, 1024
    a = torch.zeros([m, k], dtype=dtype, device="cuda").contiguous()

    # Scale factors are not checked here; underscore-prefixed to mark them
    # intentionally unused (Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # No NaN values
    assert not torch.isnan(a_fp8.float()).any(), "NaN found in output for zero input"

    # All outputs should be zero
    assert (a_fp8.float() == 0).all(), "Non-zero output for zero input"


@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout):
    """Test mixed inputs: some blocks with normal values, some with denormals.

    This mimics real-world scenarios where different regions of a tensor
    may have vastly different magnitudes.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    torch.random.manual_seed(123)

    m, k = 256, 1024
    a = torch.randn([m, k], dtype=torch.float32)

    # Make some rows have very small values (denormals)
    # Rows 0-63: normal magnitude
    # Rows 64-127: very small (denormal range)
    # Rows 128-191: normal magnitude
    # Rows 192-255: extremely small
    a[64:128, :] *= 1e-38
    a[192:256, :] *= 1e-40

    a = a.to(dtype).cuda().contiguous()

    # Scale factors are not checked here; underscore-prefixed to mark them
    # intentionally unused (Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # No NaN values should be produced anywhere
    nan_mask = torch.isnan(a_fp8.float())
    nan_count = nan_mask.sum().item()
    if nan_count > 0:
        # Report the first offending position to aid debugging.
        nan_positions = torch.where(nan_mask)
        first_nan_row = nan_positions[0][0].item()
        first_nan_col = nan_positions[1][0].item()
        pytest.fail(
            f"Found {nan_count} NaN values. First NaN at row={first_nan_row}, col={first_nan_col}"
        )


@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_mxfp8_quantize_single_denormal_in_block(dtype, is_sf_swizzled_layout):
    """Test a block where most values are normal but one is a tiny denormal.

    This specifically tests the scenario from the original bug report where
    a single float32 denormal value in a block would become NaN due to
    0 * infinity when FTZ mode flushes it to zero.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    m, k = 64, 1024
    # Start with small but normal-range values
    a = torch.full([m, k], 1e-36, dtype=torch.float32)

    # Insert a few extremely small values (float32 denormals) at specific positions
    # These are the values that triggered NaN in the original bug
    denormal_positions = [(0, 498), (0, 911), (32, 100), (63, 512)]
    for row, col in denormal_positions:
        a[row, col] = 9.18e-40  # A float32 denormal value

    a = a.to(dtype).cuda().contiguous()

    # Scale factors are not checked here; underscore-prefixed to mark them
    # intentionally unused (Ruff RUF059).
    a_fp8, _a_sf = mxfp8_quantize(a, is_sf_swizzled_layout)

    # Check that no NaN is produced
    nan_mask = torch.isnan(a_fp8.float())
    assert not nan_mask.any(), f"Found NaN at positions: {torch.where(nan_mask)}"


if __name__ == "__main__":
    # Allow running this test module directly: `python test_fp8_quantize.py`.
    pytest.main(["-v", __file__])
Loading