Allow quantized linear registration in a different file (#783)
* Allow quantized linear registration in a different file

Summary:

Previously we had to maintain a specific ordering in the quantized linear dispatch table of AffineQuantizedTensor, because the table contains a fallback entry that dequantizes the input: https://github.com/pytorch/ao/blob/ba2d3b1333b90ccd0186216649a1c58c6a17ce56/torchao/dtypes/affine_quantized_tensor.py#L1195

Dispatches with two quantized inputs (static or dynamic quantization) had to be registered before this entry, and the weight only quantization dispatches after it. In practice, however, the fallback is not really used or needed, since people typically just want to call into a very specific kernel.
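To illustrate the ordering constraint with a small stand-alone sketch (plain Python, not the actual torchao code): the dispatch table is a dict scanned in insertion order, so a catch-all fallback entry shadows every more specific entry registered after it.

from typing import Callable, Dict

_TABLE: Dict[Callable, Callable] = {}

def _register(condition, impl):
    _TABLE[condition] = impl

def _dispatch(x):
    # first matching condition wins, so registration order is the priority order
    for condition, impl in _TABLE.items():
        if condition(x):
            return impl(x)
    raise NotImplementedError("no dispatch found")

_register(lambda x: isinstance(x, int), lambda x: f"specific impl for int: {x}")
_register(lambda x: True, lambda x: f"fallback impl: {x}")  # must be registered last

print(_dispatch(3))      # specific impl for int: 3
print(_dispatch("abc"))  # fallback impl: abc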

From offline discussions with @drisspg and @HDCharles, it might be useful to add a "quantized_linear_impl" field to `LayoutType`. This lets people specify, and check, which quantized_linear_impl they want to use, so they can be sure they are calling into that specific kernel. When this field is set, we also skip the fallback path for quantized linear (dequantize all activation and weight tensors and run the floating point linear op). I think this can be added to a specific layout type when people want it; we don't have to enforce it in the base `LayoutType`.
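A minimal sketch of that idea, using hypothetical names (this is not the torchao implementation): the layout type carries an optional quantized_linear_impl tag, and the dequantize-everything fallback is skipped whenever the tag is set.

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class MyLayoutType:
    # name of the specific quantized linear kernel the user wants; None means "no preference"
    quantized_linear_impl: Optional[str] = None

def run_linear(input_tensor, weight_tensor, bias, layout_type, specialized_op, float_fallback):
    try:
        return specialized_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        # the user asked for a specific impl, so surface the error instead of silently
        # dequantizing everything and running the floating point linear op
        if getattr(layout_type, "quantized_linear_impl", None) is not None:
            raise
        return float_fallback(input_tensor, weight_tensor, bias)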

Test Plan:
python test/dtypes/test_affine_quantized.py -k test_register_new_dispatch

Reviewers:

Subscribers:

Tasks:

Tags:

* fix error

* de-register dispatch

* make register/deregister fn public

* rebase and fix error
jerryzh168 authored Sep 3, 2024
1 parent e2dad4a commit e15e509
Showing 4 changed files with 94 additions and 32 deletions.
38 changes: 38 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -87,6 +87,44 @@ def test_to_device(self, apply_quant):
        ql = apply_quant(l)
        ql.cuda()

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_register_new_dispatch(self):
        from torchao.dtypes.affine_quantized_tensor import (
            register_aqt_quantized_linear_dispatch,
            deregister_aqt_quantized_linear_dispatch,
        )
        from torchao.dtypes import to_affine_quantized_intx
        from torchao.dtypes import AffineQuantizedTensor
        from torchao.quantization.quant_primitives import MappingType

        def dispatch_condition(input_tensor, weight_tensor, bias):
            return (
                isinstance(weight_tensor, AffineQuantizedTensor) and
                weight_tensor.quant_min == 0 and
                weight_tensor.quant_max == 2**6-1
            )

        def impl(input_tensor, weight_tensor, bias):
            # this is just for testing, normally people will call into uint6 weight only
            # quantized linear operator here
            assert False, "dispatching to my impl for uint6 weight only quant"

        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

        def apply_uint6_weight_only_quant(linear):
            linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
            return linear

        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        apply_uint6_weight_only_quant(l)

        example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
        with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
            l(example_input)

        deregister_aqt_quantized_linear_dispatch(dispatch_condition)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)

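For reference, out-of-tree registration (the point of this PR) looks roughly like the sketch below. The import paths and the register/deregister functions are the ones added in this diff; the dispatch condition and the kernel body are purely illustrative placeholders.

import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import (
    register_aqt_quantized_linear_dispatch,
    deregister_aqt_quantized_linear_dispatch,
)

def my_dispatch_condition(input_tensor, weight_tensor, bias):
    # match bfloat16 activations plus a particular weight quantization range (illustrative)
    return (
        isinstance(weight_tensor, AffineQuantizedTensor)
        and input_tensor.dtype == torch.bfloat16
        and weight_tensor.quant_min == -8
        and weight_tensor.quant_max == 7
    )

def my_impl(input_tensor, weight_tensor, bias):
    # a real implementation would call a custom kernel here; dequantize-and-linear is a placeholder
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)

register_aqt_quantized_linear_dispatch(my_dispatch_condition, my_impl)
# ... run the quantized model ...
deregister_aqt_quantized_linear_dispatch(my_dispatch_condition)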
78 changes: 48 additions & 30 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -39,6 +39,9 @@
    TORCH_VERSION_AT_LEAST_2_5,
    _is_float8_type
)
import logging

logger = logging.getLogger(__name__)

from torchao.float8.float8_tensor import ScaledMMConfig
aten = torch.ops.aten
@@ -88,9 +91,28 @@ class QuantizedLinearNotImplementedError(NotImplementedError):
    pass


_QLINEAR_DISPATCH_TABLE = {}
def _register_quantized_linear_dispatch(dispatch_condition, impl):
    _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl
_AQT_QLINEAR_DISPATCH_TABLE = {}
def register_aqt_quantized_linear_dispatch(dispatch_condition, impl):
    """Register a dispatch for the quantized linear op with a dispatch_condition function and an impl function.
    Both take three arguments:
        input_tensor: dimension is (M1, M2, ..., in_features)
        weight_tensor: dimension is (out_features, in_features)
        bias: dimension is (out_features,)
    so that they can be shared by the F.linear, aten.mm and aten.addmm dispatches.

    Args:
        `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]): the dispatch
            condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight
        `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): the specialized
            quantized linear implementation
    """
    _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl

def deregister_aqt_quantized_linear_dispatch(dispatch_condition):
    if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE:
        del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition]
    else:
        logger.warning(f"Attempting to remove a non-existent dispatch condition {dispatch_condition}")

class AffineQuantizedTensor(TorchAOBaseTensor):
"""
Expand Down Expand Up @@ -189,7 +211,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor

    @staticmethod
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items():
        for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
            if dispatch_condition(input_tensor, weight_tensor, bias):
                return impl(input_tensor, weight_tensor, bias)
        raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op")
@@ -440,7 +462,7 @@ def extra_repr(self):

@dataclass(frozen=True)
class Float8LayoutType(LayoutType):
    mm_config: Optional[ScaledMMConfig]
    mm_config: Optional[ScaledMMConfig] = None


@register_layout_cls(PlainLayoutType)
@@ -598,13 +620,13 @@ def from_plain(

@register_layout_cls(Float8LayoutType)
class Float8AQTLayout(AQTLayout):
    """
    Layout storage class for float8 layout for affine quantized tensor
    """
    float8_data: torch.Tensor
    scale: torch.Tensor
    transposed: bool

    def __new__(
        cls,
        float8_data: torch.Tensor,
@@ -639,7 +661,7 @@ def _apply_fn_to_data(self, fn):
        fn(self.float8_data)
        fn(self.scale)
        return self

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
@@ -976,21 +998,6 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh
        y += bias
    return y

# this is for the case when linear activation is quantized, but is not caught by the previous
# conditions that expects a quantized activation, we just dequantize the activation so that
# it can continue with the weight only quantization dispatches
# NOTE: this is a fallback path that must be registered after all the implementations that expects
# input tensor to be quantized
def _linear_quantized_act_fallback_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
    )

def _linear_quantized_act_fallback_impl(input_tensor, weight_tensor, bias):
    input_tensor = input_tensor.dequantize()
    # dequantize activation and redispatch to F.linear
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
    return (
        # input is native bfloat16 tensor
@@ -1187,19 +1194,18 @@ def _linear_fp_act_fp8_weight_impl(
    ).reshape(out_shape)


def _register_quantized_linear_dispatches():
def _register_aqt_quantized_linear_dispatches():
    for dispatch_condition, impl in [
        (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
        (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
        (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl),
        (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl),
        (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
        (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
        (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl),
    ]:
        _register_quantized_linear_dispatch(dispatch_condition, impl)
        register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

_register_quantized_linear_dispatches()
_register_aqt_quantized_linear_dispatches()

@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
@@ -1216,7 +1222,11 @@ def _(func, types, args, kwargs):
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError:
    except QuantizedLinearNotImplementedError as e:
        # the fallback path is only taken when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
            raise e

        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1239,7 +1249,11 @@ def _(func, types, args, kwargs):
    try:
        weight_tensor = weight_tensor.t()
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError:
    except QuantizedLinearNotImplementedError as e:
        # the fallback path is only taken when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
            raise e

        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
@@ -1259,7 +1273,11 @@ def _(func, types, args, kwargs):
    try:
        weight_tensor = weight_tensor.t()
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError:
    except QuantizedLinearNotImplementedError as e:
        # the fallback path is only taken when the user did not specify a specific quantized linear implementation with `layout_type.quantized_linear_impl`
        if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None:
            raise e

        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
8 changes: 7 additions & 1 deletion torchao/dtypes/utils.py
@@ -1,5 +1,5 @@
import torch
from typing import Dict, Callable, Union, Tuple
from typing import Dict, Callable, Union, Tuple, Optional
from collections import defaultdict
import functools
from dataclasses import dataclass
@@ -73,6 +73,12 @@ class MyTensor(torch.Tensor):

"""
Base class for different LayoutType, should not be instantiated directly
used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles
for int4 tensor core tiled layout
Note: layout is an abstraction not only for custom data representation, it is also used for how the
layout interacts with different operators, e.g. the same data representation can have different
behaviors when running the same operator, e.g. transpose, quantized_linear.
"""
@dataclass(frozen=True)
class LayoutType:
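As a concrete illustration of the pattern this docstring describes (a hypothetical subclass, not part of this diff): a layout type is just a frozen dataclass whose fields carry configuration for the layout tensor and the kernels that consume it.

from dataclasses import dataclass
from torchao.dtypes.utils import LayoutType

@dataclass(frozen=True)
class MyTiledLayoutType(LayoutType):
    # packing parameter consumed by the packed layout tensor / kernels (illustrative field)
    inner_k_tiles: int = 8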
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
    """
    Applies float8 weight-only symmetric per-channel quantization to linear layers.
    Args:
        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
