Merged
140 changes: 89 additions & 51 deletions optimum/gptq/quantizer.py
@@ -25,15 +25,14 @@
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.pytorch_utils import Conv1D
from transformers.utils.quantization_config import QuantizationMethod
from transformers.utils.quantization_config import QuantizationMethod, GPTQConfig

from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available
from ..utils.modeling_utils import recurse_getattr
from .constants import GPTQ_CONFIG
from .data import get_dataset, prepare_dataset
from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen


if is_accelerate_available():
from accelerate import (
cpu_offload_with_hook,
@@ -47,16 +46,14 @@
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear


if is_gptqmodel_available():
from gptqmodel import exllama_set_max_input_length
from gptqmodel import exllama_set_max_input_length, BACKEND
from gptqmodel.quantization import GPTQ
from gptqmodel.utils.importer import hf_select_quant_linear
from gptqmodel.utils.model import gptqmodel_post_init as gptq_post_init

logger = getLogger(__name__)


if not is_gptqmodel_available():
logger.warning("auto_gptq support will be deprecated in the future; please `pip install gptqmodel` to use GPTQ models")

@@ -85,6 +82,7 @@ def __init__(
sym: bool = True,
true_sequential: bool = True,
use_cuda_fp16: bool = False,
checkpoint_format: str = "gptq",
model_seqlen: Optional[int] = None,
block_name_to_quantize: Optional[str] = None,
module_name_preceding_first_block: Optional[List[str]] = None,
@@ -167,6 +165,7 @@ def __init__(
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs
self.modules_in_block_to_quantize = modules_in_block_to_quantize
self.checkpoint_format = checkpoint_format

self.serialization_keys = [
"bits",
@@ -178,6 +177,7 @@
"true_sequential",
"quant_method",
"modules_in_block_to_quantize",
"checkpoint_format",
]

if self.bits not in [2, 3, 4, 8]:
@@ -199,13 +199,34 @@ def __init__(
)
self.exllama_version = self.exllama_config["version"]

def select_quant_linear(self, pack: bool):
if is_gptqmodel_available():
self.quant_linear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=self.sym,
pack=pack,
backend=self.backend,
)
else:
self.quant_linear = hf_select_quant_linear(
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
)
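A minimal usage sketch of the new `select_quant_linear` helper (the configuration below is hypothetical): the quantizer resolves its `QuantLinear` class once and caches it on `self.quant_linear`, delegating to gptqmodel's backend-aware selector when gptqmodel is installed and to the auto-gptq selector otherwise.

```python
from optimum.gptq import GPTQQuantizer

quantizer = GPTQQuantizer(bits=4, group_size=128, desc_act=False, sym=True)

# Resolve the gptqmodel backend first (a no-op when only auto-gptq is installed).
quantizer.select_gptqmodel_backend({"device_map": "auto"})

# pack=False is the conversion path; pack_model later calls this with pack=True.
quantizer.select_quant_linear(pack=False)
quant_linear_cls = quantizer.quant_linear
```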

def to_dict(self):
"""
Returns the args in dict format.
"""
gptq_dict = {}
for key in self.serialization_keys:
gptq_dict[key] = getattr(self, key)

return gptq_dict
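Since `checkpoint_format` is now in `serialization_keys`, it travels with the rest of the quantization parameters. A quick sketch (values are hypothetical):

```python
from optimum.gptq import GPTQQuantizer

cfg = GPTQQuantizer(bits=4, group_size=128, checkpoint_format="gptq_v2").to_dict()
# cfg["checkpoint_format"] == "gptq_v2", alongside bits, group_size, sym, etc.
```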

@classmethod
@@ -222,7 +243,7 @@ def from_dict(cls, config_dict: Dict[str, Any]):
"""
return cls(**config_dict)

def convert_model(self, model: nn.Module):
def convert_model(self, model: nn.Module, **kwargs):
"""
Convert the model to a GPTQ model by getting and replacing the layers.

@@ -231,6 +252,8 @@ def convert_model(self, model: nn.Module):
Model to be converted

"""
self.select_gptqmodel_backend(kwargs)

if self.block_name_to_quantize is None:
self.block_name_to_quantize = get_block_name_with_pattern(model)
block_name = self.block_name_to_quantize
@@ -243,9 +266,25 @@ def convert_model(self, model: nn.Module):
f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)"
)
del layers_to_be_replaced[name]

self.select_quant_linear(pack=False)

self._replace_by_quant_layers(model, layers_to_be_replaced)

return model

def quantize_preprocess(self, model, **kwargs):
self.select_gptqmodel_backend(kwargs)

def select_gptqmodel_backend(self, kwargs):
if is_gptqmodel_available():
self.backend = BACKEND.AUTO
if kwargs.get("device_map") is not None and kwargs.get("device_map") != "auto":
device_map = kwargs.get("device_map")
devices = [device_map] if isinstance(device_map, str) else list(device_map.values())
if "cpu" in devices or torch.device("cpu") in devices:
self.backend = BACKEND.IPEX
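The fallback above means that any CPU placement in the device map routes gptqmodel to its IPEX kernels, while `"auto"` or a pure-GPU map keeps the automatic backend. A sketch with hypothetical device maps (meaningful only when gptqmodel is installed):

```python
from optimum.gptq import GPTQQuantizer

quantizer = GPTQQuantizer(bits=4, group_size=128)

quantizer.select_gptqmodel_backend({"device_map": {"model.embed_tokens": 0, "model.layers": "cpu"}})
# any "cpu" entry -> quantizer.backend == BACKEND.IPEX

quantizer.select_gptqmodel_backend({"device_map": "auto"})
# "auto" (or no device_map at all) -> quantizer.backend == BACKEND.AUTO
```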

def get_no_split_module_classes(self, model):
"""
Get the modules that should not be split across multiple devices.
@@ -270,20 +309,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
name (`str`, defaults to `""`):
To keep track of the name of the current module
"""
if is_gptqmodel_available():
QuantLinear = hf_select_quant_linear(
bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
)
else:
QuantLinear = hf_select_quant_linear(
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
)
if isinstance(module, QuantLinear):
if isinstance(module, self.quant_linear):
return
for attr in dir(module):
layer = getattr(module, attr)
@@ -302,7 +328,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
out_features = layer.weight.shape[1]
bias = layer.bias is not None
if is_gptqmodel_available():
new_layer = QuantLinear(
new_layer = self.quant_linear(
self.bits,
self.group_size,
self.desc_act,
@@ -314,7 +340,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
)
else:
if not (self.desc_act) or self.group_size == -1:
new_layer = QuantLinear(
new_layer = self.quant_linear(
self.bits,
self.group_size,
in_features,
@@ -324,7 +350,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
weight_dtype=layer.weight.dtype,
)
else:
new_layer = QuantLinear(
new_layer = self.quant_linear(
self.bits,
self.group_size,
in_features,
@@ -359,18 +385,29 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):

if not is_auto_gptq_available() and not is_gptqmodel_available():
raise RuntimeError(
"gptqmodel or auto-gptq is required in order to perform quantzation : `pip install gptqmodel` or `pip install auto-gptq`"
"gptqmodel or auto-gptq is required in order to perform gptq quantzation: `pip install gptqmodel` or `pip install auto-gptq`"
)

gptq_supports_cpu = (
is_auto_gptq_available()
and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
) or is_gptqmodel_available()
is_auto_gptq_available()
and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
) or is_gptqmodel_available()

if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed to quantize model.")
raise RuntimeError("No cuda gpu or cpu support using Intel/IPEX found. A gpu or avx512 enabled cpu with Intel/IPEX is required for quantization.")

if not self.sym and not is_gptqmodel_available():
raise ValueError("Asymmetric quantization (sym=False) is not supported with auto-gptq. Please use gptqmodel instead: `pip install gptqmodel`")

if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available():
raise ValueError("gptq_v2 format only supported with gptqmodel. Please install gptqmodel: `pip install gptqmodel`")

model.eval()

# gptqmodel works internally in the gptq_v2 format to support asymmetric quantization; gptq (v1) only supports sym=True
if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2" and self.backend != BACKEND.IPEX:
self.checkpoint_format = "gptq_v2"
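The new validation above, together with the internal format upgrade, means asymmetric quantization and the `gptq_v2` checkpoint format are only reachable through gptqmodel. A couple of hypothetical configurations that need it:

```python
from optimum.gptq import GPTQQuantizer

# Both configurations below require gptqmodel; with only auto-gptq installed,
# quantize_model raises before any calibration work starts.
asym_quantizer = GPTQQuantizer(bits=4, group_size=128, sym=False)
v2_quantizer = GPTQQuantizer(bits=4, group_size=128, checkpoint_format="gptq_v2")
```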

# For Transformer model
has_config = False
has_device_map = False
@@ -618,7 +655,7 @@ def tmp(_, input, output):
)
self.disable_exllama = True
# Step 4: Pack the model at the end (Replacing the layers)
self.pack_model(model=model, quantizers=quantizers)
self.pack_model(model=model, device=device, quantizers=quantizers)

model.is_quantized = True
model.quantization_method = QuantizationMethod.GPTQ
@@ -629,12 +666,18 @@ def tmp(_, input, output):
# Step 5: Any post-initialization that requires device information, for example buffer initialization on device.
model = self.post_init_model(model)

# convert gptqmodel internal gptq_v2 format to v1 for saving/compat
if self.checkpoint_format == "gptq_v2":
from gptqmodel.utils.model import convert_gptq_v2_to_v1_format
model = convert_gptq_v2_to_v1_format(model, self.bits, self.quant_linear)
self.checkpoint_format = "gptq"

torch.cuda.empty_cache()
if hasattr(torch, "xpu"):
torch.xpu.empty_cache()
return model
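An end-to-end sketch of the updated flow (model id, dataset, and config values are placeholders, and gptqmodel is assumed since `sym=False` is used): during quantization the quantizer works in the internal `gptq_v2` layout and converts the result back to the v1 `gptq` layout before returning, so the saved checkpoint stays compatible with existing loaders.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

model_id = "facebook/opt-125m"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128, sym=False)
quantized_model = quantizer.quantize_model(model, tokenizer)
# quantizer.checkpoint_format was "gptq_v2" while quantizing (non-IPEX backend)
# and is "gptq" again here, ready to be saved in the v1 layout.
```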

def post_init_model(self, model):
def post_init_model(self, model, **kwargs):
"""
Post-initialization that requires device information, for example buffer initialization on device.

@@ -655,21 +698,26 @@ def post_init_model(self, model):
class StoreAttr(object):
pass

if is_gptqmodel_available() and self.checkpoint_format == "gptq" and self.backend != BACKEND.IPEX:
from gptqmodel.utils.model import convert_gptq_v1_to_v2_format
model = convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear)

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
model = gptq_post_init(model, use_act_order=self.desc_act)
if (
self.desc_act
and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
and self.max_input_length is not None
self.desc_act
and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
and self.max_input_length is not None
):
model = exllama_set_max_input_length(model, self.max_input_length)
return model

def pack_model(
self,
model: nn.Module,
quantizers: Dict[str, Tuple],
self,
model: nn.Module,
device: torch.device,
quantizers: Dict[str, Tuple],
):
"""
Pack the model by replacing its layers with quantized layers
@@ -680,24 +728,14 @@ def pack_model(
quantizers (`Dict[str,Tuple]`):
A mapping of the layer name and the data needed to pack the layer
"""
if is_gptqmodel_available():
QuantLinear = hf_select_quant_linear(
bits=self.bits, group_size=self.group_size, desc_act=self.desc_act, sym=self.sym
)
else:
QuantLinear = hf_select_quant_linear(
use_triton=False,
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
)
logger.info("Packing model...")
layers = get_layers(model)
layers = {n: layers[n] for n in quantizers}

self.select_quant_linear(pack=True)

self._replace_by_quant_layers(model, quantizers)
qlayers = get_layers(model, [QuantLinear])
qlayers = get_layers(model, [self.quant_linear])
for name in qlayers:
logger.info(name)
quantizers[name], scale, zero, g_idx = quantizers[name]
@@ -838,7 +876,7 @@ def load_quantized_model(
quantizer.exllama_version = quantizer.exllama_config["version"]
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
model = quantizer.convert_model(model, device_map=device_map)

if no_split_module_classes is None:
no_split_module_classes = quantizer.get_no_split_module_classes(model)
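For loading an already quantized checkpoint, `device_map` is now forwarded into `convert_model`, so an explicit CPU placement selects gptqmodel's IPEX kernels; with `"auto"` (as in the sketch below, where the model id and save folder are placeholders) the backend stays automatic.

```python
import torch
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
from optimum.gptq import load_quantized_model

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16)
empty_model.tie_weights()

quantized_model = load_quantized_model(empty_model, save_folder="./opt-125m-gptq", device_map="auto")
```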