48 changes: 46 additions & 2 deletions tests/models/retribert/test_tokenization_retribert.py
@@ -27,9 +27,9 @@
    _is_punctuation,
    _is_whitespace,
)
from transformers.testing_utils import require_tokenizers, slow
from transformers.testing_utils import require_tokenizers, require_torch, slow

from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english
from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english, merge_model_tokenizer_mappings


# Copied from transformers.tests.bert.test_modeling_bert.py with Bert->RetriBert
@@ -338,3 +338,47 @@ def test_change_tokenize_chinese_chars(self):
        ]
        self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
        self.assertListEqual(tokens_without_spe_char_r, expected_tokens)

    # RetriBertModel doesn't define `get_input_embeddings`, and its forward method doesn't take only the tokenizer's output as input
    @require_torch
    @slow
    def test_torch_encode_plus_sent_to_model(self):
        import torch

        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING

        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)

        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):

                if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
                    return

                config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
                config = config_class()

                if config.is_encoder_decoder or config.pad_token_id is None:
                    return

                model = model_class(config)

                # The following test is different from the common one
                self.assertGreaterEqual(model.bert_query.get_input_embeddings().weight.shape[0], len(tokenizer))
Review comment on lines +367 to +368 (Contributor Author):

In the common tests it was:

                # Make sure the model contains at least the full vocabulary size in its embedding matrix
                is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight")
                if is_using_common_embeddings:
                    self.assertGreaterEqual(model.get_input_embeddings().weight.shape[0], len(tokenizer))
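
For readers outside the PR: the override has to go through `model.bert_query` because `RetriBertModel` keeps its query encoder as a submodule and, as noted above, does not define a top-level `get_input_embeddings`. A minimal sketch of the same vocabulary-size check, assuming a transformers version that still ships RetriBert (everything below is illustrative and not part of this diff):

# Minimal sketch of the vocabulary-size check used in the override above.
# Assumes a transformers version that still includes RetriBert.
from transformers import RetriBertConfig, RetriBertModel

config = RetriBertConfig()      # BERT-like config, includes `vocab_size`
model = RetriBertModel(config)  # randomly initialized; query encoder lives in `model.bert_query`

# RetriBertModel does not define a top-level `get_input_embeddings`, so the check
# goes through the query encoder's word-embedding matrix instead:
query_embeddings = model.bert_query.get_input_embeddings()
assert query_embeddings.weight.shape[0] >= config.vocab_size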


                # Build sequence
                first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
                sequence = " ".join(first_ten_tokens)
                encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")

                # Ensure that the BatchEncoding.to() method works.
                encoded_sequence.to(model.device)

                batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
                # This should not fail

                with torch.no_grad():  # saves some time
                    # The following lines are different from the common ones
                    model.embed_questions(**encoded_sequence)
                    model.embed_questions(**batch_encoded_sequence)
Review comment on lines +381 to +384 (Contributor Author):

In the common test it was:

                with torch.no_grad():  # saves some time
                    model(**encoded_sequence)
                    model(**batch_encoded_sequence)
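
Side note for readers outside the PR: as the comment at the top of the override says, RetriBertModel's forward pass does not take just the tokenizer output, which is why the test calls `embed_questions` instead. A minimal usage sketch along the same lines, assuming a transformers version that still ships RetriBert and that the `yjernite/retribert-base-uncased` checkpoint (an assumption here, not taken from this diff) is reachable:

# Minimal usage sketch, not part of this diff. The checkpoint name is an assumption;
# only the `embed_questions(**tokenizer_output)` call mirrors the test above.
import torch

from transformers import RetriBertModel, RetriBertTokenizer

tokenizer = RetriBertTokenizer.from_pretrained("yjernite/retribert-base-uncased")
model = RetriBertModel.from_pretrained("yjernite/retribert-base-uncased")

inputs = tokenizer("how do trees make oxygen?", return_tensors="pt")
with torch.no_grad():
    # `embed_questions` takes the tokenizer output directly, as in the test above
    question_embedding = model.embed_questions(**inputs)
print(question_embedding.shape)  # (batch_size, projection_dim)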