60 changes: 52 additions & 8 deletions src/transformers/modeling_utils.py
@@ -105,7 +105,6 @@

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."


if is_accelerate_available():
@@ -689,21 +688,41 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
def _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers=False, pretrained_model_name_or_path=None
):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
found_gamma = False
found_beta = False
warning_msg = "This model "
if pretrained_model_name_or_path is not None:
warning_msg += f"(`{pretrained_model_name_or_path}`) "
Contributor

Not sure if we need to include this in the error message - I'm generally against adding args to functions just for auxiliary logic like logging

Collaborator

+1 on this, we should not need to modify the function for logging purposes

Contributor Author

I think we can do type(model_to_load) instead here and be fine
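
For illustration only, a minimal sketch of the suggestion above — deriving the model identifier from the class instead of threading the checkpoint name through the loader just for logging. The helper name and message wording here are hypothetical, not part of the PR:

def _rename_warning_header(model_to_load) -> str:
    # Hypothetical helper: identify the model by its class name so the loading
    # function does not need an extra `pretrained_model_name_or_path` argument
    # purely for logging.
    return (
        f"This model (`{type(model_to_load).__name__}`) contains parameters that have been "
        "renamed internally (a few are listed below but more are present in the model):\n"
    )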

warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
# We add only the first key as an example
if not found_gamma:
renamed_keys[key] = key.replace("gamma", "weight")
found_gamma = True
new_key = key.replace("gamma", "weight")
Contributor

Two nits:

  • We can use new_key rather than replace twice.
  • We don't need to keep track of both "found" and replaced keys

e.g. something like:

def _load_state_dict_into_model(
    model_to_load, state_dict, start_prefix, assign_to_params_buffers=False, pretrained_model_name_or_path=None
):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    renamed_gamma = {}
    renamed_beta = {}
    warning_msg = "This model "
    if pretrained_model_name_or_path is not None:
        warning_msg += f"(`{pretrained_model_name_or_path}`) "
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
            # We add only the first key as an example
            renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
        if "beta" in key:
            new_key = key.replace("beta", "bias")
            # We add only the first key as an example
            renamed_beta[key] = new_key if not renamed_beta else renamed_beta
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    

    renamed_keys = renamed_gamma | renamed_beta
    if renamed_keys:
        warning_msg = "Model contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
        for old_key, new_key in renamed_keys.items():
            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
        logger.info_once(warning_msg)
    
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

Contributor Author

Note we want this to be and not or, because it's important that the user knows if both gamma or beta are present. Otherwise went with this solution!

Contributor

> because it's important that the user knows if both gamma or beta are present.

Right, but I think in this case we'd still want to use or. If and is used, we only warn when a key contained both gamma and beta, but not if there were two separate keys -- one key which contained gamma and one key which contained beta

Contributor Author

Correct, I misspoke about what I actually wound up doing here. In the end it's renamed_keys = {**renamed_gamma, **renamed_beta} to catch everything
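
For context, a tiny self-contained example (with made-up key names) of why merging the two dicts reports both renames even when no single key contains both substrings:

# One key contained "gamma", a different key contained "beta".
renamed_gamma = {"encoder.layer.0.gamma": "encoder.layer.0.weight"}
renamed_beta = {"encoder.layer.1.beta": "encoder.layer.1.bias"}

# The merged dict keeps one example from each group, so the warning lists both.
renamed_keys = {**renamed_gamma, **renamed_beta}
assert renamed_keys == {
    "encoder.layer.0.gamma": "encoder.layer.0.weight",
    "encoder.layer.1.beta": "encoder.layer.1.bias",
}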

if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
# We add only the first key as an example
if not found_beta:
renamed_keys[key] = key.replace("beta", "bias")
found_beta = True
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
if renamed_keys:
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
Contributor Author

@muellerzr Aug 13, 2024

I think this logic can be refactored too. I can do so in a follow-up (we repeat it for meta model)
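
As a rough sketch of that follow-up, the renaming pass could live in one helper shared by `_load_state_dict_into_model` and `_load_state_dict_into_meta_model`; the helper below is hypothetical and only illustrates the shape of the refactor:

def _collect_renamed_keys(state_dict_keys):
    # Hypothetical shared helper: map every key containing the legacy
    # `gamma`/`beta` names to `weight`/`bias`, returning old-key -> new-key
    # pairs for the caller to apply and summarize in a single log message.
    renamed = {}
    for key in state_dict_keys:
        new_key = key.replace("gamma", "weight").replace("beta", "bias")
        if new_key != key:
            renamed[key] = new_key
    return renamed

Both loaders could then pop the old keys, insert the new ones, and pass the first few entries to logger.info_once as examples.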

for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

@@ -819,6 +838,7 @@ def _load_state_dict_into_meta_model(
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -841,18 +861,34 @@

old_keys = []
new_keys = []
renamed_keys = {}
found_gamma = False
found_beta = False
warning_msg = "This model "
if pretrained_model_name_or_path is not None:
warning_msg += f"(`{pretrained_model_name_or_path}`) "
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
is_quantized = hf_quantizer is not None
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
if not found_gamma:
renamed_keys[key] = key.replace("gamma", "weight")
found_gamma = True
new_key = key.replace("gamma", "weight")
if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
if not found_beta:
renamed_keys[key] = key.replace("beta", "bias")
found_beta = True
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
if renamed_keys:
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

@@ -4352,14 +4388,15 @@ def _find_mismatched_keys(
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
pretrained_model_name_or_path=pretrained_model_name_or_path,
)
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
assign_to_params_buffers = check_support_param_buffer_assignment(
model_to_load, state_dict, start_prefix
)
error_msgs = _load_state_dict_into_model(
model_to_load, state_dict, start_prefix, assign_to_params_buffers
model_to_load, state_dict, start_prefix, assign_to_params_buffers, pretrained_model_name_or_path
)

else:
@@ -4429,6 +4466,7 @@ def _find_mismatched_keys(
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
pretrained_model_name_or_path=pretrained_model_name_or_path,
)
error_msgs += new_error_msgs
else:
@@ -4541,7 +4579,12 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal

@staticmethod
def _load_pretrained_model_low_mem(
model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
model,
loaded_state_dict_keys,
resolved_archive_file,
start_prefix="",
hf_quantizer=None,
pretrained_model_name_or_path=None,
):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -4571,6 +4614,7 @@ def _load_pretrained_model_low_mem(
start_prefix,
expected_keys=expected_keys,
hf_quantizer=hf_quantizer,
pretrained_model_name_or_path=pretrained_model_name_or_path,
)
return error_msgs

15 changes: 15 additions & 0 deletions src/transformers/utils/logging.py
@@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
logging.Logger.warning_once = warning_once


@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
"""
This method is identical to `logger.info()`, but will emit the info with the same message only once.

Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
The assumption here is that all info messages are unique across the code. If they aren't, we need to switch to
another type of cache that includes the caller frame information in the hashing function.
"""
self.info(*args, **kwargs)


logging.Logger.info_once = info_once
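
A quick usage sketch for the new helper (the message text here is just an illustrative example):

from transformers.utils import logging

logger = logging.get_logger(__name__)

# The second identical call hits the lru_cache and is skipped, so the same
# informational message is only emitted once per logger.
logger.info_once("Some model parameters were renamed internally while loading.")
logger.info_once("Some model parameters were renamed internally while loading.")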


class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""

8 changes: 4 additions & 4 deletions tests/utils/test_modeling_utils.py
@@ -1640,12 +1640,12 @@ def forward(self):

logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
warning_msg_gamma = "`gamma_param` -> `weight_param`"
model = TestModelGamma(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)

@@ -1664,12 +1664,12 @@ def __init__(self, config):
def forward(self):
return self.beta_param.sum()

warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
warning_msg_beta = "`beta_param` -> `bias_param`"
model = TestModelBeta(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)
