[WIP] Hard error when ignoring tensors. #27484
Changes from all commits: ae5cadb, 92d1715, 88571f3, d8e1ed1, 7ee5bd0, 6a47030
```diff
@@ -29,7 +29,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile

 import torch
```
```diff
@@ -564,6 +564,65 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules


+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
```
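For intuition (not part of the diff), here is a minimal sketch of what `_end_ptr` reports for slices of one shared buffer; the helper is restated so the snippet runs on its own, and the variable names are made up for illustration:

```python
import torch


def _end_ptr(tensor: torch.Tensor) -> int:
    # Same logic as the helper added above: exclusive end address of the bytes
    # the tensor actually covers inside its storage.
    if tensor.nelement():
        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
    else:
        stop = tensor.data_ptr()
    return stop


buffer = torch.zeros(10, dtype=torch.float32)  # one contiguous 40-byte storage
first_half = buffer[:5]                        # covers bytes [0, 20) of that storage
second_half = buffer[5:]                       # covers bytes [20, 40)

assert first_half.data_ptr() == buffer.data_ptr()
assert _end_ptr(first_half) == first_half.data_ptr() + 5 * first_half.element_size()
# The two views touch but do not overlap, which is what `_find_disjoint` checks below.
assert second_half.data_ptr() == _end_ptr(first_half)
```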
```diff
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
```
Comment on lines +591 to +595

Collaborator: Here we merge the tensors that have the same ending in the filtered tensors.

Contributor (Author): Merged the tensors? We add the name to a previous set, so that the list …
```diff
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
```
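A small usage sketch (not part of the diff) of how `_find_disjoint` separates the two cases, assuming `_end_ptr` and `_find_disjoint` as defined above; the tensor names are hypothetical:

```python
import torch

buffer = torch.arange(8, dtype=torch.float32)
state_dict = {
    "a": buffer[:4],   # bytes [0, 16) of the shared storage
    "b": buffer[4:],   # bytes [16, 32): same storage, but no overlap with "a"
    "c": buffer[2:6],  # bytes [8, 24): overlaps both "a" and "b"
}

# Non-overlapping views of one storage are reported as disjoint and can be cloned.
shared, disjoint = _find_disjoint([{"a", "b"}], state_dict)
print(shared, disjoint)  # [] ['a', 'b']

# Genuinely overlapping views stay in the "shared" bucket and flow to _find_identical.
shared, disjoint = _find_disjoint([{"a", "b", "c"}], state_dict)
print(shared, disjoint)  # [{'a', 'b', 'c'}] []
```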
```diff
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
```
Collaborator: With this, tensors that are not loaded on the same device would have a different entry in the set dict, right?

Collaborator: This does not seem to be tested by the small dummy test (but it is most probably tested by some models that do have this?).

Contributor (Author): Yes, because nothing in transformers could trigger this, only …

Contributor (Author): Identical was already previously tested by other tests (since it's a regular flow for shared tensors with …). The dummy test only tests the new path for disjoint tensors (which cannot be triggered by regular transformers code).
```diff
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
```
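And a matching sketch (not part of the diff) for `_find_identical`, assuming the helpers above; the names mimic tied input/output embeddings and are hypothetical:

```python
import torch

weight = torch.randn(4, 4)
state_dict = {
    "embed.weight": weight,    # the exact same tensor registered under two names
    "lm_head.weight": weight,
}

shared, identical = _find_identical([{"embed.weight", "lm_head.weight"}], state_dict)
# Same device, same start pointer, same end pointer -> the group is "identical",
# so save_pretrained can drop all but one name (with at most a warning).
print(shared)     # []
print(identical)  # [{'embed.weight', 'lm_head.weight'}]
```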
```diff
@@ -2354,6 +2413,8 @@ def save_pretrained(
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
             warn_names = set()
+            error_names = set()
```
Collaborator: error_names is only updated once with …

Contributor (Author): It's just a sort of code comment; it's mostly that some names will trigger a warning and some an error. No strong feelings here, it just feels a bit clearer that there are different names.
```diff
+            to_delete_names = set()
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
```
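For context (not shown in this hunk), `shared_ptrs` maps a storage identifier to every state-dict name backed by that storage. A minimal sketch of that grouping, using PyTorch 2.x `untyped_storage()` for simplicity; transformers builds it with a storage-identity helper, but the idea is the same, and the names below are hypothetical:

```python
import collections

import torch


def build_shared_ptrs(state_dict):
    # Group names by the address of their underlying storage; keep only the
    # groups that actually share memory. Illustrative only.
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        ptrs[tensor.untyped_storage().data_ptr()].append(name)
    return {ptr: names for ptr, names in ptrs.items() if len(names) > 1}


weight = torch.randn(4, 4)
state_dict = {"embed.weight": weight, "lm_head.weight": weight, "bias": torch.zeros(4)}
print(build_shared_ptrs(state_dict))  # one group: ['embed.weight', 'lm_head.weight']
```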
```diff
@@ -2364,25 +2425,42 @@ def save_pretrained(
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                del state_dict[name]
-
-                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-                # If the link between tensors was done at runtime then `from_pretrained` will not get
-                # the key back leading to random tensor. A proper warning will be shown
-                # during reload (if applicable), but since the file is not necessarily compatible with
-                # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
```
Contributor: What does "property" mean in this context?
```diff
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back leading to random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
```
Collaborator: A bit strange to me that we re-compute the intersection and difference with the …

Contributor (Author): It makes everything a bit less entangled. Set calculations should be ridiculously fast compared to the actual saving/loading part.
```diff
+                for name in known:
+                    del state_dict[name]
+                unknown = sorted(inames.difference(to_delete_names))
+                for name in unknown[1:]:
+                    del state_dict[name]
+                    warn_names.add(name)
+
+            error_names.update(shared_names)
+
             if len(warn_names) > 0:
                 logger.warning_once(
                     f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
                 )

+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
+                )
+
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
             weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
```
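Putting the pieces together (a sketch, not from the PR): this mirrors the decision `save_pretrained` now makes for an undeclared group of shared names. Disjoint views are cloned, identical views are deduplicated with a warning, and whatever remains — overlapping but non-identical views, as below — is what feeds `error_names` and triggers the new `RuntimeError`. The names are hypothetical and `_find_disjoint`/`_find_identical` are assumed from the diff above:

```python
import torch

big = torch.zeros(10)
state_dict = {"whole": big, "half": big[:5]}  # "half" overlaps "whole" without matching it
groups = [{"whole", "half"}]

shared, disjoint = _find_disjoint(groups, state_dict)    # nothing is disjoint here
shared, identical = _find_identical(shared, state_dict)  # and nothing is identical either
print(disjoint, identical, shared)                       # [] [] [{'whole', 'half'}]

# A non-empty `shared` feeds error_names, so saving such a model with the default
# safetensors serialization now raises:
#   RuntimeError: The weights trying to be saved contained shared tensors ...
# The error message suggests `safe_serialization=False` or removing the sharing.
```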