42 changes: 36 additions & 6 deletions src/transformers/modeling_utils.py
@@ -105,7 +105,6 @@

 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
-PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."


 if is_accelerate_available():
@@ -693,17 +692,30 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
     new_keys = []
+    renamed_gamma = {}
+    renamed_beta = {}
+    warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
             new_key = key.replace("gamma", "weight")
+            # We record only the first renamed key as an example for the warning
+            if not renamed_gamma:
+                renamed_gamma[key] = new_key
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
             new_key = key.replace("beta", "bias")
+            # We record only the first renamed key as an example for the warning
+            if not renamed_beta:
+                renamed_beta[key] = new_key
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    renamed_keys = {**renamed_gamma, **renamed_beta}
+    if renamed_keys:
+        warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
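The net effect of this block is easiest to see in isolation. Below is a minimal, self-contained sketch of the renaming pass (the `encoder.layer.0.*` keys are hypothetical, chosen only for illustration): legacy TensorFlow-style `gamma`/`beta` names are mapped to PyTorch's `weight`/`bias` via plain `str.replace`, and the renames are applied after the scan so the dict is not mutated while it is being iterated.

# Standalone sketch of the renaming pass above (hypothetical keys).
state_dict = {"encoder.layer.0.gamma": 1.0, "encoder.layer.0.beta": 2.0}

old_keys, new_keys = [], []
for key in state_dict:
    new_key = None
    if "gamma" in key:
        new_key = key.replace("gamma", "weight")
    if "beta" in key:
        new_key = key.replace("beta", "bias")
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)

for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(state_dict)
# {'encoder.layer.0.weight': 1.0, 'encoder.layer.0.bias': 2.0}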
@@ -819,6 +831,7 @@ def _load_state_dict_into_meta_model(
     is_safetensors=False,
     keep_in_fp32_modules=None,
     unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
+    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -841,18 +854,30 @@

     old_keys = []
     new_keys = []
+    renamed_gamma = {}
+    renamed_beta = {}
     is_quantized = hf_quantizer is not None
+    warning_msg = f"This model {type(model)} "
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
             new_key = key.replace("gamma", "weight")
+            # We record only the first renamed key as an example for the warning
+            if not renamed_gamma:
+                renamed_gamma[key] = new_key
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
             new_key = key.replace("beta", "bias")
+            # We record only the first renamed key as an example for the warning
+            if not renamed_beta:
+                renamed_beta[key] = new_key
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    renamed_keys = {**renamed_gamma, **renamed_beta}
+    if renamed_keys:
+        warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
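For illustration only (the model class and parameter names here are hypothetical), the single aggregated record emitted through `logger.info_once` reads along these lines:

A pretrained model of type `BertModel` contains parameters that have been renamed internally (a few are listed below but more are present in the model):
* `encoder.layer.0.gamma` -> `encoder.layer.0.weight`
* `encoder.layer.0.beta` -> `encoder.layer.0.bias`
If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users.

Because only the first `gamma` hit and the first `beta` hit are recorded, the list shows at most two example keys no matter how many parameters were actually renamed.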
@@ -4541,7 +4566,12 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):

     @staticmethod
     def _load_pretrained_model_low_mem(
-        model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
+        model,
+        loaded_state_dict_keys,
+        resolved_archive_file,
+        start_prefix="",
+        hf_quantizer=None,
+        pretrained_model_name_or_path=None,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
15 changes: 15 additions & 0 deletions src/transformers/utils/logging.py
@@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
 logging.Logger.warning_once = warning_once


+@functools.lru_cache(None)
+def info_once(self, *args, **kwargs):
+    """
+    This method is identical to `logger.info()`, but will emit the info with the same message only once.
+
+    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
+    cache. The assumption here is that all info messages are unique across the code. If they aren't, then we need to
+    switch to another type of cache that includes the caller frame information in the hashing function.
+    """
+    self.info(*args, **kwargs)
+
+
+logging.Logger.info_once = info_once
+
+
 class EmptyTqdm:
     """Dummy tqdm which doesn't do anything."""
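A quick usage sketch of the new helper, assuming verbosity is set to INFO (`get_logger` and `set_verbosity_info` are existing `transformers` logging utilities):

from transformers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

logger.info_once("Renamed keys detected.")  # emitted
logger.info_once("Renamed keys detected.")  # lru_cache hit, silent
logger.info_once("A different message.")    # emitted

Since `functools.lru_cache(None)` keys on the full argument tuple, including `self`, the same message sent through two different logger instances is emitted once per logger, not once globally.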
10 changes: 6 additions & 4 deletions tests/utils/test_modeling_utils.py
@@ -1640,17 +1640,18 @@ def forward(self):

         logger = logging.get_logger("transformers.modeling_utils")
         config = PretrainedConfig()
-        warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
+        warning_msg_gamma = "`gamma_param` -> `weight_param`"
         model = TestModelGamma(config)

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir)
-            with LoggingLevel(logging.WARNING):
+            with LoggingLevel(logging.INFO):
                 with CaptureLogger(logger) as cl1:
                     _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)

         missing_keys = loading_info["missing_keys"]
         unexpected_keys = loading_info["unexpected_keys"]
+        self.assertIn("`TestModelGamma`", cl1.out)
         self.assertIn(warning_msg_gamma, cl1.out)
         self.assertIn("gamma_param", missing_keys)
         self.assertIn("weight_param", unexpected_keys)
@@ -1664,17 +1665,18 @@ def __init__(self, config):
             def forward(self):
                 return self.beta_param.sum()

-        warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
+        warning_msg_beta = "`beta_param` -> `bias_param`"
         model = TestModelBeta(config)

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir)
-            with LoggingLevel(logging.WARNING):
+            with LoggingLevel(logging.INFO):
                 with CaptureLogger(logger) as cl2:
                     _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

         missing_keys = loading_info["missing_keys"]
         unexpected_keys = loading_info["unexpected_keys"]
+        self.assertIn("`TestModelBeta`", cl2.out)
         self.assertIn(warning_msg_beta, cl2.out)
         self.assertIn("beta_param", missing_keys)
         self.assertIn("bias_param", unexpected_keys)
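For reference, the `TestModelGamma` fixture exercised above looks roughly like the sketch below, reconstructed from the assertions; the actual definition in `test_modeling_utils.py` may differ in details such as the parameter shape.

import torch
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class TestModelGamma(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.gamma_param = nn.Parameter(torch.ones(10))
        self.post_init()

    def forward(self):
        return self.gamma_param.sum()

Saving this model writes `gamma_param` into the checkpoint; on reload the loader renames that key to `weight_param`, which is why the test expects `gamma_param` among `missing_keys` and `weight_param` among `unexpected_keys`.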