diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index 10a6d06a3f9f..65af151a3e82 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -97,7 +97,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve # Convert quantization_config to layer-wise config skip_modules = quantization_config.skip_modules - quant_config = quantization_config.to_dict() + quant_config = quantization_config.quant_config linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert)) if any(key in linear_tags for key in quant_config.keys()): @@ -113,7 +113,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve ) # We store quantization config as linear_tag -> hqq quant config - model.config.quantization_config = patch_params + model.config.quantization_config = { + "quant_config": quant_config, + "quant_method": quantization_config.quant_method, + "skip_modules": skip_modules, + } if not has_been_replaced: logger.warning("No linear modules were found in your model for quantization.") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 683b519a28cd..883625a7600a 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -56,6 +56,7 @@ prune_linear_layer, ) from .quantizers import AutoHfQuantizer, HfQuantizer +from .quantizers.quantizer_hqq import HqqHfQuantizer from .quantizers.quantizers_utils import get_module_from_name from .safetensors_conversion import auto_conversion from .utils import ( @@ -851,8 +852,9 @@ def _load_state_dict_into_meta_model( state_dict[new_key] = state_dict.pop(old_key) for param_name, param in state_dict.items(): + # print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys) # First part of the test is always true as load_state_dict_keys always contains state_dict keys. - if param_name not in loaded_state_dict_keys or param_name not in expected_keys: + if param_name not in loaded_state_dict_keys: # or param_name not in expected_keys: #TODO @mobicham continue if param_name.startswith(start_prefix): @@ -883,12 +885,20 @@ def _load_state_dict_into_meta_model( # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 - old_param = model - splits = param_name.split(".") - for split in splits: - old_param = getattr(old_param, split) - if old_param is None: - break + + # TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear + check_old_param = True + if is_quantized: + if isinstance(hf_quantizer, HqqHfQuantizer): + check_old_param, old_param = False, None + + if check_old_param: + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + if old_param is None: + break if old_param is not None: if dtype is None: @@ -925,6 +935,10 @@ def _load_state_dict_into_meta_model( ) ) ): + # TODO @mobicham: skip module to device for HQQLinear since it's already on device + if is_quantized: + if isinstance(hf_quantizer, HqqHfQuantizer) and hf_quantizer.pre_quantized: + continue # For backward compatibility with older versions of `accelerate` and for non-quantized params set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) else: @@ -3679,6 +3693,7 @@ def from_pretrained( from_pt = not (from_tf | from_flax) # load pt weights early so that we know which dtype to init the model under + if from_pt: if not is_sharded and state_dict is None: # Time to load the checkpoint @@ -3947,7 +3962,12 @@ def from_pretrained( and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ ): device_map_kwargs["force_hooks"] = True - if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + + # TODO @mobicham: HQQLinear breaks with dispatch_model() when loading + do_dispatch_model = True + # if pre_quantized: + # do_dispatch_model = not isinstance(hf_quantizer, HqqHfQuantizer) + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled() and do_dispatch_model: dispatch_model(model, **device_map_kwargs) if hf_quantizer is not None: @@ -4128,7 +4148,7 @@ def _fix_key(key): value = torch.empty(*param.size(), dtype=target_dtype) if ( not is_quantized - or getattr(hf_quantizer, "requires_parameters_quantization", False) + or (getattr(hf_quantizer, "requires_parameters_quantization", False)) or not hf_quantizer.check_quantized_param( model, param_value=value, param_name=key, state_dict={} ) diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 14be75369dec..fd1bd3f16b69 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -91,6 +91,14 @@ def validate_environment(self, *args, **kwargs): else: self.using_multi_gpu = len(set(device_map.values())) > 1 + def update_missing_keys( + self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs + ) -> List[str]: + if self.pre_quantized: + return [key for key in missing_keys if ("weight" not in key)] + else: + return missing_keys + def check_quantized_param( self, model: "PreTrainedModel", @@ -100,8 +108,11 @@ def check_quantized_param( **kwargs, ) -> bool: module, tensor_name = get_module_from_name(model, param_name) + layer_name = ".".join(param_name.split(".")[:-1]) + if "lm_head" in layer_name: + return False # TODO @mobicham: get 'lm_head' from skip_modules in the quantization config - return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + return isinstance(module, torch.nn.Linear) and (True if self.pre_quantized else tensor_name == "weight") def create_quantized_param( self, @@ -122,21 +133,47 @@ def create_quantized_param( from hqq.core.quantize import HQQLinear module, tensor_name = get_module_from_name(model, param_name) - - layer_name = param_name.replace(".weight", "").replace(".bias", "") + layer_name = ".".join(param_name.split(".")[:-1]) parent_module = find_parent(model, layer_name) node = layer_name.split(".")[-1] - # Step 0: set module state_dict + # print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj + + # set module state_dict module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} + if self.pre_quantized: + hqq_layer = HQQLinear( + linear_layer=None, + quant_config=None, # module.quant_config + compute_dtype=self.torch_dtype, + device=target_device, + ) + + try: + hqq_layer.load_state_dict(module_state_dict) + except Exception: + # TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this? + # Currently setting a fake layer so that loading doesn't break + print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys()) + hqq_layer = HQQLinear( + torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False), + module.quant_config, + compute_dtype=self.torch_dtype, + device=target_device, + del_orig=True, + ) + + setattr(parent_module, node, hqq_layer) + torch.cuda.empty_cache() + return + # Step 1: populate module with weight/bias from module state dict for key in module_state_dict: setattr(module, key, torch.nn.Parameter(module_state_dict[key])) # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module # directly doesn't work. - if hasattr(module, "quant_config"): hqq_layer = HQQLinear( module, @@ -193,7 +230,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs @property def is_serializable(self): - return False + return True @property def is_trainable(self) -> bool: diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 20d142b83f46..80295e89634f 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -257,12 +257,26 @@ def post_init(self): """ pass + @classmethod + def from_dict(cls, config: Dict[str, Any]): + """ + Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py + """ + instance = cls() + instance.quant_config = config["quant_config"] + instance.skip_modules = config["skip_modules"] + return instance + def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ - return self.quant_config + return { + "quant_config": self.quant_config, + "quant_method": self.quant_method, + "skip_modules": self.skip_modules, + } def __repr__(self): config_dict = self.to_dict()