Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ff40f1a
HQQ model serialization attempt
mobicham Jul 18, 2024
fa8a9f5
Merge branch 'huggingface:main' into main
mobicham Jul 18, 2024
75dfe0a
fix hqq dispatch and unexpected keys
SunMarc Aug 1, 2024
f2ea032
Merge remote-tracking branch 'upstream/main' into hqq_serialization
SunMarc Aug 1, 2024
bc9cb55
Merge remote-tracking branch 'upstream/main' into hqq_serialization
SunMarc Aug 1, 2024
a8704d2
style
SunMarc Aug 1, 2024
5cb7d81
remove check_old_param
mobicham Aug 27, 2024
7f1b85d
revert to check HQQLinear in quantizer_hqq.py
mobicham Aug 28, 2024
71cccd4
revert to check HQQLinear in quantizer_hqq.py
mobicham Aug 28, 2024
ff982b3
update HqqConfig default params
mobicham Aug 28, 2024
d35ea7c
make ci happy
mobicham Aug 28, 2024
cbe219f
make ci happy
mobicham Aug 28, 2024
2bb974c
revert to HQQLinear check in quantizer_hqq.py
mobicham Aug 28, 2024
9f7c235
check hqq_min version 0.2.0
mobicham Aug 28, 2024
7f15b49
Merge branch 'main' into hqq_serialization
mobicham Aug 28, 2024
383e028
set axis=1 as default in quantization_config.py
mobicham Aug 28, 2024
cf5a05c
Merge branch 'hqq_serialization' of https://github.com/mobiusml/trans…
mobicham Aug 28, 2024
4682a72
validate_env with hqq>=0.2.0 version message
mobicham Aug 28, 2024
813ed62
deprecated hqq kwargs message
mobicham Aug 28, 2024
d0c594c
make ci happy
mobicham Aug 28, 2024
7e019b3
remove run_expected_keys_check hack + bump to 0.2.1 min hqq version
mobicham Aug 29, 2024
0dd1152
fix unexpected_keys hqq update
mobicham Aug 29, 2024
9053ad5
add pre_quantized check
mobicham Aug 29, 2024
2b6e7df
add update_expected_keys to base quantizerr
mobicham Aug 29, 2024
e68110a
ci base.py fix?
mobicham Aug 29, 2024
433c3a0
ci base.py fix?
mobicham Aug 29, 2024
4db1991
fix "quantization typo" src/transformers/utils/quantization_config.py
mobicham Sep 30, 2024
3b56533
Merge branch 'main' into hqq_serialization
mobicham Sep 30, 2024
a8843cf
fix post merge
mobicham Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/source/en/quantization/hqq.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
quant_config = HqqConfig(nbits=8, group_size=64)
```

``` Python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}
quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/integrations/hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_

has_been_replaced = True

# Add these placeholder attributes so that loading does not fail
for att in ["W_q", "meta"]:
setattr(module, att, None)

if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
Expand Down Expand Up @@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve

# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
quant_config = quantization_config.quant_config
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))

if any(key in linear_tags for key in quant_config.keys()):
Expand All @@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
)

# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
model.config.quantization_config = {
"quant_config": quant_config,
"quant_method": quantization_config.quant_method,
"skip_modules": skip_modules,
}

if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")
Expand Down
16 changes: 14 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
prune_linear_layer,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizer_hqq import HqqHfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion
from .utils import (
Expand Down Expand Up @@ -857,9 +858,14 @@ def _load_state_dict_into_meta_model(

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

# We add this because each HQQLinear module has a very large state_dict (19 params per module), which makes loading extremely slow
run_expected_keys_check = True
if isinstance(hf_quantizer, HqqHfQuantizer):
run_expected_keys_check = False

Comment thread
mobicham marked this conversation as resolved.
Outdated
for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
if param_name not in loaded_state_dict_keys or ((param_name not in expected_keys) and run_expected_keys_check):
continue

if param_name.startswith(start_prefix):
Expand Down Expand Up @@ -891,12 +897,17 @@ def _load_state_dict_into_meta_model(
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

old_param = model
splits = param_name.split(".")
for split in splits:
old_param = getattr(old_param, split)
# Not all the attributes of a module are Parameters/Tensor
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this potentially for ints?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, some parameters are strings (packing format, etc.), booleans, or integers. They are metadata that is required for dequantization.

if old_param is None:
break

if old_param is not None:
if dtype is None:
param = param.to(old_param.dtype)
Expand Down Expand Up @@ -3725,6 +3736,7 @@ def from_pretrained(
from_pt = not (from_tf | from_flax)

# load pt weights early so that we know which dtype to init the model under

if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
Expand Down Expand Up @@ -4181,7 +4193,7 @@ def _fix_key(key):
value = torch.empty(*param.size(), dtype=target_dtype)
if (
not is_quantized
or getattr(hf_quantizer, "requires_parameters_quantization", False)
or (getattr(hf_quantizer, "requires_parameters_quantization", False))
or not hf_quantizer.check_quantized_param(
model, param_value=value, param_name=key, state_dict={}
)
Expand Down
60 changes: 53 additions & 7 deletions src/transformers/quantizers/quantizer_hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def validate_environment(self, *args, **kwargs):
else:
self.using_multi_gpu = len(set(device_map.values())) > 1

def update_missing_keys(
self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
) -> List[str]:
if self.pre_quantized:
return [key for key in missing_keys if ("weight" not in key)]
else:
return missing_keys

def check_quantized_param(
self,
model: "PreTrainedModel",
Expand All @@ -99,9 +107,18 @@ def check_quantized_param(
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)

return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
if self.pre_quantized:
return (
(isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
and tensor_name != "weight"
and tensor_name != "bias"
)
else:
return isinstance(module, torch.nn.Linear) and tensor_name == "weight"

def create_quantized_param(
self,
Expand All @@ -122,21 +139,50 @@ def create_quantized_param(
from hqq.core.quantize import HQQLinear

module, tensor_name = get_module_from_name(model, param_name)

layer_name = param_name.replace(".weight", "").replace(".bias", "")
layer_name = ".".join(param_name.split(".")[:-1])
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]

# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
# set module state_dict
module_state_dict = {}
for k, v in state_dict.items():
if layer_name + "." in k:
module_state_dict[k.split(".")[-1]] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

if self.pre_quantized:
if isinstance(module, (torch.nn.Linear, HQQLinear)):
hqq_layer = HQQLinear(
linear_layer=None,
quant_config=None,
compute_dtype=self.torch_dtype,
device=target_device,
)

hqq_layer.axis = None
hqq_layer.channel_wise = None
hqq_layer.load_state_dict(module_state_dict)
Comment thread
mobicham marked this conversation as resolved.
Outdated

if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)

setattr(parent_module, node, hqq_layer)

# cleanup
del module.__dict__, module
torch.cuda.empty_cache()
return

# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent, as doing it on the module
# directly doesn't work.

if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
Expand Down Expand Up @@ -193,7 +239,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs

@property
def is_serializable(self):
return False
return True

@property
def is_trainable(self) -> bool:
Expand Down
30 changes: 16 additions & 14 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,6 @@ class HqqConfig(QuantizationConfigMixin):
Number of bits. Supported values are (8, 4, 3, 2, 1).
group_size (`int`, *optional*, defaults to 64):
Group-size value. Supported values are any value that evenly divides weight.shape[axis].
quant_zero (`bool`, *optional*, defaults to `True`):
Quantize the zero-point if set to `True`.
quant_scale (`bool`, *optional*, defaults to `False`):
Quantize the scaling if set to `True`.
offload_meta (`bool`, *optional*, defaults to `False`):
Offload the meta-data to the CPU if set to `True`.
view_as_float (`bool`, *optional*, defaults to `False`):
View the quantized weight as float (used in distributed training) if set to `True`.
axis (`int`, *optional*, defaults to 1):
Expand All @@ -215,11 +209,8 @@ def __init__(
self,
nbits: int = 4,
group_size: int = 64,
quant_zero: bool = True,
quant_scale: bool = False,
offload_meta: bool = False,
Comment thread
SunMarc marked this conversation as resolved.
view_as_float: bool = False,
axis: int = 0,
Comment thread
mobicham marked this conversation as resolved.
axis: int = 1,
Comment thread
mobicham marked this conversation as resolved.
Outdated
dynamic_config: Optional[dict] = None,
skip_modules: List[str] = ["lm_head"],
**kwargs,
Expand All @@ -239,9 +230,6 @@ def __init__(
**{
"nbits": nbits,
"group_size": group_size,
"quant_zero": quant_zero,
"quant_scale": quant_scale,
"offload_meta": offload_meta,
"view_as_float": view_as_float,
"axis": axis,
}
Expand All @@ -258,12 +246,26 @@ def post_init(self):
"""
pass

@classmethod
def from_dict(cls, config: Dict[str, Any]):
"""
Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
"""
instance = cls()
instance.quant_config = config["quant_config"]
instance.skip_modules = config["skip_modules"]
return instance

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
return {
"quant_config": self.quant_config,
"quant_method": self.quant_method,
"skip_modules": self.skip_modules,
}

def __repr__(self):
config_dict = self.to_dict()
Expand Down
74 changes: 46 additions & 28 deletions tests/quantization/hqq/test_hqq.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def test_to_dict(self):
quantization_config = HqqConfig()
hqq_orig_config = quantization_config.to_dict()

for key in hqq_orig_config:
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
self.assertEqual(quantization_config.quant_config, hqq_orig_config['quant_config'])


@slow
Expand All @@ -109,7 +108,7 @@ def test_fp16_quantized_model(self):
"""
Simple LLM model testing fp16
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
quant_config = HqqConfig(nbits=8, group_size=64)

hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
Expand All @@ -118,26 +117,24 @@ def test_fp16_quantized_model(self):
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)

def test_f16_quantized_model_with_offloading(self):

@slow
@require_torch_gpu
@require_torch_multi_gpu
@require_accelerate
class HQQTestMultiGPU(unittest.TestCase):
def tearDown(self):
cleanup()

def test_fp16_quantized_model_multipgpu(self):
"""
Simple LLM model testing bfp16 with meta-data offloading
Simple LLM model testing fp16 with multi-gpu
"""
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
quant_config = HqqConfig(
dynamic_config={
"self_attn.q_proj": q4_config,
"self_attn.k_proj": q4_config,
"self_attn.v_proj": q4_config,
"self_attn.o_proj": q4_config,
"mlp.gate_proj": q3_config,
"mlp.up_proj": q3_config,
"mlp.down_proj": q3_config,
}
)

quant_config = HqqConfig(nbits=8, group_size=64)

hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
)

check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
Expand All @@ -146,22 +143,43 @@ def test_f16_quantized_model_with_offloading(self):

@slow
@require_torch_gpu
@require_torch_multi_gpu
@require_accelerate
class HQQTestMultiGPU(unittest.TestCase):
class HQQSerializationTest(unittest.TestCase):
def tearDown(self):
cleanup()

def test_fp16_quantized_model_multipgpu(self):
def test_model_serialization(self):
"""
Simple LLM model testing fp16 with multi-gpu
Simple HQQ LLM save/load test
"""

quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
quant_config = HqqConfig(nbits=4, group_size=64)

hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)

check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
Comment thread
mobicham marked this conversation as resolved.
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)

with torch.no_grad():
logits_ref = hqq_runner.model.forward(input_tensor).logits

#Save
saved_model_id = 'quant_model'
hqq_runner.model.save_pretrained(saved_model_id)

#Remove old model
del hqq_runner.model
torch.cuda.empty_cache()

#Load and check if the logits match
model_loaded = AutoModelForCausalLM.from_pretrained(
'quant_model',
torch_dtype=torch.float16,
device_map=torch_device,
low_cpu_mem_usage=True
)

with torch.no_grad():
logits_loaded = model_loaded.forward(input_tensor).logits

self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)