diff --git a/tests/test_modeling_luke.py b/tests/test_modeling_luke.py index 488e18df3979..36793945fdd1 100644 --- a/tests/test_modeling_luke.py +++ b/tests/test_modeling_luke.py @@ -31,6 +31,7 @@ LukeConfig, LukeModel, LukeEntityAwareAttentionModel, + LukeTokenizer ) from transformers.models.luke.modeling_luke import ( LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -184,12 +185,12 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) -def prepare_luke_batch_inputs(): +def prepare_luke_batch_inputs(tokenizer): # Taken from Open Entity dev set text = """Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon .""" span = (39,42) - ENTITY_TOKEN = '[ENT]' + ENTITY_TOKEN = '' max_mention_length = 30 conv_tables = ( @@ -206,10 +207,7 @@ def preprocess_and_tokenize(text, start, end=None): for a, b in conv_tables: target_text = target_text.replace(a, b) - if isinstance(tokenizer, RobertaTokenizer): - return tokenizer.tokenize(target_text, add_prefix_space=True) - else: - return tokenizer.tokenize(target_text) + return tokenizer.tokenize(target_text.strip(), add_prefix_space=True) tokens = [tokenizer.cls_token] tokens += preprocess_and_tokenize(text, 0, span[0]) @@ -244,7 +242,8 @@ class LukeModelIntegrationTests(unittest.TestCase): def test_inference_no_head(self): model = LukeEntityAwareAttentionModel.from_pretrained("nielsr/luke-large").to(torch_device) - encoding = prepare_luke_batch_inputs() + tokenizer = LukeTokenizer.from_pretrained("nielsr/luke-large") + encoding = prepare_luke_batch_inputs(tokenizer) # convert all values to PyTorch tensors for key, value in encoding.items(): encoding[key] = torch.as_tensor(encoding[key]).unsqueeze(0).to(torch_device) @@ -263,9 +262,9 @@ def test_inference_no_head(self): # Verify entity hidden states expected_shape = torch.Size((1, 2, 1024)) - self.assertEqual(outputs.entity_last_hidden_state.shape == expected_shape) + self.assertTrue(outputs.entity_last_hidden_state.shape == expected_shape) expected_slice = torch.tensor([[ 0.3251, 0.3981, -0.0689], [-0.0098, 0.1215, 0.3544]]) - - self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) \ No newline at end of file + + self.assertTrue(torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))