Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions utils/create_dummy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import argparse
import collections.abc
import copy
import importlib
import inspect
import json
import os
Expand Down Expand Up @@ -67,6 +66,9 @@
if not is_tf_available():
raise ValueError("Please install TensorFlow.")

from get_test_info import get_model_to_tester_mapping, get_tester_classes_for_model # noqa E402

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't this be with the other imports above?
Note: can't comment on the sys.path(".") above but the current folder is always in the path, so it's not necessary.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgugger In a few files under utilis, if there are statements like importlib.import_module("tests/models/..."), it requires adding

# This is required to make the module import works (when the python process is running from the root of the repo)
sys.path.append(".")

Otherwise, I get ModuleNotFoundError: No module named 'tests'.

The file utils/create_dummy_models.py had this importlib.import_module(...) before this PR, so it required sys.path.append("."). I can remove it from this PR.

(I think I should try to figure out why we need it in the case mentioned above - but if you happen to know, please teach me a lesson 🙏 )

@ydshieh ydshieh Mar 16, 2023

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't this be with the other imports above?

Yes, it works! Thanks a lot

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, that's why it's there then. Could you make sure the comment is added so we now why this sys.path.apoend is there?

@ydshieh ydshieh Mar 16, 2023

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's no longer needed for this file create_dummy_models.py (as I re-used the method in another file), and I removed it.

In another file (get_test_info.py), I have a comment

# This is required to make the module import works (when the python process is running from the root of the repo)



FRAMEWORKS = ["pytorch", "tensorflow"]
INVALID_ARCH = []
TARGET_VOCAB_SIZE = 1024
Expand Down Expand Up @@ -94,8 +96,12 @@
"TFCamembertModel",
"TFCamembertForCausalLM",
"DecisionTransformerModel",
"GraphormerModel",
"InformerModel",
Comment thread
ydshieh marked this conversation as resolved.
"JukeboxModel",
"MarianForCausalLM",
"MaskFormerSwinModel",
"MaskFormerSwinBackbone",
Comment thread
ydshieh marked this conversation as resolved.
"MT5Model",
"MT5ForConditionalGeneration",
"TFMT5ForConditionalGeneration",
Expand Down Expand Up @@ -126,6 +132,7 @@
"XLMRobertaForQuestionAnswering",
"TFXLMRobertaForSequenceClassification",
"TFXLMRobertaForMaskedLM",
"TFXLMRobertaForCausalLM",

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot this one is a previous PR

"TFXLMRobertaForQuestionAnswering",
"TFXLMRobertaModel",
"TFXLMRobertaForMultipleChoice",
Expand Down Expand Up @@ -355,7 +362,7 @@ def build_processor(config_class, processor_class, allow_no_checkpoint=False):
return processor


def get_tiny_config(config_class, **model_tester_kwargs):
def get_tiny_config(config_class, model_class=None, **model_tester_kwargs):
"""Retrieve a tiny configuration from `config_class` using each model's `ModelTester`.

Args:
Expand All @@ -378,9 +385,18 @@ def get_tiny_config(config_class, **model_tester_kwargs):
module_name = model_type_to_module_name(model_type)
if not modeling_name.startswith(module_name):
raise ValueError(f"{modeling_name} doesn't start with {module_name}!")
module = importlib.import_module(f".models.{module_name}.test_modeling_{modeling_name}", package="tests")
camel_case_model_name = config_class.__name__.split("Config")[0]
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
Comment thread
ydshieh marked this conversation as resolved.
test_file = os.path.join("tests", "models", module_name, f"test_modeling_{modeling_name}.py")
models_to_model_testers = get_model_to_tester_mapping(test_file)
# Find the model tester class
model_tester_class = None
tester_classes = []
if model_class is not None:
tester_classes = get_tester_classes_for_model(test_file, model_class)
else:
for _tester_classes in models_to_model_testers.values():
tester_classes.extend(_tester_classes)
if len(tester_classes) > 0:
model_tester_class = sorted(tester_classes, key=lambda x: x.__name__)[0]
Comment thread
ydshieh marked this conversation as resolved.
except ModuleNotFoundError:
error = f"Tiny config not created for {model_type} - cannot find the testing module from the model name."
raise ValueError(error)
Expand Down