Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into fp6_quant
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst authored May 16, 2024
2 parents 8bf081c + cda787c commit 558f4e4
Show file tree
Hide file tree
Showing 7 changed files with 368 additions and 94 deletions.
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 43.0)
self.assertGreaterEqual(sqnr, 42.75)
if device == "cuda":
self.assertTrue("mixed_mm" in code)

Expand Down
47 changes: 35 additions & 12 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
_choose_qparams_per_token_asymmetric,
_GenericFakeQuantize,
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
Expand Down Expand Up @@ -58,7 +59,7 @@ def _get_qmin_qmax(self, n_bit: int):
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
Expand All @@ -67,6 +68,7 @@ def test_fake_quantize_per_channel_group(self):
torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
(s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
zp = zp.to(torch.int32)
x2 = copy.deepcopy(x)

# fake quant op
Expand All @@ -84,18 +86,15 @@ def test_fake_quantize_per_channel_group(self):
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
x2 = copy.deepcopy(x)
# TODO: use torch.ops.aten.quantized_decomposed version instead
(s, zp) = _choose_qparams_per_token_asymmetric(
x,
torch.int8, # not used
)
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

# fake quant op
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
Expand Down Expand Up @@ -130,7 +129,7 @@ def _set_ptq_weight(
ptq_linear.scales = s
ptq_linear.zeros = zp

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
Expand All @@ -155,7 +154,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
Expand Down Expand Up @@ -189,7 +188,7 @@ def test_qat_8da4w_quantizer(self):
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

Expand All @@ -201,7 +200,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
Expand Down Expand Up @@ -254,7 +253,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
qat_out2 = qat_model2(*x2)
torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
Expand Down Expand Up @@ -299,6 +298,30 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
    """
    Test that the generic fake quantize used in 8da4w QAT matches
    the numerics of existing fake quantize ops in PyTorch in both
    the forward and the backward passes.
    """
    # Seed the RNG like the other tests in this file do, so the
    # exact-equality (atol=0, rtol=0) checks below are reproducible.
    torch.manual_seed(self.SEED)

    (qmin, qmax) = self._get_qmin_qmax(4)
    py_input = torch.randn(16, 64).float().requires_grad_()
    py_s = torch.randn(16).float()
    py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)

    # Reference path: PyTorch's per-channel fake quantize along axis 0.
    py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax)
    py_out.sum().backward()

    # The copy is taken after the reference backward pass; its gradient is
    # zeroed so the second backward pass accumulates from zero.
    ao_input = copy.deepcopy(py_input)
    ao_input.grad.data.zero_()
    # Scales and zero points are reshaped to column vectors (one per row).
    ao_s = copy.deepcopy(py_s).reshape(-1, 1)
    ao_zp = copy.deepcopy(py_zp).reshape(-1, 1)
    ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax)
    ao_out.sum().backward()

    # Forward outputs and input gradients must match bit-for-bit.
    torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
    torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
86 changes: 73 additions & 13 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,10 @@ def test_eval_wrapper(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.subclass import (
AffineQuantizedTensor,
LinearActQuantizedTensor,
)
from torchao.quantization.quant_primitives import MappingType
import copy

Expand All @@ -409,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self):
quant_max = 7

# TODO: make a general helper function?
# input settings
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
Expand All @@ -421,13 +425,18 @@ def get_per_token_block_size(x):
input_target_dtype = torch.int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

def dynamic_quant(linear):
# note: order is important
linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
dynamic_quant(m.linear1)
dynamic_quant(m.linear2)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
Expand Down Expand Up @@ -461,9 +470,6 @@ def test_quantized_tensor_subclass_int4(self):
preserve_zero = False
zero_point_dtype = torch.bfloat16

# weight only quantization
input_quant_func = None

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
Expand All @@ -475,7 +481,6 @@ def to_quantized(weight):
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
input_quant_func=input_quant_func,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
Expand Down Expand Up @@ -506,16 +511,13 @@ def test_quantized_tensor_subclass_int8(self):
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# weight only quantization
input_quant_func = None

m = ToyLinearModel().eval().to(torch.bfloat16)
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
Expand All @@ -532,5 +534,63 @@ def to_quantized(weight):
torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
    """
    Check that composing AffineQuantizedTensor (for the weight) with
    LinearActQuantizedTensor (for the input quantization function) produces
    exactly the same output as change_linear_weights_to_int8_dqtensors.
    """
    from torchao.quantization.subclass import (
        AffineQuantizedTensor,
        LinearActQuantizedTensor,
    )
    from torchao.quantization.quant_primitives import (
        MappingType,
        ZeroPointDomain,
    )
    import copy

    # weight settings
    weight_mapping = MappingType.SYMMETRIC
    weight_dtype = torch.int8
    weight_eps = torch.finfo(torch.float32).eps
    weight_zp_dtype = torch.int64

    def get_weight_block_size(w):
        # one block per row of the weight
        return (1, w.shape[1])

    # input settings
    def get_per_token_block_size(x):
        # every leading dimension gets block size 1, the last one stays whole
        dims = list(x.shape)
        return [1] * (len(dims) - 1) + dims[-1:]

    def input_quant_func(x):
        return AffineQuantizedTensor.from_float(
            x,
            MappingType.SYMMETRIC,
            get_per_token_block_size(x),
            torch.int8,
            eps=1e-5,
            quant_min=-127,
            quant_max=127,
            scale_dtype=torch.float,
        )

    # use 1024 so that we don't need padding
    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
    m_copy = copy.deepcopy(m)
    example_inputs = tuple(t.to(torch.bfloat16).to("cuda") for t in m.example_inputs())

    def dynamic_quant(linear):
        # note: order is important -- quantize the weight first, then wrap
        # the quantized weight with the input quantization function
        quantized_weight = AffineQuantizedTensor.from_float(
            linear.weight,
            weight_mapping,
            get_weight_block_size(linear.weight),
            weight_dtype,
            eps=weight_eps,
            zero_point_dtype=weight_zp_dtype,
        )
        linear.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
        wrapped_weight = LinearActQuantizedTensor.from_float(linear.weight, input_quant_func)
        linear.weight = torch.nn.Parameter(wrapped_weight, requires_grad=False)

    quantized_layers = (m.linear1, m.linear2)
    for layer in quantized_layers:
        dynamic_quant(layer)
    for layer in quantized_layers:
        assert isinstance(layer.weight, LinearActQuantizedTensor)
    for layer in quantized_layers:
        assert isinstance(layer.weight.original_weight_tensor, AffineQuantizedTensor)

    # reference
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    change_linear_weights_to_int8_dqtensors(m_copy)

    res = m(*example_inputs)
    ref = m_copy(*example_inputs)

    self.assertTrue(torch.equal(res, ref))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
Expand Down
Loading

0 comments on commit 558f4e4

Please sign in to comment.