diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile
index 44d1ceb2bfdd..700df877d10f 100755
--- a/docker/transformers-quantization-latest-gpu/Dockerfile
+++ b/docker/transformers-quantization-latest-gpu/Dockerfile
@@ -53,6 +53,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
 
 # Add vptq for quantization testing
 RUN python3 -m pip install --no-cache-dir vptq
 
+# Add spqr for quantization testing
+RUN python3 -m pip install --no-cache-dir spqr_quant[gpu]
+
 # Add hqq for quantization testing
 RUN python3 -m pip install --no-cache-dir hqq
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 5540bd1826a9..17a1bb1b3b33 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -166,6 +166,8 @@
     - local: quantization/aqlm
       title: AQLM
     - local: quantization/vptq
       title: VPTQ
+    - local: quantization/spqr
+      title: SpQR
     - local: quantization/quanto
       title: Quanto
diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md
index cd3e2705ab34..6da5b8ce69b5 100755
--- a/docs/source/en/main_classes/quantization.md
+++ b/docs/source/en/main_classes/quantization.md
@@ -81,6 +81,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 
 [[autodoc]] BitNetConfig
 
+## SpQRConfig
+
+[[autodoc]] SpQRConfig
+
 ## FineGrainedFP8Config
 
 [[autodoc]] FineGrainedFP8Config
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index caebebe81547..94696e300a57 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -61,6 +61,7 @@ Use the table below to help you decide which quantization method to use.
 | [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
 | [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 5 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
 | [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
+| [SpQR](./spqr.md) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
 | [FINEGRAINED_FP8](./finegrained_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
diff --git a/docs/source/en/quantization/spqr.md b/docs/source/en/quantization/spqr.md
new file mode 100644
index 000000000000..b9ebb99b69cb
--- /dev/null
+++ b/docs/source/en/quantization/spqr.md
@@ -0,0 +1,35 @@
+
+
+# SpQR
+
+The [SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm involves a 16x16 tiled bi-level group 3-bit quantization structure, with sparse outliers, as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).
+
+To SpQR-quantize a model, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.
+
+Load a pre-SpQR-quantized model with [`~PreTrainedModel.from_pretrained`].
+ +```python +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +quantized_model = AutoModelForCausalLM.from_pretrained( + "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf", + torch_dtype=torch.half, + device_map="auto" +) +tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf") +``` diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b6412cf59360..8b97168ecf5f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1029,6 +1029,7 @@ "HiggsConfig", "HqqConfig", "QuantoConfig", + "SpQRConfig", "TorchAoConfig", "VptqConfig", ], @@ -6202,6 +6203,7 @@ HiggsConfig, HqqConfig, QuantoConfig, + SpQRConfig, TorchAoConfig, VptqConfig, ) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index c78564dcba8e..b545c5da50a5 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -106,6 +106,7 @@ ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], + "spqr": ["replace_with_spqr_linear"], "vptq": ["replace_with_vptq_linear"], } @@ -210,6 +211,7 @@ ) from .peft import PeftAdapterMixin from .quanto import replace_with_quanto_layers + from .spqr import replace_with_spqr_linear from .vptq import replace_with_vptq_linear try: diff --git a/src/transformers/integrations/spqr.py b/src/transformers/integrations/spqr.py new file mode 100644 index 000000000000..58b71740d37c --- /dev/null +++ b/src/transformers/integrations/spqr.py @@ -0,0 +1,122 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"SpQR (Sparse-Quantized Representation) integration file" + +from ..utils import is_accelerate_available, is_spqr_available, is_torch_available + + +if is_torch_available(): + import torch.nn as nn + + +def replace_with_spqr_linear( + model, + quantization_config=None, + modules_to_not_convert=None, + current_key_name=None, + has_been_replaced=False, +): + """ + Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers. + `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the + conversion has been successful or not. + + Args: + model (`torch.nn.Module`): + The model to convert, can be any `torch.nn.Module` instance. + quantization_config (`SpQRConfig`): + The quantization config object that contains the quantization parameters. + modules_to_not_convert (`list[str]`, *optional*): + A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be + converted. + current_key_name (`list`, *optional*): + A list that contains the current key name. This is used for recursion and should not be passed by the user. + has_been_replaced (`bool`, *optional*): + A boolean that indicates if the conversion has been successful or not. 
This is used for recursion and + should not be passed by the user. + """ + if modules_to_not_convert is None: + modules_to_not_convert = [] + + if is_accelerate_available(): + from accelerate import init_empty_weights + if is_spqr_available(): + from spqr_quant import QuantizedLinear + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear): + # Check if the current key is not in the `modules_to_not_convert` + if ".".join(current_key_name) + ".weight" not in modules_to_not_convert: + with init_empty_weights(): + tensor_name = ".".join(current_key_name) + + shapes = quantization_config.shapes + shapes_keys = shapes.keys() + + shapes_valid = ( + f"{tensor_name}.dense_weights.shape" in shapes_keys + and f"{tensor_name}.row_offsets.shape" in shapes_keys + and f"{tensor_name}.col_vals.shape" in shapes_keys + and f"{tensor_name}.in_perm.shape" in shapes_keys + ) + + if not shapes_valid: + raise ValueError( + f"The SpQR quantization config does not contain the shape " + f"configuration for {tensor_name}. This indicates that the " + f"configuration is either invalid or corrupted." + ) + + dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"] + row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"] + col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"] + in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"] + + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = QuantizedLinear.create_placehodler( + rows=out_features, + cols=in_features, + bits=quantization_config.bits, + beta1=quantization_config.beta1, + beta2=quantization_config.beta2, + dense_weights_shape=dense_weights_shape, + row_offsets_shape=row_offsets_shape, + col_vals_shape=col_vals_shape, + in_perm_shape=in_perm_shape, + ) + has_been_replaced = True + + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + else: + pass + if len(list(module.children())) > 0: + _, has_been_replaced = replace_with_spqr_linear( + module, + quantization_config=quantization_config, + modules_to_not_convert=modules_to_not_convert, + current_key_name=current_key_name, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 0f279498f7ec..ee7c832b1de1 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -31,6 +31,7 @@ QuantizationConfigMixin, QuantizationMethod, QuantoConfig, + SpQRConfig, TorchAoConfig, VptqConfig, ) @@ -47,6 +48,7 @@ from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer +from .quantizer_spqr import SpQRHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer from .quantizer_vptq import VptqHfQuantizer @@ -66,6 +68,7 @@ "torchao": TorchAoHfQuantizer, "bitnet": BitNetHfQuantizer, "vptq": VptqHfQuantizer, + "spqr": SpQRHfQuantizer, "fp8": FineGrainedFP8HfQuantizer, } @@ -84,6 +87,7 @@ "torchao": TorchAoConfig, "bitnet": BitNetConfig, "vptq": VptqConfig, + "spqr": SpQRConfig, "fp8": FineGrainedFP8Config, } diff --git a/src/transformers/quantizers/quantizer_spqr.py 
b/src/transformers/quantizers/quantizer_spqr.py
new file mode 100644
index 000000000000..60cc1bca9b27
--- /dev/null
+++ b/src/transformers/quantizers/quantizer_spqr.py
@@ -0,0 +1,83 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Optional
+
+from .base import HfQuantizer
+
+
+if TYPE_CHECKING:
+    from ..modeling_utils import PreTrainedModel
+
+from ..integrations import replace_with_spqr_linear
+from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
+from ..utils.quantization_config import QuantizationConfigMixin
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)
+
+
+class SpQRHfQuantizer(HfQuantizer):
+    """
+    Quantizer of the SpQR method. Enables the loading of prequantized models.
+    """
+
+    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+        self.quantization_config = quantization_config
+
+    def validate_environment(self, *args, **kwargs):
+        if not torch.cuda.is_available():
+            raise RuntimeError("GPU is required to run SpQR quantized model.")
+
+        if not is_accelerate_available():
+            raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")
+
+        if not is_spqr_available():
+            raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = torch.float16
+            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
+        elif torch_dtype != torch.float16:
+            raise ValueError(
+                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to "
+                "torch.float16 explicitly."
+ ) + return torch_dtype + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + **kwargs, + ): + replace_with_spqr_linear( + model, + quantization_config=self.quantization_config, + modules_to_not_convert=self.quantization_config.modules_to_not_convert, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + @property + def is_trainable(self, model: Optional["PreTrainedModel"] = None): + return False + + def is_serializable(self, safe_serialization=None): + return True diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index c71619e4e8f9..14fef2988488 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -121,6 +121,7 @@ is_seqio_available, is_soundfile_available, is_spacy_available, + is_spqr_available, is_sudachi_available, is_sudachi_projection_available, is_tensorflow_probability_available, @@ -1191,6 +1192,13 @@ def require_vptq(test_case): return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case) +def require_spqr(test_case): + """ + Decorator marking a test that requires spqr + """ + return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case) + + def require_eetq(test_case): """ Decorator marking a test that requires eetq diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index bf56a584469f..cf13060ee307 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -193,6 +193,7 @@ is_soundfile_available, is_spacy_available, is_speech_available, + is_spqr_available, is_sudachi_available, is_sudachi_projection_available, is_tensorflow_probability_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c14ff2124aa0..bd95b6f282c0 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -201,6 +201,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") _triton_available = _is_package_available("triton") +_spqr_available = _is_package_available("spqr_quant") _torch_version = "N/A" _torch_available = False @@ -1213,6 +1214,10 @@ def is_speech_available(): return _torchaudio_available +def is_spqr_available(): + return _spqr_available + + def is_phonemizer_available(): return _phonemizer_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 329123e0f4a4..11415e895d91 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -56,6 +56,7 @@ class QuantizationMethod(str, Enum): FBGEMM_FP8 = "fbgemm_fp8" TORCHAO = "torchao" BITNET = "bitnet" + SPQR = "spqr" FP8 = "fp8" @@ -1551,6 +1552,75 @@ def post_init(self): pass +@dataclass +class SpQRConfig(QuantizationConfigMixin): + """ + This is a wrapper class about `spqr` parameters. Refer to the original publication for more details. + + Args: + bits (`int`, *optional*, defaults to 3): + Specifies the bit count for the weights and first order zero-points and scales. + Currently only bits = 3 is supported. + beta1 (`int`, *optional*, defaults to 16): + SpQR tile width. Currently only beta1 = 16 is supported. 
+ beta2 (`int`, *optional*, defaults to 16): + SpQR tile height. Currently only beta2 = 16 is supported. + shapes (`Optional`, *optional*): + A dictionary holding the shape of each object. We need this because it's impossible + to deduce the exact size of the parameters just from bits, beta1, beta2. + modules_to_not_convert (`Optional[List[str]]`, *optional*): + Optionally, provides a list of full paths of `nn.Linear` weight parameters that shall not be quantized. + Defaults to None. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + bits: int = 3, + beta1: int = 16, + beta2: int = 16, + shapes: Optional[Dict[str, int]] = None, + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ): + if shapes is None: + shapes = {} + self.shapes = shapes + self.quant_method = QuantizationMethod.SPQR + self.bits = bits + self.beta1 = beta1 + self.beta2 = beta2 + if modules_to_not_convert is None: + modules_to_not_convert = [] + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.bits, int): + raise TypeError("bits must be an int") + if not isinstance(self.beta1, int): + raise TypeError("beta1 must be an int") + if not isinstance(self.beta2, int): + raise TypeError("beta2 must be an int") + + if self.bits != 3: + raise ValueError("SpQR currently only supports bits = 3") + if self.beta1 != 16: + raise ValueError("SpQR currently only supports beta1 = 16") + if self.beta2 != 16: + raise ValueError("SpQR currently only supports beta2 = 16") + + if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list): + raise ValueError("modules_to_not_convert must be a list of strings") + + if not isinstance(self.shapes, dict): + raise TypeError("shapes must be a dict") + + @dataclass class FineGrainedFP8Config(QuantizationConfigMixin): """ diff --git a/tests/quantization/spqr_integration/__init__.py b/tests/quantization/spqr_integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/spqr_integration/test_spqr.py b/tests/quantization/spqr_integration/test_spqr.py new file mode 100644 index 000000000000..134e57af5de1 --- /dev/null +++ b/tests/quantization/spqr_integration/test_spqr.py @@ -0,0 +1,249 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
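+
+# These integration tests cover SpQRConfig serialization round-trips, replacement of nn.Linear
+# modules with spqr_quant QuantizedLinear placeholders, generation-output regression checks,
+# multi-GPU device maps, and greedy decoding under torch.compile with a static KV cache.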
+ +import gc +import tempfile +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, SpQRConfig, StaticCache +from transformers.testing_utils import ( + require_accelerate, + require_spqr, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_accelerate_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +@require_torch_gpu +class SpQRConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object + """ + quantization_config = SpQRConfig() + config_to_dict = quantization_config.to_dict() + + for key in config_to_dict: + self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + + def test_from_dict(self): + """ + Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict + """ + dict = { + "beta1": 16, + "beta2": 16, + "bits": 3, + "modules_to_not_convert": ["lm_head.weight"], + "shapes": {"model.layers.0.self_attn.q_proj.dense_weights.shape": 16}, + } + quantization_config = SpQRConfig.from_dict(dict) + + self.assertEqual(dict["beta1"], quantization_config.beta1) + self.assertEqual(dict["beta2"], quantization_config.beta2) + self.assertEqual(dict["bits"], quantization_config.bits) + self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert) + self.assertEqual(dict["shapes"], quantization_config.shapes) + + +@slow +@require_torch_gpu +@require_spqr +@require_accelerate +class SpQRTest(unittest.TestCase): + model_name = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf" + + input_text = "Hello my name is" + max_new_tokens = 32 + + EXPECTED_OUTPUT = ( + "Hello my name is Jesse. (I'm also known as Jesse) I'm a 25 year old male from United States. I'm looking for" + ) + EXPECTED_OUTPUT_COMPILE = "Hello my name is Jake and I am a 20 year old student at the University of North Texas. (Go Mean Green!) 
I am a huge fan of the Dallas" + + device_map = "cuda" + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, + device_map=cls.device_map, + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model_conversion(self): + """ + Simple test that checks if the quantized model has been converted properly + """ + from spqr_quant import QuantizedLinear + + from transformers.integrations import replace_with_spqr_linear + + model_id = "meta-llama/Llama-2-7b-hf" + config = AutoConfig.from_pretrained(model_id) + quantization_config = AutoConfig.from_pretrained(self.model_name, return_dict=False).quantization_config + quantization_config = SpQRConfig.from_dict(quantization_config) + + with init_empty_weights(): + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id, config=config) + + nb_linears = 0 + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + nb_linears += 1 + + model, _ = replace_with_spqr_linear( + model, + quantization_config=quantization_config, + modules_to_not_convert=quantization_config.modules_to_not_convert, + ) + + nb_spqr_linear = 0 + for module in model.modules(): + if isinstance(module, QuantizedLinear): + nb_spqr_linear += 1 + + self.assertEqual(nb_linears - 1, nb_spqr_linear) + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_raise_if_non_quantized(self): + model_id = "meta-llama/Llama-2-7b-hf" + quantization_config = SpQRConfig() + + with self.assertRaises(ValueError): + _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + + @unittest.skip + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto") + + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_quantized_model_compile(self): + """ + Simple test that checks if the quantized model is working properly + """ + + # Sample tokens 
greedily + def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): + logits = model( + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, + )[0] + new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) + + return new_token + + # Tokenize the test input + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)["input_ids"] + seq_length = input_ids.shape[1] + + # Setup static KV cache for generation + past_key_values = StaticCache( + config=self.quantized_model.config, + batch_size=1, + max_cache_len=seq_length + self.max_new_tokens + 1, + device=torch_device, + dtype=self.quantized_model.config._pre_quantization_dtype, + ) + + # Allocate token ids to be generated and copy prefix ids + cache_position = torch.arange(seq_length, device=torch_device) + generated_ids = torch.zeros(1, seq_length + self.max_new_tokens, dtype=torch.int, device=torch_device) + generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) + + # Do a forward pass to fill the prefix cache and compile the kernels if necessary + logits = self.quantized_model( + input_ids, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, + )[0] + next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) + generated_ids[:, [seq_length]] = next_token + + with torch.no_grad(): + # Compile the CUDA graph + decode_one_tokens = torch.compile(decode_one_tokens, mode="default", backend="inductor", fullgraph=True) + + # Generate tokens one by one + cache_position = torch.tensor([seq_length + 1], device=torch_device) + for _ in range(1, self.max_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_tokens( + self.quantized_model, next_token.clone(), None, cache_position, past_key_values + ) + generated_ids.index_copy_(1, cache_position, next_token) + cache_position += 1 + + # Check generated text + self.assertEqual( + self.tokenizer.decode(generated_ids[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_COMPILE + )
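
For context, a minimal sketch of how the per-tensor `shapes` entries consumed by `replace_with_spqr_linear` pair up with `SpQRConfig`. The module path and sizes below are hypothetical placeholders; real values ship inside the quantization config of an SpQR checkpoint (e.g. the one loaded in the tests above), so this is only an illustration of the expected key layout, not a recipe for producing them.

```python
# Minimal sketch (hypothetical module path and sizes): SpQRConfig carries the
# flattened-tensor shapes that replace_with_spqr_linear needs in order to build
# QuantizedLinear placeholders before the quantized weights are loaded.
from transformers import SpQRConfig

tensor_name = "model.layers.0.self_attn.q_proj"  # hypothetical layer path

shapes = {
    f"{tensor_name}.dense_weights.shape": 24576,  # hypothetical size
    f"{tensor_name}.row_offsets.shape": 4097,     # hypothetical size
    f"{tensor_name}.col_vals.shape": 8192,        # hypothetical size
    f"{tensor_name}.in_perm.shape": 4096,         # hypothetical size
}

config = SpQRConfig(
    bits=3,    # only bits=3 is supported
    beta1=16,  # tile width, must be 16
    beta2=16,  # tile height, must be 16
    shapes=shapes,
    modules_to_not_convert=["lm_head.weight"],
)
print(config.to_dict()["shapes"])
```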