From e5df48e1033611d6433dc3fd1de7923e0f555ebc Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 24 Jul 2024 14:55:35 -0400 Subject: [PATCH 1/5] [test] Fix regression tests (#537) * Fix regression tests Summary: Torch 2.2 is compiled with numpy 1.x, but when we `pip install -r requirements-dev.txt` we download a higher version of numpy (2.0) This causes an error with the .numpy() calls and importing torch in general. I don't think we want to pin the versions in requirements-dev.txt, so instead I added a pin to the numpy version in the specific torch spec, so it'll only run for 2.2. PT 2.3+ support numpy 2.0+ which is why those test don't fail. Test Plan: Reviewers: Subscribers: Tasks: Tags: * use == and specify for CPU as well * update --- .github/workflows/regression_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 191fb6fe6d..119d228085 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -23,7 +23,7 @@ jobs: include: - name: CUDA 2.2.2 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.2.2' + torch-spec: 'torch==2.2.2 "numpy<2" ' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CUDA 2.3 @@ -38,7 +38,7 @@ jobs: gpu-arch-version: "12.1" - name: CPU 2.2.2 runs-on: linux.4xlarge - torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu' + torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" ' gpu-arch-type: "cpu" gpu-arch-version: "" - name: CPU 2.3 From e8662e0e10dbb1511f8e673a4a5cee26d24879e4 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 25 Jul 2024 17:43:06 -0700 Subject: [PATCH 2/5] Fixing cuda device check (#536) Summary: Previous cuda device check is not general enough, this adds a better check that works for more cases like "cuda:0" Test Plan: python test/quantization/test_quant_api.py -k test_int4wo_quantized_model_to_device Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 24 ++++++++++++----------- torchao/dtypes/affine_quantized_tensor.py | 3 ++- torchao/dtypes/utils.py | 5 ++++- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ab24fc981c..c19dc2660b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -642,17 +642,19 @@ def test_int8wo_quantized_model_to_device(self): @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+") def test_int4wo_quantized_model_to_device(self): # TODO: change initial model to "cpu" - m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda") - m_copy = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda") - - quantize_(m, int4_weight_only()) - ref = m(*example_inputs) - - example_inputs_cuda = (example_inputs[0].to("cuda"),) - m.to(device="cuda") - cuda_res = m(*example_inputs_cuda) - self.assertEqual(cuda_res.cpu(), ref) + devices = ["cuda", "cuda:0"] + for device in devices: + m = ToyLinearModel().eval().to(torch.bfloat16).to(device) + m_copy = copy.deepcopy(m) + example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) + + quantize_(m, int4_weight_only()) + ref = m(*example_inputs) + + example_inputs_cuda = (example_inputs[0].to(device),) + m.to(device=device) + cuda_res = m(*example_inputs_cuda) + self.assertEqual(cuda_res.cpu(), ref) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, 
"Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index da5cc7d28b..4142937905 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -21,6 +21,7 @@ _register_layout_cls, _get_layout_tensor_constructor, LayoutType, + is_device, ) from typing import ClassVar from dataclasses import dataclass @@ -544,7 +545,7 @@ def from_plain( def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] - if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"): + if not is_device("cuda", device): raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") return self.__class__( self.packed_weight.to(device), diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 3a437b4745..656c4873ab 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable +from typing import Dict, Callable, Union from collections import defaultdict import functools from dataclasses import dataclass @@ -89,3 +89,6 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] + +def is_device(target_device_str: str, device: Union[str, torch.device]): + return torch.device(device).type == target_device_str From 428084356ace4ea94c22a3a9b3d74cff8ee41db3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 26 Jul 2024 10:43:58 +0800 Subject: [PATCH 3/5] Improve FSDP support for low-bit optimizers (#538) --- test/prototype/test_low_bit_optim.py | 9 +++ torchao/prototype/low_bit_optim/adam.py | 10 ++-- torchao/prototype/low_bit_optim/adamw.py | 10 ++-- .../prototype/low_bit_optim/subclass_4bit.py | 57 ++++++++++++++++--- .../prototype/low_bit_optim/subclass_8bit.py | 51 ++++++++++++++--- .../prototype/low_bit_optim/subclass_fp8.py | 45 +++++++++++++-- 6 files changed, 153 insertions(+), 29 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 94cfe34096..5eb0a54b62 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls): base_optim.step() self.assertEqual(fsdp_loss, base_loss) + base_param = base_optim.param_groups[0]["params"][0] + base_exp_avg = base_optim.state[base_param]["exp_avg"] + + fsdp_param = fsdp_optim.param_groups[0]["params"][0] + fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"] + full_fsdp_exp_avg = fsdp_exp_avg.full_tensor() + + self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize()) + instantiate_parametrized_tests(TestQuantize) instantiate_parametrized_tests(TestOptim) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index b3b7eeb6f3..47a99c06dc 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): - out = torch.empty_like(p) - out._local_tensor = self._subclass_zeros( - out._local_tensor, - signed, 
- self.block_size, + out = DTensor.from_local( + local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size), + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, ) else: out = self._subclass_zeros(p, signed, self.block_size) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index ad60caa435..dbde91fdd2 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int): def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): - out = torch.empty_like(p) - out._local_tensor = self._subclass_zeros( - out._local_tensor, - signed, - self.block_size, + out = DTensor.from_local( + local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size), + device_mesh=p.device_mesh, + placements=p.placements, + run_check=False, ) else: out = self._subclass_zeros(p, signed, self.block_size) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 9550b3d51c..a24cf8b1d5 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -8,7 +8,8 @@ aten = torch.ops.aten - +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml # NOTE: power-1 is linear @@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): + """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507 + + Args + codes: quantized and packed 4-bit data stored as uint8. + scale: scale data for block-wise quantization. + qmap: lookup table that maps between quantized value (code) and float value. + signed: whether the tensor is signed or unsigned. + shape: shape of original float tensor. + + NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`. + The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage. 
+ """ assert codes.dtype is torch.uint8 assert codes.ndim == 1 # flattened buffer + assert scale.ndim == 1 self.codes = codes self.scale = scale self.qmap = qmap self.signed = signed self._shape = shape - - @property - def block_size(self): - return self.codes.numel() * 2 // self.scale.numel() + self.block_size = codes.numel() * 2 // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [self.signed, self._shape] @@ -113,9 +126,37 @@ def _(func, *args, **kwargs): return func(*args, **kwargs) +# this is needed for DTensor.from_local() and for flattening tensor @OptimState4bit.implements(aten.view.default) def _(func, *args, **kwargs): x, shape = args - if len(shape) > 1 or shape[0] != -1: - raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]") - return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) + + if tuple(x.shape) == tuple(shape): + return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape) + + if len(shape) == 1 and shape[0] == -1: + return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) + + raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + + +# this is needed for DTensor.full_tensor() +@OptimState4bit.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimState4bit): + raise ValueError(f"expecting a OptimState4bit but found {type(x)}") + + codes = func(x.codes, *args[1:], **kwargs) + scale = func(x.scale, *args[1:], **kwargs) + + # adjust the first dim + shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:] + + # assume tensors from all ranks have the same signedness + return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 5b16f6363f..1e2067963a 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -6,14 +6,13 @@ aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional QMAP_SIGNED = create_dynamic_map(signed=True) QMAP_UNSIGNED = create_dynamic_map(signed=False) -# dynamic tree quantization -# https://arxiv.org/pdf/1511.04561 -# https://arxiv.org/abs/2110.02861 class OptimState8bit(Tensor): implements = classmethod(_implements) tensor_attrs = ["codes", "scale", "qmap"] @@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): ) def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool): + """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861 + + Args + codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor. + scale: scale data for block-wise quantization. + qmap: lookup table that maps between quantized value (code) and float value. + signed: whether the tensor is signed or unsigned. + + NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`. 
+ """ assert codes.dtype is torch.uint8 + assert scale.ndim == 1 self.codes = codes self.scale = scale self.qmap = qmap self.signed = signed - - @property - def block_size(self): - return self.codes.numel() // self.scale.numel() + self.block_size = codes.numel() // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [self.signed] @@ -97,3 +106,31 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) + + +# this is needed for DTensor.from_local() +@OptimState8bit.implements(aten.view.default) +def _(func, *args, **kwargs): + x, shape = args + return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed) + + +# this is needed for DTensor.full_tensor() +@OptimState8bit.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimState8bit): + raise ValueError(f"expecting a OptimState8bit but found {type(x)}") + + # assume tensors from all ranks have the same signedness + return OptimState8bit( + func(x.codes, *args[1:], **kwargs), + func(x.scale, *args[1:], **kwargs), + x.qmap.clone(), + x.signed, + ) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index e3116e20f8..b78638cd01 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -4,6 +4,9 @@ aten = torch.ops.aten +c10d_functional = torch.ops.c10d_functional +_c10d_functional = torch.ops._c10d_functional + DTYPE = torch.float8_e4m3fn @@ -32,13 +35,21 @@ def __new__(cls, codes: Tensor, scale: Tensor): ) def __init__(self, codes: Tensor, scale: Tensor): + """Create quantized FP8 optimizer state. + + Args + codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor. + scale: scale data for block-wise quantization. + + NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size). + Thus, the last dimension of the original float tensor is not necessarily divisible by block size. + Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`. 
+ """ assert codes.dtype is DTYPE + assert scale.ndim == 1 self.codes = codes self.scale = scale - - @property - def block_size(self): - return self.codes.numel() // self.scale.numel() + self.block_size = codes.numel() // scale.numel() def __tensor_flatten__(self): return self.tensor_attrs, [] @@ -99,3 +110,29 @@ def _(func, *args, **kwargs): def _(func, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] return func(*args, **kwargs) + + +# this is needed for DTensor.from_local() +@OptimStateFp8.implements(aten.view.default) +def _(func, *args, **kwargs): + x, shape = args + return OptimStateFp8(x.codes.view(shape), x.scale) + + +# this is needed for DTensor.full_tensor() +@OptimStateFp8.implements([ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, +]) +def _(func, *args, **kwargs): + x = args[0] + if not isinstance(x, OptimStateFp8): + raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}") + + # assume tensors from all ranks have the same signedness + return OptimStateFp8( + func(x.codes, *args[1:], **kwargs), + func(x.scale, *args[1:], **kwargs), + ) From c9f79bea2af11ce69125ac6902fbafda1d039f30 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Jul 2024 10:07:00 -0400 Subject: [PATCH 4/5] Implement sparsity as a AQT Layout (#498) Summary: This PR adds in sparsity as an AQTLayout, previously it was implemented using the QuantizedLinearBase subclass that will be deprecated shortly. I also added renamed `sparsify` to `sparsify_` and added in a `semi_sparse_weight()` function to be in line with our other APIs. The main code changes are in `torchao/dtypes/affine_quantized_tensor.py`, for the semi-structured cusparselt representation, we can reuse a lot of the existing PlainLayout implementation, since the compressed representation is stored in a single tensor like `int_data`. Test Plan: ``` python test/sparsity/test_sparse_api ``` --- README.md | 12 +- scripts/sam/benchmark.sh | 1 - scripts/sam/eval_combo.py | 44 +-- scripts/sam/results.csv | 10 +- test/sparsity/test_sparse_api.py | 27 +- torchao/dtypes/__init__.py | 2 + torchao/dtypes/affine_quantized_tensor.py | 77 +++++ torchao/quantization/__init__.py | 1 + torchao/quantization/quant_api.py | 26 +- torchao/sparsity/__init__.py | 11 +- .../prototype/dynamic_quant_sparse.py | 314 ------------------ torchao/sparsity/sparse_api.py | 26 +- 12 files changed, 168 insertions(+), 383 deletions(-) delete mode 100644 torchao/sparsity/prototype/dynamic_quant_sparse.py diff --git a/README.md b/README.md index 1dd2a72340..e31dc63a8f 100644 --- a/README.md +++ b/README.md @@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table Sparsifying your model is also a 1 liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute bound models like SAM, specifically the MLP layers. 
```python -from torchao.sparsity import sparsify -from torch.sparse import to_sparse_semi_structured +from torchao.sparsity import sparsify, semi_sparse_weight() -m = sparsify(m, to_sparse_semi_structured) +m = sparsify_(m, semi_sparse_weight()) ``` Sparsity can also be composed with int8 dynamic quantization for further speedups: ```python -from torchao.sparsity import sparsify -from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity import sparsify, int8_dynamic_activation_int8_semi_sparse_weight -m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight()) +m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight()) ``` -We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration. +We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi sparse (2:4) sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration. We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**. The following benchmarks were ran for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`: diff --git a/scripts/sam/benchmark.sh b/scripts/sam/benchmark.sh index 5c1262f9cc..c52ce33151 100755 --- a/scripts/sam/benchmark.sh +++ b/scripts/sam/benchmark.sh @@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse # int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse) python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse - diff --git a/scripts/sam/eval_combo.py b/scripts/sam/eval_combo.py index e83ec25300..b9733bd98b 100644 --- a/scripts/sam/eval_combo.py +++ b/scripts/sam/eval_combo.py @@ -9,6 +9,10 @@ import time import resource +from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only +from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight +from torchao.utils import unwrap_tensor_subclass + torch._dynamo.config.cache_size_limit = 50000 def unbind_jagged(device, data, sizes, offsets): @@ -279,30 +283,17 @@ def run( block.attn.use_rel_pos = use_rel_pos if compress == "int8_dynamic_quant": - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight - from torchao.utils import unwrap_tensor_subclass quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight()) predictor.model.image_encoder = 
unwrap_tensor_subclass(predictor.model.image_encoder) elif compress == "sparse_mlp_only": def mlp_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'mlp' in name - from torchao.sparsity import sparsify - from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only) + sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only) elif compress == "sparse": - from torchao.sparsity import sparsify - from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity apply_fake_sparsity(predictor.model.image_encoder) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured) + sparsify_(predictor.model.image_encoder, semi_sparse_weight()) elif compress == "int8_dynamic_quant_sparse": - from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor - SparseSemiStructuredTensor._FORCE_CUTLASS = False - from torchao.sparsity import sparsify, apply_fake_sparsity - from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight - from torchao.utils import unwrap_tensor_subclass - def attn_only(mod, name): return isinstance(mod, torch.nn.Linear) and 'attn' in name def mlp_lin1_only(mod, name): @@ -316,20 +307,17 @@ def mlp_only(mod, name): apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only) - quantize_( - predictor.model.image_encoder, - int8_dynamic_activation_int8_weight(), - attn_only - ) + quantize_(predictor.model.image_encoder, + int8_dynamic_activation_int8_weight(), + attn_only) + quantize_(predictor.model.image_encoder, + int8_dynamic_activation_int8_semi_sparse_weight(), + mlp_lin1_only) + sparsify_(predictor.model.image_encoder, + semi_sparse_weight(), + mlp_lin2_only) predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder) - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, - int8_dynamic_activation_int8_2x4_sparse_weight(), - mlp_lin1_only, prune=False) - - predictor.model.image_encoder = sparsify(predictor.model.image_encoder, - to_sparse_semi_structured, - mlp_lin2_only, prune=False) else: assert compress is None, f"Unsupported compress mode {compress}" @@ -413,6 +401,6 @@ def mlp_only(mod, name): vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile, use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path])) f.write(vals+"\n") - + if __name__ == '__main__': fire.Fire(run) diff --git a/scripts/sam/results.csv b/scripts/sam/results.csv index 01aad5c022..0be02c7f37 100644 --- a/scripts/sam/results.csv +++ b/scripts/sam/results.csv @@ -1,6 +1,6 @@ device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path -cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None 
-cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None -cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None -cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None -cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None +cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None +cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None +cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py index 3e566732bb..b846afa454 100644 --- a/test/sparsity/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -1,18 +1,24 @@ +import copy import logging import unittest import torch from torch import nn -from torch.sparse import to_sparse_semi_structured -from torchao.sparsity import apply_fake_sparsity, sparsify -from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight +from torchao.sparsity import ( + apply_fake_sparsity, + sparsify_, + int8_dynamic_activation_int8_semi_sparse_weight, + semi_sparse_weight, +) from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, _get_subclass_inserter, _is_linear, + int8_dynamic_activation_int8_weight, + quantize_, ) -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass from torch.testing._internal.common_utils import TestCase @@ -38,12 +44,11 @@ def test_sparse(self): apply_fake_sparsity(model) dense_result = model(input) - model = sparsify(model, to_sparse_semi_structured) + sparsify_(model, semi_sparse_weight()) sparse_result = model(input) assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3) - class TestQuantSemiSparse(TestCase): @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature") @@ -58,15 +63,15 @@ def test_quant_semi_sparse(self): .half() .cuda() ) - apply_fake_sparsity(model) - dense_result = model(input) + model_copy = copy.deepcopy(model) + quantize_(model_copy, int8_dynamic_activation_int8_weight()) + dense_result = model_copy(input) - sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight()) + quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight()) sparse_result = model(input) - assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1) - + assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2) if __name__ == "__main__": unittest.main() diff --git a/torchao/dtypes/__init__.py 
b/torchao/dtypes/__init__.py index 39372fe27f..e4b47b8229 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -7,6 +7,7 @@ to_affine_quantized_static, LayoutType, PlainLayoutType, + SemiSparseLayoutType, TensorCoreTiledLayoutType, ) @@ -19,5 +20,6 @@ "to_affine_quantized_static", "LayoutType", "PlainLayoutType", + "SemiSparseLayoutType", "TensorCoreTiledLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 4142937905..5c762231e2 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -32,6 +32,17 @@ class PlainLayoutType(LayoutType): pass +@dataclass(frozen=True) +class SemiSparseLayoutType(LayoutType): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + temp = input.detach() + pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] + temp.view(-1, 4).scatter_(1, pruning_inds, value=0) + return temp + + @dataclass(frozen=True) class TensorCoreTiledLayoutType(LayoutType): inner_k_tiles: int = 8 @@ -473,6 +484,47 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) +@register_layout_cls(SemiSparseLayoutType) +class SemiSparseAQTLayout(PlainAQTLayout): + """ + Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor + """ + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" + ) + + def get_plain(self): + # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # the identity matrix to get the original dense matrix. This is slow though. 
+ cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) + int_data_expanded = torch._cslt_sparse_mm(self.int_data, + torch.eye(cols, + dtype=self.int_data.dtype, + device=self.int_data.device).t()) + return int_data_expanded, self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, SemiSparseLayoutType) + int_data_compressed = torch._cslt_compress(int_data) + return cls(int_data_compressed, scale, zero_point, layout_type) + + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): """ @@ -669,6 +721,31 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): if bias is not None: y += bias return y + # handle int8 dynamic_quant + semi_structured_sparse + elif( + is_cuda and + input_is_int8 and + input_tensor.dtype == weight_qtensor.dtype and + isinstance(input_tensor.layout_type, PlainLayoutType) and + isinstance(weight_qtensor.layout_type, SemiSparseLayoutType) + ): + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals_int8 = weight_qtensor.layout_tensor.int_data + w_scales = weight_qtensor.layout_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16 + ).t() + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y else: input_tensor = input_tensor.dequantize() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a1cf1bf034..6bf37f0080 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -32,6 +32,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", + "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3b02930c3c..161a84c4e4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -14,13 +14,14 @@ come along with it and because that is how we access the intended quantized and mixed GEMM kernels """ - +from functools import partial import torch import torchao import torch.nn as nn import torch.nn.functional as F from typing import Any, Callable, Union, Dict, Optional +from torchao.dtypes import PlainLayoutType from torchao.utils import ( TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass, @@ -57,6 +58,7 @@ "quantize_", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", + "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", ] @@ -410,7 +412,8 @@ def apply_int8wo_quant(weight): return _get_linear_subclass_inserter(apply_int8wo_quant) -def int8_dynamic_activation_int8_weight(): + +def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers @@ -432,16 +435,31 @@ def get_weight_block_size(x): zero_point_dtype = torch.int64 # input settings + def get_per_token_block_size(x): + 
block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + input_mapping_type = MappingType.SYMMETRIC input_target_dtype = torch.int8 input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) block_size = get_weight_block_size(weight) - weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) weight = to_linear_act_quantized(weight, input_quant_func) return weight return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) + + +def int8_dynamic_activation_int8_semi_sparse_weight(): + """ + Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight + quantization + 2:4 sparsity to linear layers. + """ + from torchao.dtypes import SemiSparseLayoutType + return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()) diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 9b288c07f9..c3b10f949a 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -6,11 +6,18 @@ from .wanda import WandaSparsifier # noqa: F403 from .utils import PerChannelNormObserver # noqa: F403 -from .sparse_api import apply_fake_sparsity, sparsify +from .sparse_api import ( + apply_fake_sparsity, + sparsify_, + semi_sparse_weight, + int8_dynamic_activation_int8_semi_sparse_weight +) __all__ = [ "WandaSparsifier", "PerChannelNormObserver", "apply_fake_sparsity", - "sparsify" + "sparsify_" + "semi_sparse_weight", + "int8_dynamic_activation_int8_semi_sparse_weight" ] diff --git a/torchao/sparsity/prototype/dynamic_quant_sparse.py b/torchao/sparsity/prototype/dynamic_quant_sparse.py deleted file mode 100644 index 2f2a198278..0000000000 --- a/torchao/sparsity/prototype/dynamic_quant_sparse.py +++ /dev/null @@ -1,314 +0,0 @@ -import torch -import torch.nn as nn -from typing import Tuple, Optional - -from torchao.quantization.utils import ( - dynamically_quantize_per_channel, - quant_int8_dynamic_per_token_linear, - quantize_activation_per_token_absmax, - dequantize_per_channel, -) - -from torchao.quantization.subclass import ( - Int8DynamicallyQuantizedLinearWeight, - QuantizedLinearWeightBase, -) - -from torch.sparse import to_sparse_semi_structured - -# Quant + Sparse helper functinos -def sparse_quant_int8_dynamic_linear( - x : torch.Tensor, - w_vals_int8_packed : torch.Tensor, - w_meta_int32 : Optional[torch.Tensor], - w_scales : torch.Tensor, - bias : Optional[torch.Tensor], - out_dtype : torch.dtype, - fuse_mul=False, -): - x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) - # w_meta_int32 is either None or meta tensor - if w_meta_int32 is None: - if fuse_mul: - mm_out = sparse_quant_int8_cslt_matmul_fuse_mul( - x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype, - ) - else: - mm_out = sparse_quant_int8_cslt_matmul( - 
x_vals_int8, x_scales, w_vals_int8_packed, w_scales, out_dtype, - ) - else: - mm_out = sparse_quant_int8_cutlass_matmul( - x_vals_int8, x_scales, w_vals_int8_packed, w_meta_int32, w_scales, out_dtype, - ) - - if bias is not None: - mm_out += bias - return mm_out - -def sparse_quant_int8_cslt_matmul_fuse_mul( - x_vals_int8, - x_scales, - w_vals_int8, - w_scales, - out_dtype, -): - - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16 - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - y = y.to(out_dtype) - - return y - -def sparse_quant_int8_cslt_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_scales, - out_dtype, -): - - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - # assert w_scales.dtype == out_dtype, f'{w_scales.dtype} does not match {out_dtype}' - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), out_dtype=torch.bfloat16 - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - y = y.to(out_dtype) - - return y - - -def sparse_quant_int8_cutlass_matmul( - x_vals_int8, - x_scales, - w_vals_int8, - w_meta_int32, - w_scales, - out_dtype, -): - assert ( - x_vals_int8.dtype == torch.int8 - ), f"x dtype {x_vals_int8.dtype} not yet supported" - assert ( - w_vals_int8.dtype == torch.int8 - ), f"w dtype {w_vals_int8.dtype} not yet supported" - assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}" - assert w_meta_int32.dtype == torch.int32, f"{w_meta_int32.dtype} not yet supported" - - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous() - - assert x_scales.dtype in [ - torch.float, - torch.bfloat16, - ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" - - y_dot_int32 = torch._sparse_semi_structured_linear( - tmp, w_vals_int8, w_meta_int32.view(torch.int32), out_dtype=torch.int32 - ) - y = (y_dot_int32 * x_scales.reshape(-1, 1) * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_int32.shape[-1] - ) - y = y.to(out_dtype) - return y - -class Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight( - Int8DynamicallyQuantizedLinearWeight -): - def dequantize(self, dtype=None): - # overload dequantize op for __repr__ - zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, torch.eye(self.shape[1], - dtype=self.int_data.dtype, - 
device=self.int_data.device)) - dq_t = dequantize_per_channel( - int_data_expanded, self.q_scales, zero_points, self.dtype if dtype is None else dtype - ).to(self.dtype) - - return dq_t if not self.transposed else dq_t.t() - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_linear( - act_mat, w_qtensor.int_data, None, w_qtensor.q_scales, bias, act_mat.dtype, - fuse_mul=True - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - int_data = torch._cslt_compress(int_data) - - return cls( - int_data, - w_scales, - False, - input_float.shape, - dtype=input_float.dtype, - ) - - -class Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight(QuantizedLinearWeightBase): - - @staticmethod - def __new__(cls, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype) - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, mask_meta, q_scales, transposed, shape, **kwargs): - self.q_scales = q_scales - self.mask_meta = mask_meta - super().__init__(int_data, transposed) - - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - dq_t = dequantize_per_channel( - self.int_data, self.q_scales, 0, self.dtype if dtype is None else dtype - ).to(self.dtype) - # data was transposed to dequantize so make sure shape is correct - return dq_t if not self.transposed else dq_t.t() - - def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data if self.transposed else self.int_data.t() - - def q_params(self): - """ - Get the quantization scales for the quantized tensor - """ - return {"q_scales": self.q_scales} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.mask_meta.to(kwargs["device"]), - self.q_scales.to(kwargs["device"]), - self.transposed, - self.shape, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.mask_meta), - fn(self.q_scales), - self.transposed, - self.shape, - dtype=self.dtype, - ) - - def _change_shape(self, shape): - return self.__class__( - self.int_data, - self.mask_meta, - self.q_scales, - self.transposed, - shape, - dtype=self.dtype, - ) - - def __tensor_flatten__(self): - return ["int_data", "mask_meta", "q_scales"], [ - self.transposed, - self.dtype, - self.shape, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None - ): - int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - mask_meta = tensor_data_dict["mask_meta"] - transposed, dtype, shape = tensor_attributes - return cls( - int_data, - mask_meta, - q_scales, - transposed, - shape if outer_size is None else outer_size, - dtype=dtype, - strides=outer_stride, - ) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return sparse_quant_int8_dynamic_linear( - act_mat, - w_qtensor.int_data, - w_qtensor.mask_meta, - w_qtensor.q_scales, - bias, - act_mat.dtype, - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127): - - assert input_float.is_cuda - - 
w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - - int_data = w_int_repr.contiguous() - sparse_tensor = to_sparse_semi_structured(int_data) - - return cls( - sparse_tensor.packed, - sparse_tensor.meta, - w_scales, - False, - input_float.shape, - dtype=input_float.dtype, - ) - -def int8_dynamic_activation_int8_2x4_sparse_weight(): - return Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight.from_float diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 8f8ca24a39..a12d954422 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -7,6 +7,7 @@ _is_linear, _replace_with_custom_fn_if_matches_filter, _get_linear_subclass_inserter, + int8_dynamic_activation_int8_semi_sparse_weight, ) # Sparsity helper functions @@ -29,16 +30,21 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.step() sparsifier.squash_mask() +def semi_sparse_weight(): + """ + Convert the weight of linear moduels to semi-structured (2:4) sparsity + """ + return _get_linear_subclass_inserter(to_sparse_semi_structured) -def sparsify(model: torch.nn.Module, +def sparsify_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None) -> torch.nn.Module: """Convert the weight of linear modules in the model with `apply_tensor_subclass` This function is essentially the same as quantize, put for sparsity subclasses. Currently, we support two options for sparsity: - - semi-structured (2:4) sparsity with `to_sparse_semi_structured` - - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_2x4_sparse_weight`, which is also available via the quantize API + - semi-structured (2:4) sparsity with `semi_sparse_weight` + - int8 dynamic quantization + 2:4 sparsity with `int8_dynamic_activation_int8_semi_sparse_weight`, which is also available via the quantize API Args: model (torch.nn.Module): input model @@ -49,7 +55,7 @@ def sparsify(model: torch.nn.Module, Example:: import torch import torch.nn as nn - from torchao.sparsity import sparsify + from torchao.sparsity import sparsify_ def filter_fn(module: nn.Module, fqn: str) -> bool: return isinstance(module, nn.Linear) @@ -57,17 +63,15 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) # for 2:4 sparsity - from torch.sparse import to_sparse_semi_structured - m = sparsify(m, to_sparse_semi_structured, filter_fn) + from torchao.sparse_api import semi_sparse_weight + m = sparsify_(m, semi_sparse_weight(), filter_fn) # for int8 dynamic quantization + 2:4 sparsity - from torchao.sparsity.prototype import int8_dynamic_activation_int8_2x4_sparse_weight - m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight(), filter_fn) + from torchao.sparsity.prototype import int8_dynamic_activation_int8_semi_sparse_weight + m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight(), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, - _get_linear_subclass_inserter(apply_tensor_subclass), + apply_tensor_subclass, _is_linear if filter_fn is None else filter_fn, ) - - return model From afde1755d906ad644e04835675e7856d72c3c87b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 26 Jul 2024 10:10:32 -0700 Subject: [PATCH 5/5] Refactor LinearActQuantizedTensor (#542) Summary: * rename to LinearActivationQuantizedTensor * using `implements` util to implement torch function and torch dispatch 
overwrites Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 12 +- torchao/dtypes/affine_quantized_tensor.py | 60 +++---- torchao/dtypes/utils.py | 56 +++++- .../prototype/low_bit_optim/subclass_4bit.py | 17 +- .../prototype/low_bit_optim/subclass_8bit.py | 17 +- .../prototype/low_bit_optim/subclass_fp8.py | 17 +- torchao/prototype/quant_llm/quant_llm.py | 26 +-- torchao/quantization/README.md | 2 +- torchao/quantization/__init__.py | 6 + .../linear_activation_quantized_tensor.py | 170 ++++++++++++++++++ torchao/quantization/quant_api.py | 15 +- torchao/quantization/subclass.py | 155 ---------------- tutorials/calibration_flow/static_quant.py | 4 +- 13 files changed, 288 insertions(+), 269 deletions(-) create mode 100644 torchao/quantization/linear_activation_quantized_tensor.py diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index c19dc2660b..155a232c3e 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -22,12 +22,14 @@ from torchao.dtypes import ( AffineQuantizedTensor, ) +from torchao.quantization import ( + LinearActivationQuantizedTensor, +) from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, ) from torchao.quantization.subclass import ( - LinearActQuantizedTensor, Int8WeightOnlyQuantizedLinearWeight, Int4WeightOnlyQuantizedLinearWeight, ) @@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self): example_inputs = m.example_inputs() quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size)) - assert isinstance(m.linear1.weight, LinearActQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActQuantizedTensor) + assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) @@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self): example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda") quantize_(m, int8_dynamic_activation_int8_weight()) - assert isinstance(m.linear1.weight, LinearActQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActQuantizedTensor) + assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor) assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 5c762231e2..b71e48a3a0 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -17,7 +17,8 @@ from torchao.utils import find_multiple from torchao.dtypes.utils import ( _implements, - _ATEN_OP_OR_TORCH_FN_TABLE, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, _register_layout_cls, _get_layout_tensor_constructor, LayoutType, @@ -295,17 +296,6 @@ def from_float_static( def layout_type(self) -> LayoutType: return self.layout_tensor.layout_type - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) - - with 
torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - def _get_to_kwargs(self, *args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) device = self.device if device is None else device @@ -347,29 +337,23 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - # Note: we only added cpu path here for 8da4w, this is for executorch, in the future - # 1. we'll add cpu/cuda version (int4mm etc.) - # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like - # cpu device + et laytout --> gives current 8da4w executorch representation - # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. - # cuda device + some layout --> gives cuda kernel - - # two scenarios where we currently fall back to vanilla mm: - # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized - # kernels in CPU as well, see the note above - # 2 - we're given non-floats - quantizing long to int8 is crazy - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + implements = classmethod(_implements) + # Note: we only added cpu path here for 8da4w, this is for executorch, in the future + # 1. we'll add cpu/cuda version (int4mm etc.) + # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like + # cpu device + et laytout --> gives current 8da4w executorch representation + # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. 
+ # cuda device + some layout --> gives cuda kernel - raise NotImplementedError( - f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" - ) + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized + # kernels in CPU as well, see the note above + # 2 - we're given non-floats - quantizing long to int8 is crazy + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + __torch_function__ = classmethod(_dispatch__torch_function__) -def implements(aten_ops_or_torch_fn): - return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn) +implements = AffineQuantizedTensor.implements def register_layout_cls(layout_type_class: type(LayoutType)): return _register_layout_cls(AffineQuantizedTensor, layout_type_class) @@ -827,7 +811,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): @implements(torch.nn.functional.linear) -def functional_linear(*args, **kwargs): +def _(func, types, *args, **kwargs): input_tensor, weight_tensor, bias = ( args[0], args[1], @@ -846,7 +830,7 @@ def functional_linear(*args, **kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) @implements([aten.mm.default, aten.addmm.default]) -def aten_mm(func, *args, **kwargs): +def _(func, types, *args, **kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"{func} is not implemented for non floating point input") @@ -885,21 +869,21 @@ def aten_mm(func, *args, **kwargs): return func(input_tensor, weight_tensor) @implements([aten.detach.default]) -def detach(func, *args, **kwargs): +def _(func, types, *args, **kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) @implements([aten.clone.default]) -def clone(func, *args, **kwargs): +def _(func, types, *args, **kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) @implements([aten._to_copy.default]) -def _to_copy(func, *args, **kwargs): +def _(func, types, *args, **kwargs): return return_and_correct_aliasing( func, args, @@ -908,7 +892,7 @@ def _to_copy(func, *args, **kwargs): ) @implements([aten.t.default]) -def t(func, *args, **kwargs): +def _(func, types, *args, **kwargs): block_size = args[0].block_size assert len(block_size) == 2 transposed_block_size = (block_size[1], block_size[0]) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 656c4873ab..9d49809d47 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -5,19 +5,28 @@ from dataclasses import dataclass """ -torch_function and torch_dispatch operator dispatch registrations - -first key is a tensor subclass type like AffineQuantizedTensor, -second key is a `func` in __torhc_function__ or __torch_dispatch__, -value is a function that implements the dispatch +Helper function for implementing aten op or torch function dispatch +and dispatching to these implementations. """ -_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict) - def _implements(cls, aten_ops_or_torch_fns): """Use this decorator to implement a function for an aten ops in __torch_dispatch__ (if user passed in a list of ops) or torch function in __torch_function__ (if user passed in a single object) + + class MyTensor(torch.Tensor): + ... + implements = classmethod(_implements) + + implements = MyTensor.implements + + @implements(torch.nn.functional.linear): + def _(func, types, args, kwargs): + ... 
+ """ + if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"): + cls._ATEN_OP_OR_TORCH_FN_TABLE = {} + if not isinstance(aten_ops_or_torch_fns, (list, tuple)): aten_ops_or_torch_fns = [aten_ops_or_torch_fns] def decorator(func): @@ -26,10 +35,41 @@ def decorator(func): def wrapper(*args, **kwargs): return func(*args, **kwargs) - _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper + cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper return func return decorator +def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None): + """Use this util function for a common `__torch_function__` implementation + that dispatches to ops/functions registered with `_implements` + + class MyTensor(torch.Tensor): + ... + __torch_function__ = classmethod(_dispatch__torch_function__) + """ + kwargs = {} if kwargs is None else kwargs + if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \ + func in cls._ATEN_OP_OR_TORCH_FN_TABLE: + return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + +def _dispatch__torch_dispatch__(cls, func, types, args, kwargs): + """Use this util function for a common `__torch_dispatch__` implementation + that dispatches to ops/functions registered with `_implements` + + class MyTensor(torch.Tensor): + ... + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + """ + if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \ + func in cls._ATEN_OP_OR_TORCH_FN_TABLE: + return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs) + + raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}") + + """ Base class for different LayoutType, should not be instantiated directly """ diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index a24cf8b1d5..087c9912b4 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -2,7 +2,7 @@ import torch from torch import Tensor -from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__ from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap @@ -85,16 +85,11 @@ def __repr__(self): f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" ) - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @OptimState4bit.implements(aten.copy_.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): dst = args[0] src = args[1] @@ -121,14 +116,14 @@ def _(func, *args, **kwargs): @OptimState4bit.implements(aten.lerp.Scalar) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() and for flattening tensor @OptimState4bit.implements(aten.view.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x, shape = args if tuple(x.shape) == tuple(shape): @@ -147,7 +142,7 @@ def _(func, *args, **kwargs): 
c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x = args[0] if not isinstance(x, OptimState4bit): raise ValueError(f"expecting a OptimState4bit but found {type(x)}") diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 1e2067963a..128a020b66 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__ from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap @@ -71,16 +71,11 @@ def __repr__(self): f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" ) - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @OptimState8bit.implements(aten.copy_.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): dst = args[0] src = args[1] @@ -103,14 +98,14 @@ def _(func, *args, **kwargs): @OptimState8bit.implements(aten.lerp.Scalar) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() @OptimState8bit.implements(aten.view.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x, shape = args return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed) @@ -122,7 +117,7 @@ def _(func, *args, **kwargs): c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x = args[0] if not isinstance(x, OptimState8bit): raise ValueError(f"expecting a OptimState8bit but found {type(x)}") diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index b78638cd01..de1d629fcc 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__ aten = torch.ops.aten @@ -77,16 +77,11 @@ def __repr__(self): f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" ) - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @OptimStateFp8.implements(aten.copy_.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): dst = args[0] src = args[1] @@ -107,14 +102,14 @@ def _(func, *args, **kwargs): @OptimStateFp8.implements(aten.lerp.Scalar) -def _(func, *args, **kwargs): +def _(func, types, 
*args, **kwargs): args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() @OptimStateFp8.implements(aten.view.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x, shape = args return OptimStateFp8(x.codes.view(shape), x.scale) @@ -126,7 +121,7 @@ def _(func, *args, **kwargs): c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): x = args[0] if not isinstance(x, OptimStateFp8): raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}") diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py index 3a5dafb52a..38eed6dd5e 100644 --- a/torchao/prototype/quant_llm/quant_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -6,7 +6,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones from torchao.ops import quant_llm_linear -from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE +from torchao.dtypes.utils import _implements, _dispatch__torch_function__, _dispatch__torch_dispatch__ from torchao.quantization.quant_api import _get_linear_subclass_inserter @@ -348,6 +348,8 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te class QuantLlmLinearWeight(Tensor): implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) @staticmethod def __new__(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): @@ -399,26 +401,8 @@ def _apply_fn_to_data(self, fn): self.mbits, ) - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: - return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") - - @QuantLlmLinearWeight.implements(torch.nn.functional.linear) -def _(*args, **kwargs): +def _(func, types, *args, **kwargs): act = args[0] weight = args[1] bias = args[2] if len(args) >= 3 else None @@ -447,7 +431,7 @@ def _(*args, **kwargs): @QuantLlmLinearWeight.implements(torch.ops.aten.detach.default) -def _(func, *args, **kwargs): +def _(func, types, *args, **kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 5ab14bd297..bc4f6e4f35 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -145,7 +145,7 @@ for n, m in model.named_modules(): # note: quantization for activation need to be applied after the weight quantization # quantization activation (needed by dynamic quantization) input_quant_func = int8wo_quant # specify how input activation is quantized - m.weight = nn.Parameter(to_linear_act_quantized(m.weight, input_quant_func)) + m.weight = nn.Parameter(to_linear_activation_quantized(m.weight, input_quant_func)) ``` The 
model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support `torch.export.export` and `torch.aot_compile` with the following workaround: diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 6bf37f0080..2ac4a0c285 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -12,6 +12,10 @@ from .weight_only import * # noqa: F403 from .unified import * from .autoquant import * +from .linear_activation_quantized_tensor import ( # noqat: F403 + LinearActivationQuantizedTensor, + to_linear_activation_quantized, +) __all__ = [ "swap_conv2d_1x1_to_linear" @@ -35,4 +39,6 @@ "int8_dynamic_activation_int8_semi_sparse_weight", "int4_weight_only", "int8_weight_only", + "LinearActivationQuantizedTensor", + "to_linear_activation_quantized", ] diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py new file mode 100644 index 0000000000..e4e4fedc45 --- /dev/null +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -0,0 +1,170 @@ +import torch +from torchao.dtypes.utils import ( + _implements, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, +) +from typing import Callable +from torch.utils._python_dispatch import return_and_correct_aliasing + +__all__ = [ + "LinearActivationQuantizedTensor", + "to_linear_activation_quantized", +] + +aten = torch.ops.aten + +class LinearActivationQuantizedTensor(torch.Tensor): + """ + Applies activation quantization for linear operator + """ + def __new__( + cls, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, + ): + kwargs = {} + dtype = original_weight_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + kwargs["device"] = original_weight_tensor.device + shape = original_weight_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, + ): + self.original_weight_tensor = original_weight_tensor + self.input_quant_func = input_quant_func + + def __tensor_flatten__(self): + return ["original_weight_tensor"], [self.input_quant_func] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + original_weight_tensor = tensor_data_dict["original_weight_tensor"] + input_quant_func, = tensor_attributes + return cls( + original_weight_tensor, + input_quant_func, + ) + + @classmethod + def from_float(cls, input_float, input_quant_func): + return cls(input_float, input_quant_func) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.original_weight_tensor), + self.input_quant_func, + ) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.original_weight_tensor.to(**kwargs), + self.input_quant_func, + ) + + implements = classmethod(_implements) + __torch_function__ = classmethod(_dispatch__torch_function__) + __torch_dispatch__ = 
classmethod(_dispatch__torch_dispatch__) + +implements = LinearActivationQuantizedTensor.implements + +@implements(torch.nn.functional.linear) +def _(func, types, *args, **kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(weight_tensor, LinearActivationQuantizedTensor): + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return torch.nn.functional.linear(aqt, original_weight_tensor, bias) + + raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op") + +@implements([aten.mm.default, aten.addmm.default]) +def _(func, types, *args, **kwargs): + if not args[0].is_floating_point(): + raise NotImplementedError(f"LinearActivationQuantizedTensor: expecting a floating point input") + + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: {args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + input_tensor, weight_tensor, bias = ( + args[1], + args[2], + args[0], + ) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(bias, aqt, original_weight_tensor) + else: + # aten.mm.default + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + input_tensor, weight_tensor = ( + args[0], + args[1], + ) + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return func(aqt, original_weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, *args, **kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + +@implements(aten.clone.default) +def _(func, types, *args, **kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + +@implements(aten._to_copy.default) +def _(func, types, *args, **kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + +@implements(aten.t.default) +def _(func, types, *args, **kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + +to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 161a84c4e4..a3f6daca30 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,8 +28,11 @@ ) from .subclass import ( QuantizedLinearWeightBase, - LinearActQuantizedTensor, - to_linear_act_quantized, +) + +from .linear_activation_quantized_tensor import ( + LinearActivationQuantizedTensor, + to_linear_activation_quantized, ) from .quant_primitives import ( @@ -189,7 +192,7 @@ def _is_linear(mod, *args): and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) - and not isinstance(mod.weight, LinearActQuantizedTensor) + and not isinstance(mod.weight, LinearActivationQuantizedTensor) ) import torch.nn.utils.parametrize as parametrize @@ -351,7 +354,7 @@ def 
apply_int8_dynamic_activation_int4_weight_quant(weight): input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype) weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps) - weight = to_linear_act_quantized(weight, input_quant_func) + weight = to_linear_activation_quantized(weight, input_quant_func) return weight return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant) @@ -450,7 +453,7 @@ def get_per_token_block_size(x): block_size = get_weight_block_size(weight) weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type) - weight = to_linear_act_quantized(weight, input_quant_func) + weight = to_linear_activation_quantized(weight, input_quant_func) return weight return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant) @@ -458,7 +461,7 @@ def get_per_token_block_size(x): def int8_dynamic_activation_int8_semi_sparse_weight(): """ - Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight + Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight quantization + 2:4 sparsity to linear layers. """ from torchao.dtypes import SemiSparseLayoutType diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index a2801a622f..8978cb7ce4 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -30,8 +30,6 @@ "Int8DynamicallyQuantizedLinearWeight", "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", - "LinearActQuantizedTensor", - "to_linear_act_quantized", ] @@ -599,156 +597,3 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): ) int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) return int_data, scales_and_zeros, False, groupsize, inner_k_tiles - - -class LinearActQuantizedTensor(torch.Tensor): - """ - Applies activation quantization for linear operator - """ - def __new__( - cls, - original_weight_tensor: torch.Tensor, - input_quant_func: Callable, - ): - kwargs = {} - dtype = original_weight_tensor.dtype - kwargs["dtype"] = dtype - kwargs["requires_grad"] = False - kwargs["device"] = original_weight_tensor.device - shape = original_weight_tensor.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - original_weight_tensor: torch.Tensor, - input_quant_func: Callable, - ): - self.original_weight_tensor = original_weight_tensor - self.input_quant_func = input_quant_func - - def __tensor_flatten__(self): - return ["original_weight_tensor"], [self.input_quant_func] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - original_weight_tensor = tensor_data_dict["original_weight_tensor"] - input_quant_func, = tensor_attributes - return cls( - original_weight_tensor, - input_quant_func, - ) - - @classmethod - def from_float(cls, input_float, input_quant_func): - return cls(input_float, input_quant_func) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func is torch.nn.functional.linear: - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if isinstance(weight_tensor, LinearActQuantizedTensor): - 
input_quant_func = weight_tensor.input_quant_func - original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return torch.nn.functional.linear(aqt, original_weight_tensor, bias) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.original_weight_tensor), - self.input_quant_func, - ) - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.original_weight_tensor.to(**kwargs), - self.input_quant_func, - ) - - def __torch_dispatch__(cls, func, types, args, kwargs): - if ( - func in [aten.mm.default, aten.addmm.default] - and args[0].is_floating_point() - ): - if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - input_tensor, weight_tensor, bias = ( - args[1], - args[2], - args[0], - ) - input_quant_func = weight_tensor.input_quant_func - original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(bias, aqt, original_weight_tensor) - else: - # aten.mm.default - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - input_tensor, weight_tensor = ( - args[0], - args[1], - ) - input_quant_func = weight_tensor.input_quant_func - original_weight_tensor = weight_tensor.original_weight_tensor - aqt = input_quant_func(input_tensor) - return func(aqt, original_weight_tensor) - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten._to_copy.default: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - - if func is aten.t.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.t) - ) - - raise NotImplementedError( - f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" - ) - -to_linear_act_quantized = LinearActQuantizedTensor.from_float diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index a546c5ab89..7911f645e1 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -11,7 +11,7 @@ from torchao.dtypes import to_affine_quantized_static from torchao.quantization.utils import compute_error from torchao.quantization import quantize_ -from torchao.quantization.subclass import to_linear_act_quantized +from torchao.quantization import to_linear_activation_quantized from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter @@ -60,7 +60,7 @@ def weight_quant_func(weight): # activation 
quantization act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype) - linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False) + linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False) return linear
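
---

For reference, a minimal sketch (not part of any patch above) of how a wrapper tensor subclass picks up the shared dispatch helpers that these patches add to `torchao/dtypes/utils.py`. The class name `MyQuantTensor` and its single-field layout are illustrative assumptions; only the `_implements` / `_dispatch__torch_function__` / `_dispatch__torch_dispatch__` usage and the `(func, types, *args, **kwargs)` handler signature come from the patches themselves.

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

aten = torch.ops.aten


class MyQuantTensor(torch.Tensor):
    """Illustrative wrapper subclass holding a single plain tensor."""

    @staticmethod
    def __new__(cls, plain: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, plain.shape, dtype=plain.dtype, device=plain.device, requires_grad=False
        )

    def __init__(self, plain: torch.Tensor):
        self.plain = plain

    def _apply_fn_to_data(self, fn):
        return self.__class__(fn(self.plain))

    # per-class op table plus the shared dispatch implementations from the patch
    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


# handlers registered via `implements` uniformly take (func, types, *args, **kwargs)
@MyQuantTensor.implements(aten.detach.default)
def _(func, types, *args, **kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


t = MyQuantTensor(torch.randn(4, 4))
t_detached = t.detach()            # routed through the registered aten.detach handler
print(type(t_detached).__name__)   # MyQuantTensor
# any aten op without a registered handler raises NotImplementedError
# from _dispatch__torch_dispatch__
```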
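Likewise, a minimal usage sketch for the new `to_linear_activation_quantized` entry point, mirroring the README and `tutorials/calibration_flow/static_quant.py` changes above. The `fake_int8_quant` function is a self-contained stand-in for a real input quantization function (such as the `to_affine_quantized_static` lambda in the tutorial) and is not defined by the patch.

```python
import torch
from torch import nn
from torchao.quantization import to_linear_activation_quantized


def fake_int8_quant(x: torch.Tensor) -> torch.Tensor:
    # stand-in input quantizer: symmetric per-tensor int8 round trip,
    # kept in float so the example needs nothing beyond this file
    scale = x.abs().amax().clamp(min=1e-8) / 127.0
    return (x / scale).round().clamp(-128, 127) * scale


linear = nn.Linear(16, 8)
# wrap the weight; the registered F.linear handler quantizes the activation
# with fake_int8_quant before running the matmul against the original weight
linear.weight = nn.Parameter(
    to_linear_activation_quantized(linear.weight, fake_int8_quant),
    requires_grad=False,
)

out = linear(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 8])
```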