Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/transformers/quantizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class HfQuantizer(ABC):

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
self.quantization_config = quantization_config
self.metadata = {}

# -- Handle extra kwargs below --
self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
Expand Down Expand Up @@ -392,10 +393,6 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
"""Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
return None, {}

def update_state_dict_with_metadata(self, state_dict, metadata):
    """Hook for adjusting a state dict using serialization *metadata*.

    The base implementation performs no transformation and simply hands the
    given ``state_dict`` back to the caller; quantizer subclasses override
    this when their checkpoint format stores extra information in *metadata*.
    """
    # Default: nothing to rewrite — return the input unchanged.
    return state_dict

@abstractmethod
def is_serializable(self, safe_serialization=None): ...

Expand Down
98 changes: 42 additions & 56 deletions src/transformers/quantizers/quantizer_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import importlib
import re
import types
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Union

from packaging import version
Expand All @@ -38,14 +37,14 @@
if is_torchao_available():
import torchao

if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
from torchao.prototype.awq import AWQConfig
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -87,6 +86,11 @@ def _linear_extra_repr(self):
SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
torchao.quantization.Float8WeightOnlyConfig,
torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
torchao.quantization.Int4WeightOnlyConfig,
torchao.quantization.IntxWeightOnlyConfig,
torchao.quantization.Int8DynamicActivationIntxWeightConfig,
torchao.quantization.ModuleFqnToConfig,
AWQConfig,
]

TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
Expand All @@ -104,20 +108,6 @@ class TorchAoHfQuantizer(HfQuantizer):
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)

if isinstance(self.quantization_config.quant_type, str):
is_int_4 = "int4" in self.quantization_config.quant_type
else:
config_name = self.quantization_config.quant_type.__class__.__name__
is_int_4 = fuzzy_match_size(config_name) == "4"

# TODO: better way to get the serialized key names? Hard to read from torchao codebase
if is_int_4:
self.weight_ao_keys = ["qdata", "scale", "zero_point"]
else:
self.weight_ao_keys = ["qdata", "scale"]
# Instead of serializing the simple torch.Tensor like usual, torchao adds a `:_data` suffix so we need this
self.full_ao_keys = self.weight_ao_keys + ["_data"]

def validate_environment(self, *args, **kwargs):
if not is_torchao_available():
raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
Expand Down Expand Up @@ -168,11 +158,11 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
the safetensors format.
"""
if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
if TORCHAO_VERSION >= version.parse("0.14.0"):
if TORCHAO_VERSION >= version.parse("0.15.0"):
return flatten_tensor_state_dict(model.state_dict())
else:
raise RuntimeError(
f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
)
else:
return None, {}
Expand Down Expand Up @@ -234,7 +224,7 @@ def _process_model_before_weight_loading(
return

def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.full_ao_keys)]
return [k for k in unexpected_keys if "_weight_" not in k]

def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
if self.quantization_config.quant_type == "autoquant":
Expand All @@ -243,7 +233,7 @@ def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **
# check if the param_name is not in self.modules_to_not_convert
if any(key + "." in param_name or key == param_name for key in self.modules_to_not_convert):
return False
elif any(param_name.endswith(f":{x}") for x in self.full_ao_keys):
elif "_weight_" in param_name:
return True
else:
# we only quantize the weight of nn.Linear and nn.Embedding
Expand Down Expand Up @@ -284,42 +274,12 @@ def create_quantized_param(
"""
from torchao.quantization import quantize_

full_name = param_name
# Those are the pre quantized weights
if ":" in param_name:
param_name = param_name.rsplit(":", 1)[0]
module, tensor_name = get_module_from_name(model, param_name)

if self.pre_quantized:
# If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but was
# already done) - if it's unsafe-serialized (i.e. not safetensors), not need for anything either
is_unsafe_serialization = ":" not in full_name
if tensor_name == "bias" or is_unsafe_serialization:
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(target_device), requires_grad=param_value.requires_grad
)
return
# Sanity check for the new serialization format
elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.metadata)):
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

# Save the states for later quantization when they are all gathered
if not hasattr(self, "ao_params"):
self.ao_params = defaultdict(dict)
self.ao_params[param_name].update({full_name: param_value})

# We are ready for quantization in this case (we retrieved all the needed keys)
if len(self.ao_params[param_name]) == len(self.weight_ao_keys):
new_param = unflatten_tensor_state_dict(self.ao_params[param_name], self.metadata)[param_name]
# Set it
module._parameters[tensor_name] = torch.nn.Parameter(
new_param.to(target_device), requires_grad=new_param.requires_grad
)

# Free memory
del self.ao_params[param_name]
module._parameters[tensor_name] = torch.nn.Parameter(
param_value.to(target_device), requires_grad=param_value.requires_grad
)

# Add repr to the module
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
Expand Down Expand Up @@ -430,6 +390,32 @@ def preprocess_model(self, model: "PreTrainedModel", config, dtype=None, checkpo

def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
if TORCHAO_VERSION >= version.parse("0.15.0") and is_metadata_torchao(self.metadata):
_, updated_state_dict = unflatten_tensor_state_dict(model.state_dict(), self.metadata)

weights_to_register = set(updated_state_dict.keys())

for name, param in list(model.named_parameters()):
module_fqn, weight_name = name.rsplit(".", 1)
module = model.get_submodule(module_fqn)
weight = getattr(module, weight_name)

device = weight.device
requires_grad = weight.requires_grad

if "_weight_" in weight_name:
delattr(module, weight_name)

if name in weights_to_register:
new_param_value = updated_state_dict[name]
new_param = torch.nn.Parameter(new_param_value.to(device), requires_grad=requires_grad)
module.register_parameter(weight_name, new_param)

weights_to_register.remove(name)

model.load_state_dict(updated_state_dict, strict=False)
return

if self.quantization_config.quant_type == "autoquant":
from torchao import autoquant
from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST
Expand All @@ -448,11 +434,11 @@ def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
_is_torchao_serializable = type(
self.quantization_config.quant_type
) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.15.0")
if not _is_torchao_serializable:
logger.warning(
f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
and torchao version >= 0.14.0, please set `safe_serialization` to False for \
and torchao version >= 0.15.0, please set `safe_serialization` to False for \
{type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
)
return _is_torchao_serializable
Expand Down
14 changes: 12 additions & 2 deletions tests/quantization/torchao_integration/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,6 @@ def tearDown(self):
def test_original_model_expected_output(self):
    """Generation from the quantized model must match the reference string."""
    encoded = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
    generated = self.quantized_model.generate(**encoded, max_new_tokens=self.max_new_tokens)
    decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
    self.assertEqual(decoded, self.EXPECTED_OUTPUT)

def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
Expand All @@ -723,11 +722,12 @@ def test_serialization_expected_output(self):


@require_torchao
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSafeSerializationTest(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
# placeholder
Expand All @@ -748,6 +748,16 @@ def tearDown(self):
"What are we having for dinner?\n\nJess: (smiling) I",
),
(torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
(Int4WeightOnlyConfig(), "What are we having for dinner?"),
(
Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"),
"What are we having for dinner?\nRed, white, and green beans,",
),
(
torchao.quantization.Int8DynamicActivationIntxWeightConfig(),
"What are we having for dinner?\n\nJessica: (smiling)",
),
(torchao.quantization.IntxWeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
]
if is_torchao_available()
else []
Expand Down