Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions optimum/gptq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from gptqmodel import exllama_set_max_input_length
from gptqmodel.quantization import GPTQ
from gptqmodel.utils.importer import hf_select_quant_linear
from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, nested_move_to
from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
from gptqmodel.version import __version__ as gptqmodel_version

Expand Down Expand Up @@ -511,9 +511,11 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):

blocks = recurse_getattr(model, self.block_name_to_quantize)

cur_layer_device = get_device(blocks[0])

if not has_device_map:
# put modules from module_name_preceding_first_block on cuda or xpu or cpu
to_device = 0 if has_device_more_than_cpu() else "cpu"
to_device = cur_layer_device
for module_name in self.module_name_preceding_first_block:
module = recurse_getattr(model, module_name)
if module is None:
Expand All @@ -525,26 +527,22 @@ def store_input_hook(_, input, *args):
kwargs = args[0]
if input is None:
if "hidden_states" in kwargs:
input = (kwargs["hidden_states"],)
input = (nested_move_to(kwargs["hidden_states"], cur_layer_device),)
else:
raise ValueError("No input value found in the foward pass")
layer_inputs.append(input)
other_kwargs = {}
for k, v in kwargs.items(): # make sure other arguments also be captured
if k not in ["hidden_states"]:
other_kwargs[k] = v
other_kwargs[k] = nested_move_to(v, cur_layer_device)
layer_input_kwargs.append(other_kwargs)
raise ValueError

if self.cache_block_outputs:
handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
data[k] = v.to(0)
else:
data[k] = v.to(device)
data[k] = nested_move_to(v, cur_layer_device)
try:
model(**data)
except ValueError:
Expand All @@ -571,11 +569,7 @@ def store_input_hook(_, input, *args):
handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
data[k] = v.to(0)
else:
data[k] = v.to(device)
data[k] = nested_move_to(v, cur_layer_device)
try:
model(**data)
except ValueError:
Expand All @@ -587,6 +581,7 @@ def store_input_hook(_, input, *args):
if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu():
block = block.to(0)
layers = get_layers(block)
block_device = get_device(block)
if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
if self.true_sequential:
layers_name_list = self.modules_in_block_to_quantize
Expand Down Expand Up @@ -620,6 +615,10 @@ def tmp(_, input, output):
for j in range(len(dataset)):
# the args are already on the gpu
# don't need to store the output
layer_inputs[j] = nested_move_to(layer_inputs[j], block_device)
for k, v in layer_input_kwargs[j].items():
layer_input_kwargs[j][k] = nested_move_to(v, block_device)

block(*layer_inputs[j], **layer_input_kwargs[j])
# remove hook
for h in handles:
Expand Down
14 changes: 14 additions & 0 deletions optimum/gptq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,17 @@ def get_seqlen(model: nn.Module):
"We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`"
)
return 2048

def move_to(obj: "torch.Tensor | nn.Module", device: torch.device):
    """
    Return `obj` placed on `device`, skipping the transfer when it is already there.

    Args:
        obj (`torch.Tensor` or `nn.Module`):
            Object whose current device is looked up with `get_device`.
        device (`torch.device`):
            Target device.

    Returns:
        `obj` itself when `get_device(obj)` already equals `device`, otherwise the result of `obj.to(device)`.
    """
    # The union annotation is written as a string on purpose: a bare `torch.Tensor | nn.Module` is
    # evaluated when the function is defined and raises `TypeError` on Python < 3.10.
    if get_device(obj) != device:
        obj = obj.to(device)
    return obj


def nested_move_to(v, device):
    """
    Recursively move every tensor found in `v` to `device`.

    Args:
        v:
            A `torch.Tensor`, a (possibly nested) `list`/`tuple` of values, or any other object.
        device:
            Target device, forwarded to `move_to` for each tensor.

    Returns:
        - for a tensor: the result of `move_to(v, device)`;
        - for a list/tuple (including namedtuples): a new container of the same type whose elements
          have been processed recursively;
        - for anything else: `v` itself, untouched.
    """
    if isinstance(v, torch.Tensor):
        return move_to(v, device)
    if isinstance(v, (list, tuple)):
        moved = [nested_move_to(e, device) for e in v]
        if isinstance(v, tuple) and hasattr(v, "_fields"):
            # namedtuple constructors take one positional argument per field, so passing a single
            # list (as done for plain lists/tuples below) would raise `TypeError`.
            return type(v)(*moved)
        return type(v)(moved)
    return v
2 changes: 1 addition & 1 deletion optimum/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
# Version floors for optional dependencies, each built with `version.parse`.
# NOTE(review): how these are compared against installed versions is not visible here — confirm at the use sites.
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
# NOTE(review): the two GPTQMODEL assignments below look like the before/after pair of a diff; if both
# were executed, the later one ("1.4.2") is the value that remains bound. The trailing comment on the
# "1.4.1" line also looks stale: under PEP 440 ordering 1.4.0.dev0 sorts below 1.4.1 — confirm intent.
GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.1") # Allows 1.4.0.dev0
GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")


# This is the minimal required version to support some ONNX Runtime features
Expand Down
5 changes: 4 additions & 1 deletion tests/gptq/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,10 @@ def test_exllama_serialization(self):
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestNoBlockCaching(GPTQTestCUDA):
Expand Down