Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
Expand Down Expand Up @@ -124,6 +123,7 @@
from .utils import (
CONFIG_NAME,
WEIGHTS_NAME,
find_labels,
get_full_repo_name,
is_apex_available,
is_datasets_available,
Expand Down Expand Up @@ -495,11 +495,7 @@ def __init__(
self.current_flos = 0
self.hp_search_backend = None
self.use_tune_checkpoints = False
default_label_names = (
["start_positions", "end_positions"]
if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values()
else ["labels"]
)
default_label_names = find_labels(self.model.__class__)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PaddingStrategy,
TensorType,
cached_property,
find_labels,
is_tensor,
to_numpy,
to_py_obj,
Expand Down
21 changes: 21 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Generic utilities
"""

import inspect
from collections import OrderedDict, UserDict
from contextlib import ExitStack
from dataclasses import fields
Expand Down Expand Up @@ -289,3 +290,23 @@ def __enter__(self):

def __exit__(self, *args, **kwargs):
    """
    Leave the wrapped context managers by delegating to `self.stack.__exit__`.

    The delegate's return value is propagated so that, when it reports the in-flight exception as handled (a
    truthy return), the enclosing `with` statement suppresses that exception instead of re-raising it.

    Args:
        *args, **kwargs: The exception details supplied by the `with` statement, forwarded unchanged.

    Returns:
        Whatever `self.stack.__exit__` returns (presumably a `bool`, assuming `self.stack` is a
        `contextlib.ExitStack` -- confirm against `__init__`).
    """
    return self.stack.__exit__(*args, **kwargs)


def find_labels(model_class):
    """
    Find the labels used by a given model.

    The label names are read from the signature of the model's forward-pass method: `call` for classes whose name
    starts with `"TF"`, `__call__` for names starting with `"Flax"`, and `forward` otherwise. If the method
    selected by the name prefix does not exist on the class (for instance a user-defined subclass whose name does
    not carry the framework prefix), the first of `forward`, `call`, `__call__` that the class or one of its bases
    defines is used instead.

    Args:
        model_class (`type`): The class of the model.

    Returns:
        `List[str]`: The parameter names containing `"label"`, in signature order. For classes whose name contains
        `"QuestionAnswering"`, `start_positions` and `end_positions` are included as well.

    Raises:
        `AttributeError`: If the class has neither the prefix-selected method nor any of the fallback methods.
    """
    model_name = model_class.__name__
    if model_name.startswith("TF"):
        preferred = "call"
    elif model_name.startswith("Flax"):
        preferred = "__call__"
    else:
        preferred = "forward"

    try:
        forward_method = getattr(model_class, preferred)
    except AttributeError:
        # The name prefix is only a convention, so fall back to whichever forward-pass method actually exists.
        for candidate in ("forward", "call", "__call__"):
            # Look in the class dictionaries instead of using `hasattr`: every class exposes the `__call__` of
            # its metaclass, which would make the `__call__` check always succeed.
            if any(candidate in vars(base) for base in model_class.__mro__):
                forward_method = getattr(model_class, candidate)
                break
        else:
            # Nothing usable was found: re-raise the original error unchanged.
            raise

    signature = inspect.signature(forward_method)
    # Question-answering heads take their targets as span positions rather than as `*label*` arguments.
    extra_label_names = ("start_positions", "end_positions") if "QuestionAnswering" in model_name else ()
    return [p for p in signature.parameters if "label" in p or p in extra_label_names]
39 changes: 35 additions & 4 deletions tests/utils/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels,
get_file_from_repo,
get_from_cache,
has_file,
hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
)


Expand Down Expand Up @@ -158,24 +162,51 @@ def test_get_file_from_repo_local(self):
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))


class ContextManagerTests(unittest.TestCase):
class GenericUtilTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_no_context(self, mock_stdout):
def test_context_managers_no_context(self, mock_stdout):
with ContextManagers([]):
print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")

@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_one_context(self, mock_stdout):
def test_context_managers_one_context(self, mock_stdout):
with ContextManagers([context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")

@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_two_context(self, mock_stdout):
def test_context_managers_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")

def test_find_labels(self):
    """Check `find_labels` on the BERT heads of every framework that is installed."""
    if is_torch_available():
        from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification

        torch_cases = (
            (BertForSequenceClassification, ["labels"]),
            (BertForPreTraining, ["labels", "next_sentence_label"]),
            (BertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in torch_cases:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_tf_available():
        from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification

        tf_cases = (
            (TFBertForSequenceClassification, ["labels"]),
            (TFBertForPreTraining, ["labels", "next_sentence_label"]),
            (TFBertForQuestionAnswering, ["start_positions", "end_positions"]),
        )
        for model_class, expected_labels in tf_cases:
            self.assertEqual(find_labels(model_class), expected_labels)

    if is_flax_available():
        # Flax models don't have labels
        from transformers import (
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
            FlaxBertForSequenceClassification,
        )

        flax_classes = (
            FlaxBertForSequenceClassification,
            FlaxBertForPreTraining,
            FlaxBertForQuestionAnswering,
        )
        for model_class in flax_classes:
            self.assertEqual(find_labels(model_class), [])