-
Notifications
You must be signed in to change notification settings - Fork 375
Add Int4PlainInt32Tensor #2845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add Int4PlainInt32Tensor #2845
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
ec3e065
Add Int4XPUTensorIntZP
liangan1 1dc5b2c
Add int4_xpu_tensor
liangan1 e63b100
Update int4_xpu_tensor.py
liangan1 5ef1ca2
Fix typo
liangan1 a28dd89
Fix code format issue
liangan1 8a0f124
fix bug
liangan1 a0ff36f
Fix code format
liangan1 5e9c476
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 2c4c2ce
Update int4_xpu_tensor.py
liangan1 e48ea0b
change the pack format to plain
liangan1 c4e5b9d
fix typo
liangan1 7063e56
Update quant_api.py
liangan1 5b87d8b
merge main branch
liangan1 6076877
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 8d2acd2
Update __init__.py
liangan1 43acd66
Update __init__.py
liangan1 a047c00
change Int4XPUTensorIntZP to Int4PlainInt32
liangan1 3f70b2b
Update __init__.py
liangan1 402dd72
Refine code
liangan1 282f1a8
Refine code
liangan1 cd781fc
Update __init__.py
liangan1 afadf69
Update __init__.py
liangan1 b68beef
Add more comments about the original weight dtype
liangan1 66e05ff
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 105b4b9
fix code format issue
liangan1 b24ff1a
fix code format issue
liangan1 77868bc
skip ut if no xpu
liangan1 970aa17
Update test_int4_plain_int32_tensor.py
liangan1 78f6bb2
Add assert for the original weight data type
liangan1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
89 changes: 89 additions & 0 deletions
89
test/quantization/quantize_/workflows/int4/test_int4_xpu.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import tempfile | ||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.testing._internal.common_utils import ( | ||
| TestCase, | ||
| instantiate_parametrized_tests, | ||
| parametrize, | ||
| run_tests, | ||
| ) | ||
|
|
||
| from torchao.quantization import ( | ||
| Int4WeightOnlyConfig, | ||
| quantize_, | ||
| ) | ||
| from torchao.quantization.quant_primitives import ( | ||
| ZeroPointDomain, | ||
| ) | ||
| from torchao.quantization.utils import compute_error | ||
| from torchao.utils import ( | ||
| torch_version_at_least, | ||
| ) | ||
|
|
||
|
|
||
| def get_config(group_size): | ||
| return Int4WeightOnlyConfig( | ||
| group_size=group_size, | ||
| packing_format="plain_int32", | ||
| zero_point_domain=ZeroPointDomain.INT, | ||
| version=2, | ||
| ) | ||
|
|
||
|
|
||
| @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+") | ||
| class Int4XPUTensorIntZP(TestCase): | ||
| @parametrize( | ||
| "sizes", | ||
| [ | ||
| ((128,), 256, 128), | ||
| ((32, 128), 512, 128), | ||
| ((2, 32, 128), 256, 12), | ||
| ], | ||
| ) | ||
| @parametrize("dtype", [torch.bfloat16, torch.half]) | ||
| @parametrize("group_size", [32, 64, 128]) | ||
| def test_linear(self, sizes, dtype, group_size): | ||
| device = "xpu" | ||
| M, N, K = sizes | ||
| input = torch.randn(*M, K, dtype=dtype, device=device) | ||
| linear = torch.nn.Linear(K, N, dtype=dtype, device=device) | ||
| original = linear(input) | ||
| quantize_(linear, get_config(group_size)) | ||
| quantized = linear(input) | ||
| self.assertTrue(compute_error(original, quantized) > 20) | ||
|
|
||
| compiled_linear = torch.compile(linear) | ||
| quantized_and_compiled = compiled_linear(input) | ||
| self.assertTrue(compute_error(original, quantized_and_compiled) > 20) | ||
|
|
||
| @parametrize("dtype", [torch.bfloat16, torch.half]) | ||
| def test_module_path(self, dtype): | ||
| linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu") | ||
| quantize_(linear, get_config(group_size=128)) | ||
| self.assertEqual( | ||
| str(type(linear.weight)), | ||
| "<class 'torchao.quantization.Int4XPUTensorIntZP'>", | ||
| ) | ||
|
|
||
| with tempfile.NamedTemporaryFile() as f: | ||
| torch.save(linear.state_dict(), f) | ||
| f.seek(0) | ||
| state_dict = torch.load(f) | ||
| self.assertEqual( | ||
| str(type(state_dict["weight"])), | ||
| "<class 'torchao.quantization.Int4XPUTensorIntZP'>", | ||
| ) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(Int4XPUTensorIntZP) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_tests() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,9 @@ | ||
| from .int4_preshuffled_tensor import Int4PreshuffledTensor | ||
| from .int4_tensor import Int4Tensor | ||
| from .int4_xpu_tensor import Int4XPUTensorIntZP | ||
|
|
||
| __all__ = [ | ||
| "Int4PreshuffledTensor", | ||
| "Int4Tensor", | ||
| "Int4XPUTensorIntZP", | ||
| ] |
181 changes: 181 additions & 0 deletions
181
torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,181 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD 3-Clause license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
|
|
||
| from typing import List | ||
|
|
||
| import torch | ||
|
|
||
| from torchao.quantization.quant_primitives import ( | ||
| MappingType, | ||
| _choose_qparams_affine, | ||
| _quantize_affine, | ||
| ) | ||
| from torchao.utils import ( | ||
| TorchAOBaseTensor, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "Int4XPUTensorIntZP", | ||
| ] | ||
|
|
||
| aten = torch.ops.aten | ||
|
|
||
|
|
||
| class Int4XPUTensorIntZP(TorchAOBaseTensor): | ||
| """ | ||
| int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only) | ||
| Tensor Attributes: | ||
| qdata: packed int4 weigh, always viewed as a 2D (N, K/2) tensor | ||
| scale: (K/group_size, N), dtype is the same as the original Tensor dtype | ||
| zero_point: (K/group_size, N) | ||
| Non-Tensor Attributes: | ||
| block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size). | ||
| we only support group_size = 32/64/128. | ||
| shape: shape of the original Tensor | ||
| """ | ||
|
|
||
| tensor_data_names = ["qdata", "scale", "zero_point"] | ||
| tensor_attribute_names = ["block_size", "shape"] | ||
|
|
||
| def __new__( | ||
| cls, | ||
| qdata, | ||
| scale, | ||
| zero_point, | ||
| block_size, | ||
| shape, | ||
| ): | ||
| kwargs = {} | ||
| kwargs["device"] = qdata.device | ||
| kwargs["dtype"] = scale.dtype | ||
| kwargs["requires_grad"] = False | ||
| return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] | ||
|
|
||
| def __init__(self, qdata, scale, zero_point, block_size, shape): | ||
| self.qdata = qdata | ||
| self.scale = scale | ||
| self.zero_point = zero_point | ||
| self.block_size = block_size | ||
|
|
||
| def _quantization_type(self): | ||
| return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" | ||
|
|
||
| @classmethod | ||
| def from_hp( | ||
| cls, | ||
| w: torch.Tensor, | ||
| block_size: List[int], | ||
| ): | ||
| assert w.ndim == 2 and w.device.type == "xpu", ( | ||
| f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}" | ||
| ) | ||
| assert len(block_size) == w.ndim | ||
|
|
||
| original_shape = w.shape | ||
| mapping_type = MappingType.ASYMMETRIC | ||
| target_dtype = torch.int32 | ||
| quant_min = 0 | ||
| quant_max = 15 | ||
| eps = 1e-6 | ||
| scale_dtype = None | ||
| zero_point_dtype = torch.int32 | ||
| scale, zero_point = _choose_qparams_affine( | ||
| w, | ||
| mapping_type.name, | ||
| block_size, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| eps, | ||
| scale_dtype, | ||
| zero_point_dtype, | ||
| ) | ||
| int_data = _quantize_affine( | ||
| w, | ||
| block_size, | ||
| scale, | ||
| zero_point, | ||
| target_dtype, | ||
| quant_min, | ||
| quant_max, | ||
| ) | ||
| assert int_data.dtype == torch.int32, ( | ||
| "torch.ops.aten._convert_weight_to_int4pack expects `int32` dtype" | ||
| ) | ||
| packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) | ||
| packed_weight = torch.ops.aten._convert_weight_to_int4pack( | ||
jerryzh168 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| packed_weight.contiguous(), 8 | ||
| ) | ||
| scale = scale.reshape(int_data.shape[0], -1) | ||
| zero_point = zero_point.reshape(int_data.shape[0], -1) | ||
| return Int4XPUTensorIntZP( | ||
| packed_weight, | ||
| scale.transpose(0, 1).contiguous(), | ||
| zero_point.transpose(0, 1).contiguous().to(torch.int8), | ||
| block_size, | ||
| original_shape, | ||
| ) | ||
|
|
||
|
|
||
| implements = Int4XPUTensorIntZP.implements | ||
|
|
||
|
|
||
| @implements([torch.nn.functional.linear, aten.linear.default]) | ||
| def _(func, types, args, kwargs): | ||
| input_tensor, weight_tensor, bias = ( | ||
| args[0], | ||
| args[1], | ||
| args[2] if len(args) > 2 else None, | ||
| ) | ||
| assert input_tensor.device.type == "xpu", ( | ||
| f"For XPU device only but got: {input_tensor.device}" | ||
| ) | ||
| assert isinstance(weight_tensor, Int4XPUTensorIntZP), ( | ||
| f"Expected weight_tensor to be Int4XPUTensorIntZP, got: {type(weight_tensor)}" | ||
| ) | ||
| assert weight_tensor.block_size[0] == 1, ( | ||
| f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" | ||
| ) | ||
| assert input_tensor.shape[-1] == weight_tensor.shape[1], ( | ||
| f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" | ||
| ) | ||
|
|
||
| act_mat = input_tensor | ||
| packed_weight = weight_tensor.qdata | ||
| scale = weight_tensor.scale | ||
| zero_point = weight_tensor.zero_point | ||
|
|
||
| orig_act_size = act_mat.size() | ||
| orig_dtype = act_mat.dtype | ||
|
|
||
| # reshape to 2D | ||
| act_mat = act_mat.reshape(-1, act_mat.shape[-1]) | ||
|
|
||
| # groupwise int4 quantization | ||
| groupsize = weight_tensor.block_size[1] | ||
| y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros( | ||
| act_mat, packed_weight, groupsize, scale, zero_point | ||
| ) | ||
|
|
||
| # remove out_feature padding | ||
| assert weight_tensor.ndim == 2 | ||
| orig_out_features = weight_tensor.shape[-2] | ||
| y = y[:, :orig_out_features] | ||
| y = y.reshape(*orig_act_size[:-1], orig_out_features) | ||
|
|
||
| if bias is not None: | ||
| y += bias | ||
| return y.to(orig_dtype) | ||
|
|
||
|
|
||
| Int4XPUTensorIntZP.__module__ = "torchao.quantization" | ||
|
|
||
| # Allow a model with Int4XPUTensorIntZP weights to be loaded with `weights_only=True` | ||
| torch.serialization.add_safe_globals([Int4XPUTensorIntZP]) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.