Commit 433cd14

Add support for using AffineQuantizedTensor with weights_only=True (#630)
Summary: `torch.load(file, weights_only=True)` is the safer way to load a checkpoint, so ideally we use it by default. Out of the box it does not work with tensor subclasses, but now that we have https://pytorch.org/docs/main/notes/serialization.html#torch.serialization.add_safe_globals we can add all of the tensor subclass classes and the special types they reference to the safe globals list so that they load correctly with `weights_only=True`.

Test Plan: python test/dtypes/test_affine_quantized.py
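As a rough sketch of the user-visible effect (assuming PyTorch 2.5+ and a CUDA device, mirroring the new test below; importing torchao now registers the subclasses via `torch.serialization.add_safe_globals`, so no manual registration should be needed):

import tempfile
import torch
from torchao.quantization.quant_api import int8_weight_only

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = int8_weight_only()(l)  # replaces l.weight with an AffineQuantizedTensor

with tempfile.NamedTemporaryFile() as f:
    torch.save(ql.state_dict(), f)
    f.seek(0)
    # Succeeds only because the subclass types are registered as safe globals
    state_dict = torch.load(f, weights_only=True)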
1 parent 32f421b commit 433cd14

6 files changed: +88 -33 lines changed

test/dtypes/test_affine_quantized.py (+23 -2)
@@ -2,17 +2,23 @@
     TestCase,
     run_tests,
 )
-from torchao.quantization.quant_api import int4_weight_only
+from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int8_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight,
+)
 import torch
 import unittest
+import tempfile
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_5,
 )
 
 
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    # @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
         l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         t = l.weight
@@ -31,5 +37,20 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_weights_only(self):
+        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
+            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+            ql = apply_quant(l)
+            with tempfile.NamedTemporaryFile() as f:
+                torch.save(ql.state_dict(), f)
+                f.seek(0)
+                # `weights_only=True` is enabled for torch 2.5+
+                if TORCH_VERSION_AFTER_2_5:
+                    _ = torch.load(f, weights_only=True)
+                else:
+                    _ = torch.load(f, weights_only=False)
+
+
 if __name__ == "__main__":
     run_tests()

torchao/dtypes/affine_quantized_tensor.py (+4)
@@ -920,3 +920,7 @@ def _(func, types, args, kwargs):
 
 to_affine_quantized = AffineQuantizedTensor.from_float
 to_affine_quantized_static = AffineQuantizedTensor.from_float_static
+
+if TORCH_VERSION_AFTER_2_5:
+    # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([AffineQuantizedTensor])

torchao/dtypes/utils.py (+4 -1)
@@ -3,6 +3,7 @@
 from collections import defaultdict
 import functools
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 """
 Helper function for implementing aten op or torch function dispatch
@@ -94,7 +95,6 @@ def extra_repr(self) -> str:
 class PlainLayoutType(LayoutType):
     pass
 
-
 """
 layout tensor constructor registration for different tensor subclasses
 
@@ -117,6 +117,9 @@ def _register_layout_cls(cls: Callable, layout_type_class: type(LayoutType)):
     """
     def decorator(layout_cls):
         _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class] = layout_cls.from_plain
+        if TORCH_VERSION_AFTER_2_5:
+            # Allow serialization to work for models using this layout tensor subclass
+            torch.serialization.add_safe_globals([layout_type_class, layout_cls])
         return layout_cls
     return decorator
 
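Registering safe globals inside the decorator makes it a single choke point: every layout class that registers a constructor is automatically whitelisted for `weights_only` loading, so future layouts need no extra serialization code. A self-contained sketch of the same pattern, using hypothetical names (`register_serializable` and `MyLayout` are illustrations, not torchao APIs):

import torch

_REGISTRY = {}

def register_serializable(key):
    def decorator(cls):
        _REGISTRY[key] = cls
        # Feature-gate instead of a version check, for this sketch only
        if hasattr(torch.serialization, "add_safe_globals"):
            torch.serialization.add_safe_globals([cls])
        return cls
    return decorator

@register_serializable("my_layout")
class MyLayout:  # hypothetical stand-in for a layout tensor subclass
    pass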
torchao/quantization/linear_activation_quantized_tensor.py (+5)
@@ -6,6 +6,7 @@
 )
 from typing import Callable
 from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 __all__ = [
     "LinearActivationQuantizedTensor",
@@ -175,3 +176,7 @@ def _(func, types, args, kwargs):
     )
 
 to_linear_activation_quantized = LinearActivationQuantizedTensor.from_float
+
+if TORCH_VERSION_AFTER_2_5:
+    # Allow a model with LinearActivationQuantizedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])

torchao/quantization/quant_api.py (+49 -30)
@@ -48,6 +48,7 @@
 from .utils import _get_per_token_block_size
 import logging
 from .autoquant import autoquant, AutoQuantizableLinearWeight
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 
 __all__ = [
@@ -326,6 +327,35 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         _is_linear if filter_fn is None else filter_fn,
     )
 
+def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
+    mapping_type = MappingType.ASYMMETRIC
+    target_dtype = torch.int8
+    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
+
+def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
+    if weight.shape[-1] % group_size != 0:
+        return weight
+
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+
+    # weight settings
+    mapping_type = MappingType.SYMMETRIC
+    block_size = (1, group_size)
+    target_dtype = torch.int8
+    eps = torch.finfo(torch.float32).eps
+    quant_min = -8
+    quant_max = 7
+
+    # input settings
+    input_quant_func = _int8_asymm_per_token_quant
+
+    weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+    weight = to_linear_activation_quantized(weight, input_quant_func)
+    return weight
 
 def int8_dynamic_activation_int4_weight(group_size=32):
     """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
@@ -336,31 +366,11 @@ def int8_dynamic_activation_int4_weight(group_size=32):
     `group_size`: parameter for quantization, controls the granularity of quantization, smaller
     size is more fine grained
     """
-    def apply_int8_dynamic_activation_int4_weight_quant(weight):
-        if weight.shape[-1] % group_size != 0:
-            return weight
-
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        quant_min = -8
-        quant_max = 7
-
-        # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
-        input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype)
-
-        weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
-        weight = to_linear_activation_quantized(weight, input_quant_func)
-        return weight
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(apply_int8_dynamic_activation_int4_weight_quant(lin.weight, group_size), requires_grad=False)
+        return lin
 
-    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)
+    return insert_subclass
 
 
 def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -421,6 +431,16 @@ def apply_int8wo_quant(weight):
 
     return _get_linear_subclass_inserter(apply_int8wo_quant)
 
+def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
+    # avoid circular dep
+    from torchao.dtypes import to_affine_quantized
+    mapping_type = MappingType.SYMMETRIC
+    target_dtype = torch.int8
+    eps = 1e-5
+    quant_min = -127
+    quant_max = 127
+    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype, eps=eps, quant_min=quant_min, quant_max=quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+
 
 def int8_dynamic_activation_int8_weight(layout_type=PlainLayoutType()):
     """
@@ -444,12 +464,7 @@ def get_weight_block_size(x):
         zero_point_dtype = torch.int64
 
         # input settings
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, _get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
 
         block_size = get_weight_block_size(weight)
         weight = to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, layout_type=layout_type)
@@ -466,3 +481,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     """
     from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())
+
+
+if TORCH_VERSION_AFTER_2_5:
+    torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
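Most of the churn in this file is hoisting the input-quantization lambdas into the named module-level functions `_int8_asymm_per_token_quant` and `_int8_symm_per_token_reduced_range_quant`. The likely reason is serialization: `LinearActivationQuantizedTensor` stores its `input_quant_func`, pickle records plain functions by qualified name, and a lambda has no importable qualified name, so only the named versions can round-trip through a checkpoint and be listed in `add_safe_globals`. A quick demonstration of the underlying constraint:

import pickle

def named_quant_fn(x):
    return x  # module-level function: pickled by its qualified name

pickle.dumps(named_quant_fn)  # fine

try:
    pickle.dumps(lambda x: x)  # no importable qualified name
except (pickle.PicklingError, AttributeError) as e:
    print("cannot pickle a lambda:", e)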

torchao/quantization/quant_primitives.py (+3)
@@ -52,6 +52,9 @@ class ZeroPointDomain(Enum):
     INT = auto()
     FLOAT = auto()
 
+if TORCH_VERSION_AFTER_2_5:
+    torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])
+
 """
 Map from dtype to the bound value of integers
 TODO: maybe can replace this with call to torch.iinfo