Move quant ops to utils.py #331

Merged · 1 commit · Jun 9, 2024
4 changes: 2 additions & 2 deletions test/dtypes/test_nf4.py
@@ -241,7 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@parametrize("shape", [(16, 16), (32, 16)])
@@ -430,7 +430,7 @@ def test_to_cpu(self):
for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
inner_tensor = getattr(nf4_tensor, attr)
self.assertEqual(inner_tensor.device.type, "cpu")


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)
297 changes: 8 additions & 289 deletions test/integration/test_integration.py
@@ -28,15 +28,18 @@
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
safe_int_mm,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
dequantize_per_tensor,
dynamically_quantize_per_channel,
dynamically_quantize_per_tensor,
quant_int8_dynamic_linear,
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
safe_int_mm,
dequantize_affine,
)

from torchao.quantization.smoothquant import (
Expand Down Expand Up @@ -369,167 +372,7 @@ def test_debug_x_absmax(self):
y1 = m(x0)


class PythonQuantPrimitivesUnitTest(unittest.TestCase):
def _test_dynamic_quant_per_tensor_numerics_impl(
self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme
):
x = torch.randn(256, dtype=float_dtype, device=device)
y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor(
x, qmin, qmax, int_dtype, qscheme
)

# reference
# quantize_per_tensor_dynamic doesn't work for half, so we cast there and back
x_for_ref = x.half().float() if float_dtype == torch.float16 else x

# quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic
# quant manually with observers + static quant
obs = MinMaxObserver(
dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax
).to(device)
obs(x_for_ref)
ref_scale, ref_zero_point = obs.calculate_qparams()
y_ref = torch.quantize_per_tensor(
x_for_ref, ref_scale, ref_zero_point, qint_dtype
)

# y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False)
# print(y_ref)
if float_dtype == torch.float:
assert torch.equal(y_vals, y_ref.int_repr())
else:
# numerics are not exactly aligned yet, off-by-one probably due
# to rounding
assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
torch.testing.assert_close(
y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
)
if y_zero_point is not None:
assert torch.equal(
y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
)
else:
self.assertTrue(y_ref.q_zero_point() == 0)

# dequantize and check again
x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype)
y_ref_dq = y_ref.dequantize().to(float_dtype)
if float_dtype == torch.float:
torch.testing.assert_close(x_dq, y_ref_dq)
else:
sqnr = compute_error(x_dq, y_ref_dq)
self.assertTrue(sqnr.item() > 45.0)

def test_dynamic_quant_per_tensor_numerics_cpu(self):
# verifies that dynamic quant per tensor in plain pytorch matches
# numerics of production AO code
# TODO(future): test this on cpu-half, need to first make
# torch.aminmax support half on cpu
test_cases = (
(
0,
255,
torch.uint8,
torch.quint8,
torch.float32,
"cpu",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float32,
"cpu",
torch.per_tensor_symmetric,
),
)
for row in test_cases:
self._test_dynamic_quant_per_tensor_numerics_impl(*row)

@unittest.skip("test case incorrect on A10G")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_dynamic_quant_per_tensor_numerics_cuda(self):
# verifies that dynamic quant per tensor in plain pytorch matches
# numerics of production AO code
test_cases = (
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_affine,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_symmetric,
),
(
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
torch.per_tensor_symmetric,
),
(
-127,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
torch.per_tensor_symmetric,
),
)
for row in test_cases:
self._test_dynamic_quant_per_tensor_numerics_impl(*row)

class PythonQuantUtilOpUnitTest(unittest.TestCase):
def _test_dynamic_quant_per_channel_numerics_impl(
self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
):
@@ -705,130 +548,6 @@ def wrap_torch_int_mm(x, w):
torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0)
torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0)

def _test_qlinear_per_channel_numerics(
self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device
):
qconfig = torch.ao.quantization.per_channel_dynamic_qconfig

x = torch.randn(*x_shape, device=device, dtype=float_dtype)

# TODO: test bias true and false
# Note: reference path only works on float because lack of aten quant primitives
# support of half, so we cast back and forth to emulate
lin_ref = (
nn.Sequential(nn.Linear(*lin_shape))
.eval()
.to(float_dtype)
.float()
.to(device)
)
y_ref = lin_ref(x.float())
weight = lin_ref[0].weight
bias = lin_ref[0].bias

qconfig_mapping = QConfigMapping().set_global(qconfig)
lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),))
lin_ref_q = convert_to_reference_fx(lin_ref_p)
y_q_ref = lin_ref_q(x.float())

# scale, zp of weight (get from reference model)
w_obs = qconfig.weight()
w_obs(weight)
lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams()
lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype)
# print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp)

w_vals, _s, _z = dynamically_quantize_per_channel(
getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8
)
w_vals = w_vals.t().contiguous()
w_vals_sums = w_vals.sum(dim=0)

# do our version of the quantized linear operator
y = quant_int8_dynamic_linear(
x,
qmin,
qmax,
int_dtype,
w_vals,
lin_ref_w_scale,
w_vals_sums,
bias,
float_dtype,
)

# print('y', y)
# print('y_q_ref', y_q_ref)
# print('y_ref', y_ref)

sqnr_ref = compute_error(y_ref, y_q_ref)
sqnr_our = compute_error(y_ref, y)
# print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our)
# for large shapes, sqnr can be in the high 30s for float32 and float16
self.assertTrue(sqnr_our.item() >= 37.5)

def test_qlinear_per_channel_numerics_cpu(self):
# Note: the AO codebase doesn't easily support qint8 activations,
# so the test cases below are for the quant primitives defined in
# this file only. The AO reference is using quint8 here.
test_cases = (
((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"),
((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),
)
for test_case in test_cases:
self._test_qlinear_per_channel_numerics(*test_case)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_qlinear_per_channel_numerics_cuda(self):
test_cases = (
# Note: torch._int_mm needs int8 activations, so we don't test uint8
# activations on CUDA at all
(
(32, 32),
(32, 16),
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
),
(
(32, 32),
(32, 16),
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
),
# a large shape from LLaMa 1.5B - currently fails for float16
(
(17, 4096),
(4096, 1536),
-128,
127,
torch.int8,
torch.qint8,
torch.float32,
"cuda",
),
(
(17, 4096),
(4096, 1536),
-128,
127,
torch.int8,
torch.qint8,
torch.float16,
"cuda",
),
)
for test_case in test_cases:
self._test_qlinear_per_channel_numerics(*test_case)


class TestSubclass(unittest.TestCase):
@run_supported_device_dtype
def _test_dequantize_impl(
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -18,7 +18,7 @@
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import get_group_qparams_symmetric
Member:
I'm having trouble figuring out what should be in quant_primitives vs utils

Contributor Author:

Functions in utils are mostly helper functions that call the quant_primitive ops with some fixed parameters, e.g.:

quantize_affine can support: symmetric/asymmetric, per tensor/group/channel etc., int8/int4/int3

A helper function in utils could be int8_symmetric_per_tensor_quant, which calls the quantize_affine op with fixed settings.
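
As a rough illustration of that split (not code from this PR): a minimal sketch of what such a utils-level helper might look like, assuming the choose_qparams_affine/quantize_affine signatures implied by the imports in the test_integration.py diff above. The helper name int8_symmetric_per_tensor_quant is the hypothetical one from the comment, and the exact keyword arguments are assumptions.

```python
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)


def int8_symmetric_per_tensor_quant(x: torch.Tensor):
    """Hypothetical utils-style helper: calls the generic affine quant
    primitives with parameters fixed to int8 symmetric per-tensor quant."""
    # One block spanning the whole tensor gives per-tensor granularity.
    block_size = tuple(x.shape)
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        block_size,
        target_dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
    )
    x_int8 = quantize_affine(
        x,
        block_size,
        scale,
        zero_point,
        output_dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
    )
    return x_int8, scale, zero_point
```

The design point is that the primitive stays fully general (mapping type, block size, target dtype are all parameters), while the utils helper bakes in one common configuration for convenience.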

from torchao.utils import TORCH_VERSION_AFTER_2_4

