
Commit d0e6246

Add Float8 Weight Only and FP8 weight + dynamic activation (#740)

* mixin
* fix memory being held by autograd

1 parent 05224a9 commit d0e6246

13 files changed: +566, -95 lines
scripts/hf_eval.py (+1, -1)
@@ -114,7 +114,7 @@ def all_linear(mod, name):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
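
With "fp6" dropped from the choices, a typical invocation now looks like the following (a minimal sketch using only the flags visible in this hunk; the script's model-selection flags are not shown here and are omitted):

    # Evaluate with int8 weight-only quantization, compiled, on 100 samples
    python scripts/hf_eval.py -q int8wo --compile --limit 100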

test/dtypes/test_affine_quantized.py (+46, -25)
@@ -10,12 +10,33 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 import torch
 import unittest
 import tempfile
 
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+def get_quantization_functions(do_sparse: bool, do_int4: bool):
+    base_functions = [
+        int8_weight_only(),
+        int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int8_weight(),
+    ]
+    if do_int4:
+        base_functions.append(int4_weight_only(group_size=32))
+
+    if do_sparse:
+        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+
+    if is_cuda_8_9:
+        base_functions.append(float8_weight_only())
+
+    return base_functions
+
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -38,36 +59,36 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
-                            int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-            ql = apply_quant(l)
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(ql.state_dict(), f)
-                f.seek(0)
-                # `weights_only=True` is enabled for torch 2.5+
-                if TORCH_VERSION_AT_LEAST_2_5:
-                    _ = torch.load(f, weights_only=True)
-                else:
-                    _ = torch.load(f, weights_only=False)
+    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    def test_weights_only(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(l)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(ql.state_dict(), f)
+            f.seek(0)
+            # `weights_only=True` is enabled for torch 2.5+
+            if TORCH_VERSION_AT_LEAST_2_5:
+                _ = torch.load(f, weights_only=True)
+            else:
+                _ = torch.load(f, weights_only=False)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_to_device(self):
-        from torchao.quantization import quantize_
-        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to("cuda")
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
+    def test_to_device(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to("cuda")
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to(device="cuda")
 
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
-        ql.to(device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.cuda()
 
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
-        ql.cuda()
 
+common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 
 if __name__ == "__main__":
     run_tests()
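
The refactor above replaces hand-rolled for-loops with PyTorch's internal test parametrization helpers, which generate one named test case per quantization function. A minimal standalone sketch of the same mechanism (the class, parameter, and values are illustrative, not from this commit):

    import torch
    from torch.testing._internal import common_utils
    from torch.testing._internal.common_utils import TestCase, run_tests

    class ExampleTest(TestCase):
        # Expands into one test per value of "x" (names like test_double_x_1).
        @common_utils.parametrize("x", [1, 2, 3])
        def test_double(self, x):
            self.assertEqual(torch.tensor(x).mul(2).item(), 2 * x)

    # Required: materializes the parametrized variants on the class.
    common_utils.instantiate_parametrized_tests(ExampleTest)

    if __name__ == "__main__":
        run_tests()

Because get_quantization_functions checks is_cuda_8_9 at collection time, the float8_weight_only variant is only generated on GPUs with compute capability >= 8.9.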
(new test file) (+103)

@@ -0,0 +1,103 @@
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,
+)
+import pytest
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+from numpy import full
+from torch.testing._internal.common_utils import (
+    run_tests,
+)
+from torch._inductor.test_case import TestCase as InductorTestCase
+from torch.testing._internal import common_utils
+from torch._dynamo.testing import CompileCounterWithBackend
+
+from torchao.quantization import (
+    quantize_,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
+)
+from torchao.float8.float8_utils import compute_error
+import torch
+import unittest
+import pytest
+import tempfile
+import copy
+import random
+
+from unittest.mock import patch
+
+
+random.seed(0)
+torch.manual_seed(0)
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+class ToyLinearModel(torch.nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
+        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+class TestAffineQuantizedFloat8Compile(InductorTestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize("compile", [True, False])
+    # Inputs are (M,..), K, N
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((256,), 512, 256),
+            ((64,), 128, 64),
+            ((32, 128), 64, 256),
+            ((64, 256), 512, 128),
+        ],
+    )
+    def test_fp8_linear_variants(
+        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
+    ):
+        M, N, K = sizes
+        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+
+        mode_map = {
+            "dynamic": float8_dynamic_activation_float8_weight,
+            "weight-only": float8_weight_only,
+        }
+
+        # Create a linear layer with bfloat16 dtype
+        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
+
+        quantized_model = copy.deepcopy(model)
+        factory = mode_map[mode]()
+        quantize_(model, factory)
+
+        if compile:
+            quantized_model = torch.compile(quantized_model, fullgraph=True)
+
+        output_original = model(input_tensor)
+        output_quantized = quantized_model(input_tensor)
+
+        error = compute_error(output_original, output_quantized)
+        assert (
+            compute_error(output_original, output_quantized) > 20
+        ), f"Quantization error is too high got a SQNR of {error}"
+
+
+common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
+
+if __name__ == "__main__":
+    pytest.main([__file__])
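
The test asserts that compute_error, a signal-to-quantized-noise ratio (SQNR) in decibels, stays above 20 dB between the float8 and high-precision outputs. In user code, the two quantization paths it exercises are applied through quantize_; a minimal sketch, assuming a CUDA GPU with compute capability >= 8.9 (the layer shapes are arbitrary):

    import torch
    from torchao.quantization import (
        quantize_,
        float8_weight_only,
        float8_dynamic_activation_float8_weight,
    )

    model = torch.nn.Sequential(
        torch.nn.Linear(256, 512, bias=False)
    ).eval().to(torch.bfloat16).to("cuda")

    # Weight-only float8: weights are stored in float8, activations stay in bf16.
    quantize_(model, float8_weight_only())

    # Alternatively, quantize activations dynamically as well:
    # quantize_(model, float8_dynamic_activation_float8_weight())

    out = model(torch.randn(128, 256, dtype=torch.bfloat16, device="cuda"))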

torchao/dtypes/__init__.py (+4)
@@ -12,6 +12,8 @@
     PlainLayoutType,
     SemiSparseLayoutType,
     TensorCoreTiledLayoutType,
+    Float8LayoutType,
+    Float8AQTLayout,
 )
 
 __all__ = [
@@ -27,4 +29,6 @@
     "PlainLayoutType",
     "SemiSparseLayoutType",
     "TensorCoreTiledLayoutType",
+    "Float8LayoutType",
+    "Float8AQTLayout",
 ]
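
Since the new layout types are re-exported from the package root here, downstream code can import them directly (sketch):

    from torchao.dtypes import Float8LayoutType, Float8AQTLayout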
