
Commit e0b3d31

jainapurva authored and drisspg committed
Add float8 quant primitives
1 parent cfabc13 commit e0b3d31

11 files changed: +522 -62 lines
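
Before the per-file diffs, a bit of orientation: the "float8 quant primitives" added here revolve around scaling a tensor into one of PyTorch's float8 dtypes and back. The sketch below is illustrative only, not code from this commit; the helper names quantize_fp8/dequantize_fp8 and the per-tensor scaling scheme are assumptions for exposition.

import torch

FP8_DTYPE = torch.float8_e4m3fn  # max representable magnitude is 448.0

def quantize_fp8(x: torch.Tensor):
    # Pick a per-tensor scale so the largest |x| maps to the fp8 max,
    # then cast the scaled values down to float8.
    fp8_max = torch.finfo(FP8_DTYPE).max
    scale = x.abs().max().clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(FP8_DTYPE)
    return x_fp8, scale

def dequantize_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    # Cast back up to a wide dtype and undo the scaling.
    return x_fp8.to(dtype) * scale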

scripts/hf_eval.py

+1 -1
@@ -111,7 +111,7 @@ def all_linear(mod, name):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
+134 (new file)
@@ -0,0 +1,134 @@
from torch.testing._internal.common_utils import (
    run_tests,
)
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal import common_utils
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.quantization.quant_api import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)
from torchao.float8.float8_utils import compute_error
import torch
import unittest
import pytest
import tempfile
import copy
import random

from unittest.mock import patch
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    unwrap_tensor_subclass,
)

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


random.seed(0)
torch.manual_seed(0)

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class ToyLinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear2 = torch.nn.Linear(out_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestAffineQuantizedFloat8(InductorTestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tensor_core_layout_transpose(self):
        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        t = l.weight
        shape = t.shape
        apply_float8_weight_only_quant = float8_weight_only()
        ql = apply_float8_weight_only_quant(l)
        aqt = ql.weight
        aqt_shape = aqt.shape
        assert aqt_shape == shape

        # transpose shape test
        for _ in range(10):
            t = t.t()
            aqt = aqt.t()
            shape = t.shape
            aqt_shape = aqt.shape
            assert aqt_shape == shape

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_weights_only_save_load(self):
        for apply_quant in [float8_weight_only()]:
            # TODO Fails when l requires grad
            l = torch.nn.Linear(
                128, 256, dtype=torch.bfloat16, device="cuda"
            ).requires_grad_(False)
            ql = apply_quant(l)
            with tempfile.NamedTemporaryFile() as f:
                torch.save(ql.state_dict(), f)
                f.seek(0)
                # `weights_only=True` is enabled for torch 2.5+
                if TORCH_VERSION_AT_LEAST_2_5:
                    _ = torch.load(f, weights_only=True)
                else:
                    _ = torch.load(f, weights_only=False)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    @common_utils.parametrize("compile", [True, False])
    # Inputs are (M,..), N, K
    @common_utils.parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((256,), 512, 256),
            ((64,), 128, 64),
            ((32, 128), 64, 256),
            ((64, 256), 512, 128),
        ],
    )
    def test_dynamic_fp8_linear(
        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
    ):
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")

        mode_map = {
            "dynamic": float8_dynamic_activation_float8_weight,
            "weight-only": float8_weight_only,
        }

        # Create a toy two-layer model in the requested dtype
        model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")

        quantized_model = copy.deepcopy(model)
        factory = mode_map[mode]()
        quantize_(quantized_model, factory)

        if compile:
            quantized_model = torch.compile(quantized_model, fullgraph=True)

        output_original = model(input_tensor)
        output_quantized = quantized_model(input_tensor)

        assert compute_error(output_original, output_quantized) > 20, "SQNR is too low"


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8)

if __name__ == "__main__":
    pytest.main([__file__])
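
A note on the final assertion: compute_error from torchao.float8.float8_utils reports a signal-to-quantization-noise ratio (SQNR) in decibels, so requiring a value above 20 means the quantized output must track the reference closely. Below is a minimal sketch of that metric, assuming the standard norm-ratio definition (the name sqnr is mine, not torchao's):

import torch

def sqnr(ref: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means a closer match.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - approx))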

torchao/dtypes/__init__.py

+4
@@ -10,6 +10,8 @@
     PlainLayoutType,
     SemiSparseLayoutType,
     TensorCoreTiledLayoutType,
+    Float8LayoutType,
+    Float8AQTLayout,
 )

 __all__ = [
@@ -24,4 +26,6 @@
     "PlainLayoutType",
     "SemiSparseLayoutType",
     "TensorCoreTiledLayoutType",
+    "Float8LayoutType",
+    "Float8AQTLayout",
 ]
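
With Float8LayoutType and Float8AQTLayout now exported, end-to-end usage mirrors the test above. A minimal sketch, assuming a CUDA GPU with compute capability >= 8.9 (the test's own requirement for the dynamic mode):

import torch
from torchao.quantization.quant_api import (
    quantize_,
    float8_weight_only,
    float8_dynamic_activation_float8_weight,
)

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, bias=False),
).to(torch.bfloat16).to("cuda").eval()

# Quantize the linear weights to float8 in place:
quantize_(model, float8_weight_only())
# ...or quantize both activations (dynamically) and weights:
# quantize_(model, float8_dynamic_activation_float8_weight())

out = model(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))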
