
Commit

Add support for using AffineQuantizedTensor with weights_only=True (#630)

Summary:
`torch.load(file, weights_only=True)` is safer, so ideally we would use it by default, but it does not work with tensor subclasses out of the box. Now that we have https://pytorch.org/docs/main/notes/serialization.html#torch.serialization.add_safe_globals, we can register all of the tensor subclass classes and the special types they reference as safe globals so that they can be loaded with `weights_only=True`.
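
A minimal sketch of the round trip this enables, mirroring the new test below (assumes a CUDA device; on torch older than 2.5 fall back to `weights_only=False`):

```python
import tempfile
import torch
from torchao.quantization.quant_api import int8_weight_only
from torchao.utils import TORCH_VERSION_AFTER_2_5

# quantize a linear layer; its weight becomes an AffineQuantizedTensor subclass
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = int8_weight_only()(l)

with tempfile.NamedTemporaryFile() as f:
    torch.save(ql.state_dict(), f)
    f.seek(0)
    # the subclass is registered via torch.serialization.add_safe_globals,
    # so the safer weights_only path can be used on torch 2.5+
    sd = torch.load(f, weights_only=TORCH_VERSION_AFTER_2_5)
```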

Test Plan:
python test/dtypes/test_affine_quantized.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Aug 8, 2024
1 parent 32f421b commit 433cd14
Showing 6 changed files with 88 additions and 33 deletions.
25 changes: 23 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -2,17 +2,23 @@
TestCase,
run_tests,
)
from torchao.quantization.quant_api import int4_weight_only
from torchao.quantization.quant_api import (
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
)
import torch
import unittest
import tempfile
from torchao.utils import (
TORCH_VERSION_AFTER_2_5,
)


class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
@@ -31,5 +37,20 @@ def test_tensor_core_layout_transpose(self):
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_weights_only(self):
for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
# `weights_only=True` is enabled for torch 2.5+
if TORCH_VERSION_AFTER_2_5:
_ = torch.load(f, weights_only=True)
else:
_ = torch.load(f, weights_only=False)


if __name__ == "__main__":
run_tests()
4 changes: 4 additions & 0 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -920,3 +920,7 @@ def _(func, types, args, kwargs):

to_affine_quantized = AffineQuantizedTensor.from_float
to_affine_quantized_static = AffineQuantizedTensor.from_float_static

if TORCH_VERSION_AFTER_2_5:
# Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([AffineQuantizedTensor])
5 changes: 4 additions & 1 deletion torchao/dtypes/utils.py
@@ -3,6 +3,7 @@
from collections import defaultdict
import functools
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

"""
Helper function for implementing aten op or torch function dispatch
@@ -94,7 +95,6 @@ def extra_repr(self) -> str:
class PlainLayoutType(LayoutType):
pass


"""
layout tensor constructor registration for different tensor subclasses
@@ -117,6 +117,9 @@ def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
"""
def decorator(layout_cls):
_LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
if TORCH_VERSION_AFTER_2_5:
# Allow serialization to work for models that use this layout tensor subclass
torch.serialization.add_safe_globals([layout_type_class, layout_cls])
return layout_cls
return decorator

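A hedged usage sketch of the registration decorator above — `MyLayoutType` and `MyAQTLayout` are hypothetical names introduced only for illustration; `LayoutType` / `_register_layout_cls` come from this file and `AffineQuantizedTensor` from the module changed above. Registering a layout tensor class records its `from_plain` constructor and, on torch 2.5+, also allowlists both classes for `weights_only=True` loading:

```python
import torch
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.dtypes.utils import LayoutType, _register_layout_cls

class MyLayoutType(LayoutType):  # hypothetical layout descriptor
    pass

@_register_layout_cls(AffineQuantizedTensor, MyLayoutType)
class MyAQTLayout(torch.Tensor):  # hypothetical layout tensor subclass
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        ...  # would pack int_data / scale / zero_point into this layout
```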
5 changes: 5 additions & 0 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -6,6 +6,7 @@
)
from typing import Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TORCH_VERSION_AFTER_2_5

__all__ = [
"LinearActivationQuantizedTensor",
@@ -175,3 +176,7 @@ def _(func, types, args, kwargs):
)

to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float

if TORCH_VERSION_AFTER_2_5:
# Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])
79 changes: 49 additions & 30 deletions torchao/quantization/quant_api.py
@@ -48,6 +48,7 @@
from .utils import _get_per_token_block_size
import logging
from .autoquant import autoquant, AutoQuantizableLinearWeight
from torchao.utils import TORCH_VERSION_AFTER_2_5


__all__ = [
@@ -326,6 +327,35 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
_is_linear if filter_fn is None else filter_fn,
)

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
# avoid circular dep
from torchao.dtypes import to_affine_quantized

mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int8
return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)

def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
if weight.shape[-1] % group_size != 0:
return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# input settings
input_quant_func = _int8_asymm_per_token_quant

weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight

def int8_dynamic_activation_int4_weight(group_size=32):
"""Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
@@ -336,31 +366,11 @@ def int8_dynamic_activation_int4_weight(group_size=32):
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained
"""
def apply_int8_dynamic_activation_int4_weight_quant(weight):
if weight.shape[-1] % group_size != 0:
return weight

# avoid circular dep
from torchao.dtypes import to_affine_quantized

# weight settings
mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype)

weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight
def insert_subclass(lin):
lin.weight = torch.nn.Parameter(apply_int8_dynamic_activation_int4_weight_quant(lin.weight, group_size), requires_grad=False)
return lin

return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)
return insert_subclass
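
A short usage sketch of the refactored entry point, mirroring the new test in test_affine_quantized.py (same layer shape, dtype, and device as the test):

```python
import torch
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight

lin = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# the call returns `insert_subclass`, which swaps lin.weight for the quantized
# tensor subclass (int4 weight + int8 dynamic activation quant) in place
lin = int8_dynamic_activation_int4_weight(group_size=32)(lin)
```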


def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -421,6 +431,16 @@ def apply_int8wo_quant(weight):

return _get_linear_subclass_inserter(apply_int8wo_quant)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
# avoid circular dep
from torchao.dtypes import to_affine_quantized
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)


def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
"""
@@ -444,12 +464,7 @@ def get_weight_block_size(x):
zero_point_dtype = torch.int64

# input settings
input_mapping_type = MappingType.SYMMETRIC
input_target_dtype = torch.int8
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
input_quant_func = _int8_symm_per_token_reduced_range_quant

block_size = get_weight_block_size(weight)
weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
@@ -466,3 +481,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
"""
from torchao.dtypes import SemiSparseLayoutType
return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


if TORCH_VERSION_AFTER_2_5:
torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
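
Presumably the former `input_quant_func` lambdas were hoisted into the named module-level functions `_int8_asymm_per_token_quant` and `_int8_symm_per_token_reduced_range_quant` because the function reference is part of the serialized `LinearActivationQuantizedTensor` state: only importable, named callables can be pickled by reference and then allowlisted with `add_safe_globals`, while lambdas and closure-local functions cannot be pickled at all. A pure-Python sketch of that constraint (names are illustrative):

```python
import pickle

def named_quant_func(x):  # module-level: pickled by module + qualified name
    return x

lambda_quant_func = lambda x: x  # lambdas have no importable qualified name

pickle.dumps(named_quant_func)  # ok
try:
    pickle.dumps(lambda_quant_func)
except (pickle.PicklingError, AttributeError) as err:
    print("cannot pickle a lambda by reference:", err)
```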
3 changes: 3 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -52,6 +52,9 @@ class ZeroPointDomain(Enum):
INT = auto()
FLOAT = auto()

if TORCH_VERSION_AFTER_2_5:
torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])

"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo
