diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a11f7743ed8e..33ecceacb17a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4015,8 +4015,11 @@ def save_pretrained(
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)
 
+        metadata = {}
         if hf_quantizer is not None:
-            state_dict = hf_quantizer.get_state_dict(self)
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+        metadata["format"] = "pt"
+
         # Only save the model itself if we are using distributed training
         model_to_save = unwrap_model(self)
         # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
@@ -4294,7 +4297,7 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
             else:
                 save_function(shard, os.path.join(save_directory, shard_file))
 
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 653953abec0a..323faa9c17e2 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -338,9 +338,9 @@ def is_compileable(self) -> bool:
         """Flag indicating whether the quantized model can be compiled"""
         return False
 
-    def get_state_dict(self, model):
-        """Get state dict. Useful when we need to modify a bit the state dict due to quantization"""
-        return None
+    def get_state_dict_and_metadata(self, model, safe_serialization=False):
+        """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
+        return None, {}
 
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index b9076007d38d..d0d370a11df6 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -379,7 +379,7 @@ def update_param_name(self, param_name: str) -> str:
                 return param_name.replace("down_proj", "down_proj_blocks")
         return param_name
 
-    def get_state_dict(self, model):
+    def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
         from ..integrations import Mxfp4GptOssExperts
 
         state_dict = model.state_dict()
@@ -411,7 +411,8 @@ def get_state_dict(self, model):
                     ).transpose(-1, -2)
                 )
 
-        return state_dict
+        metadata = {}
+        return state_dict, metadata
 
     def is_serializable(self, safe_serialization=None):
         return True