diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index a4f5010981..5f2d1153df 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -87,6 +87,44 @@ def test_to_device(self, apply_quant):
         ql = apply_quant(l)
         ql.cuda()
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_register_new_dispatch(self):
+        from torchao.dtypes.affine_quantized_tensor import (
+            register_aqt_quantized_linear_dispatch,
+            deregister_aqt_quantized_linear_dispatch,
+        )
+        from torchao.dtypes import to_affine_quantized_intx
+        from torchao.dtypes import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+
+        def dispatch_condition(input_tensor, weight_tensor, bias):
+            return (
+                isinstance(weight_tensor, AffineQuantizedTensor) and
+                weight_tensor.quant_min == 0 and
+                weight_tensor.quant_max == 2**6-1
+            )
+
+        def impl(input_tensor, weight_tensor, bias):
+            # this is just for testing; normally this would call into a uint6 weight-only
+            # quantized linear operator here
+            assert False, "dispatching to my impl for uint6 weight only quant"
+
+        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
+
+        def apply_uint6_weight_only_quant(linear):
+            linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
+            return linear
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        apply_uint6_weight_only_quant(l)
+
+        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
+        with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
+            l(example_input)
+
+        deregister_aqt_quantized_linear_dispatch(dispatch_condition)
+
+
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 06bb8aeff9..11b9356adf 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -39,6 +39,9 @@
     TORCH_VERSION_AT_LEAST_2_5,
     _is_float8_type
 )
+import logging
+
+logger = logging.getLogger(__name__)
 from torchao.float8.float8_tensor import ScaledMMConfig
 
 aten = torch.ops.aten
@@ -88,9 +91,28 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
     pass
 
 
-_QLINEAR_DISPATCH_TABLE = {}
-def _register_quantized_linear_dispatch(dispatch_condition, impl):
-    _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+_AQT_QLINEAR_DISPATCH_TABLE = {}
+def register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
+    """Register a dispatch for the quantized linear op, with a dispatch_condition function and an impl function
+    that both take three arguments:
+      input_tensor: dimension is (M1, M2, ..., in_features)
+      weight_tensor: dimension is (out_features, in_features)
+      bias: dimension is (out_features,)
+    so that they can be shared by the F.linear, aten.mm and aten.addmm dispatches
+
+    Args:
+        `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]): the dispatch
+            condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight
+        `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): the specialized
+            quantized linear implementation
+    """
+    _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
+
+def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
+    if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE:
+        del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition]
+    else:
+        logger.warning(f"Attempting to remove non-existent dispatch condition {dispatch_condition}")
 
 class AffineQuantizedTensor(TorchAOBaseTensor):
     """
@@ -189,7 +211,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 
     @staticmethod
     def _quantized_linear_op(input_tensor, weight_tensor, bias):
-        for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
+        for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
             if dispatch_condition(input_tensor, weight_tensor, bias):
                 return impl(input_tensor, weight_tensor, bias)
         raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
@@ -440,7 +462,7 @@ def extra_repr(self):
 
 @dataclass(frozen=True)
 class Float8LayoutType(LayoutType):
-    mm_config: Optional[ScaledMMConfig]
+    mm_config: Optional[ScaledMMConfig] = None
 
 
 @register_layout_cls(PlainLayoutType)
@@ -598,13 +620,13 @@ def from_plain(
 
 @register_layout_cls(Float8LayoutType)
 class Float8AQTLayout(AQTLayout):
-    """ 
+    """
     Layout storage class for float8 layout for affine quantized tensor
     """
     float8_data: torch.Tensor
     scale: torch.Tensor
     transposed: bool
-    
+
     def __new__(
         cls,
         float8_data: torch.Tensor,
@@ -639,7 +661,7 @@ def _apply_fn_to_data(self, fn):
         fn(self.float8_data)
         fn(self.scale)
         return self
-    
+
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         return self.__class__(
@@ -976,21 +998,6 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
     y += bias
     return y
 
-# this is for the case when linear activation is quantized, but is not caught by the previous
-# conditions that expects a quantized activation, we just dequantize the activation so that
-# it can continue with the weight only quantization dispatches
-# NOTE: this is a fallback path that must be registered after all the implementations that expects
-# input tensor to be quantized
-def _linear_quantized_act_fallback_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-    )
-
-def _linear_quantized_act_fallback_impl(input_tensor, weight_tensor, bias):
-    input_tensor = input_tensor.dequantize()
-    # dequantize activation and redispatch to F.linear
-    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-
 def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
     return (
         # input is native bfloat16 tensor
@@ -1187,19 +1194,18 @@ def _linear_fp_act_fp8_weight_impl(
     ).reshape(out_shape)
 
 
-def _register_quantized_linear_dispatches():
+def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
-        (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
     ]:
-        _register_quantized_linear_dispatch(dispatch_condition, impl)
+        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
 
-_register_quantized_linear_dispatches()
+_register_aqt_quantized_linear_dispatches()
 
 @implements(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
@@ -1216,7 +1222,11 @@ def _(func, types, args, kwargs):
     # make the branches easier to understand in `_quantized_linear_op`
     try:
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # the fallback path is only used when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1239,7 +1249,11 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # the fallback path is only used when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1259,7 +1273,11 @@ def _(func, types, args, kwargs):
     try:
         weight_tensor = weight_tensor.t()
         return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except QuantizedLinearNotImplementedError:
+    except QuantizedLinearNotImplementedError as e:
+        # the fallback path is only used when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
+        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
+            raise e
+
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
         if isinstance(weight_tensor, AffineQuantizedTensor):
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 036a5ca929..3a197da05e 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Dict, Callable, Union, Tuple
+from typing import Dict, Callable, Union, Tuple, Optional
 from collections import defaultdict
 import functools
 from dataclasses import dataclass
@@ -73,6 +73,12 @@ class MyTensor(torch.Tensor):
 
 """
 Base class for different LayoutType, should not be instantiated directly
+It is used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles
+for the int4 tensor core tiled layout.
+
+Note: layout is an abstraction not only for custom data representation; it also governs how the
+layout tensor interacts with different operators, since the same data representation can have different
+behaviors when running the same operator, e.g. transpose, quantized_linear.
 """
 @dataclass(frozen=True)
 class LayoutType:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index aa2d3b3f93..ba670f23b7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
 
 def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     """
     Applies float8 weight-only symmetric per-channel quantization to linear layers.
-    
+
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
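
Example usage of the new registration API, mirroring test_register_new_dispatch above (a minimal sketch: the check/impl names are illustrative, and the impl just dequantizes and falls back to F.linear as a stand-in for a real uint6 weight-only kernel):

    import torch
    from torchao.dtypes import to_affine_quantized_intx, AffineQuantizedTensor
    from torchao.dtypes.affine_quantized_tensor import (
        register_aqt_quantized_linear_dispatch,
        deregister_aqt_quantized_linear_dispatch,
    )
    from torchao.quantization.quant_primitives import MappingType

    def uint6_weight_only_check(input_tensor, weight_tensor, bias):
        # match weights quantized to the uint6 range [0, 2**6 - 1]
        return (
            isinstance(weight_tensor, AffineQuantizedTensor)
            and weight_tensor.quant_min == 0
            and weight_tensor.quant_max == 2**6 - 1
        )

    def uint6_weight_only_impl(input_tensor, weight_tensor, bias):
        # stand-in body: dequantize and redispatch to F.linear; a real
        # implementation would call a specialized uint6 kernel here
        return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

    register_aqt_quantized_linear_dispatch(uint6_weight_only_check, uint6_weight_only_impl)

    # quantize a linear layer to uint6 so the new dispatch condition matches;
    # F.linear, aten.mm and aten.addmm all route through _quantized_linear_op,
    # which tries each registered (check, impl) pair in order
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
    linear.weight = torch.nn.Parameter(
        to_affine_quantized_intx(
            linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6 - 1
        ),
        requires_grad=False,
    )
    linear(torch.randn(1, 128, dtype=torch.bfloat16))  # routed through uint6_weight_only_impl

    deregister_aqt_quantized_linear_dispatch(uint6_weight_only_check)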
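
The new guard in the F.linear / aten.mm / aten.addmm handlers only re-raises QuantizedLinearNotImplementedError when the weight's layout type carries a non-None quantized_linear_impl attribute; otherwise the old dequantize-and-fall-back behavior is kept. A layout type opting into the strict behavior could look like the sketch below (MyLayoutType and its quantized_linear_impl field are hypothetical, inferred from the hasattr check in this diff rather than an API it defines):

    from dataclasses import dataclass
    from typing import Optional

    from torchao.dtypes.utils import LayoutType

    @dataclass(frozen=True)
    class MyLayoutType(LayoutType):
        # when this is set, the handlers above propagate
        # QuantizedLinearNotImplementedError instead of silently
        # dequantizing and running the unquantized F.linear path
        quantized_linear_impl: Optional[str] = None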