diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 83667f54fa..2eb90d94d3 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -25,7 +25,7 @@
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer
 from transformers.pytorch_utils import Conv1D
-from transformers.utils.quantization_config import QuantizationMethod
+from transformers.utils.quantization_config import QuantizationMethod, GPTQConfig

 from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
 from ..utils.modeling_utils import recurse_getattr
@@ -33,7 +33,6 @@
 from .data import get_dataset, prepare_dataset
 from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen

-
 if is_accelerate_available():
     from accelerate import (
         cpu_offload_with_hook,
@@ -47,16 +46,14 @@
     from auto_gptq.quantization import GPTQ
     from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear

-
 if is_gptqmodel_available():
-    from gptqmodel import exllama_set_max_input_length
+    from gptqmodel import exllama_set_max_input_length, BACKEND
     from gptqmodel.quantization import GPTQ
     from gptqmodel.utils.importer import hf_select_quant_linear
     from gptqmodel.utils.model import gptqmodel_post_init as gptq_post_init

 logger = getLogger(__name__)

-
 if not is_gptqmodel_available():
     logger.warning("auto_gptq will be deprecated in the future, please `pip install gptqmodel` for gptq model")

@@ -85,6 +82,7 @@ def __init__(
         sym: bool = True,
         true_sequential: bool = True,
         use_cuda_fp16: bool = False,
+        checkpoint_format: str = "gptq",
         model_seqlen: Optional[int] = None,
         block_name_to_quantize: Optional[str] = None,
         module_name_preceding_first_block: Optional[List[str]] = None,
@@ -167,6 +165,7 @@ def __init__(
         self.quant_method = QuantizationMethod.GPTQ
         self.cache_block_outputs = cache_block_outputs
         self.modules_in_block_to_quantize = modules_in_block_to_quantize
+        self.checkpoint_format = checkpoint_format

         self.serialization_keys = [
             "bits",
@@ -178,6 +177,7 @@ def __init__(
             "true_sequential",
             "quant_method",
             "modules_in_block_to_quantize",
+            "checkpoint_format",
         ]

         if self.bits not in [2, 3, 4, 8]:
@@ -199,6 +199,26 @@ def __init__(
                 )
         self.exllama_version = self.exllama_config["version"]

+    def select_quant_linear(self, pack: bool):
+        if is_gptqmodel_available():
+            self.quant_linear = hf_select_quant_linear(
+                bits=self.bits,
+                group_size=self.group_size,
+                desc_act=self.desc_act,
+                sym=self.sym,
+                pack=pack,
+                backend=self.backend,
+            )
+        else:
+            self.quant_linear = hf_select_quant_linear(
+                use_triton=False,
+                desc_act=self.desc_act,
+                group_size=self.group_size,
+                bits=self.bits,
+                disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
+                disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
+            )
+
     def to_dict(self):
         """
         Returns the args in dict format.
@@ -206,6 +226,7 @@ def to_dict(self):
         gptq_dict = {}
         for key in self.serialization_keys:
             gptq_dict[key] = getattr(self, key)
+
         return gptq_dict

     @classmethod
@@ -222,7 +243,7 @@ def from_dict(cls, config_dict: Dict[str, Any]):
         """
         return cls(**config_dict)

-    def convert_model(self, model: nn.Module):
+    def convert_model(self, model: nn.Module, **kwargs):
         """
         Convert the model to a GPTQ model by getting and replacing the layers.

@@ -231,6 +252,8 @@ def convert_model(self, model: nn.Module):
                 Model to be converted
         """

+        self.select_gptqmodel_backend(kwargs)
+
         if self.block_name_to_quantize is None:
             self.block_name_to_quantize = get_block_name_with_pattern(model)
         block_name = self.block_name_to_quantize
@@ -243,9 +266,25 @@ def convert_model(self, model: nn.Module):
                         f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)"
                     )
                     del layers_to_be_replaced[name]
+
+        self.select_quant_linear(pack=False)
+
         self._replace_by_quant_layers(model, layers_to_be_replaced)
+
         return model

+    def quantize_preprocess(self, model, **kwargs):
+        self.select_gptqmodel_backend(kwargs)
+
+    def select_gptqmodel_backend(self, kwargs):
+        if is_gptqmodel_available():
+            self.backend = BACKEND.AUTO
+            if kwargs.get("device_map") is not None and kwargs.get("device_map") != "auto":
+                device_map = kwargs.get("device_map")
+                devices = [device_map] if isinstance(device_map, str) else list(device_map.values())
+                if "cpu" in devices or torch.device("cpu") in devices:
+                    self.backend = BACKEND.IPEX
+
     def get_no_split_module_classes(self, model):
         """
         Get the modules that should not be split across multiple devices.
@@ -270,20 +309,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
             name (`str`, defaults to `""`):
                 To keep track of the name of the current module
         """
-        if is_gptqmodel_available():
-            QuantLinear = hf_select_quant_linear(
-                bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
-            )
-        else:
-            QuantLinear = hf_select_quant_linear(
-                use_triton=False,
-                desc_act=self.desc_act,
-                group_size=self.group_size,
-                bits=self.bits,
-                disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-                disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
-            )
-        if isinstance(module, QuantLinear):
+        if isinstance(module, self.quant_linear):
             return
         for attr in dir(module):
             layer = getattr(module, attr)
@@ -302,7 +328,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
                 out_features = layer.weight.shape[1]
             bias = layer.bias is not None
             if is_gptqmodel_available():
-                new_layer = QuantLinear(
+                new_layer = self.quant_linear(
                     self.bits,
                     self.group_size,
                     self.desc_act,
@@ -314,7 +340,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
                 )
             else:
                 if not (self.desc_act) or self.group_size == -1:
-                    new_layer = QuantLinear(
+                    new_layer = self.quant_linear(
                         self.bits,
                         self.group_size,
                         in_features,
@@ -324,7 +350,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
                         weight_dtype=layer.weight.dtype,
                     )
                 else:
-                    new_layer = QuantLinear(
+                    new_layer = self.quant_linear(
                         self.bits,
                         self.group_size,
                         in_features,
@@ -359,18 +385,29 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):

         if not is_auto_gptq_available() and not is_gptqmodel_available():
             raise RuntimeError(
-                "gptqmodel or auto-gptq is required in order to perform quantzation : `pip install gptqmodel` or `pip install auto-gptq`"
+                "gptqmodel or auto-gptq is required in order to perform GPTQ quantization: `pip install gptqmodel` or `pip install auto-gptq`"
             )

         gptq_supports_cpu = (
-            is_auto_gptq_available()
-            and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
-        ) or is_gptqmodel_available()
+            is_auto_gptq_available()
+            and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+        ) or is_gptqmodel_available()
+
         if not gptq_supports_cpu and not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
+            raise RuntimeError("No CUDA GPU found and no CPU support via Intel IPEX. A GPU, or an AVX512-capable CPU with Intel IPEX, is required for quantization.")
+
+        if not self.sym and not is_gptqmodel_available():
+            raise ValueError("Asymmetric quantization (sym=False) is not supported with auto-gptq. Please use gptqmodel instead: `pip install gptqmodel`")
+
+        if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available():
+            raise ValueError("The gptq_v2 checkpoint format is only supported by gptqmodel. Please install it: `pip install gptqmodel`")

         model.eval()

+        # gptqmodel uses gptq_v2 internally to support asymmetric quantization; gptq (v1) only supports sym=True
+        if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2" and self.backend != BACKEND.IPEX:
+            self.checkpoint_format = "gptq_v2"
+
         # For Transformer model
         has_config = False
         has_device_map = False
@@ -618,7 +655,7 @@ def tmp(_, input, output):
                     )
                     self.disable_exllama = True
         # Step 4: Pack the model at the end (Replacing the layers)
-        self.pack_model(model=model, quantizers=quantizers)
+        self.pack_model(model=model, device=device, quantizers=quantizers)

         model.is_quantized = True
         model.quantization_method = QuantizationMethod.GPTQ
@@ -629,12 +666,18 @@ def tmp(_, input, output):
         # Step 5: Any post-initialization that require device information, for example buffers initialization on device.
         model = self.post_init_model(model)

+        # convert gptqmodel's internal gptq_v2 format back to v1 for saving/compatibility
+        if self.checkpoint_format == "gptq_v2":
+            from gptqmodel.utils.model import convert_gptq_v2_to_v1_format
+            model = convert_gptq_v2_to_v1_format(model, self.bits, self.quant_linear)
+            self.checkpoint_format = "gptq"
+
         torch.cuda.empty_cache()
         if hasattr(torch, "xpu"):
             torch.xpu.empty_cache()

         return model

-    def post_init_model(self, model):
+    def post_init_model(self, model, **kwargs):
         """
         Post-initialization that require device information, for example buffers initialization on device.

@@ -655,21 +698,26 @@ def post_init_model(self, model):
         class StoreAttr(object):
             pass

+        if is_gptqmodel_available() and self.checkpoint_format == "gptq" and self.backend != BACKEND.IPEX:
+            from gptqmodel.utils.model import convert_gptq_v1_to_v2_format
+            model = convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear)
+
         model.quantize_config = StoreAttr()
         model.quantize_config.desc_act = self.desc_act
         model = gptq_post_init(model, use_act_order=self.desc_act)
         if (
-            self.desc_act
-            and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
-            and self.max_input_length is not None
+            self.desc_act
+            and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
+            and self.max_input_length is not None
         ):
             model = exllama_set_max_input_length(model, self.max_input_length)
         return model

     def pack_model(
-        self,
-        model: nn.Module,
-        quantizers: Dict[str, Tuple],
+        self,
+        model: nn.Module,
+        device: torch.device,
+        quantizers: Dict[str, Tuple],
     ):
         """
         Pack the model by replacing the layers by quantized layers
@@ -680,24 +728,14 @@ def pack_model(
             quantizers (`Dict[str,Tuple]`):
                 A mapping of the layer name and the data needed to pack the layer
         """
-        if is_gptqmodel_available():
-            QuantLinear = hf_select_quant_linear(
-                bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
-            )
-        else:
-            QuantLinear = hf_select_quant_linear(
-                use_triton=False,
-                desc_act=self.desc_act,
-                group_size=self.group_size,
-                bits=self.bits,
-                disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
-                disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
-            )
         logger.info("Packing model...")
         layers = get_layers(model)
         layers = {n: layers[n] for n in quantizers}
+
+        self.select_quant_linear(pack=True)
+
         self._replace_by_quant_layers(model, quantizers)
-        qlayers = get_layers(model, [QuantLinear])
+        qlayers = get_layers(model, [self.quant_linear])
         for name in qlayers:
             logger.info(name)
             quantizers[name], scale, zero, g_idx = quantizers[name]
@@ -838,7 +876,7 @@ def load_quantized_model(
     quantizer.exllama_version = quantizer.exllama_config["version"]
     quantizer.max_input_length = max_input_length

-    model = quantizer.convert_model(model)
+    model = quantizer.convert_model(model, device_map=device_map)

     if no_split_module_classes is None:
         no_split_module_classes = quantizer.get_no_split_module_classes(model)
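
Usage sketch (not part of the patch): a minimal example of the loading path touched by the last hunk, under the assumption that `accelerate`, `transformers`, and `gptqmodel` are installed. `load_quantized_model` now forwards `device_map` to `convert_model`, so `select_gptqmodel_backend` picks `BACKEND.IPEX` when the device map resolves to CPU and `BACKEND.AUTO` otherwise, before `select_quant_linear(pack=False)` chooses the quantized linear class. The model id and checkpoint folder are placeholders.

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.gptq import load_quantized_model

model_id = "facebook/opt-125m"            # placeholder: architecture of the quantized checkpoint
save_folder = "/path/to/gptq-checkpoint"  # placeholder: folder holding the GPTQ weights

# Build an empty (meta-device) model; load_quantized_model materializes the
# quantized weights into it.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained(model_id), torch_dtype=torch.float16
    )
empty_model.tie_weights()

# device_map="auto" keeps BACKEND.AUTO; a CPU-only device map (e.g. "cpu")
# routes quant-linear selection to BACKEND.IPEX with this patch.
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")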