8 changes: 6 additions & 2 deletions src/transformers/integrations/hqq.py
@@ -97,7 +97,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
@@ -113,7 +113,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")
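For readers comparing the two sides of this hunk: `model.config.quantization_config` previously held the per-layer `patch_params` mapping (linear_tag -> hqq quant config, per the comment above), and now holds a single serializable dict. A rough sketch of the new shape, with illustrative values rather than anything taken from this PR:

```python
# Illustrative only: the nested keys under "quant_config" depend on the HQQ
# settings the user picked; nbits / group_size below are example values.
stored_config = {
    "quant_config": {"weight_quant_params": {"nbits": 4, "group_size": 64}},
    "quant_method": "hqq",        # value of QuantizationMethod.HQQ
    "skip_modules": ["lm_head"],  # modules left unquantized
}
```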
38 changes: 29 additions & 9 deletions src/transformers/modeling_utils.py
@@ -56,6 +56,7 @@
prune_linear_layer,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizer_hqq import HqqHfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
@@ -851,8 +852,9 @@ def _load_state_dict_into_meta_model(
state_dict[new_key] = state_dict.pop(old_key)

for param_name, param in state_dict.items():
# print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys)
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
if param_name not in loaded_state_dict_keys: # or param_name not in expected_keys: #TODO @mobicham
continue

if param_name.startswith(start_prefix):
@@ -883,12 +885,20 @@
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if old_param is None:
break

# TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear
check_old_param = True
if is_quantized:
if isinstance(hf_quantizer, HqqHfQuantizer):
check_old_param, old_param = False, None

if check_old_param:
old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
if old_param is None:
break

if old_param is not None:
if dtype is None:
@@ -925,6 +935,10 @@ def _load_state_dict_into_meta_model(
)
)
):
# TODO @mobicham: skip module to device for HQQLinear since it's already on device
if is_quantized:
if isinstance(hf_quantizer, HqqHfQuantizer) and hf_quantizer.pre_quantized:
continue
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
@@ -3679,6 +3693,7 @@ def from_pretrained(
from_pt = not (from_tf | from_flax)

# load pt weights early so that we know which dtype to init the model under

if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
@@ -3947,7 +3962,12 @@
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():

# TODO @mobicham: HQQLinear breaks with dispatch_model() when loading
do_dispatch_model = True
# if pre_quantized:
# do_dispatch_model = not isinstance(hf_quantizer, HqqHfQuantizer)
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled() and do_dispatch_model:
dispatch_model(model, **device_map_kwargs)

if hf_quantizer is not None:
@@ -4128,7 +4148,7 @@ def _fix_key(key):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
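Taken together, the modeling_utils.py hunks only change the path exercised by an ordinary `from_pretrained` call on a pre-quantized HQQ checkpoint: the HQQ quantizer restores the quantized fields (W_q, etc.) itself, so the generic `old_param` lookup and `set_module_tensor_to_device` are bypassed for it, and a `do_dispatch_model` flag (currently always True) is left as a hook for skipping `dispatch_model()`. A minimal usage sketch; the model id is a placeholder, not a checkpoint referenced in this PR:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder id: any checkpoint saved with HQQ quantization would exercise
# the pre-quantized branch touched by the hunks above.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-hqq-quantized-model",
    torch_dtype=torch.float16,
    device_map="cuda",
)
```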
49 changes: 43 additions & 6 deletions src/transformers/quantizers/quantizer_hqq.py
@@ -91,6 +91,14 @@ def validate_environment(self, *args, **kwargs):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1

def update_missing_keys(
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
) -> List[str]:
if self.pre_quantized:
return [key for key in missing_keys if ("weight" not in key)]
else:
return missing_keys

def check_quantized_param(
self,
model: "PreTrainedModel",
@@ -100,8 +108,11 @@ def check_quantized_param(
**kwargs,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
layer_name = ".".join(param_name.split(".")[:-1])
if "lm_head" in layer_name:
return False # TODO @mobicham: get 'lm_head' from skip_modules in the quantization config

return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
return isinstance(module, torch.nn.Linear) and (True if self.pre_quantized else tensor_name == "weight")

def create_quantized_param(
self,
@@ -122,21 +133,47 @@ def create_quantized_param(
from hqq.core.quantize import HQQLinear

module, tensor_name = get_module_from_name(model, param_name)

layer_name = param_name.replace(".weight", "").replace(".bias", "")
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]

# Step 0: set module state_dict
# print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj

# set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}

if self.pre_quantized:
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None, # module.quant_config
compute_dtype=self.torch_dtype,
device=target_device,
)

try:
hqq_layer.load_state_dict(module_state_dict)
except Exception:
# TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
# Currently setting a fake layer so that loading doesn't break
print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
hqq_layer = HQQLinear(
torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
module.quant_config,
compute_dtype=self.torch_dtype,
device=target_device,
del_orig=True,
)

setattr(parent_module, node, hqq_layer)
torch.cuda.empty_cache()
return

# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
# directly doesn't work.

if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
@@ -193,7 +230,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs

@property
def is_serializable(self):
return False
return True

@property
def is_trainable(self) -> bool:
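Conceptually, the new `pre_quantized` branch of `create_quantized_param` rebuilds each layer from its slice of the checkpoint instead of re-quantizing a float `nn.Linear`. A condensed sketch of that branch, reusing the same `HQQLinear` calls as the diff (the fake-layer fallback and the `setattr` re-attachment onto the parent module are omitted):

```python
import torch
from hqq.core.quantize import HQQLinear

def rebuild_hqq_layer(module_state_dict, compute_dtype=torch.float16, device="cuda"):
    # Build an empty HQQLinear shell, then restore its quantized fields
    # (W_q, etc.) from the checkpoint slice for this layer.
    hqq_layer = HQQLinear(
        linear_layer=None,
        quant_config=None,
        compute_dtype=compute_dtype,
        device=device,
    )
    hqq_layer.load_state_dict(module_state_dict)
    return hqq_layer
```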
16 changes: 15 additions & 1 deletion src/transformers/utils/quantization_config.py
@@ -257,12 +257,26 @@ def post_init(self):
"""
pass

@classmethod
def from_dict(cls, config: Dict[str, Any]):
"""
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
"""
instance = cls()
instance.quant_config = config["quant_config"]
instance.skip_modules = config["skip_modules"]
return instance

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
return {
"quant_config": self.quant_config,
"quant_method": self.quant_method,
"skip_modules": self.skip_modules,
}

def __repr__(self):
config_dict = self.to_dict()
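Together with `is_serializable` now returning True, this `from_dict` / `to_dict` pair is what lets an HQQ config round-trip through a saved `config.json`. A minimal sketch of that round trip, assuming `HqqConfig` is importable from `transformers` and the `hqq` package is installed; the payload values are illustrative:

```python
from transformers import HqqConfig

# Illustrative payload; in practice this dict comes from to_dict() above.
serialized = {
    "quant_config": {"weight_quant_params": {"nbits": 4, "group_size": 64}},
    "quant_method": "hqq",
    "skip_modules": ["lm_head"],
}

restored = HqqConfig.from_dict(serialized)  # classmethod added in this hunk
assert restored.to_dict()["skip_modules"] == ["lm_head"]
```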