Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions hqq/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################
import os
import json
import torch
from torch import nn
from torch import float16
Expand Down Expand Up @@ -235,6 +236,17 @@ def cache_model(cls, model, save_dir: str):
@classmethod
def get_config_file(cls, save_dir: str) -> str:
    """Return the path of the model's ``config.json`` inside *save_dir*."""
    config_path = pjoin(save_dir, "config.json")
    return config_path

@classmethod
def get_patch_params_file(cls, save_dir: str) -> str:
    """Return the path of the ``quantize_config.json`` file inside *save_dir*."""
    params_path = pjoin(save_dir, "quantize_config.json")
    return params_path

@classmethod
def load_patch_params(cls, save_dir: str) -> "dict | None":
    """Load per-layer quantization patch parameters from *save_dir*.

    Returns the parsed JSON dict, or ``None`` when the
    ``quantize_config.json`` file is absent (e.g. models saved before
    patch params were serialized).
    """
    # Compute the path once instead of twice (original called
    # get_patch_params_file for both the existence check and the open).
    params_file = cls.get_patch_params_file(save_dir)
    if not os.path.isfile(params_file):
        return None
    with open(params_file) as f:
        return json.load(f)

@classmethod
def get_weight_file(cls, save_dir: str) -> str:
Expand Down Expand Up @@ -392,8 +404,17 @@ def _patch_other(layer):

model.hqq_quantized = True

model.patch_params = patch_params

return model

# Save model patch params
@classmethod
def save_patch_params(cls, model, save_dir: str):
    """Serialize ``model.patch_params`` as JSON into *save_dir*.

    Written to the path given by ``get_patch_params_file`` so that
    ``load_patch_params`` can restore it later.
    """
    out_path = cls.get_patch_params_file(save_dir)
    with open(out_path, 'w') as f:
        json.dump(model.patch_params, f, indent=2)

# Prepares model weights by iterating through modules. There might be some parameters that are NOT modules, like model.param1
@classmethod
def serialize_weights(cls, model, verbose: bool = False) -> dict:
Expand Down Expand Up @@ -421,6 +442,9 @@ def save_quantized(cls, model, save_dir: str, verbose: bool = False):
# Save config
cls.cache_model(model, save_dir)

# Save patch params
cls.save_patch_params(model, save_dir)

# Serialization
weights = cls.serialize_weights(model, verbose=verbose)

Expand All @@ -445,6 +469,8 @@ def try_snapshot_download(
raise Exception("Weight file missing. Check your cache directory.")
if not os.path.exists(cls.get_config_file(save_dir)):
raise Exception("Config file missing. Check your cache directory.")
if not os.path.exists(cls.get_patch_params_file(save_dir)):
raise Exception("Quantize config file missing. Check your cache directory.")

return save_dir

Expand Down Expand Up @@ -476,6 +502,9 @@ def from_quantized(
# Name the layers
cls.setup_model(model)

# Load patch params
patch_params = cls.load_patch_params(save_dir)

# Load weights
try:
weights = cls.load_weights(save_dir)
Expand All @@ -485,15 +514,15 @@ def from_quantized(

# load_state_dict() doesn't work with modules initialized with init_empty_weights(), so we need to do this manually
@torch.no_grad()
def _load_module(module, params=None):
def _load_module(module, patch_params=None):
if module.name not in weights:
return module.to(device=device, dtype=compute_dtype, non_blocking=True)

state_dict = weights[module.name]
if "W_q" in state_dict:
module = HQQLinear(
linear_layer=None,
quant_config=None,
quant_config=patch_params,
compute_dtype=compute_dtype,
device=device,
)
Expand All @@ -515,14 +544,16 @@ def _load_module(module, params=None):

# Load modules
cls.patch_model(
model, _load_module, _load_module, {k: None for k in model.linear_tags}
model, _load_module, _load_module, {k: patch_params.get(k, None) if patch_params else None for k in model.linear_tags}
)

# Load other weights that are not part of any module
cls.post_module_load(model, weights)

model.hqq_quantized = True

model.patch_params = patch_params

# Set base class
model.base_class = cls

Expand Down