diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 4d61f97ac9..e72b89156f 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,7 +1,13 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    to_affine_quantized,
+    LayoutType,
+    PlainLayoutType,
+    TensorCoreTiledLayoutType,
+)

 __all__ = [
     "NF4Tensor",
@@ -9,4 +15,7 @@
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "LayoutType",
+    "PlainLayoutType",
+    "TensorCoreTiledLayoutType",
 ]
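
With this change the layout types are exported from `torchao.dtypes`, so callers can construct and inspect them directly. A minimal sketch, assuming a torchao build that includes this patch:

```python
from torchao.dtypes import PlainLayoutType, TensorCoreTiledLayoutType

plain = PlainLayoutType()
tiled = TensorCoreTiledLayoutType(inner_k_tiles=8)

print(plain)  # PlainLayoutType()
print(tiled)  # TensorCoreTiledLayoutType(inner_k_tiles=8)
```
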
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 9c1e5e9bde..8f55176669 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -20,10 +20,35 @@
     _ATEN_OP_OR_TORCH_FN_TABLE,
     _register_layout_cls,
     _get_layout_tensor_constructor,
+    LayoutType,
 )
+from typing import ClassVar
+from dataclasses import dataclass

 aten = torch.ops.aten

+@dataclass(frozen=True)
+class PlainLayoutType(LayoutType):
+    pass
+
+@dataclass(frozen=True)
+class TensorCoreTiledLayoutType(LayoutType):
+    inner_k_tiles: int = 8
+
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        orig_out_features, orig_in_features = input.shape
+        in_features = find_multiple(orig_in_features, 1024)
+        out_features = find_multiple(orig_out_features, 8)
+        input = torch.nn.functional.pad(
+            input,
+            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
+        )
+        return input
+
+    def extra_repr(self):
+        return f"inner_k_tiles={self.inner_k_tiles}"
+
+
 def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
@@ -52,10 +77,10 @@ class AQTLayout(torch.Tensor):
     """
     Base class for the layout tensor for `AffineQuantizedTensor`
     """
-    # this should be set for each layout class during registration
-    extended_layout: Optional[str] = None
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass

-    def get_plain() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def get_layout_type(self) -> LayoutType:
         pass

     @classmethod
@@ -64,9 +89,15 @@ def from_plain(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         pass

+    def __repr__(self):
+        int_data, scale, zero_point = self.get_plain()
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
         device = self.device if device is None else device
@@ -194,30 +225,17 @@ def from_float(
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-        extended_layout: str = "plain",
-        # TODO: this is only for "tensor_core_tiled", need to figure out
-        # the proper API for this arg
-        inner_k_tiles: Optional[int] = None,
+        layout_type: LayoutType = PlainLayoutType(),
     ):
         original_shape = input_float.shape
-        if extended_layout == "tensor_core_tiled":
-            orig_out_features, orig_in_features = input_float.shape
-            in_features = find_multiple(orig_in_features, 1024)
-            out_features = find_multiple(orig_out_features, 8)
-            input_float = torch.nn.functional.pad(
-                input_float,
-                (0, in_features - orig_in_features, 0, out_features - orig_out_features),
-            )
+        input_float = layout_type.pre_process(input_float)

         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = layout_type.post_process(int_data)

-        layout_cls_ctr = get_layout_tensor_constructor(extended_layout)
-        # TODO: this is temporary, need to come up with the proper UX
-        if extended_layout == "tensor_core_tiled":
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point, inner_k_tiles)
-        else:
-            layout_tensor = layout_cls_ctr(int_data, scale, zero_point)
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
         return cls(
             layout_tensor,
             block_size,
@@ -229,8 +247,8 @@ def from_float(
         )

     @property
-    def extended_layout(self) -> str:
-        return self.layout_tensor.extended_layout
+    def layout_type(self) -> LayoutType:
+        return self.layout_tensor.layout_type

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -308,13 +326,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 def implements(aten_ops_or_torch_fn):
     return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)

-def register_layout_cls(extended_layout: str):
-    return _register_layout_cls(AffineQuantizedTensor, extended_layout)
+def register_layout_cls(layout_type_class: type(LayoutType)):
+    return _register_layout_cls(AffineQuantizedTensor, layout_type_class)

-def get_layout_tensor_constructor(extended_layout: str):
-    return _get_layout_tensor_constructor(AffineQuantizedTensor, extended_layout)
+def get_layout_tensor_constructor(layout_type_class: type(LayoutType)):
+    return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class)

-@register_layout_cls("plain")
+@register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
     """
     Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point
@@ -330,6 +348,7 @@ def __new__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = int_data.device
@@ -346,20 +365,23 @@ def __init__(
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
         self.int_data = int_data
         self.scale = scale
         self.zero_point = zero_point
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["int_data", "scale", "zero_point"], []
+        return ["int_data", "scale", "zero_point"], [self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        return cls(int_data, scale, zero_point)
+        layout_type, = tensor_attributes
+        return cls(int_data, scale, zero_point, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -367,6 +389,7 @@ def to(self, *args, **kwargs):
             self.int_data.to(kwargs["device"]),
             self.scale.to(kwargs["device"]),
             self.zero_point.to(kwargs["device"]),
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
@@ -374,6 +397,7 @@ def _apply_fn_to_data(self, fn):
             fn(self.int_data),
             fn(self.scale),
             fn(self.zero_point),
+            self.layout_type,
         )

     @classmethod
@@ -398,19 +422,24 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return self.int_data, self.scale, self.zero_point

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
     @classmethod
     def from_plain(
         cls,
         int_data: torch.Tensor,
         scale: torch.Tensor,
         zero_point: torch.Tensor,
+        layout_type: LayoutType,
     ):
-        return cls(int_data, scale, zero_point)
+        assert isinstance(layout_type, PlainLayoutType)
+        return cls(int_data, scale, zero_point, layout_type)

-@register_layout_cls("tensor_core_tiled")
+@register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
     """
     Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
@@ -427,6 +456,7 @@ def __new__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         kwargs = {}
         kwargs["device"] = packed_weight.device
@@ -443,31 +473,40 @@ def __init__(
         packed_weight: torch.Tensor,
         scale_and_zero: torch.Tensor,
         transposed: bool,
+        layout_type: LayoutType,
     ):
         self.packed_weight = packed_weight
         self.scale_and_zero = scale_and_zero
         self.transposed = False
+        self.layout_type = layout_type

     def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], [self.transposed]
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self.layout_type]

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        transposed, = tensor_attributes
-        return cls(packed_weight, scale_and_zero, transposed)
+        transposed, layout_type, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, layout_type)

     @classmethod
-    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType
+    ):
+        assert isinstance(layout_type, TensorCoreTiledLayoutType)
         # assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
         # packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
+        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero, False)
+        return cls(packed_weight, scale_and_zero, False, layout_type)

     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -477,7 +516,8 @@ def to(self, *args, **kwargs):
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
-            self.transposed
+            self.transposed,
+            self.layout_type,
         )

     def _apply_fn_to_data(self, fn):
@@ -485,10 +525,6 @@ def _apply_fn_to_data(self, fn):
         self.scale_and_zero = fn(self.scale_and_zero)
         return self

-    def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        return f"TensorCoreTiledAQTLayout(int_data={int_data}, scale={scale}, zero_point={zero_point})"
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -511,7 +547,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

-    def get_plain(self):
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
             ZeroPointDomain,
             quantize_affine,
@@ -542,6 +578,9 @@ def get_plain(self):
         int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain)
         return int_data, scale, zero

+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
 def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     """
     Quantized version of F.linear operator
@@ -565,8 +604,8 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         is_cuda and
         input_is_int8 and
         input_tensor.dtype == weight_qtensor.dtype and
-        input_tensor.extended_layout == "plain" and
-        weight_qtensor.extended_layout == "plain"
+        isinstance(input_tensor.layout_type, PlainLayoutType) and
+        isinstance(weight_qtensor.layout_type, PlainLayoutType)
     ):
         #
         # 1. do the matrix form of dot(X_i, W_j)
@@ -608,7 +647,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
             weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
-            weight_qtensor.extended_layout == "tensor_core_tiled"
+            isinstance(weight_qtensor.layout_type, TensorCoreTiledLayoutType)
         ):
             assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
             assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
@@ -651,7 +690,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
             weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
-            weight_qtensor.extended_layout == "plain"
+            isinstance(weight_qtensor.layout_type, PlainLayoutType)
         ):
             # TODO: enable cpu and mps efficient path
             # per channel int8 weight only quantizated mm
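
The layout registry in `utils.py` below is now keyed on the `LayoutType` subclass rather than a layout string. A rough sketch of how a custom layout could plug into the new registration; `MyLayoutType` and `MyAQTLayout` are hypothetical names, not part of this diff, and a real layout tensor subclass would also need the `__new__`/`__init__`/flatten machinery shown above:

```python
from dataclasses import dataclass

import torch

from torchao.dtypes import LayoutType
from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls


@dataclass(frozen=True)
class MyLayoutType(LayoutType):  # hypothetical layout type
    def post_process(self, input: torch.Tensor) -> torch.Tensor:
        # e.g. repack or pad the quantized data for a specific kernel
        return input.contiguous()


@register_layout_cls(MyLayoutType)
class MyAQTLayout(AQTLayout):  # hypothetical layout tensor subclass
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType,
    ):
        assert isinstance(layout_type, MyLayoutType)
        ...  # pack int_data/scale/zero_point into this layout's storage
```

`to_affine_quantized` then picks up `MyAQTLayout.from_plain` through `get_layout_tensor_constructor(type(layout_type))`.
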
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 1e4eb692a5..3a437b4745 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -1,6 +1,8 @@
+import torch
 from typing import Dict, Callable
 from collections import defaultdict
 import functools
+from dataclasses import dataclass

 """
 torch_function and torch_dispatch operator dispatch registrations
@@ -28,6 +30,23 @@ def wrapper(*args, **kwargs):
         return func
     return decorator

+"""
+Base class for the different layout types; should not be instantiated directly
+"""
+@dataclass(frozen=True)
+class LayoutType:
+    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
+        return input
+
+    def post_process(self, input: torch.Tensor) -> torch.Tensor:
+        return input
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.extra_repr()})"
+
+    def extra_repr(self) -> str:
+        return ""
+

 """
-layout tensor constructor registration for different tensor subclassesa
+layout tensor constructor registration for different tensor subclasses
@@ -35,31 +54,38 @@ def wrapper(*args, **kwargs):
-second key is an extended layout string, like tensor_core_tiled
+second key is the class type of a `LayoutType` subclass, e.g. `TensorCoreTiledLayoutType`
 value is a constructor for the LayoutTensor class, e.g. TensorCoreTiledAQTLayout.from_plain
 """
-_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[str, Callable]] = defaultdict(dict)
+_LAYOUT_CONSTRUCTOR_TABLE: Dict[Callable, Dict[type(LayoutType), Callable]] = defaultdict(dict)

-def _register_layout_cls(cls: Callable, extended_layout: str):
+def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
     """Helper function for layout registrations, this is used to implement
     register_layout_cls decorator for each tensor subclass, see aqt.py for example usage

     Args:
         cls: Tensor subclass type
-        extended_layout: string name for the layout type
+        layout_type_class: the class type of a subclass of `LayoutType`, e.g. `PlainLayoutType`

     Returns:
         a decorator that registers the layout tensor constructor in the table
     """
     def decorator(layout_cls):
-        layout_cls.extended_layout = extended_layout
-        _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout] = layout_cls.from_plain
+        _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
         return layout_cls
     return decorator

-def _get_layout_tensor_constructor(cls: Callable, extended_layout: str) -> Callable:
-    """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `extended_layout`
+def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(LayoutType)) -> Callable:
+    """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class`
+    `layout_type_class` is the class type of a subclass of `LayoutType`, e.g. `PlainLayoutType`
+
+    Args:
+        cls: Tensor subclass type
+        layout_type_class: the class type of a subclass of `LayoutType`, e.g. `PlainLayoutType`
+
+    Returns:
+        layout tensor subclass constructor for the layout_type_class
     """
     if cls not in _LAYOUT_CONSTRUCTOR_TABLE:
         raise ValueError(f"no registered layout class constructor for: {cls}")
-    if extended_layout not in _LAYOUT_CONSTRUCTOR_TABLE[cls]:
-        raise ValueError(f"extended_layout: {extended_layout} is not supported yet for {cls}")
+    if layout_type_class not in _LAYOUT_CONSTRUCTOR_TABLE[cls]:
+        raise ValueError(f"layout_type_class: {layout_type_class} is not supported yet for {cls}")

-    return _LAYOUT_CONSTRUCTOR_TABLE[cls][extended_layout]
+    return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index d6c142476b..095dbde0b0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -380,6 +380,7 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     def apply_int4_weight_only_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
+        from torchao.dtypes import TensorCoreTiledLayoutType

         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
@@ -390,7 +391,8 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)
+        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
+        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return apply_int4_weight_only_quant
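
Taken together, call sites now pass a `LayoutType` instance instead of a layout string plus layout-specific keyword arguments. A caller-side sketch of the int4 weight-only path after this change; the quantization parameter values below are illustrative, and only the arguments visible in the `quant_api.py` hunk above are taken from this diff:

```python
import torch

from torchao.dtypes import TensorCoreTiledLayoutType, to_affine_quantized
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# The tensor_core_tiled packing path targets CUDA + bfloat16 weights.
weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
group_size = 128

aqt = to_affine_quantized(
    weight,
    MappingType.ASYMMETRIC,
    (1, group_size),   # block_size
    torch.int32,       # target_dtype (illustrative, as are the values below)
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
    # before this change: extended_layout="tensor_core_tiled", inner_k_tiles=8
    layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8),
)
```
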