diff --git a/utils/check_repo.py b/utils/check_repo.py index d05cbf8326e8..291101ec3e12 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -49,6 +49,7 @@ "test_modeling_mt5.py", "test_modeling_pegasus.py", "test_modeling_tf_camembert.py", + "test_modeling_tf_mt5.py", "test_modeling_tf_xlm_roberta.py", "test_modeling_xlm_prophetnet.py", "test_modeling_xlm_roberta.py", @@ -62,7 +63,6 @@ "T5Stack", # Building part of bigger (tested) model. "TFDPREncoder", # Building part of bigger (documented) model. "TFDPRSpanPredictor", # Building part of bigger (documented) model. - "TFElectraMainLayer", # Building part of bigger (documented) model (should it be a TFPreTrainedModel ?) ] # Update this dict with any special correspondance model name (used in modeling_xxx.py) to doc file. @@ -135,11 +135,15 @@ def get_model_modules(): "modeling_tf_transfo_xl_utilities", ] modules = [] - for attr_name in dir(transformers): - if attr_name.startswith("modeling") and attr_name not in _ignore_modules: - module = getattr(transformers, attr_name) - if inspect.ismodule(module): - modules.append(module) + for model in dir(transformers.models): + # There are some magic dunder attributes in the dir, we ignore them + if not model.startswith("__"): + model_module = getattr(transformers.models, model) + for submodule in dir(model_module): + if submodule.startswith("modeling") and submodule not in _ignore_modules: + modeling_module = getattr(model_module, submodule) + if inspect.ismodule(modeling_module): + modules.append(modeling_module) return modules @@ -244,7 +248,7 @@ def check_all_models_are_tested(): test_files = get_model_test_files() failures = [] for module in modules: - test_file = f"test_{module.__name__.split('.')[1]}.py" + test_file = f"test_{module.__name__.split('.')[-1]}.py" if test_file not in test_files: failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.") new_failures = check_models_are_tested(module, test_file) @@ -279,9 +283,9 @@ def check_models_are_documented(module, doc_file): def _get_model_name(module): """ Get the model name for the module defining it.""" - splits = module.__name__.split("_") + module_name = module.__name__.split(".")[-1] + splits = module_name.split("_") splits = splits[(2 if splits[1] in ["flax", "tf"] else 1) :] - return "_".join(splits)