2 changes: 1 addition & 1 deletion .github/workflows/model-templates.yml
@@ -59,7 +59,7 @@ jobs:
- name: Run style changes
run: |
git fetch origin master:master
make fixup
make style && make quality

- name: Failure short reports
if: ${{ always() }}
1 change: 0 additions & 1 deletion Makefile
@@ -30,7 +30,6 @@ deps_table_check_updated:
# autogenerating code

autogenerate_code: deps_table_update
python utils/class_mapping_update.py

# Check that source code meets quality standards

8 changes: 5 additions & 3 deletions src/transformers/__init__.py
@@ -213,6 +213,7 @@
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
@@ -315,7 +316,7 @@
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
Collaborator Author: This is cleaner to have the mBART-50 tokenizers in their own folder.
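For reference, a minimal sketch of the resulting layout (my own illustration, not part of the diff; it assumes sentencepiece and tokenizers are installed, since both tokenizer classes sit behind those optional dependencies):

# The mBART-50 tokenizers now resolve from their own module, while the original
# mBART tokenizers stay in models.mbart.
from transformers.models.mbart import MBartTokenizer, MBartTokenizerFast
from transformers.models.mbart50 import MBart50Tokenizer, MBart50TokenizerFast

# Top-level imports keep working exactly as before.
from transformers import MBart50Tokenizer, MBart50TokenizerFast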

_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
@@ -358,7 +359,7 @@
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart"].append("MBart50TokenizerFast")
_import_structure["models.mbart50"].append("MBart50TokenizerFast")
_import_structure["models.mobilebert"].append("MobileBertTokenizerFast")
_import_structure["models.mpnet"].append("MPNetTokenizerFast")
_import_structure["models.mt5"].append("MT5TokenizerFast")
@@ -2021,7 +2022,8 @@
from .models.led import LEDTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mbart import MBart50TokenizerFast, MBartTokenizerFast
from .models.mbart import MBartTokenizerFast
from .models.mbart50 import MBart50TokenizerFast
from .models.mobilebert import MobileBertTokenizerFast
from .models.mpnet import MPNetTokenizerFast
from .models.mt5 import MT5TokenizerFast
6 changes: 3 additions & 3 deletions src/transformers/modelcard.py
@@ -41,9 +41,7 @@
is_tokenizers_available,
is_torch_available,
)
from .training_args import ParallelMode
from .utils import logging
from .utils.modeling_auto_mapping import (
from .models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
@@ -54,6 +52,8 @@
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import logging


TASK_MAPPING = {
5 changes: 5 additions & 0 deletions src/transformers/models/__init__.py
@@ -37,6 +37,7 @@
cpm,
ctrl,
deberta,
deberta_v2,
Collaborator Author: Lots of modules were missing here.

deit,
detr,
dialogpt,
@@ -50,6 +51,8 @@
gpt2,
gpt_neo,
herbert,
hubert,
ibert,
layoutlm,
led,
longformer,
@@ -58,6 +61,7 @@
m2m_100,
marian,
mbart,
mbart50,
megatron_bert,
mmbt,
mobilebert,
@@ -82,6 +86,7 @@
vit,
wav2vec2,
xlm,
xlm_prophetnet,
xlm_roberta,
xlnet,
)
84 changes: 81 additions & 3 deletions src/transformers/models/auto/auto_factory.py
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory function to build auto-model classes."""
import importlib
from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...file_utils import copy_func
from ...utils import logging
from .configuration_auto import AutoConfig, replace_list_option_in_docstrings
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


logger = logging.get_logger(__name__)
@@ -415,7 +417,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name)
from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example)
from_config.__doc__ = from_config_docstring
from_config = replace_list_option_in_docstrings(model_mapping, use_model_types=False)(from_config)
from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config)
Collaborator Author: The internal attribute _model_mapping contains the mapping from model type to model class name. We use it to avoid importing all the models when generating the docstring (which would defeat the purpose of this PR).
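A minimal sketch of that distinction as I understand the post-PR behaviour (BertConfig/BertModel are just one illustrative entry of the mapping):

from transformers import BertConfig
from transformers.models.auto.modeling_auto import MODEL_MAPPING

# The internal name mapping holds plain strings, so the docstring decorator can
# list model types without importing a single modeling file.
print(MODEL_MAPPING._model_mapping["bert"])  # 'BertModel' (a string, nothing imported)

# Indexing the lazy mapping with a config class is what actually triggers the import.
print(MODEL_MAPPING[BertConfig])  # <class 'transformers.models.bert.modeling_bert.BertModel'>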

cls.from_config = classmethod(from_config)

if name.startswith("TF"):
@@ -431,7 +433,7 @@ def auto_class_update(cls, checkpoint_for_example="bert-base-cased", head_doc=""
shortcut = checkpoint_for_example.split("/")[-1].split("-")[0]
from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut)
from_pretrained.__doc__ = from_pretrained_docstring
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained)
cls.from_pretrained = classmethod(from_pretrained)
return cls

@@ -445,3 +447,79 @@ def get_values(model_mapping):
result.append(model)

return result


def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the
    # object at the top level.
    transformers_module = importlib.import_module("transformers")
    return getattribute_from_module(transformers_module, attr)
Collaborator Author (on lines +459 to +462): This part is mainly there to support use-cases like ("ibert", ("RobertaTokenizer", "RobertaTokenizerFast")).
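A small, hypothetical usage sketch of that fallback (my illustration; the ibert module defines no Roberta tokenizer of its own, so the lookup is retried on the top-level transformers package):

import importlib

from transformers.models.auto.auto_factory import getattribute_from_module

ibert_module = importlib.import_module(".ibert", "transformers.models")

# Tuples are resolved element-wise; neither name exists in models.ibert, so both
# fall back to the top-level transformers module.
slow, fast = getattribute_from_module(ibert_module, ("RobertaTokenizer", "RobertaTokenizerFast"))
print(slow.__name__, fast.__name__)  # RobertaTokenizer RobertaTokenizerFast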



class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (models or tokenizers, for instance) that only loads its keys and
    values when they are accessed.

    Args:
        config_mapping: The mapping from model type to config class.
        model_mapping: The mapping from model type to model (or tokenizer) class.
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._modules = {}

    def __getitem__(self, key):
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type not in self._model_mapping:
            raise KeyError(key)
        model_name = self._model_mapping[model_type]
        return self._load_attr_from_module(model_type, model_name)

    def _load_attr_from_module(self, model_type, attr):
        module_name = model_type_to_module_name(model_type)
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
        return getattribute_from_module(self._modules[module_name], attr)

    def keys(self):
        return [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]

    def values(self):
        return [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]

    def items(self):
        return [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]

    def __iter__(self):
        # Iterate over the config classes exposed by this mapping.
        return iter(self.keys())

    def __contains__(self, item):
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping
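To make the laziness concrete, a toy sketch built directly on the class above (the two OrderedDicts are made up for illustration; real code goes through the public constants such as MODEL_MAPPING instead):

from collections import OrderedDict

from transformers import BertConfig
from transformers.models.auto.auto_factory import _LazyAutoMapping

# Only class *names* at construction time: no modeling code is imported yet.
mapping = _LazyAutoMapping(
    OrderedDict([("bert", "BertConfig")]),
    OrderedDict([("bert", "BertModel")]),
)

print(BertConfig in mapping)  # True, resolved purely from the class name
print(mapping[BertConfig])    # BertModel is actually imported only at this point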