Add way to save quantize config and can be loaded again #93
fahadh4ilyas wants to merge 4 commits into dropbox:master
Conversation
Thanks a lot for the effort @fahadh4ilyas! That is correct; as a temporary solution, there's this patching function that adds a quant_config: https://github.com/mobiusml/hqq/blob/master/hqq/utils/patching.py#L29. There's an easy way to do this without needing a separate json. However, this is going in a different direction, will explain below.

**Current Direction**

I am currently refactoring the whole serialization logic to make it compatible with safetensors. The goal is to be able to directly save/load HQQ-quantized models with HF transformers. For the moment, I added support for this in the hqq lib; the way it works right now is shown below:

```python
import torch

compute_dtype = torch.float16
device = 'cuda:0'
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B', torch_dtype=compute_dtype, cache_dir='/nas/hicham/tmp/')
#Quantize
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
from hqq.models.hf.base import AutoHQQHFModel
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1, offload_meta=False)
model = AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
##########################################
#Safetensors save/load layer check
from safetensors import safe_open
from safetensors.torch import save_file
_state_dict = model.model.layers[0].self_attn.q_proj.state_dict()
save_file(_state_dict, "layer.safetensors")
_state_dict_loaded = {}
with safe_open("layer.safetensors", framework="pt") as f:
    for key in f.keys():
        _state_dict_loaded[key] = f.get_tensor(key)
#######################################
#Model save/load check (with hqq lib)
AutoHQQHFModel.save_quantized(model, 'llama3-hqq')
model_loaded = AutoHQQHFModel.from_quantized("llama3-hqq")
#quant_config loaded
print(model_loaded.model.layers[0].self_attn.q_proj.quant_config)
```

Next step is to use this logic to save/load HQQ-quantized models with HF transformers. Then we can get back to supporting quantized scale/zero. Happy to hear suggestions from you regarding this!
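As a small sanity check on the layer round-trip above (an added sketch, not part of the original snippet; it reuses `_state_dict` and `_state_dict_loaded` from it):

```python
# Every tensor read back from layer.safetensors should match the original
# state_dict entry exactly. The originals live on cuda:0 while safe_open loads
# to CPU by default, so move them to CPU before comparing.
for key, tensor in _state_dict.items():
    assert torch.equal(tensor.cpu(), _state_dict_loaded[key]), f"mismatch in {key}"
print("layer round-trip OK:", sorted(_state_dict_loaded.keys()))
```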
Doesn't safetensors support metadata? How about putting the meta and quant_config inside the metadata?
Yeah I thought about it, but it will make things even more complicated, since it will require more work on the
What do you mean by "will require more work on the"? By using metadata, we only split the current
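For illustration, a minimal sketch of the metadata idea (not hqq's implementation): safetensors metadata is a flat dict of strings, so `quant_config` would have to be JSON-encoded. This assumes `quant_config` is a plain nested dict as returned by `BaseQuantizeConfig`, and reuses `model` from the snippet above.

```python
import json
from safetensors import safe_open
from safetensors.torch import save_file

layer = model.model.layers[0].self_attn.q_proj  # quantized HQQLinear from the snippet above

# Metadata values must be strings; default=str covers non-JSON types such as torch dtypes.
metadata = {"quant_config": json.dumps(layer.quant_config, default=str)}
save_file(layer.state_dict(), "layer.safetensors", metadata=metadata)

# Tensors come back through get_tensor(), the config through metadata().
with safe_open("layer.safetensors", framework="pt") as f:
    quant_config_loaded = json.loads(f.metadata()["quant_config"])
    tensors = {key: f.get_tensor(key) for key in f.keys()}
```

Anything JSON cannot represent natively (e.g. a compute dtype) comes back as a string and would still need to be converted when rebuilding the layer.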
hqq's I am trying to see what is the right way of doing this with @SunMarc. For now, the logic is working just fine with
I also tried loading a model saved with the previous version (https://huggingface.co/mobiuslabsgmbh/Llama-2-7b-chat-hf_4bitnogs_hqq) and it worked without any issue, which is good news for backward compatibility.
Draft pull request here: huggingface/transformers#32056
Closing this since we are very close to full transformers serialization support: huggingface/transformers#33141
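For context, roughly what that transformers-side flow is expected to look like; this is a sketch based on the existing `HqqConfig` integration, and the save/reload step is an assumption about what the linked PR enables rather than something stated in this thread.

```python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

# Quantize on load through transformers' HQQ integration.
quant_config = HqqConfig(nbits=4, group_size=64)
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    torch_dtype=torch.float16,
    device_map='cuda:0',
    quantization_config=quant_config,
)

# With full serialization support, the quantized model (and its quant config)
# should round-trip through the standard save/load path.
model.save_pretrained('llama3-hqq-transformers')
model_loaded = AutoModelForCausalLM.from_pretrained('llama3-hqq-transformers', device_map='cuda:0')
```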
Because `quant_config` is gone when you load a model using `from_quantized`. I tried to re-add the `quant_config` here so that when we call `prepare_for_inference` on a loaded quantized model, it will not crash because `quant_config` is not found.
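A minimal sketch of the idea behind this PR (not the actual diff), assuming the `llama3-hqq` folder and the quantization settings from the snippet earlier in the thread, and the `prepare_for_inference` helper from hqq's patching utilities:

```python
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
from hqq.models.hf.base import AutoHQQHFModel
from hqq.utils.patching import prepare_for_inference

model_loaded = AutoHQQHFModel.from_quantized("llama3-hqq")

# Re-attach a quant_config to every HQQLinear that lost it during loading,
# using the same settings the model was quantized with, so backend patching
# can read it afterwards.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
for module in model_loaded.modules():
    if isinstance(module, HQQLinear) and getattr(module, "quant_config", None) is None:
        module.quant_config = quant_config

# Without the quant_config re-added, this call is what crashes.
prepare_for_inference(model_loaded)
```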