Migrate to config for Int8DynamicActivationIntxWeightConfig #1836

Open
Wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions .github/workflows/torchao_experimental_test.yml
@@ -36,6 +36,7 @@ jobs:
pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
pip install numpy
pip install pytest
pip install parameterized
USE_CPP=1 pip install .
- name: Run python tests
run: |
7 changes: 6 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
Contributor Author:
@drisspg @jerryzh168 are we OK with adding tensor_impl_ctr_kwargs to from_hp_to_intx?

It can be used to propagate a bias when constructing the weight tensor subclass via from_plain.
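For context, a rough sketch of the call pattern this would enable. Everything below is illustrative: the import path, shapes, and quantization parameters are placeholders, and it assumes the tensor_impl_ctr_kwargs plumbing from this PR together with a from_plain that accepts a bias keyword (in practice the config flow would do this wiring rather than user code).

import torch

from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# Import path is illustrative; the layout lives in torchao's experimental package.
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

weight = torch.randn(128, 256)  # placeholder float weight of a linear layer (n=128, k=256)
bias = torch.randn(128)         # placeholder bias to pack together with the weight

layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten")
layout.set_params(bit_width=4, group_size=32, has_weight_zeros=False, has_bias=True)

# tensor_impl_ctr_kwargs is forwarded verbatim to the layout's tensor impl
# constructor (from_plain), which is how the bias ends up packed with the
# quantized weight instead of being applied separately in linear.
qweight = AffineQuantizedTensor.from_hp_to_intx(
    weight,
    mapping_type=MappingType.SYMMETRIC,
    block_size=(1, 32),
    target_dtype=torch.int8,
    quant_min=-8,
    quant_max=7,
    _layout=layout,
    tensor_impl_ctr_kwargs={"bias": bias},
)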

@@ -204,6 +204,7 @@ def from_hp_to_intx(
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
_layout: Layout = PlainLayout(),
use_hqq: bool = False,
tensor_impl_ctr_kwargs: Optional[dict] = None,
):
"""Convert a high precision tensor to an integer affine quantized tensor."""
original_shape = input_float.shape
@@ -276,7 +277,11 @@ def from_hp_to_intx(

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
if tensor_impl_ctr_kwargs is None:
tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
Comment on lines +280 to +284
Contributor:
Don't know which style AO uses, no strong pref

Suggested change
if tensor_impl_ctr_kwargs is None:
tensor_impl_ctr_kwargs = {}
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs
)
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
)

Contributor Author:
I'd like to hear from @drisspg or someone from torchao on this change.

Not so much on the style preference, but more so on whether they're OK adding tensor_impl_ctr_kwargs to the to_affine_quantized_intx signature.

return cls(
tensor_impl,
block_size,
@@ -6,7 +6,6 @@

#pragma once
#include <cpuinfo.h>
// #include <glog/logging.h>
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
#include <torchao/experimental/ops/packed_weights_header.h>

@@ -121,6 +120,21 @@ void check_format(PackedWeightsFormat format,
}
}

void log_registration(PackedWeightsFormat format, std::string description) {
// Logging is only supported in ATen mode
#ifdef USE_ATEN
LOG(INFO) << "Registering ukernel config for linear_8bit_act_xbit_weight" << std::endl
<< "\tDescription: " << description << std::endl
<< "\tformat.type=" << static_cast<int>(format.type) << std::endl
<< "\tformat.weight_nbit=" << format.weight_nbit << std::endl
<< "\tformat.has_weight_zeros=" << format.has_weight_zeros << std::endl
<< "\tformat.has_bias=" << format.has_bias << std::endl
<< "\tformat.nr=" << format.nr << std::endl
<< "\tformat.kr=" << format.kr << std::endl
<< "\tformat.sr=" << format.sr << std::endl;
#endif // USE_ATEN
}

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
void register_ukernel_config_universal(UKernelConfigRegistrationTable &table,
PackedWeightsFormat format,
@@ -135,6 +149,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable &table,
if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
#if defined(TORCHAO_BUILD_CPU_AARCH64)
if (cpuinfo_has_arm_neon_dot()) {
log_registration(format, "universal");
namespace kernel = torchao::kernels::cpu::aarch64::linear::
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
table.register_ukernel_config(
@@ -211,6 +226,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
#if defined(TORCHAO_ENABLE_ARM_I8MM)
if (cpuinfo_has_arm_i8mm()) {
constexpr int n_step = 8;
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm");
table.register_ukernel_config(
format, uarch,
UKernelConfig{
@@ -228,6 +244,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,

if (cpuinfo_has_arm_neon_dot()) {
constexpr int n_step = 8;
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod");
table.register_ukernel_config(
format, uarch,
UKernelConfig{
@@ -249,6 +266,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
constexpr int sr = 2;
if (cpuinfo_has_arm_neon_dot()) {
constexpr int n_step = 4;
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod");
table.register_ukernel_config(
format, uarch,
UKernelConfig{
@@ -6,25 +6,20 @@

import logging
from enum import Enum, auto
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
get_tensor_impl_constructor,
register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
@@ -58,47 +53,46 @@ def target_from_str(target: str) -> Target:


class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
bit_width: Optional[int]
group_size: Optional[int]
has_weight_zeros: Optional[bool]
# The target platform for the layout, 'native' or 'aten'
target: Optional[Target]

def __init__(
self,
bit_width: Optional[int] = None,
group_size: Optional[int] = None,
has_weight_zeros: Optional[bool] = None,
target: Optional[str] = "native",
target: Union[str, Target] = "native",
):
if bit_width is not None:
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
if group_size is not None:
assert group_size >= 1, f"group_size must be positive, got {group_size}"

self.bit_width = bit_width
self.group_size = group_size
self.has_weight_zeros = has_weight_zeros
self.target = target_from_str(target)

if not self.has_params_set():
assert (
self.bit_width is None
and self.group_size is None
and self.has_weight_zeros is None
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"
if isinstance(target, str):
target = target_from_str(target)
self.target = target

self.bit_width: Optional[int] = None
self.group_size: Optional[int] = None
self.has_weight_zeros: Optional[bool] = None
# has_bias is whether the packed weights
# have bias packed with them, not whether the
# linear operator has bias
self.has_bias: Optional[bool] = None

def extra_repr(self):
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, has_bias={self.has_bias}, target={self.target}"

def has_params_set(self) -> bool:
return (
(self.bit_width is not None)
and (self.group_size is not None)
and (self.has_weight_zeros is not None)
and (self.has_bias is not None)
and (self.target is not None)
)

def set_params(
self, bit_width: int, group_size: int, has_weight_zeros: bool, has_bias: bool
):
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
assert group_size >= 1, f"group_size must be positive, got {group_size}"

self.bit_width = bit_width
self.group_size = group_size
self.has_weight_zeros = has_weight_zeros
self.has_bias = has_bias
assert self.has_params_set()


@register_layout(PackedLinearInt8DynamicActivationIntxWeightLayout)
class PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl(AQTTensorImpl):
@@ -177,11 +171,17 @@ def from_plain(
), "aten target is requires torch version > 2.6.0"
int_data = int_data.add(8)
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)

# If layout does not have bias packed with the weights, set bias to None
# It will be applied later in the linear function
if not layout.has_bias:
bias = None
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
int_data, scale, bias, layout.group_size, k, n
)
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

assert not layout.has_bias, "has_bias is not supported yet"
if layout.has_weight_zeros:
args = [
int_data.to(torch.int8),
@@ -256,9 +256,7 @@ def __tensor_unflatten__(

def _linear_check(input_tensor, weight_tensor, bias):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
bias is None or layout.target == Target.ATEN # Aten target allows bias
)
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)


def _linear_impl(input_tensor, weight_tensor, bias):
@@ -275,6 +273,10 @@ def _impl_2d_native(input_tensor, weight_tensor):
assert n == weight_tensor.tensor_impl.n_tensor.shape[1]
assert k == weight_tensor.tensor_impl.k_tensor.shape[1]

assert (
not weight_tensor.tensor_impl.get_layout().has_bias
), "has_bias is not supported yet"

# TODO(T200095131): convert self.n, self.k, self.group_size to
# int when supported by AOTI
args = (
@@ -312,113 +314,35 @@ def _impl_2d_aten(input_tensor, weight_tensor):

target = weight_tensor.tensor_impl.get_layout().target

if weight_tensor.tensor_impl.get_layout().has_bias:
assert (
bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
Comment on lines +317 to +320
Contributor:
nit: the if + assert could be collapsed into a single assert; also fine with leaving it as-is for legibility

Suggested change
if weight_tensor.tensor_impl.get_layout().has_bias:
assert (
bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"
assert (
not weight_tensor.tensor_impl.get_layout().has_bias or bias is None
), "bias should be None because it is already packed with the weights (has_bias=True)"


if target == Target.ATEN:
assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
_impl_2d = _impl_2d_aten
elif target == Target.NATIVE:
_impl_2d = _impl_2d_native
assert (
bias is None
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' "

if input_tensor.dim() == 2:
return _impl_2d(input_tensor, weight_tensor)
res = _impl_2d(input_tensor, weight_tensor)
else:
assert input_tensor.dim() >= 3
lead_shape = input_tensor.shape[0:-2]
m, k = input_tensor.shape[-2], input_tensor.shape[-1]
n, k_ = weight_tensor.shape
assert k_ == k

assert input_tensor.dim() >= 3
lead_shape = input_tensor.shape[0:-2]
m, k = input_tensor.shape[-2], input_tensor.shape[-1]
n, k_ = weight_tensor.shape
assert k_ == k
res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
res = res.reshape(*lead_shape, m, n)

res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor)
res = res.reshape(*lead_shape, m, n)
if bias is not None:
res = res + bias
return res


register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
"""
PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class.
"""

@classmethod
def from_hp_to_intx(
cls,
input_float: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None,
):
assert (
use_hqq == False
), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
assert isinstance(
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout
), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
assert (
_layout.target == Target.ATEN
), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

scale, zero_point = choose_qparams_affine(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
return cls(
tensor_impl,
block_size,
original_shape,
quant_min,
quant_max,
zero_point_domain,
dtype=input_float.dtype,
)


to_packedlinearint8dynamicactivationintxweight_quantized_intx = (
PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
)