Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ec3e065
Add Int4XPUTensorIntZP
liangan1 Aug 22, 2025
1dc5b2c
Add int4_xpu_tensor
liangan1 Aug 22, 2025
e63b100
Update int4_xpu_tensor.py
liangan1 Aug 25, 2025
5ef1ca2
Fix typo
liangan1 Aug 25, 2025
a28dd89
Fix code format issue
liangan1 Aug 25, 2025
8a0f124
fix bug
liangan1 Aug 25, 2025
a0ff36f
Fix code format
liangan1 Aug 25, 2025
5e9c476
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 Aug 26, 2025
2c4c2ce
Update int4_xpu_tensor.py
liangan1 Aug 26, 2025
e48ea0b
change the pack format to plain
liangan1 Aug 26, 2025
c4e5b9d
fix typo
liangan1 Aug 26, 2025
7063e56
Update quant_api.py
liangan1 Aug 26, 2025
5b87d8b
merge main branch
liangan1 Aug 28, 2025
6076877
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 Aug 27, 2025
8d2acd2
Update __init__.py
liangan1 Aug 29, 2025
43acd66
Update __init__.py
liangan1 Aug 29, 2025
a047c00
change Int4XPUTensorIntZP to Int4PlainInt32
liangan1 Aug 29, 2025
3f70b2b
Update __init__.py
liangan1 Aug 29, 2025
402dd72
Refine code
liangan1 Aug 29, 2025
282f1a8
Refine code
liangan1 Aug 29, 2025
cd781fc
Update __init__.py
liangan1 Sep 1, 2025
afadf69
Update __init__.py
liangan1 Sep 1, 2025
b68beef
Add more comments about the original weight dtype
liangan1 Sep 1, 2025
66e05ff
Merge branch 'main' into liangan1/int4_xpu_int_zp
liangan1 Sep 1, 2025
105b4b9
fix code format issue
liangan1 Sep 1, 2025
b24ff1a
fix code format issue
liangan1 Sep 1, 2025
77868bc
skip ut if no xpu
liangan1 Sep 1, 2025
970aa17
Update test_int4_plain_int32_tensor.py
liangan1 Sep 1, 2025
78f6bb2
Add assert for the original weight data type
liangan1 Sep 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions test/quantization/quantize_/workflows/int4/test_int4_xpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
)


def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
packing_format="int4_xpu_int_zp",
version=2,
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
class Int4XPUTensorIntZP(TestCase):
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 12),
],
)
@parametrize("dtype", [torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
device = "xpu"
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

@parametrize("dtype", [torch.bfloat16, torch.half])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
quantize_(linear, get_config(group_size=128))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4XPUTensorIntZP'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4XPUTensorIntZP'>",
)


instantiate_parametrized_tests(Int4XPUTensorIntZP)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
Int4MarlinSparseTensor,
Int4PreshuffledTensor,
Int4Tensor,
Int4XPUTensorIntZP,
IntxUnpackedTensor,
)
from .smoothquant import (
Expand Down Expand Up @@ -162,6 +163,7 @@
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Int4XPUTensorIntZP",
"IntxUnpackedTensor",
"Float8Tensor",
# smooth quant - subject to change
Expand Down
5 changes: 4 additions & 1 deletion torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
Int4PreshuffledTensor,
Int4Tensor,
IntxUnpackedTensor,
Int4XPUTensorIntZP,
QuantizeTensorToFloat8Kwargs,
)
from torchao.quantization.transform_module import (
Expand Down Expand Up @@ -518,7 +519,6 @@ def quantize_(
torch._C._log_api_usage_once("torchao.quantization.quantize_")

filter_fn = _is_linear if filter_fn is None else filter_fn

if isinstance(config, ModuleFqnToConfig):
_replace_with_custom_fn_if_matches_filter_with_name(
model,
Expand Down Expand Up @@ -1080,6 +1080,9 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size,
)
return new_weight
elif packing_format == PackingFormat.INT4_XPU_INT_ZP:
new_weight = Int4XPUTensorIntZP.from_hp(weight, block_size)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")

Expand Down
3 changes: 3 additions & 0 deletions torchao/quantization/quantize_/common/packing_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ class PackingFormat(str, Enum):
Unpacked means the subbyte quantized data is stored as int8
"""
UNPACKED_TO_INT8 = "unpacked_to_int8"

"int4_xpu_int_zp is referring to the format used by int4 weight-only quantization on XPU with int zero point, which is a groupwise quantization format."
INT4_XPU_INT_ZP = "int4_xpu_int_zp"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't include int4 and xpu in the name, can you name this in terms of of how the quantized data is packed?

Copy link
Collaborator Author

@liangan1 liangan1 Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The int4 weight xpu is a plain format tensor according to this doc, it just pack 2 int4 weight elements in a byte and then store the 4*int4 as int32. So I change it to the plain.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, we have plain that stores 2*int4 as int8, can you reuse it or would need a new one? https://github.com/pytorch/ao/blob/main/torchao/quantization/quantize_/workflows/int4/int4_tensor.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@liangan1 can you use PLAIN_INT32 for packing_format, and rename things accordingly (tensor subclass, files etc.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jerryzh168. I have added PLAIN_INT32 to be used by the xpu int4. Per my understanding, the packing format should be a dispatch policy to select the right tensor subclassing and a tensor subclass should cover a specific quantization recipe. So I suppose I should keep the current tensor name for int4 xpu.
In this PR, we just want to enable the int xpu with int zp domain. The current oneDNN backend can not support the float zp as CUDA/CPU backend and the feature is WIP. I plain to reuse this packing format in the future and dispatch the tensor with the zero point domain information.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can reuse the packing format and the tensor for float32 zero_point as well in the future I think, but today we structure tensor subclass by: dtype + packing_format, so Int4PlainInt32 might be better

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. change it to Int4PlainInt32. pls help to review again.

4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .int4.int4_tensor import (
Int4Tensor,
)
from .int4.int4_xpu_tensor import (
Int4XPUTensorIntZP,
)
from .intx.intx_unpacked_tensor import (
IntxUnpackedTensor,
)
Expand All @@ -19,6 +22,7 @@
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4MarlinSparseTensor",
"Int4XPUTensorIntZP",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"IntxUnpackedTensor",
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/int4/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .int4_preshuffled_tensor import Int4PreshuffledTensor
from .int4_tensor import Int4Tensor
from .int4_xpu_tensor import Int4XPUTensorIntZP

__all__ = [
"Int4PreshuffledTensor",
"Int4Tensor",
"Int4XPUTensorIntZP",
]
182 changes: 182 additions & 0 deletions torchao/quantization/quantize_/workflows/int4/int4_xpu_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine,
_quantize_affine,
)
from torchao.utils import (
TorchAOBaseTensor,
)

__all__ = [
"Int4XPUTensorIntZP",
]

aten = torch.ops.aten


class Int4XPUTensorIntZP(TorchAOBaseTensor):
"""
int4 weight-only quantization on XPU with oneDNN as backend (groupwise quantization only)
Tensor Attributes:
qdata: packed int4 weigh, always viewed as a 2D (N, K/2) tensor, last dimension is packed
preshuffling is specific to CPU kernels, see Note below.
scale: (K/group_size, N), dtype is the same as the original Tensor dtype
zero_point: (K/group_size, N)
Non-Tensor Attributes:
block_size: the block size for quantization, representing the granularity, for groupwise quantization, will have block_size (1, group_size).
we only support group_size = 32/64/128.
shape: shape of the original Tensor
"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]

def __new__(
cls,
qdata,
scale,
zero_point,
block_size,
shape,
):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, scale, zero_point, block_size, shape):
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size

def _quantization_type(self):
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

@classmethod
def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "xpu", (
f"Expecting 2D tensor on XPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim

original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = torch.int32
scale, zero_point = _choose_qparams_affine(
w,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = _quantize_affine(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(
packed_weight.contiguous(), 8
)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
return Int4XPUTensorIntZP(
packed_weight,
scale.transpose(0, 1).contiguous(),
zero_point.transpose(0, 1).contiguous().to(torch.int8),
block_size,
original_shape,
)


implements = Int4XPUTensorIntZP.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
assert input_tensor.device.type == "xpu", (
f"For XPU device only but got: {input_tensor.device}"
)
assert isinstance(weight_tensor, Int4XPUTensorIntZP), (
f"Expected weight_tensor to be Int4XPUTensorIntZP, got: {type(weight_tensor)}"
)
assert weight_tensor.block_size[0] == 1, (
f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
)
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
)

act_mat = input_tensor
packed_weight = weight_tensor.qdata
scale = weight_tensor.scale
zero_point = weight_tensor.zero_point

orig_act_size = act_mat.size()
orig_dtype = act_mat.dtype

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]
y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
act_mat, packed_weight, groupsize, scale, zero_point
)

# remove out_feature padding
assert weight_tensor.ndim == 2
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

if bias is not None:
y += bias
return y.to(orig_dtype)


Int4XPUTensorIntZP.__module__ = "torchao.quantization"

# Allow a model with Int4XPUTensorIntZP weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4XPUTensorIntZP])
Loading