Reduce the error log when using core models that need their weights renamed, and provide a step forward #32656
Changes from 2 commits
```diff
@@ -105,7 +105,6 @@
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
-PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."

 if is_accelerate_available():
```
```diff
@@ -689,21 +688,41 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
     return shared_tensors, identical


-def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+def _load_state_dict_into_model(
+    model_to_load, state_dict, start_prefix, assign_to_params_buffers=False, pretrained_model_name_or_path=None
+):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
     new_keys = []
+    renamed_keys = {}
+    found_gamma = False
+    found_beta = False
+    warning_msg = "This model "
+    if pretrained_model_name_or_path is not None:
+        warning_msg += f"(`{pretrained_model_name_or_path}`) "
+    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
+            # We add only the first key as an example
+            if not found_gamma:
+                renamed_keys[key] = key.replace("gamma", "weight")
+                found_gamma = True
             new_key = key.replace("gamma", "weight")
```
Contributor:

Two nits:

e.g. something like:

```python
def _load_state_dict_into_model(
    model_to_load, state_dict, start_prefix, assign_to_params_buffers=False, pretrained_model_name_or_path=None
):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    renamed_gamma = {}
    renamed_beta = {}
    warning_msg = "This model "
    if pretrained_model_name_or_path is not None:
        warning_msg += f"(`{pretrained_model_name_or_path}`) "
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
            # We add only the first key as an example
            renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
        if "beta" in key:
            new_key = key.replace("beta", "bias")
            # We add only the first key as an example
            renamed_beta[key] = new_key if not renamed_beta else renamed_beta
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    renamed_keys = renamed_gamma | renamed_beta
    if renamed_keys:
        warning_msg = "Model contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
        for old_key, new_key in renamed_keys.items():
            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
        logger.info_once(warning_msg)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)
```

Contributor (Author):

Note we want this to be `and`, not `or`, because it's important that the user knows if both gamma and beta are present. Otherwise, went with this solution!

Contributor:

Right, but I think in this case we'd still want to use

Contributor (Author):

Correct, I misspoke about what I actually wound up doing here. In the end it's `|`.
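For readers following this exchange, a minimal standalone sketch (not part of the PR, with made-up key names) of how the dict union operator behaves: entries from both operands are kept, so renames found for `gamma` and for `beta` both end up in the aggregated message.

```python
# Illustrative only: Python 3.9+ dict union keeps the entries of both dicts.
renamed_gamma = {"encoder.layer.0.LayerNorm.gamma": "encoder.layer.0.LayerNorm.weight"}
renamed_beta = {"encoder.layer.0.LayerNorm.beta": "encoder.layer.0.LayerNorm.bias"}

renamed_keys = renamed_gamma | renamed_beta
print(renamed_keys)
# {'encoder.layer.0.LayerNorm.gamma': 'encoder.layer.0.LayerNorm.weight',
#  'encoder.layer.0.LayerNorm.beta': 'encoder.layer.0.LayerNorm.bias'}
```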
```diff
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
+            # We add only the first key as an example
+            if not found_beta:
+                renamed_keys[key] = key.replace("beta", "bias")
+                found_beta = True
             new_key = key.replace("beta", "bias")
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    if renamed_keys:
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
```
Contributor (Author):

I think this logic can be refactored too. I can do so in a follow-up (we repeat it for the meta model).
```diff
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
```
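As the author notes above, the same gamma/beta block is repeated in `_load_state_dict_into_meta_model` below. A rough sketch of what a shared helper for that follow-up refactor could look like; the name `_rename_legacy_keys`, its signature, and the `log_fn` parameter are hypothetical and not part of this PR (in `modeling_utils.py` the log function would be `logger.info_once`):

```python
def _rename_legacy_keys(state_dict, model_name=None, log_fn=print):
    """Hypothetical helper (not in the PR): rename `gamma` -> `weight` and
    `beta` -> `bias` keys in place, then emit one aggregated message via `log_fn`."""
    renamed_examples = {}  # one example per pattern, used only for the message
    found = {"gamma": False, "beta": False}
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in (("gamma", "weight"), ("beta", "bias")):
            if old in key:
                new_key = new_key.replace(old, new)
                if not found[old]:
                    renamed_examples[key] = new_key
                    found[old] = True
        if new_key != key:
            state_dict[new_key] = state_dict.pop(key)
    if renamed_examples:
        warning_msg = "This model "
        if model_name is not None:
            warning_msg += f"(`{model_name}`) "
        warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
        for old_key, new_key in renamed_examples.items():
            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
        log_fn(warning_msg)
    return state_dict
```

Both loaders could then call this helper once instead of duplicating the renaming loop.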
```diff
@@ -819,6 +838,7 @@ def _load_state_dict_into_meta_model(
     is_safetensors=False,
     keep_in_fp32_modules=None,
     unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
+    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
```
```diff
@@ -841,18 +861,34 @@
     old_keys = []
     new_keys = []
+    renamed_keys = {}
+    found_gamma = False
+    found_beta = False
+    warning_msg = "This model "
+    if pretrained_model_name_or_path is not None:
+        warning_msg += f"(`{pretrained_model_name_or_path}`) "
+    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
     is_quantized = hf_quantizer is not None
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
+            if not found_gamma:
+                renamed_keys[key] = key.replace("gamma", "weight")
+                found_gamma = True
             new_key = key.replace("gamma", "weight")
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
+            if not found_beta:
+                renamed_keys[key] = key.replace("beta", "bias")
+                found_beta = True
             new_key = key.replace("beta", "bias")
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    if renamed_keys:
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
```
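The diff relies on `logger.info_once` so the aggregated message is emitted at most once, even when the loader is invoked repeatedly (for example once per shard). As a rough illustration of the idea only, not the actual transformers implementation, a once-per-message logger can be built with `functools.lru_cache`:

```python
import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(None)
def info_once(message: str) -> None:
    # lru_cache memoizes on the message, so repeated identical calls log only once.
    logger.info(message)
```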
```diff
@@ -4352,14 +4388,15 @@ def _find_mismatched_keys(
                     is_safetensors=is_safetensors,
                     keep_in_fp32_modules=keep_in_fp32_modules,
                     unexpected_keys=unexpected_keys,
+                    pretrained_model_name_or_path=pretrained_model_name_or_path,
                 )
             else:
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
                 assign_to_params_buffers = check_support_param_buffer_assignment(
                     model_to_load, state_dict, start_prefix
                 )
                 error_msgs = _load_state_dict_into_model(
-                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
+                    model_to_load, state_dict, start_prefix, assign_to_params_buffers, pretrained_model_name_or_path
                 )

         else:
```
```diff
@@ -4429,6 +4466,7 @@ def _find_mismatched_keys(
                         is_safetensors=is_safetensors,
                         keep_in_fp32_modules=keep_in_fp32_modules,
                         unexpected_keys=unexpected_keys,
+                        pretrained_model_name_or_path=pretrained_model_name_or_path,
                     )
                     error_msgs += new_error_msgs
                 else:
```
```diff
@@ -4541,7 +4579,12 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal
     @staticmethod
     def _load_pretrained_model_low_mem(
-        model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
+        model,
+        loaded_state_dict_keys,
+        resolved_archive_file,
+        start_prefix="",
+        hf_quantizer=None,
+        pretrained_model_name_or_path=None,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
```

```diff
@@ -4571,6 +4614,7 @@ def _load_pretrained_model_low_mem(
             start_prefix,
             expected_keys=expected_keys,
             hf_quantizer=hf_quantizer,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
         )
         return error_msgs
```
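For context, a sketch of the intended user-facing effect; the repo id and key names below are made up, and the message wording follows the diff above. Loading a checkpoint that still stores LayerNorm parameters as `gamma`/`beta` should now produce a single aggregated info message instead of one warning per renamed tensor:

```python
from transformers import AutoModel

# Hypothetical repo id: any checkpoint whose state dict still uses `gamma`/`beta`
# parameter names would go through the renaming path during loading.
model = AutoModel.from_pretrained("some-user/legacy-bert-checkpoint")

# Expected log output (emitted once):
# This model (`some-user/legacy-bert-checkpoint`) contains parameters that have been
# renamed internally (a few are listed below but more are present in the model):
# * `encoder.layer.0.attention.output.LayerNorm.gamma` -> `encoder.layer.0.attention.output.LayerNorm.weight`
# * `encoder.layer.0.attention.output.LayerNorm.beta` -> `encoder.layer.0.attention.output.LayerNorm.bias`
# If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users.
```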
Not sure if we need to include this in the error message - I'm generally against adding args to functions just for auxiliary logic like logging.

+1 on this, we should not need to modify the function for logging purposes.

I think we can do `type(model_to_load)` instead here and be fine.
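A minimal sketch of the alternative being suggested here, identifying the model by its class instead of threading `pretrained_model_name_or_path` through the loading functions (illustrative only; the helper name is hypothetical and this is not the code from the PR):

```python
def _build_rename_message(model_to_load, renamed_keys):
    # Hypothetical helper: use the class of the model being loaded in the message,
    # so no extra argument has to be passed down just for logging.
    warning_msg = (
        f"The model you are loading ({type(model_to_load).__name__}) contains parameters "
        "that have been renamed internally (a few are listed below but more are present in the model):\n"
    )
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += (
        "If you are using a model from the Hub, consider submitting a PR to adjust "
        "these weights and help future users."
    )
    return warning_msg
```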