diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 20f324b35e..583f15019c 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -22,71 +22,17 @@ _register_layout_cls, _get_layout_tensor_constructor, LayoutType, + PlainLayoutType, is_device, ) -from typing import ClassVar from dataclasses import dataclass from torchao.utils import TORCH_VERSION_AFTER_2_5 aten = torch.ops.aten -@dataclass(frozen=True) -class PlainLayoutType(LayoutType): - pass - -@dataclass(frozen=True) -class SemiSparseLayoutType(LayoutType): - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - # prune to 2:4 if not already - temp = input.detach() - pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] - temp.view(-1, 4).scatter_(1, pruning_inds, value=0) - return temp - - -@dataclass(frozen=True) -class TensorCoreTiledLayoutType(LayoutType): - inner_k_tiles: int = 8 - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def extra_repr(self): - return f"inner_k_tiles={self.inner_k_tiles}" - - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.layout_tensor.dtype == torch.int8 and - aqt.quant_min is None or aqt.quant_min == -128 and - aqt.quant_max is None or aqt.quant_max == 127 - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.layout_tensor.dtype == torch.int8 and - aqt.quant_min == -127 and - aqt.quant_max is None or aqt.quant_max == 127 - ) - -def _aqt_is_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - # TODO: use torch.uint4 - return ( - aqt.layout_tensor.dtype == torch.int32 and - aqt.quant_min is None or aqt.quant_min == 0 and - aqt.quant_max is None or aqt.quant_max == 15 - ) - +############################### +# Base Layout Tensor Subclass # +############################### class AQTLayout(torch.Tensor): """ Base class for the layout tensor for `AffineQuantizedTensor` @@ -126,6 +72,10 @@ def _get_to_kwargs(self, *args, **kwargs): } return kwargs +############################## +# Tensor Subclass Definition # +############################## + class AffineQuantizedTensor(torch.Tensor): """ Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: @@ -337,7 +287,6 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) - implements = classmethod(_implements) # Note: we only added cpu path here for 8da4w, this is for executorch, in the future # 1. we'll add cpu/cuda version (int4mm etc.) @@ -353,7 +302,10 @@ def _apply_fn_to_data(self, fn): __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) __torch_function__ = classmethod(_dispatch__torch_function__) -implements = AffineQuantizedTensor.implements + +###################################################### +# LayoutType and Layout Tensor Subclass Registration # +###################################################### def register_layout_cls(layout_type_class: type(LayoutType)): return _register_layout_cls(AffineQuantizedTensor, layout_type_class) @@ -361,6 +313,35 @@ def register_layout_cls(layout_type_class: type(LayoutType)): def get_layout_tensor_constructor(layout_type_class: type(LayoutType)): return _get_layout_tensor_constructor(AffineQuantizedTensor, layout_type_class) +@dataclass(frozen=True) +class SemiSparseLayoutType(LayoutType): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + temp = input.detach() + pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] + temp.view(-1, 4).scatter_(1, pruning_inds, value=0) + return temp + + +@dataclass(frozen=True) +class TensorCoreTiledLayoutType(LayoutType): + inner_k_tiles: int = 8 + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + + def extra_repr(self): + return f"inner_k_tiles={self.inner_k_tiles}" + + @register_layout_cls(PlainLayoutType) class PlainAQTLayout(AQTLayout): """ @@ -487,7 +468,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def get_plain(self): - # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though. cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) int_data_expanded = torch._cslt_sparse_mm(self.int_data, @@ -507,7 +488,7 @@ def from_plain( assert isinstance(layout_type, SemiSparseLayoutType) int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, layout_type) - + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): @@ -654,6 +635,34 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout_type(self) -> LayoutType: return self.layout_type +##################################################### +# torch functional and aten operator implementation # +##################################################### + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.layout_tensor.dtype == torch.int8 and + aqt.quant_min is None or aqt.quant_min == -128 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.layout_tensor.dtype == torch.int8 and + aqt.quant_min == -127 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.layout_tensor.dtype == torch.int32 and + aqt.quant_min is None or aqt.quant_min == 0 and + aqt.quant_max is None or aqt.quant_max == 15 + ) + def _quantized_linear_op(input_tensor, weight_qtensor, bias): """ Quantized version of F.linear operator @@ -811,8 +820,10 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): raise NotImplementedError("No specialized dispatch found for quantized linear op") +implements = AffineQuantizedTensor.implements + @implements(torch.nn.functional.linear) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( args[0], args[1], @@ -831,7 +842,7 @@ def _(func, types, *args, **kwargs): return torch.nn.functional.linear(input_tensor, weight_tensor, bias) @implements([aten.mm.default, aten.addmm.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"{func} is not implemented for non floating point input") @@ -870,21 +881,21 @@ def _(func, types, *args, **kwargs): return func(input_tensor, weight_tensor) @implements([aten.detach.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) @implements([aten.clone.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) @implements([aten._to_copy.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, @@ -893,7 +904,7 @@ def _(func, types, *args, **kwargs): ) @implements([aten.t.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): block_size = args[0].block_size assert len(block_size) == 2 transposed_block_size = (block_size[1], block_size[0]) diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 9d49809d47..56dfdaf4db 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -32,8 +32,8 @@ def _(func, types, args, kwargs): def decorator(func): for op in aten_ops_or_torch_fns: @functools.wraps(op) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper return func @@ -50,7 +50,7 @@ class MyTensor(torch.Tensor): kwargs = {} if kwargs is None else kwargs if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \ func in cls._ATEN_OP_OR_TORCH_FN_TABLE: - return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs) + return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs) with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -65,7 +65,7 @@ class MyTensor(torch.Tensor): """ if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \ func in cls._ATEN_OP_OR_TORCH_FN_TABLE: - return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs) + return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs) raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}") @@ -87,6 +87,14 @@ def __repr__(self): def extra_repr(self) -> str: return "" +""" +Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default +""" +@dataclass(frozen=True) +class PlainLayoutType(LayoutType): + pass + + """ layout tensor constructor registration for different tensor subclassesa diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 087c9912b4..5e02a5e045 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -89,7 +89,7 @@ def __repr__(self): @OptimState4bit.implements(aten.copy_.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): dst = args[0] src = args[1] @@ -116,14 +116,14 @@ def _(func, types, *args, **kwargs): @OptimState4bit.implements(aten.lerp.Scalar) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() and for flattening tensor @OptimState4bit.implements(aten.view.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x, shape = args if tuple(x.shape) == tuple(shape): @@ -142,7 +142,7 @@ def _(func, types, *args, **kwargs): c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimState4bit): raise ValueError(f"expecting a OptimState4bit but found {type(x)}") diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 128a020b66..77459a2a3d 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -75,7 +75,7 @@ def __repr__(self): @OptimState8bit.implements(aten.copy_.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): dst = args[0] src = args[1] @@ -98,14 +98,14 @@ def _(func, types, *args, **kwargs): @OptimState8bit.implements(aten.lerp.Scalar) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() @OptimState8bit.implements(aten.view.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x, shape = args return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed) @@ -117,7 +117,7 @@ def _(func, types, *args, **kwargs): c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimState8bit): raise ValueError(f"expecting a OptimState8bit but found {type(x)}") diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index de1d629fcc..ee97fffc71 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -81,7 +81,7 @@ def __repr__(self): @OptimStateFp8.implements(aten.copy_.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): dst = args[0] src = args[1] @@ -102,14 +102,14 @@ def _(func, types, *args, **kwargs): @OptimStateFp8.implements(aten.lerp.Scalar) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args] return func(*args, **kwargs) # this is needed for DTensor.from_local() @OptimStateFp8.implements(aten.view.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x, shape = args return OptimStateFp8(x.codes.view(shape), x.scale) @@ -121,7 +121,7 @@ def _(func, types, *args, **kwargs): c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default, ]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimStateFp8): raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}") diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py index 38eed6dd5e..bbcc978e77 100644 --- a/torchao/prototype/quant_llm/quant_llm.py +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -402,7 +402,7 @@ def _apply_fn_to_data(self, fn): ) @QuantLlmLinearWeight.implements(torch.nn.functional.linear) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): act = args[0] weight = args[1] bias = args[2] if len(args) >= 3 else None @@ -431,7 +431,7 @@ def _(func, types, *args, **kwargs): @QuantLlmLinearWeight.implements(torch.ops.aten.detach.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) diff --git a/torchao/quantization/linear_activation_quantized_tensor.py b/torchao/quantization/linear_activation_quantized_tensor.py index e4e4fedc45..dfe1f62de7 100644 --- a/torchao/quantization/linear_activation_quantized_tensor.py +++ b/torchao/quantization/linear_activation_quantized_tensor.py @@ -91,7 +91,7 @@ def to(self, *args, **kwargs): implements = LinearActivationQuantizedTensor.implements @implements(torch.nn.functional.linear) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( args[0], args[1], @@ -106,7 +106,7 @@ def _(func, types, *args, **kwargs): raise NotImplementedError("LinearActivationQuantizedTensor: No specialized dispatch found for linear op") @implements([aten.mm.default, aten.addmm.default]) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): if not args[0].is_floating_point(): raise NotImplementedError(f"LinearActivationQuantizedTensor: expecting a floating point input") @@ -141,19 +141,19 @@ def _(func, types, *args, **kwargs): @implements(aten.detach.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) @implements(aten.clone.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) ) @implements(aten._to_copy.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, @@ -162,7 +162,7 @@ def _(func, types, *args, **kwargs): ) @implements(aten.t.default) -def _(func, types, *args, **kwargs): +def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.t) ) diff --git a/tutorials/developer_api_guide.py b/tutorials/developer_api_guide.py new file mode 100644 index 0000000000..9c670e14b9 --- /dev/null +++ b/tutorials/developer_api_guide.py @@ -0,0 +1,372 @@ +# Following is a example for a simple dtype implemented with tensor subclass +# it shows +# * the basic structure of a new dtype tensor subclass (__new__, __init__, __tensor_flatten__, __tensor_unflatten__) +# * two types of dispatch that people can overwrite (__torch_function__, __torch_dispatch__) +# * how to abstract away packing format with layout +# * how the tensor subclass composes with torch.compile to get speedup + + +import functools +from collections import defaultdict +from typing import Any, Callable, Dict, Optional, Tuple + +import torch + +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.quantization.quant_primitives import choose_qparams_affine, MappingType +from torchao.dtypes.utils import ( + _implements, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, + _register_layout_cls, + _get_layout_tensor_constructor, + LayoutType, + PlainLayoutType, +) + +aten = torch.ops.aten + +############################### +# Base Layout Tensor Subclass # +############################### +class MyDTypeLayout(torch.Tensor): + """ + Base class for the layout tensor for `MyDTypeTensor` + """ + # get the original unpacked Tensors + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.int_data, self.scale + + def get_layout_type(self) -> LayoutType: + return self.layout_type + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + layout_type: LayoutType, + ): + """Construct a layout tensor from plain tensors and a layout_type, which main contain + extra metadata for packing etc. + """ + pass + + def __repr__(self): + int_data, scale = self.get_plain() + layout_type = self.get_layout_type() + return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, layout_type={layout_type})" + + __torch_function__ = torch._C._disabled_torch_function_impl + +############################## +# Tensor Subclass Definition # +############################## + +class MyDTypeTensor(torch.Tensor): + """We need to define __new__ for constructing a new tensor subclass instance and __init__ for initialize + the instance. There is no requirement on what the argument list should look like here, only requirement is + that `__new__` must return a Tensor instance with `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call + """ + + @staticmethod + def __new__( + cls, + layout_tensor: MyDTypeLayout, + shape: torch.Size, + dtype: Optional[torch.dtype] = None, + ): + kwargs = {} + kwargs["device"] = layout_tensor.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else layout_tensor.layout + ) + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + layout_tensor: MyDTypeLayout, + shape: torch.Size, + dtype: Optional[torch.dtype] = None, + ): + self.layout_tensor = layout_tensor + + """__tensor_flatten__ and __tensor_unflatten__ are used to desugar the tensor into native Tensors/attributes and + reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define + a Tensor subclass for torch.compile support + """ + + def __tensor_flatten__(self): + """ + Given the class, returns the fields of the class as two lists + The first one contains any tensor fields such as int_data and scale as keys to a dictionary + The second one contains all other non tensor type fields as values of a list + """ + return ["layout_tensor"], [self.shape, self.dtype] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + """ + Given the flattened data from above, returns a class instance + tensor_data_dict contains the tensor fields of the class as a dictionary + tensor_attributes contains all other non tensor type fields + """ + layout_tensor = tensor_data_dict["layout_tensor"] + shape, dtype = tensor_attributes + return cls( + layout_tensor, + shape if outer_size is None else outer_size, + dtype=dtype, + ) + + """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype + """ + + @classmethod + def from_float( + cls, + input_float: torch.Tensor, + layout_type: LayoutType = PlainLayoutType(), + ): + mapping_type = MappingType.SYMMETRIC + block_size = input_float.shape + dtype = torch.int16 + scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) + int_data = (input_float / scale).to(torch.int8) + layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) + layout_tensor = layout_tensor_ctr(int_data, scale, layout_type) + return cls(layout_tensor, input_float.shape) + + """[Optional] We can overwrite layout property of the Tensor to represent different packing formats + """ + + @property + def layout_type(self) -> LayoutType: + return self.layout_tensor.layout_type + + def dequantize(self, output_dtype=None): + """We can define a dequantize method to convert the quantized tensor to a floating point tensor""" + if output_dtype is None: + output_dtype = torch.get_default_dtype() + int_data, scale = self.layout_tensor.get_plain() + return int_data.to(output_dtype) * scale + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def _apply_fn_to_data(self, fn): + """ + Used for implementing aten ops by applying them only to the relevant tensor atributes + In this case we only want to call things like to() or view() on the layout tensor + """ + return self.__class__( + fn(self.layout_tensor), + self.shape, + self.dtype, + ) + + implements = classmethod(_implements) + + """There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch: + + __torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear, + tensor.detach, tensor.reshape, tensor.t etc. + + __torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example: + aten.mm, aten.addmm, aten.detach.default, aten.t.default etc. + + We have some helper functions that can dispatch to the functions registered with MyDTypeTensor.implements, but if the default implementation does not work for your use case, please feel free to customize it + """ + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + __torch_function__ = classmethod(_dispatch__torch_function__) + +###################################################### +# LayoutType and Layout Tensor Subclass Registration # +###################################################### + +def register_layout_cls(layout_type_class: type(LayoutType)): + return _register_layout_cls(MyDTypeTensor, layout_type_class) + +def get_layout_tensor_constructor(layout_type_class: type(LayoutType)): + return _get_layout_tensor_constructor(MyDTypeTensor, layout_type_class) + + +@register_layout_cls(PlainLayoutType) +class PlainMyDTypeLayout(MyDTypeLayout): + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + layout_type: LayoutType, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + layout_type: LayoutType, + ): + self.int_data = int_data + self.scale = scale + self.layout_type = layout_type + + def __tensor_flatten__(self): + return ["int_data", "scale"], [self.layout_type] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale = tensor_data_dict["int_data"], tensor_data_dict["scale"] + layout_type, = tensor_attributes + return cls(int_data, scale, layout_type) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + layout_type: LayoutType, + ): + """Construct a layout tensor from plain tensors and a layout_type, which main contain + extra metadata for packing etc. + """ + assert isinstance(layout_type, PlainLayoutType) + return cls(int_data, scale, layout_type) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + self.layout_type, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MyDTypeLayout dispatch: attempting to run {func}, this is not supported" + ) + +##################################################### +# torch functional and aten operator implementation # +##################################################### + +implements = MyDTypeTensor.implements + +def _quantized_linear_op(input_tensor, weight_tensor, bias): + if isinstance(input_tensor, MyDTypeTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, MyDTypeTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(torch.nn.functional.linear) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + return _quantized_linear_op(input_tensor, weight_tensor, bias) + except NotImplementedError: + if isinstance(input_tensor, MyDTypeTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, MyDTypeTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to + # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, + # which is needed for correctness in AOTAutograd. + + # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass + # of `my_dtype` + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +class M(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.linear = torch.nn.Linear(1024, 1024) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + +##################### +# Factory functions # +##################### +to_my_dtype = MyDTypeTensor.from_float + + +######## +# Test # +######## +from torchao.utils import benchmark_model + +m = M() +example_inputs = (100 * torch.randn(1024, 1024),) +NUM_WARMUPS = 10 +NUM_RUNS = 100 + +for _ in range(NUM_WARMUPS): + m(*example_inputs) +print("before quantization:", benchmark_model(m, NUM_RUNS, example_inputs[0])) + +compiled = torch.compile(m, mode="max-autotune") +for _ in range(NUM_WARMUPS): + compiled(*example_inputs) +print("after compile:", benchmark_model(compiled, NUM_RUNS, example_inputs[0])) + +# convert weights to quantized weights +m.linear.weight = torch.nn.Parameter( + to_my_dtype(m.linear.weight), requires_grad=False +) + +for _ in range(NUM_WARMUPS): + m(*example_inputs) + +print("after quantization:", benchmark_model(m, NUM_RUNS, example_inputs[0])) + +m = torch.compile(m, mode="max-autotune") + +for _ in range(NUM_WARMUPS): + m(*example_inputs) + +# NOTE: currently there is no speedup because we just dequantize the weight in the _quantized_linear op +# we plan to add custom op example in the future and that will help us to get speedup +print("after quantization and compile:", benchmark_model(m, NUM_RUNS, example_inputs[0]))