diff --git a/hqq/models/base.py b/hqq/models/base.py
index 31bb625..a66aaf3 100755
--- a/hqq/models/base.py
+++ b/hqq/models/base.py
@@ -1,6 +1,7 @@
 # Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
 #####################################################
 import os
+import json
 import torch
 from torch import nn
 from torch import float16
@@ -235,6 +236,17 @@ def cache_model(cls, model, save_dir: str):
     @classmethod
     def get_config_file(cls, save_dir: str) -> str:
         return pjoin(save_dir, "config.json")
+
+    @classmethod
+    def get_patch_params_file(cls, save_dir: str) -> str:
+        return pjoin(save_dir, "quantize_config.json")
+
+    @classmethod
+    def load_patch_params(cls, save_dir: str) -> dict:
+        if not os.path.isfile(cls.get_patch_params_file(save_dir)):
+            return None
+        with open(cls.get_patch_params_file(save_dir)) as f:
+            return json.load(f)
 
     @classmethod
     def get_weight_file(cls, save_dir: str) -> str:
@@ -392,8 +404,17 @@ def _patch_other(layer):
 
         model.hqq_quantized = True
 
+        model.patch_params = patch_params
+
         return model
 
+    # Save model patch params
+    @classmethod
+    def save_patch_params(cls, model, save_dir: str):
+
+        with open(cls.get_patch_params_file(save_dir), 'w') as f:
+            json.dump(model.patch_params, f, indent=2)
+
     # Prepares model weights by iterating through modules. It might some parameters that are NOT modules like model.param1
     @classmethod
     def serialize_weights(cls, model, verbose: bool = False) -> dict:
@@ -421,6 +442,9 @@ def save_quantized(cls, model, save_dir: str, verbose: bool = False):
 
         # Save config
         cls.cache_model(model, save_dir)
 
+        # Save patch params
+        cls.save_patch_params(model, save_dir)
+
         # Serialization
         weights = cls.serialize_weights(model, verbose=verbose)
@@ -445,6 +469,8 @@ def try_snapshot_download(
             raise Exception("Weight file missing. Check your cache directory.")
         if not os.path.exists(cls.get_config_file(save_dir)):
             raise Exception("Config file missing. Check your cache directory.")
+        if not os.path.exists(cls.get_patch_params_file(save_dir)):
+            raise Exception("Quantize config file missing. Check your cache directory.")
 
         return save_dir
 
@@ -476,6 +502,9 @@ def from_quantized(
         # Name the layers
         cls.setup_model(model)
 
+        # Load patch params
+        patch_params = cls.load_patch_params(save_dir)
+
         # Load weights
         try:
             weights = cls.load_weights(save_dir)
@@ -485,7 +514,7 @@ def from_quantized(
 
         # load_state_dict() doesn't work with modules initialized with init_empty_weights(), so we need to do this manually
         @torch.no_grad()
-        def _load_module(module, params=None):
+        def _load_module(module, patch_params=None):
             if module.name not in weights:
                 return module.to(device=device, dtype=compute_dtype, non_blocking=True)
 
@@ -493,7 +522,7 @@ def _load_module(module, params=None):
             if "W_q" in state_dict:
                 module = HQQLinear(
                     linear_layer=None,
-                    quant_config=None,
+                    quant_config=patch_params,
                     compute_dtype=compute_dtype,
                     device=device,
                 )
@@ -515,7 +544,7 @@ def _load_module(module, params=None):
 
         # Load modules
         cls.patch_model(
-            model, _load_module, _load_module, {k: None for k in model.linear_tags}
+            model, _load_module, _load_module, {k: patch_params.get(k, None) if patch_params else None for k in model.linear_tags}
         )
 
         # Load other weights that are not part of any module
@@ -523,6 +552,8 @@ def _load_module(module, params=None):
 
         model.hqq_quantized = True
 
+        model.patch_params = patch_params
+
         # Set base class
         model.base_class = cls
 
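
Usage sketch (not part of the diff): a minimal example of the round trip this change enables. The import path, the AutoHQQHFModel entry point, and the pre-existing quantized `model` are illustrative assumptions, not confirmed by this patch; only save_quantized, from_quantized, patch_params, and quantize_config.json come from the diff itself.

# Hedged sketch: save_quantized() now also writes quantize_config.json (via save_patch_params),
# and from_quantized() reads it back (via load_patch_params) so each HQQLinear is rebuilt with
# its original per-layer quant_config instead of quant_config=None.
from hqq.models.hf.base import AutoHQQHFModel  # assumed import path

save_dir = "my-quantized-model"

AutoHQQHFModel.save_quantized(model, save_dir)      # writes weights, config.json, quantize_config.json
reloaded = AutoHQQHFModel.from_quantized(save_dir)  # restores patch_params keyed by linear_tags
print(reloaded.patch_params)                        # per-layer quant configs loaded from quantize_config.json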