diff --git a/docs/source/en/quantization/hqq.md b/docs/source/en/quantization/hqq.md old mode 100644 new mode 100755 index 11489808aecb..34608cd64fd8 --- a/docs/source/en/quantization/hqq.md +++ b/docs/source/en/quantization/hqq.md @@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig # Method 1: all linear layers will use the same quantization config -quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default +quant_config = HqqConfig(nbits=8, group_size=64) ``` ``` Python # Method 2: each linear layer with the same tag will use a dedicated quantization config -q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False} -q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False} +q4_config = {'nbits':4, 'group_size':64} +q3_config = {'nbits':3, 'group_size':32} quant_config = HqqConfig(dynamic_config={ 'self_attn.q_proj':q4_config, 'self_attn.k_proj':q4_config, diff --git a/src/transformers/integrations/hqq.py b/src/transformers/integrations/hqq.py index 10a6d06a3f9f..162b365668a0 100755 --- a/src/transformers/integrations/hqq.py +++ b/src/transformers/integrations/hqq.py @@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_ has_been_replaced = True + # Add these fake parameters to avoid loading fail + for att in ["W_q", "meta"]: + setattr(module, att, None) + if len(list(module.children())) > 0: _, has_been_replaced = _prepare_for_hqq_linear( module, @@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve # Convert quantization_config to layer-wise config skip_modules = quantization_config.skip_modules - quant_config = quantization_config.to_dict() + quant_config = quantization_config.quant_config linear_tags = list(set(linear_tags) - set(skip_modules) - 
set(modules_to_not_convert)) if any(key in linear_tags for key in quant_config.keys()): @@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve ) # We store quantization config as linear_tag -> hqq quant config - model.config.quantization_config = patch_params + model.config.quantization_config = { + "quant_config": quant_config, + "quant_method": quantization_config.quant_method, + "skip_modules": skip_modules, + } if not has_been_replaced: logger.warning("No linear modules were found in your model for quantization.") diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fc0d6748cd1d..df0519566766 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model( # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model. 
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29 + old_param = model splits = param_name.split(".") for split in splits: old_param = getattr(old_param, split) + # Not all the attributes of a module are Parameters/Tensor + if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)): + old_param = None if old_param is None: break + if old_param is not None: if dtype is None: param = param.to(old_param.dtype) @@ -3819,6 +3824,7 @@ def from_pretrained( from_pt = not (from_tf | from_flax) # load pt weights early so that we know which dtype to init the model under + if from_pt: if not is_sharded and state_dict is None: # Time to load the checkpoint @@ -4176,6 +4182,9 @@ def _load_pretrained_model( expected_keys = list(model_state_dict.keys()) prefix = model.base_model_prefix + if hf_quantizer is not None: + expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys) + def _fix_key(key): if "beta" in key: return key.replace("beta", "bias") @@ -4290,7 +4299,7 @@ def _fix_key(key): value = torch.empty(*param.size(), dtype=target_dtype) if ( not is_quantized - or getattr(hf_quantizer, "requires_parameters_quantization", False) + or (getattr(hf_quantizer, "requires_parameters_quantization", False)) or not hf_quantizer.check_quantized_param( model, param_value=value, param_name=key, state_dict={} ) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py old mode 100644 new mode 100755 index 73b3dbd8b259..015c0015cf7e --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -109,6 +109,18 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li """ return missing_keys + def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]: + """ + Override this method if you want to adjust the `expected_keys`. 
+ + Args: + expected_keys (`List[str]`, *optional*): + The list of the expected keys in the initialized model. + loaded_keys (`List[str]`, *optional*): + The list of the loaded keys in the checkpoint. + """ + return expected_keys + def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]: """ returns dtypes for modules that are not quantized - used for the computation of the device_map in case diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index cd32a99c00ac..775fea8f4901 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -62,7 +62,7 @@ def __init__(self, quantization_config, **kwargs): def validate_environment(self, *args, **kwargs): if not (is_hqq_available()): raise ImportError( - "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`" + "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`." 
) if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): @@ -91,6 +91,65 @@ def validate_environment(self, *args, **kwargs): else: self.using_multi_gpu = len(set(device_map.values())) > 1 + def update_missing_keys( + self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs + ) -> List[str]: + if self.pre_quantized: + return [key for key in missing_keys if ("weight" not in key)] + else: + return missing_keys + + # Adds missing keys for HQQLinear modules that are loaded while the model was initialized with torch.nn.Linear + def update_expected_keys( + self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str] + ) -> List[str]: + if not self.pre_quantized: + return expected_keys + + # Collects all quantizable (linear) layers + def _find_hqq_quantizable_layers(model, layers): + for name, module in model.named_children(): + if isinstance(module, (torch.nn.Linear)): + layers.add(module.name) + _find_hqq_quantizable_layers(module, layers) + + new_keys = set(expected_keys) + if is_hqq_available(): + from hqq.core.quantize import HQQLinear + + # Name modules + for name, module in model.named_modules(): + module.name = name + + # valid modules are Linear layers that have HQQLinear state_dict. 
We ignore skip_modules and any layers with Linear state_dict() params + _valid_modules = set() + _find_hqq_quantizable_layers(model, _valid_modules) + _valid_modules -= set(model.config.quantization_config["skip_modules"]) + + # Append new expected layers based on _ref_keys + _ref_keys = HQQLinear( + linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu" + ).state_dict_keys() - {"bias"} + + # Clean-up + _rm_keys = set() + for key in new_keys: + if any(_module in key for _module in _valid_modules): + _rm_keys.add(key) + new_keys -= _rm_keys + # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear + + # Re-populate Linear/HQQLinear + for _module in _valid_modules: + if _module + ".weight" in loaded_keys: + new_keys.add(_module + ".weight") + else: + new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys}) + if _module + ".bias" in loaded_keys: + new_keys.add(_module + ".bias") + + return list(new_keys) + def check_quantized_param( self, model: "PreTrainedModel", @@ -99,9 +158,18 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ) -> bool: + if is_hqq_available(): + from hqq.core.quantize import HQQLinear module, tensor_name = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") + if self.pre_quantized: + return ( + (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear)) + and tensor_name != "weight" + and tensor_name != "bias" + ) + else: + return isinstance(module, torch.nn.Linear) and tensor_name == "weight" def create_quantized_param( self, @@ -122,13 +190,43 @@ def create_quantized_param( from hqq.core.quantize import HQQLinear module, tensor_name = get_module_from_name(model, param_name) - - layer_name = param_name.replace(".weight", "").replace(".bias", "") + layer_name = ".".join(param_name.split(".")[:-1]) parent_module = find_parent(model, layer_name) node = 
layer_name.split(".")[-1] - # Step 0: set module state_dict - module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key} + # set module state_dict + module_state_dict = {} + for k, v in state_dict.items(): + if layer_name + "." in k: + module_state_dict[k.split(".")[-1]] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + if self.pre_quantized: + if isinstance(module, HQQLinear): + return + else: + hqq_layer = HQQLinear( + linear_layer=None, + quant_config=None, + compute_dtype=self.torch_dtype, + device=target_device, + ) + + hqq_layer.load_state_dict(module_state_dict) + + if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor): + hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias) + + if self.using_multi_gpu: + hqq_layer = self._patch_layer_for_multigpu(hqq_layer) + + setattr(parent_module, node, hqq_layer) + + # cleanup + del module.__dict__, module + torch.cuda.empty_cache() + return # Step 1: populate module with weight/bias from module state dict for key in module_state_dict: @@ -136,7 +234,6 @@ def create_quantized_param( # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module # directly doesn't work. 
- if hasattr(module, "quant_config"): hqq_layer = HQQLinear( module, @@ -192,7 +289,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs return model def is_serializable(self, safe_serialization=None): - return False + return True @property def is_trainable(self) -> bool: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 169d3491053e..a98b17e4bd57 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -92,6 +92,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ FSDP_MIN_VERSION = "1.12.0" GGUF_MIN_VERSION = "0.10.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" +HQQ_MIN_VERSION = "0.2.1" _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) @@ -181,7 +182,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _torchdistx_available = _is_package_available("torchdistx") _torchvision_available = _is_package_available("torchvision") _mlx_available = _is_package_available("mlx") -_hqq_available = _is_package_available("hqq") +_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True) _tiktoken_available = _is_package_available("tiktoken") _blobfile_available = _is_package_available("blobfile") _liger_kernel_available = _is_package_available("liger_kernel") @@ -323,8 +324,8 @@ def is_torch_deterministic(): return True -def is_hqq_available(): - return _hqq_available +def is_hqq_available(min_version: str = HQQ_MIN_VERSION): + return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version) def is_pygments_available(): diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 19166f9ed92a..8be0bb672e51 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -193,15 +193,9 @@ class 
HqqConfig(QuantizationConfigMixin): Number of bits. Supported values are (8, 4, 3, 2, 1). group_size (`int`, *optional*, defaults to 64): Group-size value. Supported values are any value that is divisble by weight.shape[axis]). - quant_zero (`bool`, *optional*, defaults to `True`): - Quantize the zero-point if set to `True`. - quant_scale (`bool`, *optional*, defaults to `False`): - Quantize the scaling if set to `True`. - offload_meta (`bool`, *optional*, defaults to `False`): - Offload the meta-data to the CPU if set to `True`. view_as_float (`bool`, *optional*, defaults to `False`): View the quantized weight as float (used in distributed training) if set to `True`. - axis (`int`, *optional*, defaults to 0): + axis (`Optional[int]`, *optional*): Axis along which grouping is performed. Supported values are 0 or 1. dynamic_config (dict, *optional*): Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config. @@ -216,11 +210,8 @@ def __init__( self, nbits: int = 4, group_size: int = 64, - quant_zero: bool = True, - quant_scale: bool = False, - offload_meta: bool = False, view_as_float: bool = False, - axis: int = 0, + axis: Optional[int] = None, dynamic_config: Optional[dict] = None, skip_modules: List[str] = ["lm_head"], **kwargs, @@ -228,6 +219,16 @@ def __init__( if is_hqq_available(): from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]: + if deprecated_key in kwargs: + logger.info( + deprecated_key + " is deprecated. This parameter will be ignored in quantization settings." + ) + + if axis is None: + axis = 1 + logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.") + if axis not in [0, 1]: raise ValueError("Invalid axis value. 
Only 0 and 1 are allowed.") @@ -240,9 +241,6 @@ def __init__( **{ "nbits": nbits, "group_size": group_size, - "quant_zero": quant_zero, - "quant_scale": quant_scale, - "offload_meta": offload_meta, "view_as_float": view_as_float, "axis": axis, } @@ -259,12 +257,26 @@ def post_init(self): """ pass + @classmethod + def from_dict(cls, config: Dict[str, Any]): + """ + Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py + """ + instance = cls() + instance.quant_config = config["quant_config"] + instance.skip_modules = config["skip_modules"] + return instance + def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. Returns: `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ - return self.quant_config + return { + "quant_config": self.quant_config, + "quant_method": self.quant_method, + "skip_modules": self.skip_modules, + } def __repr__(self): config_dict = self.to_dict() diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index 45c64676a7e4..6d08a0f0e669 100755 --- a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -94,8 +94,7 @@ def test_to_dict(self): quantization_config = HqqConfig() hqq_orig_config = quantization_config.to_dict() - for key in hqq_orig_config: - self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key]) + self.assertEqual(quantization_config.quant_config, hqq_orig_config["quant_config"]) @slow @@ -109,7 +108,7 @@ def test_fp16_quantized_model(self): """ Simple LLM model testing fp16 """ - quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + quant_config = HqqConfig(nbits=8, group_size=64) hqq_runner = HQQLLMRunner( model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device @@ -118,26 +117,24 @@ def test_fp16_quantized_model(self): check_hqqlayer(self, 
hqq_runner.model.model.layers[0].self_attn.v_proj) check_forward(self, hqq_runner.model) - def test_f16_quantized_model_with_offloading(self): + +@slow +@require_torch_gpu +@require_torch_multi_gpu +@require_accelerate +class HQQTestMultiGPU(unittest.TestCase): + def tearDown(self): + cleanup() + + def test_fp16_quantized_model_multipgpu(self): """ - Simple LLM model testing bfp16 with meta-data offloading + Simple LLM model testing fp16 with multi-gpu """ - q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False} - q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True} - quant_config = HqqConfig( - dynamic_config={ - "self_attn.q_proj": q4_config, - "self_attn.k_proj": q4_config, - "self_attn.v_proj": q4_config, - "self_attn.o_proj": q4_config, - "mlp.gate_proj": q3_config, - "mlp.up_proj": q3_config, - "mlp.down_proj": q3_config, - } - ) + + quant_config = HqqConfig(nbits=8, group_size=64) hqq_runner = HQQLLMRunner( - model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device + model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto" ) check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) @@ -146,22 +143,40 @@ def test_f16_quantized_model_with_offloading(self): @slow @require_torch_gpu -@require_torch_multi_gpu @require_accelerate -class HQQTestMultiGPU(unittest.TestCase): +class HQQSerializationTest(unittest.TestCase): def tearDown(self): cleanup() - def test_fp16_quantized_model_multipgpu(self): + def test_model_serialization(self): """ - Simple LLM model testing fp16 with multi-gpu + Simple HQQ LLM save/load test """ - - quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) + quant_config = HqqConfig(nbits=4, group_size=64) hqq_runner = HQQLLMRunner( - model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto" + 
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device ) - check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj) - check_forward(self, hqq_runner.model) + input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device) + + with torch.no_grad(): + logits_ref = hqq_runner.model.forward(input_tensor).logits + + # Save + saved_model_id = "quant_model" + hqq_runner.model.save_pretrained(saved_model_id) + + # Remove old model + del hqq_runner.model + torch.cuda.empty_cache() + + # Load and check if the logits match + model_loaded = AutoModelForCausalLM.from_pretrained( + "quant_model", torch_dtype=torch.float16, device_map=torch_device, low_cpu_mem_usage=True + ) + + with torch.no_grad(): + logits_loaded = model_loaded.forward(input_tensor).logits + + self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)