Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4015,8 +4015,11 @@ def save_pretrained(
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

metadata = {}
if hf_quantizer is not None:
state_dict = hf_quantizer.get_state_dict(self)
state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
metadata["format"] = "pt"

# Only save the model itself if we are using distributed training
model_to_save = unwrap_model(self)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
Expand Down Expand Up @@ -4294,7 +4297,7 @@ def save_pretrained(
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
else:
save_function(shard, os.path.join(save_directory, shard_file))

Expand Down
6 changes: 3 additions & 3 deletions src/transformers/quantizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,9 @@ def is_compileable(self) -> bool:
"""Flag indicating whether the quantized model can be compiled"""
return False

def get_state_dict(self, model):
    """Return a quantization-adjusted state dict for ``model``.

    This default performs no adjustment and always yields ``None``;
    ``model`` is not inspected here.
    """
    adjusted_state_dict = None
    return adjusted_state_dict
def get_state_dict_and_metadata(self, model, safe_serialization=False):
    """Return the state dict to save together with its metadata.

    Hook for cases where quantization requires the state dict to be
    adjusted slightly before saving. This default does no adjustment.

    Args:
        model: The model being saved; not inspected by this default.
        safe_serialization: Not used by this default.

    Returns:
        A ``(state_dict, metadata)`` pair, here ``None`` and a newly
        created empty dict.
    """
    state_dict = None
    metadata = {}
    return state_dict, metadata

# Abstract hook with no default body: it is declared with @abstractmethod, so
# concrete subclasses are expected to supply the implementation. It receives
# the model plus arbitrary keyword arguments.
# NOTE(review): the name suggests it runs before weights are loaded; confirm
# against the caller, which is not visible here.
@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs): ...
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/quantizers/quantizer_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def update_param_name(self, param_name: str) -> str:
return param_name.replace("down_proj", "down_proj_blocks")
return param_name

def get_state_dict(self, model):
def get_state_dict_and_metadata(self, model):
from ..integrations import Mxfp4GptOssExperts

state_dict = model.state_dict()
Expand Down Expand Up @@ -411,7 +411,8 @@ def get_state_dict(self, model):
).transpose(-1, -2)
)

return state_dict
metadata = {}
return state_dict, metadata

def is_serializable(self, safe_serialization=None):
    """Report whether models handled by this quantizer can be serialized.

    The answer is unconditionally ``True``; ``safe_serialization`` is
    accepted but does not influence the result.
    """
    serializable = True
    return serializable
Expand Down