-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate to config for Int8DynamicActivationIntxWeightConfig #1836
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -204,6 +204,7 @@ def from_hp_to_intx( | |||||||||||||||||
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, | ||||||||||||||||||
_layout: Layout = PlainLayout(), | ||||||||||||||||||
use_hqq: bool = False, | ||||||||||||||||||
tensor_impl_ctr_kwargs: Optional[dict] = None, | ||||||||||||||||||
): | ||||||||||||||||||
"""Convert a high precision tensor to an integer affine quantized tensor.""" | ||||||||||||||||||
original_shape = input_float.shape | ||||||||||||||||||
|
@@ -276,7 +277,11 @@ def from_hp_to_intx( | |||||||||||||||||
|
||||||||||||||||||
data = _layout.post_process(data) | ||||||||||||||||||
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) | ||||||||||||||||||
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) | ||||||||||||||||||
if tensor_impl_ctr_kwargs is None: | ||||||||||||||||||
tensor_impl_ctr_kwargs = {} | ||||||||||||||||||
tensor_impl = tensor_impl_ctr( | ||||||||||||||||||
data, scale, zero_point, _layout, **tensor_impl_ctr_kwargs | ||||||||||||||||||
) | ||||||||||||||||||
Comment on lines
+280
to
+284
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't know which style AO uses, no strong pref
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to hear from @drisspg or someone from torchao on this change. Not so much on the style preference, but more so on whether they're OK adding tensor_impl_ctr_kwargs to the to_affine_quantized_intx signature. |
||||||||||||||||||
return cls( | ||||||||||||||||||
tensor_impl, | ||||||||||||||||||
block_size, | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,25 +6,20 @@ | |||||||||||||||
|
||||||||||||||||
import logging | ||||||||||||||||
from enum import Enum, auto | ||||||||||||||||
from typing import Optional, Tuple | ||||||||||||||||
from typing import Optional, Tuple, Union | ||||||||||||||||
|
||||||||||||||||
import torch | ||||||||||||||||
from torch.utils._python_dispatch import return_and_correct_aliasing | ||||||||||||||||
|
||||||||||||||||
from torchao.dtypes.affine_quantized_tensor import ( | ||||||||||||||||
AffineQuantizedTensor, | ||||||||||||||||
get_tensor_impl_constructor, | ||||||||||||||||
register_layout, | ||||||||||||||||
) | ||||||||||||||||
from torchao.dtypes.affine_quantized_tensor_ops import ( | ||||||||||||||||
register_aqt_quantized_linear_dispatch, | ||||||||||||||||
) | ||||||||||||||||
from torchao.dtypes.utils import AQTTensorImpl, Layout | ||||||||||||||||
from torchao.quantization.quant_primitives import ( | ||||||||||||||||
MappingType, | ||||||||||||||||
ZeroPointDomain, | ||||||||||||||||
choose_qparams_affine, | ||||||||||||||||
quantize_affine, | ||||||||||||||||
) | ||||||||||||||||
from torchao.utils import ( | ||||||||||||||||
TORCH_VERSION_AT_LEAST_2_6, | ||||||||||||||||
|
@@ -58,47 +53,46 @@ def target_from_str(target: str) -> Target: | |||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): | ||||||||||||||||
bit_width: Optional[int] | ||||||||||||||||
group_size: Optional[int] | ||||||||||||||||
has_weight_zeros: Optional[bool] | ||||||||||||||||
# The target platform for the layout, 'native' or 'aten' | ||||||||||||||||
target: Optional[Target] | ||||||||||||||||
|
||||||||||||||||
def __init__( | ||||||||||||||||
self, | ||||||||||||||||
bit_width: Optional[int] = None, | ||||||||||||||||
group_size: Optional[int] = None, | ||||||||||||||||
has_weight_zeros: Optional[bool] = None, | ||||||||||||||||
target: Optional[str] = "native", | ||||||||||||||||
target: Union[str, Target] = "native", | ||||||||||||||||
): | ||||||||||||||||
if bit_width is not None: | ||||||||||||||||
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" | ||||||||||||||||
if group_size is not None: | ||||||||||||||||
assert group_size >= 1, f"group_size must be positive, got {group_size}" | ||||||||||||||||
|
||||||||||||||||
self.bit_width = bit_width | ||||||||||||||||
self.group_size = group_size | ||||||||||||||||
self.has_weight_zeros = has_weight_zeros | ||||||||||||||||
self.target = target_from_str(target) | ||||||||||||||||
|
||||||||||||||||
if not self.has_params_set(): | ||||||||||||||||
assert ( | ||||||||||||||||
self.bit_width is None | ||||||||||||||||
and self.group_size is None | ||||||||||||||||
and self.has_weight_zeros is None | ||||||||||||||||
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" | ||||||||||||||||
if isinstance(target, str): | ||||||||||||||||
target = target_from_str(target) | ||||||||||||||||
self.target = target | ||||||||||||||||
|
||||||||||||||||
self.bit_width: Optional[int] = None | ||||||||||||||||
self.group_size: Optional[int] = None | ||||||||||||||||
self.has_weight_zeros: Optional[bool] = None | ||||||||||||||||
# has_bias is whether the packed weights | ||||||||||||||||
# have bias packed with them, not whether the | ||||||||||||||||
# linear operator has bias | ||||||||||||||||
self.has_bias: Optional[bool] = None | ||||||||||||||||
|
||||||||||||||||
def extra_repr(self): | ||||||||||||||||
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}" | ||||||||||||||||
return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, has_bias={self.has_bias}, target={self.target}" | ||||||||||||||||
|
||||||||||||||||
def has_params_set(self) -> bool: | ||||||||||||||||
return ( | ||||||||||||||||
(self.bit_width is not None) | ||||||||||||||||
and (self.group_size is not None) | ||||||||||||||||
and (self.has_weight_zeros is not None) | ||||||||||||||||
and (self.has_bias is not None) | ||||||||||||||||
and (self.target is not None) | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
def set_params( | ||||||||||||||||
self, bit_width: int, group_size: int, has_weight_zeros: bool, has_bias: bool | ||||||||||||||||
): | ||||||||||||||||
assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" | ||||||||||||||||
assert group_size >= 1, f"group_size must be positive, got {group_size}" | ||||||||||||||||
|
||||||||||||||||
self.bit_width = bit_width | ||||||||||||||||
self.group_size = group_size | ||||||||||||||||
self.has_weight_zeros = has_weight_zeros | ||||||||||||||||
self.has_bias = has_bias | ||||||||||||||||
assert self.has_params_set() | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
@register_layout(PackedLinearInt8DynamicActivationIntxWeightLayout) | ||||||||||||||||
class PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl(AQTTensorImpl): | ||||||||||||||||
|
@@ -177,11 +171,17 @@ def from_plain( | |||||||||||||||
), "aten target is requires torch version > 2.6.0" | ||||||||||||||||
int_data = int_data.add(8) | ||||||||||||||||
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) | ||||||||||||||||
|
||||||||||||||||
# If layout does not have bias packed with the weights, set bias to None | ||||||||||||||||
# It will be applied later in the linear function | ||||||||||||||||
if not layout.has_bias: | ||||||||||||||||
bias = None | ||||||||||||||||
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight( | ||||||||||||||||
int_data, scale, bias, layout.group_size, k, n | ||||||||||||||||
) | ||||||||||||||||
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) | ||||||||||||||||
|
||||||||||||||||
assert not layout.has_bias, "has_bias is not supported yet" | ||||||||||||||||
if layout.has_weight_zeros: | ||||||||||||||||
args = [ | ||||||||||||||||
int_data.to(torch.int8), | ||||||||||||||||
|
@@ -256,9 +256,7 @@ def __tensor_unflatten__( | |||||||||||||||
|
||||||||||||||||
def _linear_check(input_tensor, weight_tensor, bias): | ||||||||||||||||
layout = weight_tensor.tensor_impl.get_layout() | ||||||||||||||||
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( | ||||||||||||||||
bias is None or layout.target == Target.ATEN # Aten target allows bias | ||||||||||||||||
) | ||||||||||||||||
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
def _linear_impl(input_tensor, weight_tensor, bias): | ||||||||||||||||
|
@@ -275,6 +273,10 @@ def _impl_2d_native(input_tensor, weight_tensor): | |||||||||||||||
assert n == weight_tensor.tensor_impl.n_tensor.shape[1] | ||||||||||||||||
assert k == weight_tensor.tensor_impl.k_tensor.shape[1] | ||||||||||||||||
|
||||||||||||||||
assert ( | ||||||||||||||||
not weight_tensor.tensor_impl.get_layout().has_bias | ||||||||||||||||
), "has_bias is not supported yet" | ||||||||||||||||
|
||||||||||||||||
# TODO(T200095131): convert self.n, self.k, self.group_size to | ||||||||||||||||
# int when supported by AOTI | ||||||||||||||||
args = ( | ||||||||||||||||
|
@@ -312,113 +314,35 @@ def _impl_2d_aten(input_tensor, weight_tensor): | |||||||||||||||
|
||||||||||||||||
target = weight_tensor.tensor_impl.get_layout().target | ||||||||||||||||
|
||||||||||||||||
if weight_tensor.tensor_impl.get_layout().has_bias: | ||||||||||||||||
assert ( | ||||||||||||||||
bias is None | ||||||||||||||||
), "bias should be None because it is already packed with the weights (has_bias=True)" | ||||||||||||||||
Comment on lines
+317
to
+320
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
|
||||||||||||||||
|
||||||||||||||||
if target == Target.ATEN: | ||||||||||||||||
assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" | ||||||||||||||||
_impl_2d = _impl_2d_aten | ||||||||||||||||
elif target == Target.NATIVE: | ||||||||||||||||
_impl_2d = _impl_2d_native | ||||||||||||||||
assert ( | ||||||||||||||||
bias is None | ||||||||||||||||
), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' " | ||||||||||||||||
|
||||||||||||||||
if input_tensor.dim() == 2: | ||||||||||||||||
return _impl_2d(input_tensor, weight_tensor) | ||||||||||||||||
res = _impl_2d(input_tensor, weight_tensor) | ||||||||||||||||
else: | ||||||||||||||||
assert input_tensor.dim() >= 3 | ||||||||||||||||
lead_shape = input_tensor.shape[0:-2] | ||||||||||||||||
m, k = input_tensor.shape[-2], input_tensor.shape[-1] | ||||||||||||||||
n, k_ = weight_tensor.shape | ||||||||||||||||
assert k_ == k | ||||||||||||||||
|
||||||||||||||||
assert input_tensor.dim() >= 3 | ||||||||||||||||
lead_shape = input_tensor.shape[0:-2] | ||||||||||||||||
m, k = input_tensor.shape[-2], input_tensor.shape[-1] | ||||||||||||||||
n, k_ = weight_tensor.shape | ||||||||||||||||
assert k_ == k | ||||||||||||||||
res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) | ||||||||||||||||
res = res.reshape(*lead_shape, m, n) | ||||||||||||||||
|
||||||||||||||||
res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) | ||||||||||||||||
res = res.reshape(*lead_shape, m, n) | ||||||||||||||||
if bias is not None: | ||||||||||||||||
res = res + bias | ||||||||||||||||
return res | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
register_aqt_quantized_linear_dispatch( | ||||||||||||||||
_linear_check, | ||||||||||||||||
_linear_impl, | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): | ||||||||||||||||
""" | ||||||||||||||||
PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class. | ||||||||||||||||
""" | ||||||||||||||||
|
||||||||||||||||
@classmethod | ||||||||||||||||
def from_hp_to_intx( | ||||||||||||||||
cls, | ||||||||||||||||
input_float: torch.Tensor, | ||||||||||||||||
mapping_type: MappingType, | ||||||||||||||||
block_size: Tuple[int, ...], | ||||||||||||||||
target_dtype: torch.dtype, | ||||||||||||||||
quant_min: Optional[int] = None, | ||||||||||||||||
quant_max: Optional[int] = None, | ||||||||||||||||
eps: Optional[float] = None, | ||||||||||||||||
scale_dtype: Optional[torch.dtype] = None, | ||||||||||||||||
zero_point_dtype: Optional[torch.dtype] = None, | ||||||||||||||||
preserve_zero: bool = True, | ||||||||||||||||
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, | ||||||||||||||||
_layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(), | ||||||||||||||||
use_hqq: bool = False, | ||||||||||||||||
bias: Optional[torch.Tensor] = None, | ||||||||||||||||
): | ||||||||||||||||
assert ( | ||||||||||||||||
use_hqq == False | ||||||||||||||||
), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" | ||||||||||||||||
assert isinstance( | ||||||||||||||||
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout | ||||||||||||||||
), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" | ||||||||||||||||
assert ( | ||||||||||||||||
_layout.target == Target.ATEN | ||||||||||||||||
), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." | ||||||||||||||||
original_shape = input_float.shape | ||||||||||||||||
input_float = _layout.pre_process(input_float) | ||||||||||||||||
|
||||||||||||||||
scale, zero_point = choose_qparams_affine( | ||||||||||||||||
input_float, | ||||||||||||||||
mapping_type, | ||||||||||||||||
block_size, | ||||||||||||||||
target_dtype, | ||||||||||||||||
quant_min, | ||||||||||||||||
quant_max, | ||||||||||||||||
eps, | ||||||||||||||||
scale_dtype, | ||||||||||||||||
zero_point_dtype, | ||||||||||||||||
preserve_zero, | ||||||||||||||||
zero_point_domain, | ||||||||||||||||
) | ||||||||||||||||
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None | ||||||||||||||||
# TODO should probably consolidate ZeroPointDomain.NONE and None | ||||||||||||||||
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: | ||||||||||||||||
zero_point = None | ||||||||||||||||
data = quantize_affine( | ||||||||||||||||
input_float, | ||||||||||||||||
block_size, | ||||||||||||||||
scale, | ||||||||||||||||
zero_point, | ||||||||||||||||
target_dtype, | ||||||||||||||||
quant_min, | ||||||||||||||||
quant_max, | ||||||||||||||||
zero_point_domain, | ||||||||||||||||
) | ||||||||||||||||
# Note: output will be uint8 tensor for sub byte tensors for now | ||||||||||||||||
|
||||||||||||||||
data = _layout.post_process(data) | ||||||||||||||||
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) | ||||||||||||||||
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias) | ||||||||||||||||
return cls( | ||||||||||||||||
tensor_impl, | ||||||||||||||||
block_size, | ||||||||||||||||
original_shape, | ||||||||||||||||
quant_min, | ||||||||||||||||
quant_max, | ||||||||||||||||
zero_point_domain, | ||||||||||||||||
dtype=input_float.dtype, | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
to_packedlinearint8dynamicactivationintxweight_quantized_intx = ( | ||||||||||||||||
PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx | ||||||||||||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@drisspg @jerryzh168 are we ok adding tensor_impl_ctr_kwargs to from_hp_to_intx.
It can be used to propagate a bias when constructing the weight tensor subclass via from_plain.