
Add HQQ model serialization support#32056

Closed
mobicham wants to merge 2 commits into huggingface:main from mobiusml:main

Conversation

@mobicham
Copy link
Contributor

What does this PR do?

The goal of this PR is to add support for saving/loading transformers models quantized via HQQ.
Currently, it's not possible to directly save/load HQQ-quantized models via model.save_pretrained. The workaround is to use the hqq lib, which offers serialization but without safetensors support.

The first step was to make HQQLinear.state_dict compatible with safetensors: dropbox/hqq@74bbe01
Since safetensors only supports torch.Tensor values and doesn't allow nested dictionaries, I had to implement encoding/decoding logic for the parameters, which works fine for the moment.
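The actual encoding used in the hqq commit above isn't reproduced here, but a minimal sketch of the general idea (flattening a nested state_dict and turning non-tensor metadata into uint8 tensors so everything is safetensors-compatible) could look like this; `encode_state_dict`, `decode_meta`, and the `_meta` key suffix are hypothetical names, not the PR's actual API:

```python
import json
import torch

def encode_state_dict(sd: dict) -> dict:
    """Sketch: flatten a nested state_dict into a single-level dict of
    torch.Tensor values, as safetensors requires (hypothetical helper)."""
    flat = {}
    for key, value in sd.items():
        if isinstance(value, dict):
            # Nested dicts are flattened with dotted keys.
            for sub_key, sub_value in encode_state_dict(value).items():
                flat[f"{key}.{sub_key}"] = sub_value
        elif isinstance(value, torch.Tensor):
            flat[key] = value
        else:
            # Non-tensor metadata (ints, strings, ...) is serialized to
            # JSON bytes and stored as a uint8 tensor.
            raw = json.dumps(value).encode("utf-8")
            flat[f"{key}_meta"] = torch.tensor(list(raw), dtype=torch.uint8)
    return flat

def decode_meta(tensor: torch.Tensor):
    """Inverse of the metadata encoding above."""
    return json.loads(bytes(tensor.tolist()).decode("utf-8"))
```

Decoding on load would walk the flat keys back into the nested structure HQQLinear expects; the sketch only shows the metadata round-trip.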

The second step was integrating the new HQQLinear state_dict loading with the transformers lib, which I found a bit tricky because of the following issues:

  • Some state_dict parameters are split across 2 safetensors files, so loading fails for that specific layer since I can only get the parameters from the first file. For now, I put in a fake quantized layer so that loading doesn't break, but you'll see a message.
  • dispatch_model breaks with the HQQLinear module; I am not sure what the issue is here. Before it reaches that function, the layer is already loaded from the state_dict. If I skip the dispatching function, I get a model with the meta device for all the layers except HQQLinear.
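For the first issue, the standard sharded-checkpoint index (model.safetensors.index.json) maps each parameter name to the shard file that holds it, so one way to detect layers whose parameters span multiple shards is to group the weight map by layer prefix. This is a hedged sketch, not code from the PR; `shards_for_layer` is a hypothetical helper:

```python
from collections import defaultdict

def shards_for_layer(weight_map: dict) -> dict:
    """Sketch: given the "weight_map" from a sharded safetensors index
    (param name -> shard filename), return which shard files each layer's
    parameters live in. Layers mapping to more than one file would need
    all of those shards loaded before the quantized layer can be rebuilt."""
    layer_shards = defaultdict(set)
    for param_name, shard_file in weight_map.items():
        # Drop the final component to get the layer prefix,
        # e.g. "model.layers.0.self_attn.q_proj.W_q" -> "...q_proj".
        prefix = param_name.rsplit(".", 1)[0]
        layer_shards[prefix].add(shard_file)
    return dict(layer_shards)
```

Collecting all shards for a layer up front (rather than reading shards one at a time) is one possible fix for the partial-layer loading failure described above.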

Since this is a draft pull request, I am commenting out some things while debugging and will fix that later (you'll see a TODO @mobicham comment so I can get back to it).

I will need to drop support for quantized scale/zero, so I will also need to update HqqConfig and the documentation, as well as add tests later.

Any help is highly appreciated, thank you!

Who can review?

@SunMarc

Example

import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    torch_dtype=torch.float16,
    cache_dir='/nas/hicham/tmp/',
    device_map="cuda:0",
    quantization_config=quant_config
)

# Test
with torch.no_grad():
    out = model.forward(torch.zeros((1, 8), dtype=torch.int32, device='cuda:0'))

# Save
model.save_pretrained("llama3-hqq")
del model
torch.cuda.empty_cache(); gc.collect();

# Load
model_loaded = AutoModelForCausalLM.from_pretrained(
    'llama3-hqq',
    torch_dtype=torch.float16,
    cache_dir='.',
    device_map="cuda:0"
)

## Doesn't work here because of `dispatch_model()`
with torch.no_grad():
    out = model_loaded.forward(torch.zeros((1, 8), dtype=torch.int32, device='cuda:0'))

@amyeroberts
Copy link
Contributor

Gentle ping @SunMarc

@mobiusml mobiusml closed this by deleting the head repository Aug 27, 2024


3 participants