|
| 1 | +from numpy import full |
| 2 | +from torch.testing._internal.common_utils import ( |
| 3 | + run_tests, |
| 4 | +) |
| 5 | +from torch._inductor.test_case import TestCase as InductorTestCase |
| 6 | +from torch.testing._internal import common_utils |
| 7 | +from torch._dynamo.testing import CompileCounterWithBackend |
| 8 | + |
| 9 | +from torchao.quantization.quant_api import ( |
| 10 | + quantize_, |
| 11 | + float8_weight_only, |
| 12 | + float8_dynamic_activation_float8_weight, |
| 13 | +) |
| 14 | +from torchao.float8.float8_utils import compute_error |
| 15 | +import torch |
| 16 | +import unittest |
| 17 | +import pytest |
| 18 | +import tempfile |
| 19 | +import copy |
| 20 | +import random |
| 21 | + |
| 22 | +from unittest.mock import patch |
| 23 | +from torchao.utils import ( |
| 24 | + TORCH_VERSION_AT_LEAST_2_5, |
| 25 | + unwrap_tensor_subclass, |
| 26 | +) |
| 27 | + |
| 28 | +if not TORCH_VERSION_AT_LEAST_2_5: |
| 29 | + pytest.skip("Unsupported PyTorch version", allow_module_level=True) |
| 30 | + |
| 31 | + |
| 32 | +random.seed(0) |
| 33 | +torch.manual_seed(0) |
| 34 | + |
| 35 | +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) |
| 36 | +is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) |
| 37 | + |
| 38 | + |
| 39 | +class ToyLinearModel(torch.nn.Module): |
| 40 | + def __init__(self, in_features, out_features): |
| 41 | + super().__init__() |
| 42 | + self.linear1 = torch.nn.Linear(in_features, out_features, bias=False) |
| 43 | + self.linear2 = torch.nn.Linear(out_features, in_features, bias=False) |
| 44 | + |
| 45 | + def forward(self, x): |
| 46 | + x = self.linear1(x) |
| 47 | + x = self.linear2(x) |
| 48 | + return x |
| 49 | + |
| 50 | + |
| 51 | +class TestAffineQuantizedFloat8(InductorTestCase): |
| 52 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 53 | + def test_tensor_core_layout_transpose(self): |
| 54 | + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") |
| 55 | + t = l.weight |
| 56 | + shape = t.shape |
| 57 | + apply_float8_weight_only_quant = float8_weight_only() |
| 58 | + ql = apply_float8_weight_only_quant(l) |
| 59 | + aqt = ql.weight |
| 60 | + aqt_shape = aqt.shape |
| 61 | + assert aqt_shape == shape |
| 62 | + |
| 63 | + # transpose shape test |
| 64 | + for _ in range(10): |
| 65 | + t = t.t() |
| 66 | + aqt = aqt.t() |
| 67 | + shape = t.shape |
| 68 | + aqt_shape = aqt.shape |
| 69 | + assert aqt_shape == shape |
| 70 | + |
| 71 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 72 | + def test_weights_only_save_load(self): |
| 73 | + for apply_quant in [float8_weight_only()]: |
| 74 | + # TODO Fails when l requires grad |
| 75 | + l = torch.nn.Linear( |
| 76 | + 128, 256, dtype=torch.bfloat16, device="cuda" |
| 77 | + ).requires_grad_(False) |
| 78 | + ql = apply_quant(l) |
| 79 | + with tempfile.NamedTemporaryFile() as f: |
| 80 | + torch.save(ql.state_dict(), f) |
| 81 | + f.seek(0) |
| 82 | + # `weights_only=True` is enabled for torch 2.5+ |
| 83 | + if TORCH_VERSION_AT_LEAST_2_5: |
| 84 | + _ = torch.load(f, weights_only=True) |
| 85 | + else: |
| 86 | + _ = torch.load(f, weights_only=False) |
| 87 | + |
| 88 | + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 89 | + @unittest.skipIf(not is_cuda_8_9, "Need H100") |
| 90 | + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) |
| 91 | + @common_utils.parametrize("mode", ["dynamic", "weight-only"]) |
| 92 | + @common_utils.parametrize("compile", [True, False]) |
| 93 | + # Inputs are (M,..), K, N |
| 94 | + @common_utils.parametrize( |
| 95 | + "sizes", |
| 96 | + [ |
| 97 | + ((128,), 256, 128), |
| 98 | + ((256,), 512, 256), |
| 99 | + ((64,), 128, 64), |
| 100 | + ((32, 128), 64, 256), |
| 101 | + ((64, 256), 512, 128), |
| 102 | + ], |
| 103 | + ) |
| 104 | + def test_dynamic_fp8_linear( |
| 105 | + self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple |
| 106 | + ): |
| 107 | + M, N, K = sizes |
| 108 | + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") |
| 109 | + |
| 110 | + mode_map = { |
| 111 | + "dynamic": float8_dynamic_activation_float8_weight, |
| 112 | + "weight-only": float8_weight_only, |
| 113 | + } |
| 114 | + |
| 115 | + # Create a linear layer with bfloat16 dtype |
| 116 | + model = ToyLinearModel(K, N).eval().to(dtype).to("cuda") |
| 117 | + |
| 118 | + quantized_model = copy.deepcopy(model) |
| 119 | + factory = mode_map[mode]() |
| 120 | + quantize_(model, factory) |
| 121 | + |
| 122 | + if compile: |
| 123 | + quantized_model = torch.compile(quantized_model, fullgraph=True) |
| 124 | + |
| 125 | + output_original = model(input_tensor) |
| 126 | + output_quantized = quantized_model(input_tensor) |
| 127 | + |
| 128 | + assert compute_error(output_original, output_quantized) > 20, "Error is too low" |
| 129 | + |
| 130 | + |
| 131 | +common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8) |
| 132 | + |
| 133 | +if __name__ == "__main__": |
| 134 | + pytest.main([__file__]) |
0 commit comments