Efficient Inference Kernel for SpQR #34976
New documentation page for SpQR quantization (new file):

````markdown
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# SpQR

The [SpQR](https://github.com/Vahe1994/SpQR) quantization algorithm uses a 16x16 tiled bi-level group 3-bit quantization structure with sparse outliers, as detailed in [SpQR: A Sparse-Quantized Representation for Near-Lossless LLM Weight Compression](https://arxiv.org/abs/2306.03078).

To SpQR-quantize a model, refer to the [Vahe1994/SpQR](https://github.com/Vahe1994/SpQR) repository.

Load a pre-quantized SpQR model with [`~PreTrainedModel.from_pretrained`]:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

quantized_model = AutoModelForCausalLM.from_pretrained(
    "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf",
    torch_dtype=torch.half,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf")
```
````
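Once loaded, the model behaves like any other causal LM. A minimal generation sanity check (the prompt and the `max_new_tokens` value below are illustrative choices, not part of the PR):

```python
# Assumes `quantized_model` and `tokenizer` from the snippet above.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(quantized_model.device)
outputs = quantized_model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```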
New integration module that recursively replaces `nn.Linear` layers with SpQR quantized layers (new file):

```python
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"SpQR (Sparse-Quantized Representation) integration file"

from ..utils import is_accelerate_available, is_spqr_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_spqr_linear(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`SpQRConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
            the corresponding module will not be converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    if is_accelerate_available():
        from accelerate import init_empty_weights
    if is_spqr_available():
        from spqr_quant import QuantizedLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check if the current key is not in the `modules_to_not_convert`
            if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
                with init_empty_weights():
                    tensor_name = ".".join(current_key_name)

                    shapes = quantization_config.shapes
                    shapes_keys = shapes.keys()

                    shapes_valid = (
                        f"{tensor_name}.dense_weights.shape" in shapes_keys
                        and f"{tensor_name}.row_offsets.shape" in shapes_keys
                        and f"{tensor_name}.col_vals.shape" in shapes_keys
                        and f"{tensor_name}.in_perm.shape" in shapes_keys
                    )

                    if not shapes_valid:
                        raise ValueError(
                            f"The SpQR quantization config does not contain the shape "
                            f"configuration for {tensor_name}. This indicates that the "
                            f"configuration is either invalid or corrupted."
                        )

                    dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
                    row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
                    col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
                    in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]

                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = QuantizedLinear.create_placehodler(
                        rows=out_features,
                        cols=in_features,
                        bits=quantization_config.bits,
                        beta1=quantization_config.beta1,
                        beta2=quantization_config.beta2,
                        dense_weights_shape=dense_weights_shape,
                        row_offsets_shape=row_offsets_shape,
                        col_vals_shape=col_vals_shape,
                        in_perm_shape=in_perm_shape,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_spqr_linear(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
```
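The function above is an instance of the standard recursive module-surgery pattern in PyTorch: walk `named_children`, swap matching `nn.Linear` layers in `_modules`, and recurse into containers while tracking the dotted key path so the skip list can match full parameter paths such as `lm_head.weight`. A self-contained toy sketch of that pattern, with a hypothetical `DummyQuantLinear` standing in for `spqr_quant.QuantizedLinear`:

```python
import torch.nn as nn


class DummyQuantLinear(nn.Module):
    """Hypothetical stand-in for a quantized linear placeholder (illustration only)."""

    def __init__(self, rows, cols):
        super().__init__()
        self.rows, self.cols = rows, cols


def replace_linears(model, skip=(), prefix=()):
    """Recursively replace nn.Linear children, skipping listed parameter paths."""
    for name, module in model.named_children():
        path = prefix + (name,)
        if isinstance(module, nn.Linear):
            # Like the real implementation, match full parameter paths, e.g. "lm_head.weight".
            if ".".join(path) + ".weight" not in skip:
                model._modules[name] = DummyQuantLinear(module.out_features, module.in_features)
        else:
            replace_linears(module, skip=skip, prefix=path)
    return model


# Every Linear in this toy model gets swapped for the placeholder.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
print(replace_linears(toy))
```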
New `HfQuantizer` subclass that wires SpQR into the model-loading flow (new file):

```python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..integrations import replace_with_spqr_linear
from ..utils import is_accelerate_available, is_spqr_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class SpQRHfQuantizer(HfQuantizer):
    """
    Quantizer of the SpQR method. Enables the loading of prequantized models.
    """

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("GPU is required to run SpQR quantized model.")

        if not is_accelerate_available():
            raise ImportError("Using `spqr` quantization requires Accelerate: `pip install accelerate`")

        if not is_spqr_available():
            raise ImportError("Using `spqr` quantization requires SpQR: `pip install spqr_quant[gpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            torch_dtype = torch.float16
            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
        elif torch_dtype != torch.float16:
            raise ValueError(
                "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to "
                "torch.float16 explicitly."
            )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        replace_with_spqr_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        return True
```
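One consequence of `update_torch_dtype` above: SpQR checkpoints can only be loaded in `torch.float16`. A sketch of the caller-side behavior (assuming a CUDA GPU and `spqr_quant` installed, and reusing the checkpoint name from the docs page):

```python
import torch
from transformers import AutoModelForCausalLM

repo = "elvircrn/Llama-2-7b-SPQR-3Bit-16x16-red_pajama-hf"

# Works: torch_dtype left unset (defaults to float16) or set to torch.float16.
model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.float16, device_map="auto")

# Any other dtype is rejected by SpQRHfQuantizer.update_torch_dtype.
try:
    AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16, device_map="auto")
except ValueError as err:
    print(err)
```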
Review discussion:

**Reviewer:** Can't we quantize a model on the fly with `quantization_config=...`, using `spqr` as the quantization type?

**Author:** We can't, this is not supported as of yet. This PR only adds inference support.