From 1648e6940aecc93464448b0a2c9273de80384205 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 11 Jun 2024 16:09:50 -0700 Subject: [PATCH] Deprecate top level quantization APIs Summary: This PR deprecates a few quantization APIs and here are the bc-breaking notes: 1. int8 weight only quantization int8 weight only quant module swap API ``` apply_weight_only_int8_quant(model) ``` and int8 weight only tensor subclass API ``` change_linear_weights_to_int8_woqtensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int8wo_quant())) ``` 2. int8 dynamic quantization ``` apply_dynamic_quant(model) ``` or ``` change_linear_weights_to_int8_dqtensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int8dyn_quant())) ``` 3. int4 weight only quantization ``` change_linear_weights_to_int4_wotensors(model) ``` --> unified tensor subclass API ``` quantize(model, get_apply_int4wo_quant())) ``` Test Plan: python test/quantization/test_quant_api.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_aq.py | 4 +- test/integration/test_integration.py | 92 +++++-- test/prototype/mx_formats/test_mx_linear.py | 2 +- test/quantization/test_quant_api.py | 64 +++-- torchao/dtypes/__init__.py | 4 +- torchao/dtypes/aqt.py | 6 +- torchao/quantization/README.md | 257 +++++++++--------- torchao/quantization/quant_api.py | 275 +++++++++++--------- torchao/quantization/subclass.py | 4 +- torchao/utils.py | 5 + tutorials/quantize_vit/run_vit_b_quant.py | 12 +- 11 files changed, 399 insertions(+), 326 deletions(-) diff --git a/test/dtypes/test_aq.py b/test/dtypes/test_aq.py index 6967e5f310..f4211c8921 100644 --- a/test/dtypes/test_aq.py +++ b/test/dtypes/test_aq.py @@ -2,7 +2,7 @@ TestCase, run_tests, ) -from torchao.quantization.quant_api import get_apply_int4wo_quant +from torchao.quantization.quant_api import int4wo import torch import unittest @@ -12,7 +12,7 @@ class TestAQ(TestCase): def test_tensor_core_layout_transpose(self): t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda") shape = t.shape - apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32) + apply_int4wo_quant = int4wo(groupsize=32) aqt = apply_int4wo_quant(t) aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c3e483fe5b..b853b0589d 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -20,12 +20,17 @@ DynamicallyPerAxisQuantizedLinear, ) from torchao.quantization.quant_api import ( - apply_dynamic_quant, - apply_weight_only_int8_quant, + int4wo, + int8wo, + int8da_int8w, + quantize, + _replace_with_custom_fn_if_matches_filter, +) +# APIs to be deprecated (used for torch 2.2.2 and 2.3) +from torchao.quantization.quant_api import ( change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors, change_linear_weights_to_int4_woqtensors, - _replace_with_custom_fn_if_matches_filter, ) from torchao.quantization.quant_primitives import ( safe_int_mm, @@ -73,26 +78,53 @@ from parameterized import parameterized import itertools import logging -from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4, is_fbcode +from torchao.utils import ( + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_4, + unwrap_tensor_subclass, + is_fbcode, +) logger = logging.getLogger("INFO") torch.manual_seed(0) config.cache_size_limit = 100 -# TODO: use this to reduce the number of tests 
-TENSOR_SUBCLASS_APIS = [ - change_linear_weights_to_int8_dqtensors, - change_linear_weights_to_int8_woqtensors, - change_linear_weights_to_int4_woqtensors, -] - COMMON_DEVICES = ["cpu", "cuda"] COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() +def _int8wo_api(mod): + if TORCH_VERSION_AFTER_2_4: + quantize(mod, int8wo()) + unwrap_tensor_subclass(mod) + else: + change_linear_weights_to_int8_woqtensors(mod) + +def _int8da_int8w_api(mod): + if TORCH_VERSION_AFTER_2_4: + quantize(mod, int8da_int8w()) + unwrap_tensor_subclass(mod) + else: + change_linear_weights_to_int8_dqtensors(mod) + +def _int4wo_api(mod): + if TORCH_VERSION_AFTER_2_4: + quantize(mod, int4wo()) + unwrap_tensor_subclass(mod) + else: + change_linear_weights_to_int4_woqtensors(mod) + +# TODO: use this to reduce the number of tests +TENSOR_SUBCLASS_APIS = [ + _int8wo_api, + _int8da_int8w_api, + _int4wo_api, +] + + def combine_parameters(a, b): new_tuples = [] for (tuple1, tuple2) in itertools.product(a, b): @@ -756,14 +788,14 @@ def _test_lin_weight_subclass_api_impl( @unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen") def test_int8_dynamic_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype + _int8da_int8w_api, device, 35, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_int8_weight_only_quant_subclass_api(self, device, dtype): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int8_woqtensors, device, 40, test_dtype=dtype + _int8wo_api, device, 40, test_dtype=dtype ) @parameterized.expand(COMMON_DEVICE_DTYPE) @@ -773,7 +805,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): self.skipTest(f"Fails for {dtype}") for test_shape in ([(16, 1024, 16)] + ([(1, 1024, 256)] if device=='cuda' else [])): self._test_lin_weight_subclass_api_impl( - change_linear_weights_to_int4_woqtensors, + _int4wo_api, device, 15, test_shape=test_shape, @@ -789,8 +821,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): for groupsize in [64, 32]: for inner_k_tiles in [4, 2]: kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles} + + def api(mod): + if TORCH_VERSION_AFTER_2_4: + quantize(mod, int4wo(**kwargs)) + unwrap_tensor_subclass(mod) + else: + change_linear_weights_to_int4_woqtensors(mod, **kwargs) + self._test_lin_weight_subclass_api_impl( - lambda mod: change_linear_weights_to_int4_woqtensors(mod, **kwargs), + api, device, 15, test_shape=test_shape, @@ -805,7 +845,7 @@ def test_dynamic_quant(self): m = nn.Sequential(nn.Linear(K, N)) y_ref = m(x) - apply_dynamic_quant(m) + quantize(m, int8da_int8w()) y_test = m(x) sqnr = compute_error(y_ref, y_test) @@ -819,7 +859,7 @@ def test_weight_only_quant(self): x = torch.randn(*x_shape) m = nn.Sequential(nn.Linear(4, 5)) y_ref = m(x) - apply_weight_only_int8_quant(m) + _int8wo_api(m) y_wo = m(x) sqnr = compute_error(y_ref, y_wo) self.assertGreater(sqnr, 44.0) @@ -842,7 +882,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): x = torch.randn(*x_shape).to(device).to(dtype) m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype) y_ref = m(x) - apply_weight_only_int8_quant(m) + _int8wo_api(m) m(x) m_c = torch.compile(m, mode="max-autotune") y_wo, (code,) = run_and_get_code(m_c, x) @@ -869,7 +909,7 @@ def 
test_weight_only_quant_use_mixed_mm(self, device, dtype): x = torch.randn(*x_shape).to(device).to(dtype) m = nn.Sequential(nn.Linear(4, 5)).to(device).to(dtype) y_ref = m(x) - apply_weight_only_int8_quant(m) + _int8wo_api(m) m_c = torch.compile(m, mode="max-autotune") y_wo, (code,) = run_and_get_code(m_c, x) sqnr = compute_error(y_ref, y_wo) @@ -910,6 +950,7 @@ def forward(self, x): # save quantized state_dict api(model) + torch.save(model.state_dict(), "test.pth") # get quantized reference model_qc = torch.compile(model, mode="max-autotune") @@ -925,6 +966,7 @@ def forward(self, x): # load quantized state_dict state_dict = torch.load("test.pth", mmap=True) os.remove("test.pth") + model.load_state_dict(state_dict, assign=True) model = model.to(device=test_device, dtype=test_dtype).eval() @@ -941,13 +983,13 @@ def forward(self, x): def test_save_load_dqtensors(self, device, dtype): if device == "cpu": self.skipTest(f"indcutor failed for cpu right now") - self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype) + self._test_handle_save_load_meta_impl(_int8da_int8w_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @torch.no_grad() @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_save_load_int8woqtensors(self, device, dtype): - self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_woqtensors, device, test_dtype=dtype) + self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.") @@ -955,7 +997,7 @@ def test_save_load_int8woqtensors(self, device, dtype): def test_save_load_int4woqtensors(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") - self._test_handle_save_load_meta_impl(change_linear_weights_to_int4_woqtensors, device, 20, test_dtype=dtype) + self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype) class TorchCompileUnitTest(unittest.TestCase): @@ -1275,8 +1317,7 @@ def forward(self, x): model = test_model().to(dtype=test_dtype, device=test_device).eval() ref_f = model(x) - kwargs = {"dtype": test_dtype} - api(model, **kwargs) + api(model) # running model model(x) @@ -1321,8 +1362,7 @@ def forward(self, x): model = test_model().to(dtype=test_dtype, device=test_device).eval() ref_f = model(x) - kwargs = {"dtype": test_dtype} - api(model, **kwargs) + api(model) # running model ref = model(x) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index c453b0fe38..05e8a0f32e 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -189,7 +189,7 @@ def test_inference_compile_simple(elem_dtype): if elem_dtype is torch.float8_e4m3fn: assert sqnr >= 20.0 else: - assert sqnr >= 14.0 + assert sqnr >= 13.5 def test_filter_fn(): diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e0ead9ad12..e71f67767d 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -20,7 +20,6 @@ import torchao from torchao.dtypes import ( - to_aq, AffineQuantizedTensor, ) from torchao.quantization.quant_primitives import ( @@ -28,22 +27,19 @@ ZeroPointDomain, ) from torchao.quantization.subclass import ( - to_laq, LinearActQuantizedTensor, Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, ) from torchao.quantization.quant_api import ( 
_replace_with_custom_fn_if_matches_filter, - apply_dynamic_quant, - apply_weight_only_int8_quant, Quantizer, TwoStepQuantizer, quantize, - get_apply_8da4w_quant, - get_apply_int4wo_quant, - get_apply_int8wo_quant, - get_apply_int8dyn_quant, + int8da_int4w, + int4wo, + int8wo, + int8da_int8w, ) from torchao.utils import ( TORCH_VERSION_AFTER_2_3, @@ -52,7 +48,9 @@ from pathlib import Path from torchao._models.llama.tokenizer import get_tokenizer from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao.utils import unwrap_tensor_subclass import copy +import tempfile def dynamic_quant(model, example_inputs): @@ -62,20 +60,6 @@ def dynamic_quant(model, example_inputs): m = convert_pt2e(m) return m -def _apply_dynamic_quant(model): - """ - Applies dynamic symmetric per-token activation and per-channel weight - quantization to all linear layers in the given model using - module swaps. - """ - _replace_with_custom_fn_if_matches_filter( - model, - lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)), - lambda mod, fqn: isinstance(mod, torch.nn.Linear), - ) - return model - - def capture_and_prepare(model, example_inputs): m = torch.export.export(model, example_inputs) quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True)) @@ -104,7 +88,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module: class TorchCompileDynamicQuantizer(Quantizer): def quantize(self, model: torch.nn.Module) -> torch.nn.Module: - apply_dynamic_quant(model) + quantize(model, int8da_int8w()) return model class ToyLinearModel(torch.nn.Module): @@ -167,7 +151,7 @@ class TestQuantFlow(unittest.TestCase): def test_dynamic_quant_gpu_singleline(self): m = ToyLinearModel().eval() example_inputs = m.example_inputs() - m = _apply_dynamic_quant(m) + m = quantize(m, int8da_int8w()) quantized = m(*example_inputs) # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64 # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {}) @@ -203,18 +187,28 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self): torch.testing.assert_close(quantized, compiled, atol=0, rtol=0) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "only works for torch 2.4+") def test_int8_wo_quant_save_load(self): + from torchao.quantization.quant_api import ( + change_linear_weights_to_int8_woqtensors, + ) m = ToyLinearModel().eval().cpu() - apply_weight_only_int8_quant(m) + def api(model): + model = quantize(model, int8wo()) + unwrap_tensor_subclass(model) + + api(m) + example_inputs = m.example_inputs() ref = m(*example_inputs) - _TMP_FN = "_test.pt" - torch.save(m.state_dict(), _TMP_FN) + with tempfile.NamedTemporaryFile() as f: + torch.save(m.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + + m2 = ToyLinearModel().eval().cpu() + api(m2) - state_dict = torch.load(_TMP_FN) - os.remove(_TMP_FN) - m2 = ToyLinearModel().eval() - apply_weight_only_int8_quant(m2) m2.load_state_dict(state_dict) m2 = m2.to(device="cuda") example_inputs = map(lambda x: x.cuda(), example_inputs) @@ -508,7 +502,7 @@ def test_quantized_tensor_subclass_8da4w(self): m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m = quantize(m, get_apply_8da4w_quant(groupsize=groupsize)) + 
m = quantize(m, int8da_int4w(groupsize=groupsize)) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) @@ -537,7 +531,7 @@ def test_quantized_tensor_subclass_int4(self): example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") groupsize = 32 - m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize)) + m = quantize(m, int4wo(groupsize=groupsize)) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -557,7 +551,7 @@ def test_quantized_tensor_subclass_int8_wo(self): m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - m = quantize(m, get_apply_int8wo_quant()) + m = quantize(m, int8wo()) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -580,7 +574,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): m_copy = copy.deepcopy(m) # setting batch_size to 20 to be compatible with the kernel example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") - m = quantize(m, get_apply_int8dyn_quant()) + m = quantize(m, int8da_int8w()) assert isinstance(m.linear1.weight, LinearActQuantizedTensor) assert isinstance(m.linear2.weight, LinearActQuantizedTensor) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index dccd22f3d4..87d129bc1b 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,11 +1,11 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor -from .aqt import AffineQuantizedTensor, to_aq +from .aqt import AffineQuantizedTensor, to_affine_quantized __all__ = [ "NF4Tensor", "to_nf4", "UInt4Tensor" "AffineQuantizedTensor", - "to_aq", + "to_affine_quantized", ] diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index ae05720b05..8091946f77 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -200,7 +200,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.view(tenssor.shape[::-1]), tenssor.scale, tenssor.zero_point + tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -394,8 +394,6 @@ def __new__( kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout ) - if dtype is None: - dtype = scale.dtype kwargs["dtype"] = dtype if strides is not None: kwargs["strides"] = strides @@ -800,4 +798,4 @@ def t(func, *args, **kwargs): ) return return_and_correct_aliasing(func, args, kwargs, new) -to_aq = AffineQuantizedTensor.from_float +to_affine_quantized = AffineQuantizedTensor.from_float diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 04efddd0ca..bdfae98f50 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -56,38 +56,157 @@ with open("quantization-cache.pkl", "wb") as f: with open("quantization-cache.pkl", "rb") as f: torchao.quantization.AUTOQUANT_CACHE.update(pickle.load(f)) ``` +## Affine Quantization +Affine quantization refers to the type of quantization that maps from floating point numbers to quantized numbers (typically integer) with an affine transformation, i.e.: `quantized_val = float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data. 
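+
+As a toy, hand-rolled illustration of that mapping (per-tensor asymmetric int8, not the torchao primitives described below):
+```python
+import torch
+
+x = torch.randn(4, 8)
+quant_min, quant_max = -128, 127
+scale = (x.max() - x.min()) / (quant_max - quant_min)
+zero_point = quant_min - torch.round(x.min() / scale)
+# quantize: quantized_val = float_val / scale + zero_point
+x_q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max).to(torch.int8)
+# dequantize: approximately recovers x
+x_dq = (x_q.to(torch.float32) - zero_point) * scale
+```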
+ +### Quantization Primitives +We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. + +### Quantized Tensor Subclass +We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) + +### Quantization Flow Example +Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul +as an example: +```python +import torch +from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain +from torchao.dtypes import to_affine_quantized +from torch._inductor.runtime.runtime_utils import do_bench_gpu +import copy +from torchao.quantization.quant_api import ( + quantize, + int4wo, +) + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m=64, n=32, k=64): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): + return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + +dtype = torch.bfloat16 +m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") +m_bf16 = copy.deepcopy(m) +example_inputs = m.example_inputs(dtype=dtype, device="cuda") + +m_bf16 = torch.compile(m_bf16, mode='max-autotune') +# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) +groupsize = 32 +m = quantize(m, int4wo(groupsize=groupsize)) + +torch._inductor.config.force_fuse_int_mm_with_mul = True +torch._inductor.config.use_mixed_mm = True + +# temporary workaround for tensor subclass + torch.compile +from torchao.quantization.utils import unwrap_tensor_subclass +m = unwrap_tensor_subclass(m) +# compile the model to improve performance +m = torch.compile(m, mode='max-autotune') + +# benchmark to see the speedup +from torchao.utils import benchmark_model + +num_runs = 100 +torch._dynamo.reset() +bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0]) +print(f"bf16 mean time: {bf16_time}") +int4_time = benchmark_model(m, num_runs, example_inputs[0]) +print(f"int4 weight only quantized mean time: {int4_time}") +print(f"speedup: {bf16_time / int4_time}") +# output (1xA100 GPU machine) +bf16 mean time: 71.457685546875 +int4 weight only quantized mean time: 31.4580908203125 +speedup: 2.2715200981216173 +``` -## A8W8 Dynamic Quantization +What we do underlying the APIs are roughly the following: +``` +from torchao.dtypes import to_affine_quantized +def int8wo_quant(weight): + return to_affine_quantized(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64) + +for n, m in model.named_modules(): + if 
isinstance(m, torch.nn.Linear):
+        # optional filtering for module name, shape etc.
+        m.weight = nn.Parameter(int8wo_quant(m.weight))
+
+        # note: activation quantization needs to be applied after the weight quantization
+        # quantize the activation (needed by dynamic quantization)
+        input_quant_func = int8wo_quant  # specify how the input activation is quantized (the same function is reused here just for illustration)
+        m.weight = nn.Parameter(to_linear_act_quantized(m.weight, input_quant_func))
+```
+The model/tensor subclass should also be compatible with AOTI and torch.export. Currently we support
+`torch.export.export` and `torch._export.aot_compile` with the following workaround:
+```
+from torchao.quantization.utils import unwrap_tensor_subclass
+m_unwrapped = unwrap_tensor_subclass(m)
+
+
+# export
+m = torch.export.export(m_unwrapped, example_inputs).module()
+
+# aot_compile
+torch._export.aot_compile(m_unwrapped, example_inputs)
+```
+
+### Other Available Quantization Techniques
+#### A8W8 Dynamic Quantization
 ```python
 # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 from torchao.quantization import quant_api
-# convert linear modules to quantized tensor subclasses
-quant_api.change_linear_weights_to_int8_dqtensors(model)
+
+# for torch 2.4+
+from torchao.quantization.quant_api import quantize
+quantize(model, "int8_dynamic")
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+change_linear_weights_to_int8_dqtensors(model)
 ```
-## A16W8 WeightOnly Quantization
+#### A16W8 WeightOnly Quantization
 ```python
-from torchao.quantization import quant_api
-quant_api.change_linear_weights_to_int8_woqtensors(model)
+# for torch 2.4+
+from torchao.quantization.quant_api import quantize
+quantize(model, "int8_weight_only")
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+change_linear_weights_to_int8_woqtensors(model)
 ```
 This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor.
-## A16W4 WeightOnly Quantization
+#### A16W4 WeightOnly Quantization
 ```python
-from torchao.quantization import quant_api
-quant_api.change_linear_weights_to_int4_woqtensors(model)
+# for torch 2.4+
+from torchao.quantization.quant_api import quantize
+quantize(model, "int4_weight_only")
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+change_linear_weights_to_int4_woqtensors(model)
 ```
 Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
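+
+A quick way to sanity check that impact is to compare outputs of the original and the quantized model. The sketch below is illustrative only: it assumes you have a float `model` and a representative `example_input` in a dtype/device the int4 kernel supports (bfloat16 on CUDA), and it defines its own SQNR helper rather than relying on a specific torchao utility:
+```python
+import copy
+import torch
+from torchao.quantization.quant_api import quantize
+
+def sqnr(ref, res):
+    # signal-to-quantization-noise ratio in dB
+    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - res))
+
+model_int4 = quantize(copy.deepcopy(model), "int4_weight_only")
+with torch.no_grad():
+    ref_out = model(example_input)
+    int4_out = model_int4(example_input)
+print(f"int4 weight only SQNR: {sqnr(ref_out, int4_out).item():.2f} dB")
+```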
-## A16W4 WeightOnly Quantization with GPTQ +## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ ```python from torchao._models._eval import InputRecorder, TransformerEvalWrapper @@ -137,17 +256,17 @@ model = quantizer.quantize(model, inputs).cuda() ``` -## A8W8 Dynamic Quantization +## (To be deprecated) A8W8 Dynamic Quantization ```Python from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer -quantizer = Int8DynActInt4WeightQuantizer(groupsize=32) +quantizer = Int8DynActInt4WeightQuantizer(groupsize=128) model = quantizer.quantize(model) ``` This is used in [ExecuTorch](https://github.com/pytorch/executorch) to quantize llama model right now. -## A8W8 Dynamic Quantization with Smoothquant +## (To be moved to prototype) A8W8 Dynamic Quantization with Smoothquant We've also implemented a version of [smoothquant](https://arxiv.org/abs/2211.10438) with the same GEMM format as above. Due to requiring calibration, the API is more complicated. @@ -186,118 +305,6 @@ model = torch.compile(model, mode='max-autotune') model(input) ``` -## Affine Quantization -Affine quantization refers to the type of quantization that maps from floating point numbers to quantized numbers (typically integer) with an affine transformation, i.e.: `quantized_val = float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data. - -### Quantization Primitives -We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. - -### Quantized Tensor Subclass -We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) - -### Quantization Flow -What we need to do afterwards is roughly the following - -``` -from torchao.dtypes.aqt import to_aq -def apply_int8wo_quant(weight): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - block_size = (1, weight.shape[1]) - return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - -for n, m in model.named_modules(): - if isinstance(m, torch.nn.Linear): - # optional filtering for module name, shape etc. 
- m.weight = nn.Parameter(apply_int8wo_quant(m.weight)) - # note: quantization for activation need to be applied after the weight quantization - # quantization activation (needed by dynamic quantization) - # input_quant_func = apply_int8wo_quant # specify how input activation is quantized - # m.weight = nn.Parameter(to_laq(m.weight, input_quant_func)) -``` -The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support -`torch.export.export` and `torch.aot_compile` with the following workaround: -``` -from torchao.quantization.utils import unwrap_tensor_subclass -m_unwrapped = unwrap_tensor_subclass(m) - - -# export -m = torch.export.export(m_unwrapped, example_inputs).module() - -# aot_compile -torch._export.aot_compile(m_unwrapped, example_inputs) -``` - -But we expect this will be integrated into the export path by default in the future. - - -### Example -Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul -as an example: -```python -import torch -from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain -from torchao.dtypes import to_aq -from torch._inductor.runtime.runtime_utils import do_bench_gpu -import copy -from torchao.quantization.quant_api import ( - quantize, - get_apply_int4wo_quant, -) - -class ToyLinearModel(torch.nn.Module): - def __init__(self, m=64, n=32, k=64): - super().__init__() - self.linear1 = torch.nn.Linear(m, n, bias=False) - self.linear2 = torch.nn.Linear(n, k, bias=False) - - def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): - return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - -dtype = torch.bfloat16 -m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") -m_bf16 = copy.deepcopy(m) -example_inputs = m.example_inputs(dtype=dtype, device="cuda") - -m_bf16 = torch.compile(m_bf16, mode='max-autotune') -# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) -groupsize = 32 -m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize)) - -torch._inductor.config.force_fuse_int_mm_with_mul = True -torch._inductor.config.use_mixed_mm = True - -# temporary workaround for tensor subclass + torch.compile -from torchao.quantization.utils import unwrap_tensor_subclass -m = unwrap_tensor_subclass(m) -# compile the model to improve performance -m = torch.compile(m, mode='max-autotune') - -# benchmark to see the speedup -from torchao.utils import benchmark_model - -num_runs = 100 -torch._dynamo.reset() -bf16_time = benchmark_model(m_bf16, num_runs, example_inputs[0]) -print(f"bf16 mean time: {bf16_time}") -int4_time = benchmark_model(m, num_runs, example_inputs[0]) -print(f"int4 weight only quantized mean time: {int4_time}") -print(f"speedup: {bf16_time / int4_time}") - -# output (1xA100 GPU machine) -bf16 mean time: 71.457685546875 -int4 weight only quantized mean time: 31.4580908203125 -speedup: 2.2715200981216173 -``` - ## Notes 1. 
APIs have been hardware tested on A100 and T4(colab) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 510db85512..e97d7e8ec0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -13,29 +13,22 @@ both because primitives were designed based on the fusions that come along with it and because that is how we access the intended quantized and mixed GEMM kernels - -TODO: There are 2 different approaches to quantizing a model. The first and more historically -popular approach is to use module swaps which explicitly change the linear modules and the second -approach is to instead use subclasses to change the interpretation of the linear module """ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Any, Callable +from typing import Any, Callable, Union, Dict -from .dynamic_quant import DynamicallyPerAxisQuantizedLinear from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, ) from .subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, QuantizedLinearWeightBase, - to_laq, + LinearActQuantizedTensor, + to_linear_act_quantized, ) from .quant_primitives import ( @@ -52,11 +45,6 @@ __all__ = [ - "apply_weight_only_int8_quant", - "apply_dynamic_quant", - "change_linear_weights_to_int8_dqtensors", - "change_linear_weights_to_int8_woqtensors", - "change_linear_weights_to_int4_woqtensors", "swap_conv2d_1x1_to_linear", "Quantizer", "TwoStepQuantizer", @@ -65,10 +53,10 @@ "quantize", "autoquant", "_get_subclass_inserter", - "get_apply_8da4w_quant", - "get_apply_int4wo_quant", - "get_apply_int8wo_quant", - "get_apply_int8dyn_quant", + "int8da_int4w", + "int8da_int8w", + "int4wo", + "int8wo", ] from .GPTQ import ( @@ -81,6 +69,77 @@ "Int8DynActInt4WeightGPTQQuantizer", ] +### TO BE DEPRECATED START +from .subclass import ( + Int4WeightOnlyQuantizedLinearWeight, + Int8DynamicallyQuantizedLinearWeight, + Int8WeightOnlyQuantizedLinearWeight, +) + +def _in_features_greater_than_16(mod, *args): + return hasattr(mod, "in_features") and mod.in_features > 16 + +def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): + """ + Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` + Tensor subclass, effectively applying the same form of quantization + as apply_dynamic_quant while not modifying the linear modules. + """ + if TORCH_VERSION_AFTER_2_4: + raise ImportError("This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs") + + if filter_fn is None: + filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( + *args + ) + + _replace_with_custom_fn_if_matches_filter( + model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn + ) + + +def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): + """ + Converts all linear weight tensors to the + `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, + effectively applying the same form of quantization + as apply_weight_only_int8_quant while not modifying the linear modules. 
+ """ + if TORCH_VERSION_AFTER_2_4: + raise ImportError("This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs") + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs), + _is_linear if filter_fn is None else filter_fn, + ) + +def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles=8, filter_fn=None): + """ + Converts all linear weight tensors to the + `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, + effectively applying the same form of quantization + as apply_dynamic_quant while not modifying the linear modules. + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained, choices are [256, 128, 64, 32] + `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] + """ + if TORCH_VERSION_AFTER_2_4: + raise ImportError("This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs") + + if filter_fn is None: + filter_fn = _is_linear + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles), + filter_fn, + ) + +### TO BE DEPRECATED END + + def _replace_with_custom_fn_if_matches_filter( model, @@ -115,39 +174,20 @@ def _replace_with_custom_fn_if_matches_filter( def _is_linear(mod, *args): + # avoid circular dep + from torchao.dtypes import AffineQuantizedTensor + + # adding weight tensor subclass isinstance check to make sure the weight is only quantized once + # when it is shared by multiple linear modules return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeight) + and not isinstance(mod.weight, AffineQuantizedTensor) + and not isinstance(mod.weight, LinearActQuantizedTensor) ) - -def _in_features_greater_than_16(mod, *args): - return hasattr(mod, "in_features") and mod.in_features > 16 - - -def apply_weight_only_int8_quant(model, filter_fn=None): - """ - Applies weight-only symmetric per-channel int8 quantization to all linear layers - in the given model using module swaps. - """ - _replace_with_custom_fn_if_matches_filter( - model, - WeightOnlyInt8QuantLinear.from_float, - _is_linear if filter_fn is None else filter_fn, - ) - - -def apply_dynamic_quant(model, filter_fn=None): - """ - Applies dynamic symmetric per-token activation and per-channel weight - quantization to all linear layers by converting all linear weight - tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass. - """ - change_linear_weights_to_int8_dqtensors(model, filter_fn) - - import torch.nn.utils.parametrize as parametrize def _get_subclass_inserter(cls, enable_parametrization=False, **kwargs): @@ -178,70 +218,6 @@ def insert_subclass(lin): return insert_subclass -def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight` - Tensor subclass, effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. 
- """ - if filter_fn is None: - filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16( - *args - ) - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int8dyn_quant(), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn - ) - - -def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs): - """ - Converts all linear weight tensors to the - `Int8WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_weight_only_int8_quant while not modifying the linear modules. - """ - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int8wo_quant(), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs), - _is_linear if filter_fn is None else filter_fn, - ) - - -def change_linear_weights_to_int4_woqtensors(model, groupsize=128, inner_k_tiles=8, filter_fn=None): - """ - Converts all linear weight tensors to the - `Int4WeightOnlyQuantizedLinearWeight` tensor subclass, - effectively applying the same form of quantization - as apply_dynamic_quant while not modifying the linear modules. - - Args: - `groupsize`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained, choices are [256, 128, 64, 32] - `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] - """ - if filter_fn is None: - filter_fn = _is_linear - - if TORCH_VERSION_AFTER_2_4: - quantize(model, get_apply_int4wo_quant(groupsize=groupsize, inner_k_tiles=inner_k_tiles), filter_fn) - unwrap_tensor_subclass(model, filter_fn) - else: - _replace_with_custom_fn_if_matches_filter( - model, - _get_subclass_inserter(Int4WeightOnlyQuantizedLinearWeight, enable_parametrization=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles), - filter_fn, - ) - def swap_conv2d_1x1_to_linear(model, filter_fn=None): """ Changes all conv2d 1x1 modules to equivalent linear modules so that they can then be quantized. 
@@ -281,12 +257,13 @@ def insert_subclass(lin): return insert_subclass -def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module: +def quantize(model: torch.nn.Module, apply_tensor_subclass: Union[str, Callable[[torch.Tensor], torch.Tensor]], filter_fn=None) -> torch.nn.Module: """Convert the weight of linear modules in the model with `apply_tensor_subclass` Args: model: input model apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance + or a string filter_fn: used to filter out the modules that we don't want to apply tenosr subclass Example:: @@ -303,7 +280,9 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + apply_weight_quant = lambda x: to_affine_quantized( + x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, + zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) # apply to modules under block0 submodule def filter_fn(module, fqn): @@ -312,6 +291,12 @@ def filter_fn(module, fqn): m = MyModel(...) m = quantize(m, apply_weight_quant, filter_fn) """ + if isinstance(apply_tensor_subclass, str): + assert apply_tensor_subclass in _APPLY_TS_TABLE, f"{apply_tensor_subclass} not supported: {_APPLY_TS_TABLE.keys()}" + apply_tensor_subclass = _APPLY_TS_TABLE[apply_tensor_subclass] + + assert not isinstance(apply_tensor_subclass, str) + _replace_with_custom_fn_if_matches_filter( model, _get_linear_subclass_inserter(apply_tensor_subclass), @@ -319,11 +304,18 @@ def filter_fn(module, fqn): ) return model -def get_apply_8da4w_quant(groupsize=32): +def int8da_int4w(groupsize=32): + """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear + This is used to produce a model for executorch backend, but currently executorch did not + support lowering for the quantized model from this flow yet + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + """ def apply_8da4w_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_affine_quantized # weight settings mapping_type = MappingType.SYMMETRIC @@ -345,19 +337,28 @@ def get_per_token_block_size(x): # input settings input_mapping_type = MappingType.ASYMMETRIC input_target_dtype = torch.int8 - input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) + input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) - weight = to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) - weight = to_laq(weight, input_quant_func) + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) + weight = to_linear_act_quantized(weight, input_quant_func) return weight return apply_8da4w_quant -def get_apply_int4wo_quant(groupsize=32, inner_k_tiles=8): +def int4wo(groupsize=128, inner_k_tiles=8): + """ + Applies uint4 weight-only asymmetric per-group 
quantization to linear layers, using + "tensor_core_tiled" layout for speedup with tinygemm kernel + + Args: + `groupsize`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained, choices are [256, 128, 64, 32] + `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2] + """ def apply_int4wo_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_affine_quantized mapping_type = MappingType.ASYMMETRIC block_size = (1, groupsize) @@ -368,28 +369,40 @@ def apply_int4wo_quant(weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles) + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles) return apply_int4wo_quant -def get_apply_int8wo_quant(): +def int8wo(): + """ + Applies int8 weight-only symmetric per-channel quantization to linear layers. + """ def apply_int8wo_quant(weight): # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_affine_quantized mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 block_size = (1, weight.shape[1]) - return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) return apply_int8wo_quant -def get_apply_int8dyn_quant(): +def int8da_int8w(): + """ + Applies int8 dynamic symmetric per-token activation and int8 per-channel weight + quantization to linear layers + """ def apply_int8dyn_quant(weight): + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + return weight + # avoid circular dep - from torchao.dtypes.aqt import to_aq + from torchao.dtypes import to_affine_quantized # weight settings mapping_type = MappingType.SYMMETRIC def get_weight_block_size(x): @@ -410,10 +423,18 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) block_size = get_weight_block_size(weight) - weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) - weight = to_laq(weight, input_quant_func) + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + weight = to_linear_act_quantized(weight, input_quant_func) return weight return apply_int8dyn_quant + +# shortcut 
strings that map to pre-configured `apply_tensor_subclass` callables,
+# to simplify common use cases
+_APPLY_TS_TABLE: Dict[str, Callable] = {
+    "int4_weight_only": int4wo(),
+    "int8_weight_only": int8wo(),
+    "int8_dynamic": int8da_int8w(),
+}
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index c299a5834f..a2801a622f 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -31,7 +31,7 @@
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
     "LinearActQuantizedTensor",
-    "to_laq",
+    "to_linear_act_quantized",
 ]
 
@@ -751,4 +751,4 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
 
-to_laq = LinearActQuantizedTensor.from_float
+to_linear_act_quantized = LinearActQuantizedTensor.from_float
diff --git a/torchao/utils.py b/torchao/utils.py
index 2a19993e4d..381a302645 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -126,6 +126,11 @@ def right_inverse(self, tensor):
         return plain_tensors
 
 def unwrap_tensor_subclass(model, filter_fn=None):
+    """Unwraps (nested) tensor subclasses in the model into plain tensors.
+    This is a workaround to make a model whose weights are tensor subclasses work with
+    `torch.export.export` and `torch._export.aot_compile`; we hope this can be integrated into the compile stack soon.
+    Tracking issue: https://github.com/pytorch/ao/issues/345
+    """
     for name, child in model.named_children():
         # make sure child.weight is a tensor subclass
         if (
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 918284ae1e..6115e5f21d 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -15,8 +15,12 @@
 input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')
 
 ## Quantization code - start
-# int8 act, int8 weight dynamic quantization, see README for other APIs
-torchao.apply_dynamic_quant(model)
+# int8 dynamic activation + int8 weight quantization; see ao/torchao/quantization/README.md
+# for APIs for earlier torch versions and other quantization techniques
+
+# for torch 2.4+
+from torchao.quantization.quant_api import quantize
+quantize(model, "int8_dynamic")
 ## Quantization code - end
 
 ## compilation configs
@@ -25,6 +29,10 @@
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.use_mixed_mm = True
 ## compilation configs end
 
+# temporary workaround for the API to work with torch.compile
+from torchao.utils import unwrap_tensor_subclass
+unwrap_tensor_subclass(model)
+
 model = torch.compile(model, mode='max-autotune')
 
 # Must run with no_grad when optimizing for inference