Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __init__(self, quantization_config: CompressedTensorsConfig, **kwargs):
self.compressor = ModelCompressor.from_compression_config(quantization_config)
self.run_compressed = quantization_config.run_compressed
self.quantization_config = quantization_config
self.weight_norm = None
# temporarily workaround to skip weight_norm key fixing
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
self.weight_norm = torch.nn.utils.parametrizations.weight_norm
delattr(torch.nn.utils.parametrizations, "weight_norm")

def validate_environment(self, *args, **kwargs):
if not is_compressed_tensors_available():
Expand Down Expand Up @@ -82,6 +87,9 @@ def _process_model_before_weight_loading(self, model, **kwargs):

def _process_model_after_weight_loading(self, model, **kwargs):
"""Decompress loaded model if necessary - need for qat"""
# enable post loading
if self.weight_norm:
setattr(torch.nn.utils.parametrizations, "weight_norm", self.weight_norm)

if (self.is_quantization_compressed and not self.run_compressed) or self.is_sparsification_compressed:
config = kwargs.get("config", None)
Expand Down