deduplicate code for some torchao q/dq ops #173

Merged · 2 commits · Apr 29, 2024
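For readers skimming the diff: this PR reroutes several legacy quant/dequant helpers through the shared affine primitives `choose_qparams_affine`, `quantize_affine`, and `dequantize_affine`. A minimal usage sketch follows; it is not part of the PR, and the import path and argument names are inferred from the diff below, so treat them as assumptions.

```python
# Sketch only: per-channel (axis=0) symmetric int8 quant/dequant via the
# affine primitives, mirroring how dynamically_quantize_per_channel is
# rewritten in this PR. Import locations are assumed from the diff.
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(4, 8)
block_size = (1, x.shape[1])  # one quantization block per output channel (row)

scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int64,
)
xq = quantize_affine(x, block_size, scale, zero_point, torch.int8, -128, 127)
x_dq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8, output_dtype=x.dtype)
```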
16 changes: 9 additions & 7 deletions test/integration/test_integration.py
@@ -36,6 +36,7 @@
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
safe_int_mm,
dequantize_affine,
)

from torchao.quantization.smoothquant import (
@@ -385,11 +386,11 @@ def _test_dynamic_quant_per_tensor_numerics_impl(
# to rounding
assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1
torch.testing.assert_close(
y_scale, torch.tensor([y_ref.q_scale()], device=device, dtype=float_dtype)
y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype)
)
if y_zero_point is not None:
assert torch.equal(
y_zero_point, torch.tensor([y_ref.q_zero_point()], device=device)
y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device)
)
else:
self.assertTrue(y_ref.q_zero_point() == 0)
@@ -559,8 +560,8 @@ def _test_dynamic_quant_per_channel_numerics_impl(
assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1

# dequantize
x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point)
x_ref_dq = y_ref.dequantize()
x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point, out_dtype=float_dtype)
x_ref_dq = y_ref.dequantize().to(float_dtype)
# off-by-one for scale is okay
torch.testing.assert_close(
x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001
@@ -583,7 +584,8 @@ def test_dynamic_quant_per_channel_numerics_cuda(self):
def _test_quantize_per_token_impl(self, device, dtype):
x = torch.randn(3, 3, 3, device=device, dtype=dtype)
xq, scales = quantize_activation_per_token_absmax(x)
x_dq = dequantize_per_tensor(xq, scales, None).to(x.dtype)
block_size = (1, 1, 3)
x_dq = dequantize_affine(xq, block_size, scales, None, torch.int8, output_dtype=x.dtype)
sqnr = compute_error(x, x_dq)
self.assertTrue(sqnr >= 45.0)

@@ -1174,7 +1176,7 @@ def forward(self, x):
model_qc = torch.compile(model, mode="max-autotune")
ref_q = model_qc(x).detach()

assert SQNR(ref_f, ref_q) > min_sqnr
assert SQNR(ref_f, ref_q) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"

# load model structure
with torch.device('meta'):
@@ -1191,7 +1193,7 @@ def forward(self, x):
model_qc = torch.compile(model, mode="max-autotune")
test = model_qc(x).detach()

assert SQNR(ref_f, test) > min_sqnr
assert SQNR(ref_f, test) > min_sqnr, f"got sqnr: {SQNR(ref_f, test)}, expected: {min_sqnr}"
self.assertTrue(torch.equal(ref_q, test))

@parameterized.expand(COMMON_DEVICE_DTYPE)
6 changes: 4 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self):
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 2)
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
eps = torch.finfo(torch.float32).eps
precision = torch.float32
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)

scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zp_ref))
159 changes: 57 additions & 102 deletions torchao/quantization/quant_primitives.py
@@ -40,6 +40,9 @@
"groupwise_affine_dequantize_tensor_from_qparams",
"groupwise_affine_quantize_tensor",
"groupwise_affine_dequantize_tensor",
"choose_qparams_affine",
"quantize_affine",
"dequantize_affine",
# TODO: need to clean up above functions
] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])

@@ -219,10 +222,8 @@ def dequantize_affine(
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)

dequant = input.to(torch.float32)
scale = scale.to(torch.float32)
dequant = input.to(output_dtype)
if zero_point is not None:
zero_point = zero_point.to(torch.float32)
dequant -= zero_point
dequant *= scale
dequant = dequant.view(original_shape)
@@ -260,9 +261,9 @@ def choose_qparams_affine(
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
if scale_dtype is None:
scale_dtype = torch.float32
scale_dtype = input.dtype
if zero_point_dtype is None:
zero_point_dtype = torch.float32
zero_point_dtype = input.dtype

assert len(block_size) == input.dim()
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
@@ -301,47 +302,18 @@ def dynamically_quantize_per_tensor(
target_dtype,
qscheme=torch.per_tensor_affine, # for now, reuse existing qscheme enum
):
# assumes affine quantization

# default setup for affine quantization of activations
eps = torch.finfo(torch.float32).eps

if qscheme == torch.per_tensor_affine:
# get min and max
# TODO(future): make torch.aminmax work on cpu-half
# min_val, max_val = torch.aminmax(x)
min_val = torch.min(x)
max_val = torch.max(x)

# calculate scale and zero point based on min and max
# reference: https://fburl.com/code/srbiybme
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
# TODO(future): make torch.clamp with scalar work on cpu-half
scale = torch.clamp(scale, min=eps).reshape(1)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)

# quantize based on qmin/qmax/scale/zp
# reference: torch/ao/quantization/fx/_decomposed.py?lines=63
quant = torch.clamp(
torch.round(x / scale) + zero_point, quant_min, quant_max
).to(target_dtype)

else:
assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}"
# assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max"
amax = torch.max(torch.abs(x))
scale = amax / (float(quant_max - quant_min) / 2)
scale = torch.clamp(scale, min=eps).reshape(1)
quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to(
target_dtype
)
# do not create a tensor for zero_point as this is expensive
zero_point = None

block_size = x.shape
zero_point_dtype = torch.int32

qscheme_to_mapping_type = {
torch.per_tensor_affine: MappingType.ASYMMETRIC,
Contributor: Very, very nit: Hmm, I'm wondering if MappingType is the right name... We can definitely do this in a follow-up.

Contributor (author): So MappingType describes how we map from floating-point values to quantized values. I'm open to other suggestions as well, although we may remove this and just split the function into two in the future, so we could discuss this a little later (after we have verified this with ExecuTorch).
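For illustration (not part of the PR discussion), a rough sketch of what the two mapping types mean when choosing int8 qparams; the real choose_qparams_affine may differ in details such as eps clamping, dtype handling, and per-block reduction:

```python
# Illustrative only: symmetric vs. asymmetric mapping of floats to int8.
import torch

x = torch.randn(16)
quant_min, quant_max = -128, 127

# MappingType.SYMMETRIC: scale from the max absolute value, zero_point stays 0.
scale_sym = x.abs().amax() / ((quant_max - quant_min) / 2)
xq_sym = torch.clamp(torch.round(x / scale_sym), quant_min, quant_max).to(torch.int8)

# MappingType.ASYMMETRIC: scale spans the full (min, max) range, with a
# nonzero integer zero_point so that 0.0 is exactly representable.
min_val = torch.min(x.min(), torch.zeros(()))
max_val = torch.max(x.max(), torch.zeros(()))
scale_asym = (max_val - min_val) / (quant_max - quant_min)
zero_point = quant_min - torch.round(min_val / scale_asym)
xq_asym = torch.clamp(torch.round(x / scale_asym) + zero_point, quant_min, quant_max).to(torch.int8)
```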

torch.per_tensor_symmetric: MappingType.SYMMETRIC,
}
assert qscheme in qscheme_to_mapping_type, f"unsupported qscheme {qscheme}"
mapping_type = qscheme_to_mapping_type[qscheme]
scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
return quant, scale, zero_point


@@ -374,57 +346,45 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
# assumes dense memory format
# TODO(future): relax ^ as needed

# default setup for affine quantization of activations
eps = torch.finfo(torch.float32).eps
assert x.dim() == 2, "only support 2d Tensors"

# get min and max
min_val, max_val = torch.aminmax(x, dim=1)

# calculate scale and zero point based on min and max
# reference: https://fburl.com/code/srbiybme
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device

# reference: https://fburl.com/code/4wll53rk
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
# ensure scale is the same dtype as the original tensor
scale = torch.clamp(scale, min=eps).to(x.dtype)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

# quantize based on qmin/qmax/scale/zp
# reference: torch/ao/quantization/fx/_decomposed.py?lines=63
x_div = x.transpose(0, 1) / scale
x_round = torch.round(x_div)
x_zp = x_round + zero_point
x_zp = x_zp.transpose(0, 1)
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
eps = torch.finfo(torch.float32).eps
block_size = (1, x.shape[1])
zero_point_dtype = torch.int64

mapping_type = MappingType.SYMMETRIC
scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
return quant, scale, zero_point


# reference: https://fburl.com/code/vfsygwd0


def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32):
y = int_repr.to(out_dtype)
if zero_point is not None:
y -= zero_point
return y * scale
block_size = int_repr.shape
input_dtype = int_repr.dtype
assert scale.numel() == 1, f"scale size: {scale.numel()}"
dequantized = dequantize_affine(int_repr, block_size, scale, zero_point, input_dtype, output_dtype=out_dtype)
return dequantized


# reference: https://fburl.com/code/org0fmi3


def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32):
# assumes axis is 0
y = int_repr.transpose(0, 1)
y = y.to(out_dtype)
y = y - zero_points
y = y * scales
y = y.transpose(0, 1)
return y
assert int_repr.dim() == 2, "only support 2d Tensors"
# channel axis == 0
# block_size before transpose should be (1, int_repr.shape[1]) for axis == 0 per channel quant

# TODO: transpose is for perf reasons for torch.compile, we should separate this to lowering step
int_repr = int_repr.t()
# transpose for block_size as well
block_size = (int_repr.shape[0], 1)
input_dtype = int_repr.dtype
dequantized = dequantize_affine(int_repr, block_size, scales, zero_points, input_dtype, output_dtype=out_dtype)
dequantized = dequantized.t()
return dequantized


def quant_int8_dynamic_linear(
@@ -595,7 +555,7 @@ def quant_int8_per_token_matmul(


def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128):
""" """
"""This is tinygemm specific, we'll keep this for now"""
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
@@ -644,6 +604,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
n_bit=4,
groupsize=128,
):
"""This is tinygemm specific, we'll keep this for now"""
assert groupsize > 1
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
@@ -679,6 +640,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
n_bit=4,
groupsize=128,
):
"""This is tinygemm specific, we'll keep this for now"""
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
@@ -728,26 +690,19 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

max_val_abs = torch.max(-min_val_neg, max_val_pos)
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))

scales = max_val_abs / (float(max_int - min_int) / 2)
scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
# TODO: make sure abs(scales) is not too small?
zeros = torch.full_like(scales, 0)
return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
w.shape[0], -1
)
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
eps = torch.finfo(torch.float32).eps
ranges = {}
ranges[1] = (-1, 0)
# generating ranges for bit 2 to 8
for i in range(2, 9):
ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
quant_min, quant_max = ranges[n_bit]
scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1)


if TORCH_VERSION_AFTER_2_3:
@@ -796,7 +751,7 @@ def pack_int4_from_int8(int8_data: torch.Tensor) -> torch.Tensor:

@impl(quantized_decomposed_lib, "unpack_int4_to_int8", "CompositeExplicitAutograd")
def unpack_int4_to_int8(int8_data: torch.Tensor) -> torch.Tensor:
"""Get the original weight from the normalized float weight format"""
""" Get the original weight from the normalized float weight format"""
# since we are using int8 we will decode 2 entries per byte
# Shift elements down 4 and select out the bottom 4 bits
shape = int8_data.shape
6 changes: 5 additions & 1 deletion torchao/quantization/subclass.py
@@ -218,8 +218,11 @@ def dequantize(self, dtype=None):
"""
Obtain the dequantized version of the quantized tensor subclass
"""
zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
Contributor: I'm surprised this didn't cause a regression. Seems like a big change.
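For context, a quick sanity check (not part of the PR) suggesting why the scalar 0 and an explicit zero tensor give identical per-channel dequantization results; the shapes below are made up for illustration:

```python
# Hypothetical check: per-channel dequant with scalar 0 vs. a zero tensor
# as zero_point should match exactly, following the transpose-based math
# used by dequantize_per_channel in this diff.
import torch

int_repr = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
q_scales = torch.rand(4) + 0.01

# Previous call: zero_point passed as the scalar 0.
old = ((int_repr.t().to(torch.float32) - 0) * q_scales).t()

# New call: zero_point passed as a tensor of zeros matching q_scales.
zero_points = torch.zeros(q_scales.shape, dtype=q_scales.dtype)
new = ((int_repr.t().to(torch.float32) - zero_points) * q_scales).t()

assert torch.equal(old, new)
```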

# zero_points = 0
# TODO: fix dtype here? `to(self.dtype)` is not overwritten by `dtype` arg?
dq_t = dequantize_per_channel(
self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype
self.int_data.t(), self.q_scales, zero_points, self.dtype if dtype is None else dtype
).to(self.dtype)
# data was transposed to dequantize so make sure shape is correct
return dq_t if not self.transposed else dq_t.t()
@@ -292,6 +295,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127):
Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight)
)
"""
# because we call transpose in dequantization
w_int_repr, w_scales, _ = dynamically_quantize_per_channel(
input_float, qmin, qmax, torch.int8
)