Move quant ops to utils.py
Summary:
We had a lot of "quant primitive" ops that can be expressed in terms of more primitive ops, so they are really helper functions now; this change moves them to torchao.quantization.utils.

We should be able to further deprecate some of these ops in the future, once we deprecate the subclasses and refactor smoothquant etc.
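
For example, the old per-tensor dynamic quantization helper can now be written on top of the affine primitives. A minimal sketch (mirroring the _dynamically_quantize_per_tensor helper inlined into test/integration/test_integration.py below; the function name here is just for illustration):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        quantize_affine,
    )

    def dynamically_quantize_per_tensor_sketch(x, quant_min, quant_max, target_dtype):
        # per-tensor quantization: a single block spanning the whole tensor
        block_size = x.shape
        scale, zero_point = choose_qparams_affine(
            x,
            MappingType.ASYMMETRIC,
            block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int32,
        )
        quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
        return quant, scale, zero_point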

Also moved TORCH_VERSION_AFTER_{2_2/2_3/2_4} from torchao.quantization.utils to torchao.utils
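
Downstream imports need to follow these moves; a rough before/after sketch based on the import changes in the diffs below:

    # before this commit
    # from torchao.quantization.quant_primitives import get_group_qparams_symmetric
    # from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, find_multiple

    # after this commit
    from torchao.quantization.utils import get_group_qparams_symmetric
    from torchao.utils import TORCH_VERSION_AFTER_2_4, find_multiple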

Test Plan:
python test/integration/test_integration.py
python test/quantization/test_quant_api.py
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 7, 2024
1 parent d97ae74 commit 6a9a5f4
Showing 21 changed files with 575 additions and 595 deletions.
4 changes: 2 additions & 2 deletions test/dtypes/test_nf4.py
@@ -241,7 +241,7 @@ def test_smoketest_linear_compile(self, dtype: torch.dtype):
        a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
        inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
        out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
    @parametrize("shape", [(16, 16), (32, 16)])
@@ -430,7 +430,7 @@ def test_to_cpu(self):
        for attr in _INNER_TENSOR_NAMES_FOR_SHARDING:
            inner_tensor = getattr(nf4_tensor, attr)
            self.assertEqual(inner_tensor.device.type, "cpu")


instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)
131 changes: 124 additions & 7 deletions test/integration/test_integration.py
@@ -28,15 +28,18 @@
_replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import (
safe_int_mm,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
MappingType,
)
from torchao.quantization.utils import (
dequantize_per_channel,
dequantize_per_tensor,
dynamically_quantize_per_channel,
dynamically_quantize_per_tensor,
quant_int8_dynamic_linear,
quant_int8_dynamic_per_token_linear,
quantize_activation_per_token_absmax,
safe_int_mm,
dequantize_affine,
)

from torchao.quantization.smoothquant import (
@@ -70,7 +73,7 @@
from parameterized import parameterized
import itertools
import logging
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

logger = logging.getLogger("INFO")

@@ -90,6 +93,120 @@

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

# deprecated quant primitive ops

def _dynamically_quantize_per_tensor(
    x,
    quant_min,
    quant_max,
    target_dtype,
    qscheme=torch.per_tensor_affine,  # for now, reuse existing qscheme enum
):
    eps = torch.finfo(torch.float32).eps
    block_size = x.shape
    zero_point_dtype = torch.int32

    qscheme_to_mapping_type = {
        torch.per_tensor_affine: MappingType.ASYMMETRIC,
        torch.per_tensor_symmetric: MappingType.SYMMETRIC,
    }
    assert qscheme in qscheme_to_mapping_type, f"unsupported qscheme {qscheme}"
    mapping_type = qscheme_to_mapping_type[qscheme]
    scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype)
    quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
    return quant, scale, zero_point

def _quant_int8_matmul(
    x_vals_int8,
    x_scale,
    x_zp,
    w_vals_int8_t,
    w_vals_int8_t_sums_int64,
    w_scales,
    out_dtype=torch.float32,
):
    """
    Quantized matmul of int8 operands that accumulates to int32 and returns
    out_dtype. For now, this is written for approximate numerical
    correctness, and things like aligning accumulation behaviors and
    performance optimizations are left for a future PR.
    Assumes that weight quantization is symmetric, i.e. w_zp is 0.
    Assumes that weight quantization is per-channel.
    See
    https://github.com/google/gemmlowp/blob/master/doc/quantization.md
    for an overview of quantized matmul compute.
    In scalar form, assuming out_dtype is fp32 and zw == 0:
        Y_i_j_fp32 = sx * sw * (dot(X_i, W_j) - zx * sum(W_j))
    """

    assert x_vals_int8.dtype in (
        torch.uint8,
        torch.int8,
    ), f"x dtype {x_vals_int8.dtype} not yet supported"
    assert (
        w_vals_int8_t.dtype == torch.int8
    ), f"w dtype {w_vals_int8_t.dtype} not yet supported"
    assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}"

    #
    # 1. do the matrix form of dot(X_i, W_j)
    #

    # TODO(before land): add test case for input with bsz
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t)
    y_dot_int32 = y_dot_int32.reshape(*x_vals_int8.shape[:-1], -1)

    # TODO(future): consider using integer arithmetic throughout, although
    # TBD if that is actually faster on GPUs
    # need to use 32 bits here to prevent overflow for large shapes,
    # 16 bits is not enough
    y_dot_float32 = y_dot_int32.to(torch.float32)

    #
    # 2. connect it all together
    #

    # mm_unscaled has to stay in float32 for the next two lines to prevent overflow
    mm_unscaled_float32 = y_dot_float32 - (x_zp * w_vals_int8_t_sums_int64)
    y = x_scale * w_scales * mm_unscaled_float32
    # can downcast only at the very end
    y = y.to(out_dtype)
    return y

def _quant_int8_dynamic_linear(
    x,
    x_quant_min,
    x_quant_max,
    x_q_dtype,
    w_vals_int8_t,
    w_scales,
    w_vals_int8_t_sums_int64,
    bias,
    out_dtype=torch.float32,
):
    # like F.linear, but with int8 dynamic quantization of activation,
    # and a quantized weight
    x_vals_int8, x_scale, x_zp = _dynamically_quantize_per_tensor(
        x, x_quant_min, x_quant_max, x_q_dtype
    )
    # w_vals_int8_t_sums_int64 = w_vals_int8_t.sum(dim=0)
    mm_out = _quant_int8_matmul(
        x_vals_int8,
        x_scale,
        x_zp,
        w_vals_int8_t,
        w_vals_int8_t_sums_int64,
        w_scales,
        out_dtype,
    )
    if bias is not None:
        mm_out += bias
    return mm_out
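
# A sketch of how these helpers compose (assumed usage): given a transposed
# per-channel int8 weight w_vals_int8_t with scales w_scales and a float input x,
# a dynamic int8 linear is roughly
#   y = _quant_int8_dynamic_linear(
#       x, -128, 127, torch.int8,
#       w_vals_int8_t, w_scales, w_vals_int8_t.sum(dim=0), bias=None,
#   )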

def combine_parameters(a, b):
    new_tuples = []
    for (tuple1, tuple2) in itertools.product(a, b):
@@ -374,7 +491,7 @@ def _test_dynamic_quant_per_tensor_numerics_impl(
        self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme
    ):
        x = torch.randn(256, dtype=float_dtype, device=device)
        y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor(
        y_vals, y_scale, y_zero_point = _dynamically_quantize_per_tensor(
            x, qmin, qmax, int_dtype, qscheme
        )

@@ -745,7 +862,7 @@ def _test_qlinear_per_channel_numerics(
        w_vals_sums = w_vals.sum(dim=0)

        # do our version of the quantized linear operator
        y = quant_int8_dynamic_linear(
        y = _quant_int8_dynamic_linear(
            x,
            qmin,
            qmax,
2 changes: 1 addition & 1 deletion test/quantization/model.py
@@ -10,7 +10,7 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torchao.quantization.utils import find_multiple
from torchao.utils import find_multiple

def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
4 changes: 2 additions & 2 deletions test/quantization/test_qat.py
@@ -18,8 +18,8 @@
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
2 changes: 1 addition & 1 deletion test/quantization/test_quant_api.py
@@ -44,7 +44,7 @@
get_apply_int8wo_quant,
get_apply_int8dyn_quant,
)
from torchao.quantization.utils import (
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)
20 changes: 10 additions & 10 deletions test/quantization/test_quant_primitives.py
@@ -9,17 +9,22 @@
import unittest
import torch
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
MappingType,
ZeroPointDomain,
)

# TODO: remove test for utils?
from torchao.quantization.utils import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
quantize_activation_per_token_absmax,
)

from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
)
@@ -225,7 +230,6 @@ def test_choose_qparams_tensor_sym(self):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_quantize_activation_per_token_abs_max(self):
        from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
        input = torch.randn(10, 10)
        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)

@@ -246,15 +250,13 @@ def test_quantize_activation_per_token_abs_max(self):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_quantize_activation_per_token_abs_max_zero_input(self):
        from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
        input = torch.zeros(10, 10)
        # make sure it still works
        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    def test_quantize_activation_per_token_abs_max_dtype(self):
        from torchao.quantization.quant_primitives import quantize_activation_per_token_absmax
        input = torch.zeros(10, 10, dtype=torch.bfloat16)
        quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
        self.assertTrue(scale_ref.dtype, torch.bfloat16)
@@ -439,8 +441,6 @@ def test_not_preserve_zero_not_supported(self):


    def test_get_groupwise_affine_qparams(self):
        from torchao.quantization.quant_primitives import ZeroPointDomain

        input = torch.randn(10, 256)
        n_bit = 4
        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
4 changes: 0 additions & 4 deletions torchao/__init__.py
@@ -15,15 +15,11 @@
logging.info("Skipping import of cpp extensions")

from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes

__all__ = [
"dtypes",
"apply_dynamic_quant",
"apply_weight_only_int8_quant",
"autoquant",
]
4 changes: 3 additions & 1 deletion torchao/dtypes/aqt.py
@@ -8,10 +8,12 @@
dequantize_affine,
ZeroPointDomain,
MappingType,
int_scaled_matmul,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.kernel.intmm import int_scaled_matmul
from torchao.utils import find_multiple

aten = torch.ops.aten
2 changes: 1 addition & 1 deletion torchao/kernel/intmm.py
@@ -2,7 +2,7 @@
import os
import torch

from torchao.quantization.utils import TORCH_VERSION_AFTER_2_2
from torchao.utils import TORCH_VERSION_AFTER_2_2

try:
# Only works for torch2.2 or newer.
4 changes: 2 additions & 2 deletions torchao/quantization/GPTQ.py
@@ -28,7 +28,7 @@
from typing import Any, Dict, Optional
from .unified import Quantizer

from .quant_primitives import (
from .utils import (
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
@@ -753,7 +753,7 @@ def quantize(self, model: torch.nn.Module, inputs: List[_MultiInput], **kwargs:
return model


from .quant_primitives import (
from .utils import (
get_group_qparams_symmetric,
group_quantize_tensor_symmetric,
per_token_dynamic_quant,
12 changes: 3 additions & 9 deletions torchao/quantization/__init__.py
@@ -22,18 +22,9 @@
"change_linear_weights_to_int4_woqtensors",
"swap_conv2d_1x1_to_linear"
"safe_int_mm",
"dynamically_quantize_per_tensor",
"quantize_activation_per_token_absmax",
"dynamically_quantize_per_channel",
"dequantize_per_tensor",
"dequantize_per_channel",
"autoquant",
"change_linears_to_autoquantizable",
"change_autoquantizable_to_quantized",
"quant_int8_dynamic_linear",
"quant_int8_matmul",
"quant_int8_dynamic_per_token_linear",
"quant_int8_per_token_matmul",
"get_scale",
"SmoothFakeDynQuantMixin",
"SmoothFakeDynamicallyQuantizedLinear",
@@ -48,4 +39,7 @@
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"quantize_affine",
"dequantize_affine",
"choose_qprams_affine",
]
7 changes: 5 additions & 2 deletions torchao/quantization/autoquant.py
@@ -6,10 +6,13 @@
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
quantize_activation_per_token_absmax,
safe_int_mm,
)
from .utils import TORCH_VERSION_AFTER_2_4
from .utils import (
TORCH_VERSION_AFTER_2_4,
quantize_activation_per_token_absmax,
)

import torch.nn.functional as F
try:
from torch._inductor.utils import do_bench
2 changes: 1 addition & 1 deletion torchao/quantization/dynamic_quant.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn

from .quant_primitives import (
from .utils import (
dynamically_quantize_per_channel,
quant_int8_dynamic_per_token_linear,
)
2 changes: 1 addition & 1 deletion torchao/quantization/prototype/qat.py
@@ -10,7 +10,7 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl

from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.quantization.unified import TwoStepQuantizer

from torchao.quantization.GPTQ import (
(Remaining changed files not shown.)