Commit 221514e

Deprecate top level quantization APIs (pytorch#344)
Summary:

This PR deprecates a few top-level quantization APIs. BC-breaking notes:

1. int8 weight-only quantization

The int8 weight-only module-swap API

```
apply_weight_only_int8_quant(model)
```

and the int8 weight-only tensor subclass API

```
change_linear_weights_to_int8_woqtensors(model)
```

are replaced by the unified tensor subclass API

```
quantize(model, get_apply_int8wo_quant())
```

2. int8 dynamic quantization

```
apply_dynamic_quant(model)
```

or

```
change_linear_weights_to_int8_dqtensors(model)
```

are replaced by the unified tensor subclass API

```
quantize(model, get_apply_int8dyn_quant())
```

3. int4 weight-only quantization

```
change_linear_weights_to_int4_woqtensors(model)
```

is replaced by the unified tensor subclass API

```
quantize(model, get_apply_int4wo_quant())
```

Test Plan:

python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 9d4d3f9 commit 221514e
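
For call sites migrating to the new API, below is a minimal sketch of the int8 weight-only path, modeled on the version-gated helpers this commit adds to test/integration/test_integration.py. The toy nn.Sequential model and tensor shapes are illustrative only; note that the diff itself lands the configs as int8wo(), int8da_int8w(), and int4wo() rather than the get_apply_*_quant names used in the commit message above.

```
import torch
import torch.nn as nn

# New unified API and config, as imported in the updated tests.
from torchao.quantization.quant_api import quantize, int8wo
from torchao.utils import TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass

# Illustrative toy model; any module containing nn.Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(1024, 1024))

if TORCH_VERSION_AFTER_2_4:
    # Unified tensor subclass API introduced by this PR.
    quantize(model, int8wo())
    # The updated tests call this right after quantize, before compiling or saving.
    unwrap_tensor_subclass(model)
else:
    # Deprecated path, kept for torch 2.2.2 and 2.3.
    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
    change_linear_weights_to_int8_woqtensors(model)

out = model(torch.randn(8, 1024))
```

The int8 dynamic (int8da_int8w) and int4 weight-only (int4wo) configs follow the same pattern in the updated tests.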

File tree

11 files changed: +399 −326 lines changed

test/dtypes/test_aq.py

+2 −2

@@ -2,7 +2,7 @@
     TestCase,
     run_tests,
 )
-from torchao.quantization.quant_api import get_apply_int4wo_quant
+from torchao.quantization.quant_api import int4wo
 import torch
 import unittest
 
@@ -12,7 +12,7 @@ class TestAQ(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
+        apply_int4wo_quant = int4wo(groupsize=32)
         aqt = apply_int4wo_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)

test/integration/test_integration.py

+66 −26

@@ -20,12 +20,17 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
+    int4wo,
+    int8wo,
+    int8da_int8w,
+    quantize,
+    _replace_with_custom_fn_if_matches_filter,
+)
+# APIs to be deprecated (used for torch 2.2.2 and 2.3)
+from torchao.quantization.quant_api import (
     change_linear_weights_to_int8_dqtensors,
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
-    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
     safe_int_mm,
@@ -73,26 +78,53 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+    is_fbcode,
+)
 
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
 config.cache_size_limit = 100
 
-# TODO: use this to reduce the number of tests
-TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
-]
-
 COMMON_DEVICES = ["cpu", "cuda"]
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
+def _int8wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod)
+
+def _int8da_int8w_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8da_int8w())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod)
+
+def _int4wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int4wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod)
+
+# TODO: use this to reduce the number of tests
+TENSOR_SUBCLASS_APIS = [
+    _int8wo_api,
+    _int8da_int8w_api,
+    _int4wo_api,
+]
+
+
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            _int8da_int8w_api, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            _int8wo_api, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                _int4wo_api,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for groupsize in [64, 32]:
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+
+                def api(mod):
+                    if TORCH_VERSION_AFTER_2_4:
+                        quantize(mod, int4wo(**kwargs))
+                        unwrap_tensor_subclass(mod)
+                    else:
+                        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    api,
                     device,
                     15,
                     test_shape=test_shape,
@@ -805,7 +845,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, int8da_int8w())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -819,7 +859,7 @@ def test_weight_only_quant(self):
             x = torch.randn(*x_shape)
             m = nn.Sequential(nn.Linear(4, 5))
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
             y_wo = m(x)
             sqnr = compute_error(y_ref, y_wo)
             self.assertGreater(sqnr, 44.0)
@@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m(x)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
@@ -869,7 +909,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
@@ -910,6 +950,7 @@ def forward(self, x):
 
         # save quantized state_dict
         api(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -925,6 +966,7 @@ def forward(self, x):
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -941,21 +983,21 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"indcutor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1275,8 +1317,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         model(x)
@@ -1321,8 +1362,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         ref = model(x)

test/prototype/mx_formats/test_mx_linear.py

+1 −1

@@ -189,7 +189,7 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
-        assert sqnr >= 14.0
+        assert sqnr >= 13.5
 
 
 def test_filter_fn():
