-
Notifications
You must be signed in to change notification settings - Fork 33.8k
Hqq serialization #33141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hqq serialization #33141
Changes from 10 commits
ff40f1a
fa8a9f5
75dfe0a
f2ea032
bc9cb55
a8704d2
5cb7d81
7f1b85d
71cccd4
ff982b3
d35ea7c
cbe219f
2bb974c
9f7c235
7f15b49
383e028
cf5a05c
4682a72
813ed62
d0c594c
7e019b3
0dd1152
9053ad5
2b6e7df
e68110a
433c3a0
4db1991
3b56533
a8843cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,7 @@ | |
| prune_linear_layer, | ||
| ) | ||
| from .quantizers import AutoHfQuantizer, HfQuantizer | ||
| from .quantizers.quantizer_hqq import HqqHfQuantizer | ||
| from .quantizers.quantizers_utils import get_module_from_name | ||
| from .safetensors_conversion import auto_conversion | ||
| from .utils import ( | ||
|
|
@@ -857,9 +858,14 @@ def _load_state_dict_into_meta_model( | |
|
|
||
| is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||
|
|
||
| # We add this because HQQLinear dict has a very large state_dict (19 params/per module), which makes loading extremely slow | ||
| run_expected_keys_check = True | ||
| if isinstance(hf_quantizer, HqqHfQuantizer): | ||
| run_expected_keys_check = False | ||
|
|
||
| for param_name, param in state_dict.items(): | ||
| # First part of the test is always true as load_state_dict_keys always contains state_dict keys. | ||
| if param_name not in loaded_state_dict_keys or param_name not in expected_keys: | ||
| if param_name not in loaded_state_dict_keys or ((param_name not in expected_keys) and run_expected_keys_check): | ||
| continue | ||
|
|
||
| if param_name.startswith(start_prefix): | ||
|
|
@@ -891,12 +897,17 @@ def _load_state_dict_into_meta_model( | |
| # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which | ||
| # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. | ||
| # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 | ||
|
|
||
| old_param = model | ||
| splits = param_name.split(".") | ||
| for split in splits: | ||
| old_param = getattr(old_param, split) | ||
| # Not all the attributes of a module are Parameters/Tensor | ||
| if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): | ||
| old_param = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this for potentially ints ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, some parameters are strings (packing format, etc.), booleans or integers. They are necessary meta-data to dequantize |
||
| if old_param is None: | ||
| break | ||
|
|
||
| if old_param is not None: | ||
| if dtype is None: | ||
| param = param.to(old_param.dtype) | ||
|
|
@@ -3725,6 +3736,7 @@ def from_pretrained( | |
| from_pt = not (from_tf | from_flax) | ||
|
|
||
| # load pt weights early so that we know which dtype to init the model under | ||
|
|
||
| if from_pt: | ||
| if not is_sharded and state_dict is None: | ||
| # Time to load the checkpoint | ||
|
|
@@ -4181,7 +4193,7 @@ def _fix_key(key): | |
| value = torch.empty(*param.size(), dtype=target_dtype) | ||
| if ( | ||
| not is_quantized | ||
| or getattr(hf_quantizer, "requires_parameters_quantization", False) | ||
| or (getattr(hf_quantizer, "requires_parameters_quantization", False)) | ||
| or not hf_quantizer.check_quantized_param( | ||
| model, param_value=value, param_name=key, state_dict={} | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.