Disentangle auto modules from other modeling files #13023
Changes in `src/transformers/models/__init__.py`:

```diff
@@ -37,6 +37,7 @@
     cpm,
     ctrl,
     deberta,
+    deberta_v2,
     deit,
     detr,
     dialogpt,
```

**Collaborator (Author):** Lots of modules were missing here.
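The list above is the `from . import (...)` block in `src/transformers/models/__init__.py`; completing it makes each model submodule reachable as an attribute of `transformers.models`. A minimal check (assuming a transformers version that includes this change):

```python
import transformers.models as models

# deberta_v2 is now imported by the package itself, so plain attribute
# access works without importing the submodule explicitly first.
print(models.deberta_v2.__name__)  # transformers.models.deberta_v2
```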
```diff
@@ -50,6 +51,8 @@
     gpt2,
     gpt_neo,
     herbert,
+    hubert,
+    ibert,
     layoutlm,
     led,
     longformer,
@@ -58,6 +61,7 @@
     m2m_100,
     marian,
     mbart,
+    mbart50,
     megatron_bert,
     mmbt,
     mobilebert,
@@ -82,6 +86,7 @@
     vit,
     wav2vec2,
     xlm,
+    xlm_prophetnet,
     xlm_roberta,
     xlnet,
 )
```
Changes in `src/transformers/models/auto/auto_factory.py`:

```diff
@@ -13,11 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Factory function to build auto-model classes."""
+import importlib
+from collections import OrderedDict
 
 from ...configuration_utils import PretrainedConfig
 from ...file_utils import copy_func
 from ...utils import logging
-from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
+from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
 
 
 logger = logging.get_logger(__name__)
```
```diff
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
     from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
     from_config.__doc__ = from_config_docstring
-    from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
+    from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
     cls.from_config = classmethod(from_config)
 
     if name.startswith("TF"):
```

**Collaborator (Author):** The internal attribute `_model_mapping` is passed here because `replace_list_option_in_docstrings` needs the plain model type to class name mapping, not the lazy mapping wrapping it.
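To make the distinction concrete: after this change the auto mappings are `_LazyAutoMapping` objects whose public interface resolves to classes, while the private `_model_mapping` keeps plain strings. A small illustration (assuming a transformers version that includes this change):

```python
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# The private mapping stores class *names* as strings, so reading it does
# not trigger the import of any modeling code.
print(MODEL_MAPPING._model_mapping["bert"])  # BertModel
```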
```diff
@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
     shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
     from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
     from_pretrained.__doc__ = from_pretrained_docstring
-    from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
+    from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
     cls.from_pretrained = classmethod(from_pretrained)
     return cls
 
```
```diff
@@ -445,3 +447,79 @@ def get_values(model_mapping):
             result.append(model)
 
     return result
+
+
+def getattribute_from_module(module, attr):
+    if attr is None:
+        return None
+    if isinstance(attr, tuple):
+        return tuple(getattribute_from_module(module, a) for a in attr)
+    if hasattr(module, attr):
+        return getattr(module, attr)
+    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
+    # object at the top level.
+    transformers_module = importlib.import_module("transformers")
+    return getattribute_from_module(transformers_module, attr)
```

**Collaborator (Author)** (on lines +459 to +462): This part is mainly there to support use-cases like `("ibert", ("RobertaTokenizer", "RobertaTokenizerFast"))`.
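To see what that use-case means in practice, here is a self-contained sketch (toy stand-ins, not the PR's code) of the fallback: each name in the tuple is looked up on the model's own module first and, failing that, on the top-level package.

```python
from types import SimpleNamespace

def lookup(module, attr, top_level):
    """Toy version of getattribute_from_module, with the top-level module injected."""
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(lookup(module, a, top_level) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # The attribute belongs to another model type: fall back to the top level.
    return lookup(top_level, attr, top_level)

# Stand-ins for transformers.models.ibert and the top-level transformers package.
ibert = SimpleNamespace(IBertModel="<IBertModel>")
top = SimpleNamespace(RobertaTokenizer="<RobertaTokenizer>", RobertaTokenizerFast="<RobertaTokenizerFast>")

# I-BERT reuses the RoBERTa tokenizers, so both names resolve via the fallback.
print(lookup(ibert, ("RobertaTokenizer", "RobertaTokenizerFast"), top))
# ('<RobertaTokenizer>', '<RobertaTokenizerFast>')
```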
```diff
+
+
+class _LazyAutoMapping(OrderedDict):
+    """
+    A mapping config class to object (model or tokenizer for instance) that will load keys and values when accessed.
+
+    Args:
+        - config_mapping: The mapping from model type to config class
+        - model_mapping: The mapping from model type to model (or tokenizer) class
+    """
+
+    def __init__(self, config_mapping, model_mapping):
+        self._config_mapping = config_mapping
+        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+        self._model_mapping = model_mapping
+        self._modules = {}
+
+    def __getitem__(self, key):
+        model_type = self._reverse_config_mapping[key.__name__]
+        if model_type not in self._model_mapping:
+            raise KeyError(key)
+        model_name = self._model_mapping[model_type]
+        return self._load_attr_from_module(model_type, model_name)
+
+    def _load_attr_from_module(self, model_type, attr):
+        module_name = model_type_to_module_name(model_type)
+        if module_name not in self._modules:
+            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
+        return getattribute_from_module(self._modules[module_name], attr)
+
+    def keys(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._config_mapping.items()
+            if key in self._model_mapping.keys()
+        ]
+
+    def values(self):
+        return [
+            self._load_attr_from_module(key, name)
+            for key, name in self._model_mapping.items()
+            if key in self._config_mapping.keys()
+        ]
+
+    def items(self):
+        return [
+            (
+                self._load_attr_from_module(key, self._config_mapping[key]),
+                self._load_attr_from_module(key, self._model_mapping[key]),
+            )
+            for key in self._model_mapping.keys()
+            if key in self._config_mapping.keys()
+        ]
+
+    def __iter__(self):
+        return iter(self._model_mapping.keys())
+
+    def __contains__(self, item):
+        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
+            return False
+        model_type = self._reverse_config_mapping[item.__name__]
+        return model_type in self._model_mapping
```
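As a usage sketch of the lazy behaviour (again assuming a version with this change): indexing an auto mapping with a config class imports only the matching modeling module, and only on first access.

```python
from transformers import BertConfig
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# __getitem__ maps BertConfig back to the "bert" model type, imports
# transformers.models.bert on first access, and returns the class itself.
model_class = MODEL_MAPPING[BertConfig]
print(model_class.__name__)  # BertModel
```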
**Comment:** It's cleaner to have the mBART-50 tokenizers in their own folder.