Move quant ops to utils.py
Summary:
We had a lot of "quant primitive" ops that can be expressed in terms of more primitive ops, so they are now better thought of as helper functions; we moved them to torchao.quantization.utils.

We should be able to deprecate some of these ops further in the future, after we deprecate
the tensor subclasses and refactor smoothquant etc.
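
For example, here is a minimal sketch of the kind of composition this refers to, using the affine primitives that remain in torchao.quantization.quant_primitives; the exact argument names are assumptions based on this version of torchao and may differ in other releases:

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )

    x = torch.randn(32, 64)
    # one scale/zero_point per row, i.e. per-channel over dim 0
    # (block_size = x.shape would give per-tensor qparams instead)
    block_size = (1, x.shape[1])
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, block_size, torch.int8,
        quant_min=-127, quant_max=127,
    )
    xq = quantize_affine(
        x, block_size, scale, zero_point, torch.int8,
        quant_min=-127, quant_max=127,
    )
    xdq = dequantize_affine(
        xq, block_size, scale, zero_point, torch.int8,
        quant_min=-127, quant_max=127,
    )

Something like this sequence is what helpers such as dynamically_quantize_per_channel reduce to, which is why they now live in torchao.quantization.utils rather than next to the primitives.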

Also moved TORCH_VERSION_AFTER_{2_2/2_3/2_4} from torchao.quantization.utils to torchao.utils.
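
For callers, the migration is a pure import change; a sketch, with the before/after taken from the hunks in this commit:

    # before this commit
    from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, find_multiple
    from torchao.quantization.quant_primitives import get_group_qparams_symmetric

    # after this commit
    from torchao.utils import TORCH_VERSION_AFTER_2_4, find_multiple
    from torchao.quantization.utils import get_group_qparams_symmetric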

Test Plan:
python test/integration/test_integration.py
python test/quantization/test_quant_api.py
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 7, 2024
1 parent d97ae74 commit 4b9ed66
Showing 21 changed files with 460 additions and 878 deletions.
4 changes: 2 additions & 2 deletions test/dtypes/test_nf4.py
@@ -241,7 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parametrize("shape", [(16, 16), (32, 16)])
@@ -430,7 +430,7 @@ def test_to_cpu(self):
for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
inner_tensor = getattr(nf4_tensor, attr)
self.assertEqual(inner_tensor.device.type, "cpu")


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)
299 changes: 9 additions & 290 deletions test/integration/test_integration.py
@@ -28,15 +28,18 @@
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
safe_int_mm,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
dequantize_per_tensor,
dynamically_quantize_per_channel,
dynamically_quantize_per_tensor,
quant_int8_dynamic_linear,
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
safe_int_mm,
dequantize_affine,
)

from torchao.quantization.smoothquant import (
@@ -70,7 +73,7 @@
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

@@ -369,167 +372,7 @@ def test_debug_x_absmax(self):
y1 = m(x0)


class PythonQuantPrimitivesUnitTest(unittest.TestCase):
def _test_dynamic_quant_per_tensor_numerics_impl(
self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme
):
x = torch.randn(256, dtype=float_dtype, device=device)
y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor(
x, qmin, qmax, int_dtype, qscheme
)

# reference
# quantize_per_tensor_dynamic doesn't work for half, so we cast there and back
x_for_ref = x.half().float() if float_dtype == torch.float16 else x

# quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic
# quant manually with observers + static quant
obs = MinMaxObserver(
dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax
).to(device)
obs(x_for_ref)
ref_scale, ref_zero_point = obs.calculate_qparams()
y_ref = torch.quantize_per_tensor(
x_for_ref, ref_scale, ref_zero_point, qint_dtype
)

# y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False)
# print(y_ref)
if float_dtype == torch.float:
assert torch.equal(y_vals, y_ref.int_repr())
else:
# numerics are not exactly aligned yet, off-by-one probably due
# to rounding
assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
torch.testing.assert_close(
y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
)
if y_zero_point is not None:
assert torch.equal(
y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
)
else:
self.assertTrue(y_ref.q_zero_point() == 0)

# dequantize and check again
x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype)
y_ref_dq = y_ref.dequantize().to(float_dtype)
if float_dtype == torch.float:
torch.testing.assert_close(x_dq, y_ref_dq)
else:
sqnr = compute_error(x_dq, y_ref_dq)
self.assertTrue(sqnr.item() > 45.0)

def test_dynamic_quant_per_tensor_numerics_cpu(self):
# verifies that dynamic quant per tensor in plain pytorch matches
# numerics of production AO code
# TODO(future): test this on cpu-half, need to first make
# torch.aminmax support half on cpu
test_cases = (
(
0,
255,
torch.uint8,
torch.quint8,
torch.float32,
"cpu",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_symmetric,
),
)
for row in test_cases:
self._test_dynamic_quant_per_tensor_numerics_impl(*row)

@unittest.skip("test case incorrect on A10G")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_dynamic_quant_per_tensor_numerics_cuda(self):
# verifies that dynamic quant per tensor in plain pytorch matches
# numerics of production AO code
test_cases = (
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_symmetric,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_symmetric,
),
)
for row in test_cases:
self._test_dynamic_quant_per_tensor_numerics_impl(*row)

class PythonQuantUtilOpUnitTest(unittest.TestCase):
def _test_dynamic_quant_per_channel_numerics_impl(
self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
):
@@ -705,130 +548,6 @@ def wrap_torch_int_mm(x, w):
torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0)
torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0)

def _test_qlinear_per_channel_numerics(
self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
):
qconfig = torch.ao.quantization.per_channel_dynamic_qconfig

x = torch.randn(*x_shape, device=device, dtype=float_dtype)

# TODO: test bias true and false
# Note: reference path only works on float because lack of aten quant primitives
# support of half, so we cast back and forth to emulate
lin_ref = (
nn.Sequential(nn.Linear(*lin_shape))
.eval()
.to(float_dtype)
.float()
.to(device)
)
y_ref = lin_ref(x.float())
weight = lin_ref[0].weight
bias = lin_ref[0].bias

qconfig_mapping = QConfigMapping().set_global(qconfig)
lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),))
lin_ref_q = convert_to_reference_fx(lin_ref_p)
y_q_ref = lin_ref_q(x.float())

# scale, zp of weight (get from reference model)
w_obs = qconfig.weight()
w_obs(weight)
lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams()
lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype)
# print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp)

w_vals, _s, _z = dynamically_quantize_per_channel(
getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8
)
w_vals = w_vals.t().contiguous()
w_vals_sums = w_vals.sum(dim=0)

# do our version of the quantized linear operator
y = quant_int8_dynamic_linear(
x,
qmin,
qmax,
int_dtype,
w_vals,
lin_ref_w_scale,
w_vals_sums,
bias,
float_dtype,
)

# print('y', y)
# print('y_q_ref', y_q_ref)
# print('y_ref', y_ref)

sqnr_ref = compute_error(y_ref, y_q_ref)
sqnr_our = compute_error(y_ref, y)
# print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our)
# for large shapes, sqnr can be in the high 30s for float32 and float16
self.assertTrue(sqnr_our.item() >= 37.5)

def test_qlinear_per_channel_numerics_cpu(self):
# Note: the AO codebase doesn't easily support qint8 activations,
# so the test cases below are for the quant primitives defined in
# this file only. The AO reference is using quint8 here.
test_cases = (
((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"),
((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),
)
for test_case in test_cases:
self._test_qlinear_per_channel_numerics(*test_case)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_qlinear_per_channel_numerics_cuda(self):
test_cases = (
# Note: torch._int_mm needs int8 activations, so we don't test uint8
# activations on CUDA at all
(
(32, 32),
(32, 16),
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
),
(
(32, 32),
(32, 16),
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
),
# a large shape from LLaMa 1.5B - currently fails for float16
(
(17, 4096),
(4096, 1536),
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
),
(
(17, 4096),
(4096, 1536),
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
),
)
for test_case in test_cases:
self._test_qlinear_per_channel_numerics(*test_case)


class TestSubclass(unittest.TestCase):
@run_supported_device_dtype
def _test_dequantize_impl(
2 changes: 1 addition & 1 deletion test/quantization/model.py
@@ -10,7 +10,7 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torchao.quantization.utils import find_multiple
from torchao.utils import find_multiple

def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
4 changes: 2 additions & 2 deletions test/quantization/test_qat.py
@@ -18,8 +18,8 @@
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
2 changes: 1 addition & 1 deletion test/quantization/test_quant_api.py
@@ -44,7 +44,7 @@
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
)
from torchao.quantization.utils import (
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)