Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ff40f1a
HQQ model serialization attempt
mobicham Jul 18, 2024
fa8a9f5
Merge branch 'huggingface:main' into main
mobicham Jul 18, 2024
75dfe0a
fix hqq dispatch and unexpected keys
SunMarc Aug 1, 2024
f2ea032
Merge remote-tracking branch 'upstream/main' into hqq_serialization
SunMarc Aug 1, 2024
bc9cb55
Merge remote-tracking branch 'upstream/main' into hqq_serialization
SunMarc Aug 1, 2024
a8704d2
style
SunMarc Aug 1, 2024
5cb7d81
remove check_old_param
mobicham Aug 27, 2024
7f1b85d
revert to check HQQLinear in quantizer_hqq.py
mobicham Aug 28, 2024
71cccd4
revert to check HQQLinear in quantizer_hqq.py
mobicham Aug 28, 2024
ff982b3
update HqqConfig default params
mobicham Aug 28, 2024
d35ea7c
make ci happy
mobicham Aug 28, 2024
cbe219f
make ci happy
mobicham Aug 28, 2024
2bb974c
revert to HQQLinear check in quantizer_hqq.py
mobicham Aug 28, 2024
9f7c235
check hqq_min version 0.2.0
mobicham Aug 28, 2024
7f15b49
Merge branch 'main' into hqq_serialization
mobicham Aug 28, 2024
383e028
set axis=1 as default in quantization_config.py
mobicham Aug 28, 2024
cf5a05c
Merge branch 'hqq_serialization' of https://github.com/mobiusml/trans…
mobicham Aug 28, 2024
4682a72
validate_env with hqq>=0.2.0 version message
mobicham Aug 28, 2024
813ed62
deprecated hqq kwargs message
mobicham Aug 28, 2024
d0c594c
make ci happy
mobicham Aug 28, 2024
7e019b3
remove run_expected_keys_check hack + bump to 0.2.1 min hqq version
mobicham Aug 29, 2024
0dd1152
fix unexpected_keys hqq update
mobicham Aug 29, 2024
9053ad5
add pre_quantized check
mobicham Aug 29, 2024
2b6e7df
add update_expected_keys to base quantizerr
mobicham Aug 29, 2024
e68110a
ci base.py fix?
mobicham Aug 29, 2024
433c3a0
ci base.py fix?
mobicham Aug 29, 2024
4db1991
fix "quantization typo" src/transformers/utils/quantization_config.py
mobicham Sep 30, 2024
3b56533
Merge branch 'main' into hqq_serialization
mobicham Sep 30, 2024
a8843cf
fix post merge
mobicham Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/en/quantization/hqq.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
quant_config = HqqConfig(nbits=8, group_size=64)
```

``` Python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/integrations/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_

has_been_replaced = True

# Add these placeholder attributes (set to None) to avoid failures during loading
for att in ["W_q", "meta"]:
setattr(module, att, None)

if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
Expand Down Expand Up @@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
Expand All @@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
# Not all the attributes of a module are Parameters/Tensor
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for potentially ints ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, some parameters are strings (packing format, etc.), booleans or integers. They are necessary meta-data to dequantize

if old_param is None:
break

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
Expand Down Expand Up @@ -3819,6 +3824,7 @@ def from_pretrained(
from_pt = not (from_tf | from_flax)

# load pt weights early so that we know which dtype to init the model under

if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
Expand Down Expand Up @@ -4176,6 +4182,9 @@ def _load_pretrained_model(
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix

if hf_quantizer is not None:
expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)

def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
Expand Down Expand Up @@ -4290,7 +4299,7 @@ def _fix_key(key):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/quantizers/base.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li
"""
return missing_keys

def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
    """
    Override this method if you want to adjust the `expected_keys`.

    The base implementation returns `expected_keys` unchanged.

    Args:
        model:
            The model being loaded. Not used by the base implementation.
        expected_keys (`List[str]`):
            The list of the expected keys in the initialized model.
        loaded_keys (`List[str]`):
            The list of the loaded keys in the checkpoint. Not used by the base implementation.

    Returns:
        `List[str]`: The (possibly adjusted) list of expected keys.
    """
    return expected_keys

def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
Expand Down
113 changes: 105 additions & 8 deletions src/transformers/quantizers/quantizer_hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, quantization_config, **kwargs):
def validate_environment(self, *args, **kwargs):
if not (is_hqq_available()):
raise ImportError(
"HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
"A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
)

if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
Expand Down Expand Up @@ -91,6 +91,65 @@ def validate_environment(self, *args, **kwargs):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1

def update_missing_keys(
    self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
) -> List[str]:
    """
    Filter `missing_keys` depending on whether the checkpoint is pre-quantized.

    When `self.pre_quantized` is truthy, every key whose name contains the substring `"weight"` is dropped;
    otherwise `missing_keys` is returned unchanged. `model`, `prefix` and `kwargs` are not used here.

    Returns:
        `List[str]`: The filtered list of missing keys.
    """
    if self.pre_quantized:
        # Substring match: any key containing "weight" anywhere in its name is removed.
        # NOTE(review): presumably because serialized HQQLinear layers carry no `weight` entry - confirm.
        return [key for key in missing_keys if ("weight" not in key)]
    else:
        return missing_keys

# Adds missing keys for HQQLinear modules that are loaded when the model was initialized with torch.nn.Linear
def update_expected_keys(
Comment on lines +102 to +103
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand exactly why we need to do this post loading vs while loading + is gonna be quite annoying to maintain as it feels like there's a lot of hacks but fine for me!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, this was the only way to make this work. The issue is that, when the model is created, it creates nn.Linear layers that have a "weight" parameter, but there's no "weight" parameter in HQQLinear, so when it tries to load the HQQLinear state-dict, it breaks. I agree it's a bit complicated, actually the whole HQQLinear support has been complicated.

I think one better way of doing this in the future is to have a native QuantLinear layer officially supported by HF transformers, which could potentially integrate any quant method. I pitched this idea to Marc as well, I think it's worth thinking about. Happy to assist with that as well.

self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
) -> List[str]:
if not self.pre_quantized:
return expected_keys

# Collects all quantizable (linear) layers
def _find_hqq_quantizable_layers(model, layers):
for name, module in model.named_children():
if isinstance(module, (torch.nn.Linear)):
layers.add(module.name)
_find_hqq_quantizable_layers(module, layers)

new_keys = set(expected_keys)
if is_hqq_available():
from hqq.core.quantize import HQQLinear

# Name modules
for name, module in model.named_modules():
module.name = name

# valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
_valid_modules = set()
_find_hqq_quantizable_layers(model, _valid_modules)
_valid_modules -= set(model.config.quantization_config["skip_modules"])

# Append new expected layers based on _ref_keys
_ref_keys = HQQLinear(
linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
).state_dict_keys() - {"bias"}

# Clean-up
_rm_keys = set()
for key in new_keys:
if any(_module in key for _module in _valid_modules):
_rm_keys.add(key)
new_keys -= _rm_keys
# At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear

# Re-populate Linear/HQQLinear
for _module in _valid_modules:
if _module + ".weight" in loaded_keys:
new_keys.add(_module + ".weight")
else:
new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
if _module + ".bias" in loaded_keys:
new_keys.add(_module + ".bias")

return list(new_keys)

def check_quantized_param(
self,
model: "PreTrainedModel",
Expand All @@ -99,9 +158,18 @@ def check_quantized_param(
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)

return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
if self.pre_quantized:
return (
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
and tensor_name != "weight"
and tensor_name != "bias"
)
else:
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"

def create_quantized_param(
self,
Expand All @@ -122,21 +190,50 @@ def create_quantized_param(
from hqq.core.quantize import HQQLinear

module, tensor_name = get_module_from_name(model, param_name)

layer_name = param_name.replace(".weight", "").replace(".bias", "")
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]

# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
# set module state_dict
module_state_dict = {}
for k, v in state_dict.items():
if layer_name + "." in k:
module_state_dict[k.split(".")[-1]] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

if self.pre_quantized:
if isinstance(module, HQQLinear):
return
else:
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None,
compute_dtype=self.torch_dtype,
device=target_device,
)

hqq_layer.load_state_dict(module_state_dict)

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

setattr(parent_module, node, hqq_layer)

# cleanup
del module.__dict__, module
torch.cuda.empty_cache()
return

# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing it on the module
# directly doesn't work.

if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
Expand Down Expand Up @@ -192,7 +289,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
return model

def is_serializable(self, safe_serialization=None):
return False
return True

@property
def is_trainable(self) -> bool:
Expand Down
7 changes: 4 additions & 3 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
FSDP_MIN_VERSION = "1.12.0"
GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
Expand Down Expand Up @@ -181,7 +182,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
_tiktoken_available = _is_package_available("tiktoken")
_blobfile_available = _is_package_available("blobfile")
_liger_kernel_available = _is_package_available("liger_kernel")
Expand Down Expand Up @@ -323,8 +324,8 @@ def is_torch_deterministic():
return True


def is_hqq_available():
return _hqq_available
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
    """
    Return a truthy value only when the `hqq` package was detected and its version is at least `min_version`.

    The version comparison is short-circuited (not evaluated) when the package is not available.
    """
    return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)


def is_pygments_available():
Expand Down
42 changes: 27 additions & 15 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,9 @@ class HqqConfig(QuantizationConfigMixin):
Number of bits. Supported values are (8, 4, 3, 2, 1).
group_size (`int`, *optional*, defaults to 64):
Group-size value. Supported values are any value that evenly divides weight.shape[axis].
quant_zero (`bool`, *optional*, defaults to `True`):
Quantize the zero-point if set to `True`.
quant_scale (`bool`, *optional*, defaults to `False`):
Quantize the scaling if set to `True`.
offload_meta (`bool`, *optional*, defaults to `False`):
Offload the meta-data to the CPU if set to `True`.
view_as_float (`bool`, *optional*, defaults to `False`):
View the quantized weight as float (used in distributed training) if set to `True`.
axis (`int`, *optional*, defaults to 0):
axis (`Optional[int]`, *optional*):
Axis along which grouping is performed. Supported values are 0 or 1.
dynamic_config (dict, *optional*):
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
Expand All @@ -216,18 +210,25 @@ def __init__(
self,
nbits: int = 4,
group_size: int = 64,
quant_zero: bool = True,
quant_scale: bool = False,
offload_meta: bool = False,
view_as_float: bool = False,
axis: int = 0,
axis: Optional[int] = None,
dynamic_config: Optional[dict] = None,
skip_modules: List[str] = ["lm_head"],
**kwargs,
):
if is_hqq_available():
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig

for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
if deprecated_key in kwargs:
logger.info(
deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
)

if axis is None:
axis = 1
logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")

if axis not in [0, 1]:
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")

Expand All @@ -240,9 +241,6 @@ def __init__(
**{
"nbits": nbits,
"group_size": group_size,
"quant_zero": quant_zero,
"quant_scale": quant_scale,
"offload_meta": offload_meta,
"view_as_float": view_as_float,
"axis": axis,
}
Expand All @@ -259,12 +257,26 @@ def post_init(self):
"""
pass

@classmethod
def from_dict(cls, config: Dict[str, Any]):
    """
    Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py

    Builds an instance with the default constructor arguments, then overwrites its `quant_config` and
    `skip_modules` attributes with the values stored in `config`.

    Args:
        config (`Dict[str, Any]`):
            Dictionary that must contain the `"quant_config"` and `"skip_modules"` keys (a missing key raises
            `KeyError`). Other entries are not read.

    Returns:
        The newly created instance of `cls`.
    """
    # Start from a default-constructed config; only the two attributes below are restored from `config`.
    instance = cls()
    instance.quant_config = config["quant_config"]
    instance.skip_modules = config["skip_modules"]
    return instance

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
return {
"quant_config": self.quant_config,
"quant_method": self.quant_method,
"skip_modules": self.skip_modules,
}

def __repr__(self):
config_dict = self.to_dict()
Expand Down
Loading