Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions utils/check_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"test_modeling_mt5.py",
"test_modeling_pegasus.py",
"test_modeling_tf_camembert.py",
"test_modeling_tf_mt5.py",
"test_modeling_tf_xlm_roberta.py",
"test_modeling_xlm_prophetnet.py",
"test_modeling_xlm_roberta.py",
Expand All @@ -62,7 +63,6 @@
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (documented) model.
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
"TFElectraMainLayer", # Building part of bigger (documented) model (should it be a TFPreTrainedModel ?)
]

# Update this dict with any special correspondance model name (used in modeling_xxx.py) to doc file.
Expand Down Expand Up @@ -135,11 +135,15 @@ def get_model_modules():
"modeling_tf_transfo_xl_utilities",
]
modules = []
for attr_name in dir(transformers):
if attr_name.startswith("modeling") and attr_name not in _ignore_modules:
module = getattr(transformers, attr_name)
if inspect.ismodule(module):
modules.append(module)
for model in dir(transformers.models):
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
return modules


Expand Down Expand Up @@ -244,7 +248,7 @@ def check_all_models_are_tested():
test_files = get_model_test_files()
failures = []
for module in modules:
test_file = f"test_{module.__name__.split('.')[1]}.py"
test_file = f"test_{module.__name__.split('.')[-1]}.py"
if test_file not in test_files:
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
new_failures = check_models_are_tested(module, test_file)
Expand Down Expand Up @@ -279,9 +283,9 @@ def check_models_are_documented(module, doc_file):

def _get_model_name(module):
""" Get the model name for the module defining it."""
splits = module.__name__.split("_")
module_name = module.__name__.split(".")[-1]
splits = module_name.split("_")
splits = splits[(2 if splits[1] in ["flax", "tf"] else 1) :]

return "_".join(splits)


Expand Down