Deprecate top level quantization APIs #344

Merged (1 commit, Jun 13, 2024)
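This PR replaces torchao's top-level, module-mutating quantization APIs (apply_dynamic_quant, apply_weight_only_int8_quant, change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors, change_linear_weights_to_int4_woqtensors) with a single quantize entry point plus per-technique constructors (int8da_int8w, int8wo, int4wo). A minimal migration sketch based on the API usage shown in this diff; it assumes torch 2.4+ and a CUDA device, and the model and layer sizes are illustrative:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, int4wo
from torchao.utils import unwrap_tensor_subclass

# toy model; int4 weight-only quantization expects bfloat16 weights on CUDA
model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# before (deprecated): change_linear_weights_to_int4_woqtensors(model, groupsize=32)
# after: build an apply function with int4wo(...) and hand it to quantize()
quantize(model, int4wo(groupsize=32))
unwrap_tensor_subclass(model)  # needed before torch.compile, per the test helpers below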
test/dtypes/test_aq.py (4 changes: 2 additions & 2 deletions)
@@ -2,7 +2,7 @@
     TestCase,
     run_tests,
 )
-from torchao.quantization.quant_api import get_apply_int4wo_quant
+from torchao.quantization.quant_api import int4wo
 import torch
 import unittest
 
@@ -12,7 +12,7 @@ class TestAQ(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
+        apply_int4wo_quant = int4wo(groupsize=32)
         aqt = apply_int4wo_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
test/integration/test_integration.py (92 changes: 66 additions & 26 deletions)
@@ -20,12 +20,17 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    apply_dynamic_quant,
-    apply_weight_only_int8_quant,
+    int4wo,
+    int8wo,
+    int8da_int8w,
+    quantize,
+    _replace_with_custom_fn_if_matches_filter,
+)
+# APIs to be deprecated (used for torch 2.2.2 and 2.3)
+from torchao.quantization.quant_api import (
     change_linear_weights_to_int8_dqtensors,
     change_linear_weights_to_int8_woqtensors,
     change_linear_weights_to_int4_woqtensors,
-    _replace_with_custom_fn_if_matches_filter,
 )
 from torchao.quantization.quant_primitives import (
     safe_int_mm,
 
@@ -73,26 +78,53 @@
 from parameterized import parameterized
 import itertools
 import logging
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+    is_fbcode,
+)
 
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
 config.cache_size_limit = 100
 
-# TODO: use this to reduce the number of tests
-TENSOR_SUBCLASS_APIS = [
-    change_linear_weights_to_int8_dqtensors,
-    change_linear_weights_to_int8_woqtensors,
-    change_linear_weights_to_int4_woqtensors,
-]
-
 COMMON_DEVICES = ["cpu", "cuda"]
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
+def _int8wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod)
+
+def _int8da_int8w_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int8da_int8w())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod)
+
+def _int4wo_api(mod):
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(mod, int4wo())
+        unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod)
+
+# TODO: use this to reduce the number of tests
+TENSOR_SUBCLASS_APIS = [
+    _int8wo_api,
+    _int8da_int8w_api,
+    _int4wo_api,
+]
+
+
 def combine_parameters(a, b):
     new_tuples = []
     for (tuple1, tuple2) in itertools.product(a, b):
@@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl(
     @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
     def test_int8_dynamic_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
+            _int8da_int8w_api, device, 35, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_int8_weight_only_quant_subclass_api(self, device, dtype):
         self._test_lin_weight_subclass_api_impl(
-            change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype
+            _int8wo_api, device, 40, test_dtype=dtype
         )
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
@@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
             self.skipTest(f"Fails for {dtype}")
         for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])):
             self._test_lin_weight_subclass_api_impl(
-                change_linear_weights_to_int4_woqtensors,
+                _int4wo_api,
                 device,
                 15,
                 test_shape=test_shape,
@@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for groupsize in [64, 32]:
             for inner_k_tiles in [4, 2]:
                 kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+
+                def api(mod):
+                    if TORCH_VERSION_AFTER_2_4:
+                        quantize(mod, int4wo(**kwargs))
+                        unwrap_tensor_subclass(mod)
+                    else:
+                        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
                 self._test_lin_weight_subclass_api_impl(
-                    lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs),
+                    api,
                     device,
                     15,
                     test_shape=test_shape,
@@ -805,7 +845,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        apply_dynamic_quant(m)
+        quantize(m, int8da_int8w())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -819,7 +859,7 @@ def test_weight_only_quant(self):
             x = torch.randn(*x_shape)
             m = nn.Sequential(nn.Linear(4, 5))
             y_ref = m(x)
-            apply_weight_only_int8_quant(m)
+            _int8wo_api(m)
             y_wo = m(x)
             sqnr = compute_error(y_ref, y_wo)
             self.assertGreater(sqnr, 44.0)
@@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m(x)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
@@ -869,7 +909,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
         x = torch.randn(*x_shape).to(device).to(dtype)
         m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype)
         y_ref = m(x)
-        apply_weight_only_int8_quant(m)
+        _int8wo_api(m)
         m_c = torch.compile(m, mode="max-autotune")
         y_wo, (code,) = run_and_get_code(m_c, x)
         sqnr = compute_error(y_ref, y_wo)
@@ -910,6 +950,7 @@ def forward(self, x):
 
         # save quantized state_dict
         api(model)
+
         torch.save(model.state_dict(), "test.pth")
         # get quantized reference
         model_qc = torch.compile(model, mode="max-autotune")
@@ -925,6 +966,7 @@ def forward(self, x):
         # load quantized state_dict
         state_dict = torch.load("test.pth", mmap=True)
         os.remove("test.pth")
+
         model.load_state_dict(state_dict, assign=True)
         model = model.to(device=test_device, dtype=test_dtype).eval()
 
@@ -941,21 +983,21 @@ def forward(self, x):
     def test_save_load_dqtensors(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"indcutor failed for cpu right now")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_save_load_int8woqtensors(self, device, dtype):
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
     @torch.no_grad()
     def test_save_load_int4woqtensors(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
-        self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype)
+        self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)
 
 
 class TorchCompileUnitTest(unittest.TestCase):
@@ -1275,8 +1317,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         model(x)
@@ -1321,8 +1362,7 @@ def forward(self, x):
         model = test_model().to(dtype=test_dtype, device=test_device).eval()
         ref_f = model(x)
 
-        kwargs = {"dtype": test_dtype}
-        api(model, **kwargs)
+        api(model)
 
         # running model
         ref = model(x)
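The test suite keeps both the torch 2.2/2.3 path and the torch 2.4+ path alive through thin version-gated wrappers (_int8wo_api and friends above). Downstream code that must span torch versions during the deprecation window can follow the same pattern; a sketch mirroring those helpers, where the wrapper name and example model are illustrative:

import torch.nn as nn
from torchao.quantization.quant_api import (
    quantize,
    int8wo,
    change_linear_weights_to_int8_woqtensors,  # deprecated fallback
)
from torchao.utils import TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass

def int8wo_compat(model):
    # new unified API on torch 2.4+, deprecated top-level API otherwise
    if TORCH_VERSION_AFTER_2_4:
        quantize(model, int8wo())
        unwrap_tensor_subclass(model)
    else:
        change_linear_weights_to_int8_woqtensors(model)
    return model

model = int8wo_compat(nn.Sequential(nn.Linear(4, 5)))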
test/prototype/mx_formats/test_mx_linear.py (2 changes: 1 addition & 1 deletion)
@@ -189,7 +189,7 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
-        assert sqnr >= 14.0
+        assert sqnr >= 13.5
 
 
 def test_filter_fn():