diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile
index 44d1ceb2bfdd..700df877d10f 100755
--- a/docker/transformers-quantization-latest-gpu/Dockerfile
+++ b/docker/transformers-quantization-latest-gpu/Dockerfile
@@ -53,6 +53,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add vptq for quantization testing
RUN python3 -m pip install --no-cache-dir vptq
+# Add spqr for quantization testing
+RUN python3 -m pip install --no-cache-dir spqr_quant[gpu]
+
# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 5540bd1826a9..17a1bb1b3b33 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -166,6 +166,8 @@
- local: quantization/aqlm
title: AQLM
    - local: quantization/vptq
      title: VPTQ
+    - local: quantization/spqr
+      title: SpQR
- local: quantization/quanto
title: Quanto
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index cd3e2705ab34..6da5b8ce69b5 100755
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -81,6 +81,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
[[autodoc]] BitNetConfig
+## SpQRConfig
+
+[[autodoc]] SpQRConfig
+
## FineGrainedFP8Config
[[autodoc]] FineGrainedFP8Config
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index caebebe81547..94696e300a57 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -61,6 +61,7 @@ Use the table below to help you decide which quantization method to use.
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 5 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
+| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
diff --git a/docs/source/en/quantization/spqr.md b/docs/source/en/quantization/spqr.md
new file mode 100644
index 000000000000..b9ebb99b69cb
--- /dev/null
+++ b/docs/source/en/quantization/spqr.md
@@ -0,0 +1,35 @@
+
+
+# SpQR
+
+The [SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm uses a 16x16 tiled bi-level group 3-bit quantization structure with sparse outliers, as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).
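+
+Roughly speaking (this is a sketch of the paper's decomposition, not the exact storage format used by the inference kernels), each weight matrix is split into a dense 3-bit quantized part and a small sparse part holding high-precision outliers:
+
+$$ W \approx S \odot (Q - Z) + W_{\text{outliers}} $$
+
+where the codes in `Q` use 3 bits, the scales `S` and zero-points `Z` are shared per 16x16 tile and are themselves quantized (the "bi-level" part), and the outlier term is a sparse fp16 matrix containing the few weights that quantize poorly.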
+
+To SpQR-quantize a model, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.
+
+Load a pre-SpQR-quantized model with [`~PreTrainedModel.from_pretrained`].
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+quantized_model = AutoModelForCausalLM.from_pretrained(
+ "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
+ torch_dtype=torch.half,
+ device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf")
+```
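+
+You can then run generation as usual. A quick sanity check (this snippet assumes a CUDA device is available, since SpQR inference requires a GPU):
+
+```python
+input_text = "Hello my name is"
+inputs = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)
+
+output = quantized_model.generate(**inputs, max_new_tokens=32)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+```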
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b6412cf59360..8b97168ecf5f 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1029,6 +1029,7 @@
"HiggsConfig",
"HqqConfig",
"QuantoConfig",
+ "SpQRConfig",
"TorchAoConfig",
"VptqConfig",
],
@@ -6202,6 +6203,7 @@
HiggsConfig,
HqqConfig,
QuantoConfig,
+ SpQRConfig,
TorchAoConfig,
VptqConfig,
)
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index c78564dcba8e..b545c5da50a5 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -106,6 +106,7 @@
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
+ "spqr": ["replace_with_spqr_linear"],
"vptq": ["replace_with_vptq_linear"],
}
@@ -210,6 +211,7 @@
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
+ from .spqr import replace_with_spqr_linear
from .vptq import replace_with_vptq_linear
try:
diff --git a/src/transformers/integrations/spqr.py b/src/transformers/integrations/spqr.py
new file mode 100644
index 000000000000..58b71740d37c
--- /dev/null
+++ b/src/transformers/integrations/spqr.py
@@ -0,0 +1,122 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"SpQR (Sparse-Quantized Representation) integration file"
+
+from ..utils import is_accelerate_available, is_spqr_available, is_torch_available
+
+
+if is_torch_available():
+ import torch.nn as nn
+
+
+def replace_with_spqr_linear(
+ model,
+ quantization_config=None,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ has_been_replaced=False,
+):
+ """
+ Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
+ `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
+ conversion has been successful or not.
+
+ Args:
+ model (`torch.nn.Module`):
+ The model to convert, can be any `torch.nn.Module` instance.
+ quantization_config (`SpQRConfig`):
+ The quantization config object that contains the quantization parameters.
+ modules_to_not_convert (`list[str]`, *optional*):
+ A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
+ converted.
+ current_key_name (`list`, *optional*):
+ A list that contains the current key name. This is used for recursion and should not be passed by the user.
+ has_been_replaced (`bool`, *optional*):
+ A boolean that indicates if the conversion has been successful or not. This is used for recursion and
+ should not be passed by the user.
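+
+    Example (a minimal sketch mirroring the integration tests; it assumes a SpQR checkpoint whose config carries the
+    tensor shapes expected by this function):
+
+    ```python
+    from accelerate import init_empty_weights
+    from transformers import AutoConfig, AutoModelForCausalLM, SpQRConfig
+    from transformers.integrations import replace_with_spqr_linear
+
+    # The shapes dictionary is read from the quantized checkpoint's config.
+    quantized_model_id = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"
+    quantization_config = SpQRConfig.from_dict(AutoConfig.from_pretrained(quantized_model_id).quantization_config)
+
+    # Instantiate the original architecture on the meta device, then swap nn.Linear for SpQR placeholders.
+    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config)
+
+    model, has_been_replaced = replace_with_spqr_linear(
+        model,
+        quantization_config=quantization_config,
+        modules_to_not_convert=quantization_config.modules_to_not_convert,
+    )
+    ```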
+ """
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+ if is_spqr_available():
+ from spqr_quant import QuantizedLinear
+
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, nn.Linear):
+ # Check if the current key is not in the `modules_to_not_convert`
+ if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
+ with init_empty_weights():
+ tensor_name = ".".join(current_key_name)
+
+ shapes = quantization_config.shapes
+ shapes_keys = shapes.keys()
+
+ shapes_valid = (
+ f"{tensor_name}.dense_weights.shape" in shapes_keys
+ and f"{tensor_name}.row_offsets.shape" in shapes_keys
+ and f"{tensor_name}.col_vals.shape" in shapes_keys
+ and f"{tensor_name}.in_perm.shape" in shapes_keys
+ )
+
+ if not shapes_valid:
+ raise ValueError(
+ f"The SpQR quantization config does not contain the shape "
+ f"configuration for {tensor_name}. This indicates that the "
+ f"configuration is either invalid or corrupted."
+ )
+
+ dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
+ row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
+ col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
+ in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]
+
+ in_features = module.in_features
+ out_features = module.out_features
+
+ model._modules[name] = QuantizedLinear.create_placehodler(
+ rows=out_features,
+ cols=in_features,
+ bits=quantization_config.bits,
+ beta1=quantization_config.beta1,
+ beta2=quantization_config.beta2,
+ dense_weights_shape=dense_weights_shape,
+ row_offsets_shape=row_offsets_shape,
+ col_vals_shape=col_vals_shape,
+ in_perm_shape=in_perm_shape,
+ )
+ has_been_replaced = True
+
+ # Store the module class in case we need to transpose the weight later
+ model._modules[name].source_cls = type(module)
+ # Force requires grad to False to avoid unexpected errors
+ model._modules[name].requires_grad_(False)
+ else:
+ pass
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = replace_with_spqr_linear(
+ module,
+ quantization_config=quantization_config,
+ modules_to_not_convert=modules_to_not_convert,
+ current_key_name=current_key_name,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+ return model, has_been_replaced
diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py
index 0f279498f7ec..ee7c832b1de1 100755
--- a/src/transformers/quantizers/auto.py
+++ b/src/transformers/quantizers/auto.py
@@ -31,6 +31,7 @@
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
+ SpQRConfig,
TorchAoConfig,
VptqConfig,
)
@@ -47,6 +48,7 @@
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
+from .quantizer_spqr import SpQRHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
from .quantizer_vptq import VptqHfQuantizer
@@ -66,6 +68,7 @@
"torchao": TorchAoHfQuantizer,
"bitnet": BitNetHfQuantizer,
"vptq": VptqHfQuantizer,
+ "spqr": SpQRHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
}
@@ -84,6 +87,7 @@
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
"vptq": VptqConfig,
+ "spqr": SpQRConfig,
"fp8": FineGrainedFP8Config,
}
diff --git a/src/transformers/quantizers/quantizer_spqr.py b/src/transformers/quantizers/quantizer_spqr.py
new file mode 100644
index 000000000000..60cc1bca9b27
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_spqr.py
@@ -0,0 +1,83 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Optional
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+ from ..modeling_utils import PreTrainedModel
+
+from ..integrations import replace_with_spqr_linear
+from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
+from ..utils.quantization_config import QuantizationConfigMixin
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class SpQRHfQuantizer(HfQuantizer):
+ """
+ Quantizer of the SpQR method. Enables the loading of prequantized models.
+ """
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+ self.quantization_config = quantization_config
+
+ def validate_environment(self, *args, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("GPU is required to run SpQR quantized model.")
+
+ if not is_accelerate_available():
+ raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")
+
+ if not is_spqr_available():
+ raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ torch_dtype = torch.float16
+ logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
+ elif torch_dtype != torch.float16:
+ raise ValueError(
+                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to "
+ "torch.float16 explicitly."
+ )
+ return torch_dtype
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "PreTrainedModel",
+ **kwargs,
+ ):
+ replace_with_spqr_linear(
+ model,
+ quantization_config=self.quantization_config,
+ modules_to_not_convert=self.quantization_config.modules_to_not_convert,
+ )
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ return model
+
+ @property
+ def is_trainable(self, model: Optional["PreTrainedModel"] = None):
+ return False
+
+ def is_serializable(self, safe_serialization=None):
+ return True
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index c71619e4e8f9..14fef2988488 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -121,6 +121,7 @@
is_seqio_available,
is_soundfile_available,
is_spacy_available,
+ is_spqr_available,
is_sudachi_available,
is_sudachi_projection_available,
is_tensorflow_probability_available,
@@ -1191,6 +1192,13 @@ def require_vptq(test_case):
return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
+def require_spqr(test_case):
+ """
+ Decorator marking a test that requires spqr
+ """
+ return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
+
+
def require_eetq(test_case):
"""
Decorator marking a test that requires eetq
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index bf56a584469f..cf13060ee307 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -193,6 +193,7 @@
is_soundfile_available,
is_spacy_available,
is_speech_available,
+ is_spqr_available,
is_sudachi_available,
is_sudachi_projection_available,
is_tensorflow_probability_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index c14ff2124aa0..bd95b6f282c0 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -201,6 +201,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
_triton_available = _is_package_available("triton")
+_spqr_available = _is_package_available("spqr_quant")
_torch_version = "N/A"
_torch_available = False
@@ -1213,6 +1214,10 @@ def is_speech_available():
return _torchaudio_available
+def is_spqr_available():
+ return _spqr_available
+
+
def is_phonemizer_available():
return _phonemizer_available
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 329123e0f4a4..11415e895d91 100755
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum):
FBGEMM_FP8 = "fbgemm_fp8"
TORCHAO = "torchao"
BITNET = "bitnet"
+ SPQR = "spqr"
FP8 = "fp8"
@@ -1551,6 +1552,75 @@ def post_init(self):
pass
+@dataclass
+class SpQRConfig(QuantizationConfigMixin):
+ """
+    This is a wrapper class for the `spqr` quantization parameters. Refer to the original publication for more details.
+
+ Args:
+ bits (`int`, *optional*, defaults to 3):
+ Specifies the bit count for the weights and first order zero-points and scales.
+ Currently only bits = 3 is supported.
+ beta1 (`int`, *optional*, defaults to 16):
+ SpQR tile width. Currently only beta1 = 16 is supported.
+ beta2 (`int`, *optional*, defaults to 16):
+ SpQR tile height. Currently only beta2 = 16 is supported.
+        shapes (`Dict[str, int]`, *optional*):
+            A dictionary mapping parameter names to the shapes of their quantized sub-tensors. This is required
+            because the exact sizes of these tensors cannot be deduced from `bits`, `beta1` and `beta2` alone.
+ modules_to_not_convert (`Optional[List[str]]`, *optional*):
+ Optionally, provides a list of full paths of `nn.Linear` weight parameters that shall not be quantized.
+ Defaults to None.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters from which to initialize the configuration object.
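+
+    Example:
+
+    ```python
+    >>> # A minimal sketch of constructing the config by hand. In practice the `shapes` dictionary comes from a
+    >>> # SpQR checkpoint's `quantization_config` entry (see `SpQRConfig.from_dict` in the integration tests).
+    >>> from transformers import SpQRConfig
+
+    >>> quantization_config = SpQRConfig(bits=3, beta1=16, beta2=16, modules_to_not_convert=["lm_head.weight"])
+    ```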
+ """
+
+ def __init__(
+ self,
+ bits: int = 3,
+ beta1: int = 16,
+ beta2: int = 16,
+ shapes: Optional[Dict[str, int]] = None,
+ modules_to_not_convert: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ if shapes is None:
+ shapes = {}
+ self.shapes = shapes
+ self.quant_method = QuantizationMethod.SPQR
+ self.bits = bits
+ self.beta1 = beta1
+ self.beta2 = beta2
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+ self.modules_to_not_convert = modules_to_not_convert
+ self.post_init()
+
+ def post_init(self):
+ r"""
+ Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
+ """
+ if not isinstance(self.bits, int):
+ raise TypeError("bits must be an int")
+ if not isinstance(self.beta1, int):
+ raise TypeError("beta1 must be an int")
+ if not isinstance(self.beta2, int):
+ raise TypeError("beta2 must be an int")
+
+ if self.bits != 3:
+ raise ValueError("SpQR currently only supports bits = 3")
+ if self.beta1 != 16:
+ raise ValueError("SpQR currently only supports beta1 = 16")
+ if self.beta2 != 16:
+ raise ValueError("SpQR currently only supports beta2 = 16")
+
+ if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
+ raise ValueError("modules_to_not_convert must be a list of strings")
+
+ if not isinstance(self.shapes, dict):
+ raise TypeError("shapes must be a dict")
+
+
@dataclass
class FineGrainedFP8Config(QuantizationConfigMixin):
"""
diff --git a/tests/quantization/spqr_integration/__init__.py b/tests/quantization/spqr_integration/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/spqr_integration/test_spqr.py b/tests/quantization/spqr_integration/test_spqr.py
new file mode 100644
index 000000000000..134e57af5de1
--- /dev/null
+++ b/tests/quantization/spqr_integration/test_spqr.py
@@ -0,0 +1,249 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, SpQRConfig, StaticCache
+from transformers.testing_utils import (
+ require_accelerate,
+ require_spqr,
+ require_torch_gpu,
+ require_torch_multi_gpu,
+ slow,
+ torch_device,
+)
+from transformers.utils import is_accelerate_available, is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+
+@require_torch_gpu
+class SpQRConfigTest(unittest.TestCase):
+ def test_to_dict(self):
+ """
+ Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
+ """
+ quantization_config = SpQRConfig()
+ config_to_dict = quantization_config.to_dict()
+
+ for key in config_to_dict:
+ self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
+
+ def test_from_dict(self):
+ """
+ Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
+ """
+ dict = {
+ "beta1": 16,
+ "beta2": 16,
+ "bits": 3,
+ "modules_to_not_convert": ["lm_head.weight"],
+ "shapes": {"model.layers.0.self_attn.q_proj.dense_weights.shape": 16},
+ }
+ quantization_config = SpQRConfig.from_dict(dict)
+
+ self.assertEqual(dict["beta1"], quantization_config.beta1)
+ self.assertEqual(dict["beta2"], quantization_config.beta2)
+ self.assertEqual(dict["bits"], quantization_config.bits)
+ self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert)
+ self.assertEqual(dict["shapes"], quantization_config.shapes)
+
+
+@slow
+@require_torch_gpu
+@require_spqr
+@require_accelerate
+class SpQRTest(unittest.TestCase):
+ model_name = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"
+
+ input_text = "Hello my name is"
+ max_new_tokens = 32
+
+ EXPECTED_OUTPUT = (
+ "Hello my name is Jesse. (I'm also known as Jesse) I'm a 25 year old male from United States. I'm looking for"
+ )
+ EXPECTED_OUTPUT_COMPILE = "Hello my name is Jake and I am a 20 year old student at the University of North Texas. (Go Mean Green!) I am a huge fan of the Dallas"
+
+ device_map = "cuda"
+
+    # called only once for all tests in this class
+ @classmethod
+ def setUpClass(cls):
+ """
+ Setup quantized model
+ """
+ cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+ cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+ cls.model_name,
+ device_map=cls.device_map,
+ )
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def test_quantized_model_conversion(self):
+ """
+ Simple test that checks if the quantized model has been converted properly
+ """
+ from spqr_quant import QuantizedLinear
+
+ from transformers.integrations import replace_with_spqr_linear
+
+ model_id = "meta-llama/Llama-2-7b-hf"
+ config = AutoConfig.from_pretrained(model_id)
+ quantization_config = AutoConfig.from_pretrained(self.model_name, return_dict=False).quantization_config
+ quantization_config = SpQRConfig.from_dict(quantization_config)
+
+ with init_empty_weights():
+ model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id, config=config)
+
+ nb_linears = 0
+ for module in model.modules():
+ if isinstance(module, torch.nn.Linear):
+ nb_linears += 1
+
+ model, _ = replace_with_spqr_linear(
+ model,
+ quantization_config=quantization_config,
+ modules_to_not_convert=quantization_config.modules_to_not_convert,
+ )
+
+ nb_spqr_linear = 0
+ for module in model.modules():
+ if isinstance(module, QuantizedLinear):
+ nb_spqr_linear += 1
+
+ self.assertEqual(nb_linears - 1, nb_spqr_linear)
+
+ def test_quantized_model(self):
+ """
+ Simple test that checks if the quantized model is working properly
+ """
+ input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+ output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+ self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+ def test_raise_if_non_quantized(self):
+ model_id = "meta-llama/Llama-2-7b-hf"
+ quantization_config = SpQRConfig()
+
+ with self.assertRaises(ValueError):
+ _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
+
+ @unittest.skip
+ def test_save_pretrained(self):
+ """
+ Simple test that checks if the quantized model is working properly after being saved and loaded
+ """
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.quantized_model.save_pretrained(tmpdirname)
+ model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
+
+ input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+ output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+ self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+ @require_torch_multi_gpu
+ def test_quantized_model_multi_gpu(self):
+ """
+ Simple test that checks if the quantized model is working properly with multiple GPUs
+ """
+ input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+ quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
+
+ self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
+
+ output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+
+ self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
+ def test_quantized_model_compile(self):
+ """
+ Simple test that checks if the quantized model is working properly
+ """
+
+ # Sample tokens greedily
+ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
+ logits = model(
+ cur_token,
+ position_ids=input_pos,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ return_dict=False,
+ use_cache=True,
+ )[0]
+ new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+
+ return new_token
+
+ # Tokenize the test input
+ input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"]
+ seq_length = input_ids.shape[1]
+
+ # Setup static KV cache for generation
+ past_key_values = StaticCache(
+ config=self.quantized_model.config,
+ batch_size=1,
+ max_cache_len=seq_length + self.max_new_tokens + 1,
+ device=torch_device,
+ dtype=self.quantized_model.config._pre_quantization_dtype,
+ )
+
+ # Allocate token ids to be generated and copy prefix ids
+ cache_position = torch.arange(seq_length, device=torch_device)
+ generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device)
+ generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int)
+
+ # Do a forward pass to fill the prefix cache and compile the kernels if necessary
+ logits = self.quantized_model(
+ input_ids,
+ cache_position=cache_position,
+ past_key_values=past_key_values,
+ return_dict=False,
+ use_cache=True,
+ )[0]
+ next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int)
+ generated_ids[:, [seq_length]] = next_token
+
+ with torch.no_grad():
+ # Compile the CUDA graph
+ decode_one_tokens = torch.compile(decode_one_tokens, mode="default", backend="inductor", fullgraph=True)
+
+ # Generate tokens one by one
+ cache_position = torch.tensor([seq_length + 1], device=torch_device)
+ for _ in range(1, self.max_new_tokens):
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+ next_token = decode_one_tokens(
+ self.quantized_model, next_token.clone(), None, cache_position, past_key_values
+ )
+ generated_ids.index_copy_(1, cache_position, next_token)
+ cache_position += 1
+
+ # Check generated text
+ self.assertEqual(
+ self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_COMPILE
+ )