-
Notifications
You must be signed in to change notification settings - Fork 32.6k
Hqq serialization #33141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hqq serialization #33141
Changes from all commits
ff40f1a
fa8a9f5
75dfe0a
f2ea032
bc9cb55
a8704d2
5cb7d81
7f1b85d
71cccd4
ff982b3
d35ea7c
cbe219f
2bb974c
9f7c235
7f15b49
383e028
cf5a05c
4682a72
813ed62
d0c594c
7e019b3
0dd1152
9053ad5
2b6e7df
e68110a
433c3a0
4db1991
3b56533
a8843cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,7 +62,7 @@ def __init__(self, quantization_config, **kwargs): | |
| def validate_environment(self, *args, **kwargs): | ||
| if not (is_hqq_available()): | ||
| raise ImportError( | ||
|                         "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`" | ||
|                         "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." | ||
| ) | ||
|
|
||
| if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): | ||
|
|
@@ -91,6 +91,65 @@ def validate_environment(self, *args, **kwargs): | |
| else: | ||
| self.using_multi_gpu = len(set(device_map.values())) > 1 | ||
|
|
||
| def update_missing_keys( | ||
| self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs | ||
| ) -> List[str]: | ||
| if self.pre_quantized: | ||
| return [key for key in missing_keys if ("weight" not in key)] | ||
| else: | ||
| return missing_keys | ||
|
|
||
| # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear | ||
| def update_expected_keys( | ||
|
Comment on lines
+102
to
+103
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not sure I understand exactly why we need to do this post loading vs while loading + is gonna be quite annoying to maintain as it feels like there's a lot of hacks but fine for me!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unfortunately, this was the only way to make this work. The issue is that, when the model is created, it creates [text truncated]. I think one better way of doing this in the future is to have a native [text truncated]. |
||
| self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str] | ||
| ) -> List[str]: | ||
| if not self.pre_quantized: | ||
| return expected_keys | ||
|
|
||
| # Collects all quantizable (linear) layers | ||
| def _find_hqq_quantizable_layers(model, layers): | ||
| for name, module in model.named_children(): | ||
| if isinstance(module, (torch.nn.Linear)): | ||
| layers.add(module.name) | ||
| _find_hqq_quantizable_layers(module, layers) | ||
|
|
||
| new_keys = set(expected_keys) | ||
| if is_hqq_available(): | ||
| from hqq.core.quantize import HQQLinear | ||
|
|
||
| # Name modules | ||
| for name, module in model.named_modules(): | ||
| module.name = name | ||
|
|
||
| # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params | ||
| _valid_modules = set() | ||
| _find_hqq_quantizable_layers(model, _valid_modules) | ||
| _valid_modules -= set(model.config.quantization_config["skip_modules"]) | ||
|
|
||
| # Append new expected layers based on _ref_keys | ||
| _ref_keys = HQQLinear( | ||
| linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu" | ||
| ).state_dict_keys() - {"bias"} | ||
|
|
||
| # Clean-up | ||
| _rm_keys = set() | ||
| for key in new_keys: | ||
| if any(_module in key for _module in _valid_modules): | ||
| _rm_keys.add(key) | ||
| new_keys -= _rm_keys | ||
| # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear | ||
|
|
||
| # Re-populate Linear/HQQLinear | ||
| for _module in _valid_modules: | ||
| if _module + ".weight" in loaded_keys: | ||
| new_keys.add(_module + ".weight") | ||
| else: | ||
| new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys}) | ||
| if _module + ".bias" in loaded_keys: | ||
| new_keys.add(_module + ".bias") | ||
|
|
||
| return list(new_keys) | ||
|
|
||
| def check_quantized_param( | ||
| self, | ||
| model: "PreTrainedModel", | ||
|
|
@@ -99,9 +158,18 @@ def check_quantized_param( | |
| state_dict: Dict[str, Any], | ||
| **kwargs, | ||
| ) -> bool: | ||
| if is_hqq_available(): | ||
| from hqq.core.quantize import HQQLinear | ||
| module, tensor_name = get_module_from_name(model, param_name) | ||
|
|
||
| return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") | ||
| if self.pre_quantized: | ||
| return ( | ||
| (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear)) | ||
| and tensor_name != "weight" | ||
| and tensor_name != "bias" | ||
| ) | ||
| else: | ||
| return isinstance(module, torch.nn.Linear) and tensor_name == "weight" | ||
|
|
||
| def create_quantized_param( | ||
| self, | ||
|
|
@@ -122,21 +190,50 @@ def create_quantized_param( | |
| from hqq.core.quantize import HQQLinear | ||
|
|
||
| module, tensor_name = get_module_from_name(model, param_name) | ||
|
|
||
| layer_name = param_name.replace(".weight", "").replace(".bias", "") | ||
| layer_name = ".".join(param_name.split(".")[:-1]) | ||
| parent_module = find_parent(model, layer_name) | ||
| node = layer_name.split(".")[-1] | ||
|
|
||
| # Step 0: set module state_dict | ||
| module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} | ||
| # set module state_dict | ||
| module_state_dict = {} | ||
| for k, v in state_dict.items(): | ||
| if layer_name + "." in k: | ||
| module_state_dict[k.split(".")[-1]] = v | ||
| if unexpected_keys is not None and k in unexpected_keys: | ||
| unexpected_keys.remove(k) | ||
|
|
||
| if self.pre_quantized: | ||
| if isinstance(module, HQQLinear): | ||
| return | ||
| else: | ||
| hqq_layer = HQQLinear( | ||
| linear_layer=None, | ||
| quant_config=None, | ||
| compute_dtype=self.torch_dtype, | ||
| device=target_device, | ||
| ) | ||
|
|
||
| hqq_layer.load_state_dict(module_state_dict) | ||
|
|
||
| if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): | ||
| hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) | ||
|
|
||
| if self.using_multi_gpu: | ||
| hqq_layer = self._patch_layer_for_multigpu(hqq_layer) | ||
|
|
||
| setattr(parent_module, node, hqq_layer) | ||
|
|
||
| # cleanup | ||
| del module.__dict__, module | ||
| torch.cuda.empty_cache() | ||
| return | ||
|
|
||
| # Step 1: populate module with weight/bias from module state dict | ||
| for key in module_state_dict: | ||
| setattr(module, key, torch.nn.Parameter(module_state_dict[key])) | ||
|
|
||
| # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module | ||
| # directly doesn't work. | ||
|
|
||
| if hasattr(module, "quant_config"): | ||
| hqq_layer = HQQLinear( | ||
| module, | ||
|
|
@@ -192,7 +289,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs | |
| return model | ||
|
|
||
| def is_serializable(self, safe_serialization=None): | ||
| return False | ||
| return True | ||
|
|
||
| @property | ||
| def is_trainable(self) -> bool: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this for potentially ints ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, some parameters are strings (packing format, etc.), booleans or integers. They are necessary meta-data to dequantize