35 changes: 30 additions & 5 deletions src/transformers/testing_utils.py
@@ -1409,17 +1409,42 @@ def assert_screenout(out, what):


 def set_model_tester_for_less_flaky_test(test_case):
-    if hasattr(test_case.model_tester, "num_hidden_layers"):
-        test_case.model_tester.num_hidden_layers = 1
+    target_num_hidden_layers = 1
+    # TODO (if possible): Avoid exceptional cases
+    exceptional_classes = [
+        "ZambaModelTester",
+        "RwkvModelTester",
+        "AriaVisionText2TextModelTester",
+        "GPTNeoModelTester",
+        "DPTModelTester",
+    ]
+    if test_case.model_tester.__class__.__name__ in exceptional_classes:
+        target_num_hidden_layers = None
+    if hasattr(test_case.model_tester, "out_features") or hasattr(test_case.model_tester, "out_indices"):
+        target_num_hidden_layers = None
Comment on lines +1423 to +1424 (Collaborator, author):

For some vision models it is hard to adjust the number of layers, as other parameters sometimes have to be changed at the same time.

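An illustrative aside (hypothetical values, not part of the diff): for backbone-style testers, the layer count is coupled to the stage indices that features are taken from, so shrinking it in isolation produces an inconsistent config. A minimal sketch of the failure mode the skip avoids:

    num_stages = 4
    out_indices = [1, 2, 3]  # valid while the backbone has 4 stages

    num_stages = 1           # naive "use fewer layers" shrink
    dangling = [i for i in out_indices if i >= num_stages]
    print(dangling)          # [1, 2, 3]: these stages no longer exist
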
+
+    if hasattr(test_case.model_tester, "num_hidden_layers") and target_num_hidden_layers is not None:
+        test_case.model_tester.num_hidden_layers = target_num_hidden_layers
     if (
         hasattr(test_case.model_tester, "vision_config")
         and "num_hidden_layers" in test_case.model_tester.vision_config
+        and target_num_hidden_layers is not None
     ):
         test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
-        test_case.model_tester.vision_config["num_hidden_layers"] = 1
-    if hasattr(test_case.model_tester, "text_config") and "num_hidden_layers" in test_case.model_tester.text_config:
+        test_case.model_tester.vision_config["num_hidden_layers"] = target_num_hidden_layers
+    if (
+        hasattr(test_case.model_tester, "text_config")
+        and "num_hidden_layers" in test_case.model_tester.text_config
+        and target_num_hidden_layers is not None
+    ):
         test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
-        test_case.model_tester.text_config["num_hidden_layers"] = 1
+        test_case.model_tester.text_config["num_hidden_layers"] = target_num_hidden_layers
+
+    # A few model class specific handling
+
+    # For Albert
+    if hasattr(test_case.model_tester, "num_hidden_groups"):
+        test_case.model_tester.num_hidden_groups = test_case.model_tester.num_hidden_layers
 
 
 def set_config_for_less_flaky_test(config):
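For orientation, a minimal sketch of how the three helpers are meant to be combined in a common test. It mirrors the test_modeling_common.py hunk below; the helper names and torch_device come from transformers.testing_utils, and `self` is assumed to be a ModelTesterMixin-style test case:

    from transformers.testing_utils import (
        set_config_for_less_flaky_test,
        set_model_for_less_flaky_test,
        set_model_tester_for_less_flaky_test,
        torch_device,
    )

    def test_batching_equivalence(self):
        set_model_tester_for_less_flaky_test(self)  # shrink the tester's layer counts first
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        set_config_for_less_flaky_test(config)      # then stabilize the config
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device).eval()
            set_model_for_less_flaky_test(model)    # finally stabilize the instantiated model
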
1 change: 0 additions & 1 deletion tests/models/upernet/test_modeling_upernet.py
@@ -82,7 +82,6 @@ def __init__(
         self.out_features = out_features
         self.num_labels = num_labels
         self.scope = scope
-        self.num_hidden_layers = num_stages
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
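This deletion pairs with the new out_features/out_indices check above: a tester that defines out_features is now skipped wholesale, so mirroring num_stages into num_hidden_layers is no longer needed. A small hypothetical repro of the skip (not from the PR; FakeBackboneTester and FakeCase are stand-ins for a real tester and test case):

    from transformers.testing_utils import set_model_tester_for_less_flaky_test

    class FakeBackboneTester:
        out_features = ["stage2", "stage3"]  # presence alone triggers the skip
        num_hidden_layers = 4

    tester = FakeBackboneTester()

    class FakeCase:
        model_tester = tester

    set_model_tester_for_less_flaky_test(FakeCase())
    assert tester.num_hidden_layers == 4  # left unchanged by the helper
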
4 changes: 4 additions & 0 deletions tests/test_modeling_common.py
@@ -816,7 +816,10 @@ def recursive_check(batched_object, single_row_object, model_name, key):
                     ),
                 )
 
+        set_model_tester_for_less_flaky_test(self)
+
         config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
+        set_config_for_less_flaky_test(config)
         equivalence = get_tensor_equivalence_function(batched_input)
 
         for model_class in self.all_model_classes:
@@ -827,6 +830,7 @@ def recursive_check(batched_object, single_row_object, model_name, key):
             config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
             batched_input_prepared = self._prepare_for_class(batched_input, model_class)
             model = model_class(config).to(torch_device).eval()
+            set_model_for_less_flaky_test(model)
 
             batch_size = self.model_tester.batch_size
             single_row_input = {}