@@ -25,6 +25,7 @@
     MXInferenceLinear,
     MXLinear,
 )
+from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -372,3 +373,34 @@ def test_inference_print_str():
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not is_sm_at_least_100(), reason="Requires sm100")
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("compile", [True, False])
+@torch.no_grad()
+def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
+    """
+    Smoke test for MXFP inference, with and without torch.compile.
+    """
+    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m_mx = copy.deepcopy(m)
+    config = MXFPInferenceConfig()
+    quantize_(m_mx, config=config)
+    if compile:
+        m_mx = torch.compile(m_mx, fullgraph=True)
+
+    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    y_ref = m(x)
+    y_mx = m_mx(x)
+    sqnr = compute_error(y_ref, y_mx)
+    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
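
The pass criterion is a signal-to-quantization-noise ratio (SQNR) of at least 25 dB between the bfloat16 reference output and the MXFP output. For reference, a minimal sketch of the metric, assuming torchao's compute_error uses the standard dB definition of SQNR:

import torch

def sqnr_db(ref: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: norm of the reference signal over the norm of the
    # quantization error; each +20 dB means a 10x smaller error norm.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - quantized.float())
    return 20 * torch.log10(signal / noise)

Under that definition, the 25.0 dB threshold permits a relative error norm of about 5.6% (10^(-25/20)), loose enough to absorb fp8 rounding while still catching a broken kernel.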
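
Outside the test harness, the same flow applies to any module containing linear layers. A minimal inference sketch, assuming the prototype import path from the diff and quantize_'s default behavior of swapping nn.Linear children; the model and shapes are illustrative:

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.quantization import quantize_

# Illustrative bfloat16 model on CUDA; the fast MXFP path targets sm100.
model = nn.Sequential(
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 32),
).to(device="cuda", dtype=torch.bfloat16)

# Swap the linear layers to MXFP in place, then compile for inference.
quantize_(model, config=MXFPInferenceConfig())
model = torch.compile(model, fullgraph=True)

with torch.no_grad():
    y = model(torch.randn(16, 32, device="cuda", dtype=torch.bfloat16))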