diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a0c07b526cb5..df6945cb12dc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -373,7 +373,6 @@ "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST", "LukeLayer", "LukeModel", - "LukeEntityAwareAttentionModel", "LukePreTrainedModel", "LukeForEntityClassification", "LukeForEntityPairClassification", @@ -1734,7 +1733,6 @@ ) from .models.luke import ( LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, - LukeEntityAwareAttentionModel, LukeForEntityTyping, LukeLayer, LukeModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 9bf422035b43..7f03ce4ba5c1 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -149,7 +149,7 @@ ) # Add modeling imports here -from ..luke.modeling_luke import LukeEntityAwareAttentionModel, LukeModel +from ..luke.modeling_luke import LukeModel from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel from ..mbart.modeling_mbart import ( diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 3151c0e971d1..04aa18304d1a 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -40,6 +40,7 @@ from ..layoutlm.tokenization_layoutlm import LayoutLMTokenizer from ..led.tokenization_led import LEDTokenizer from ..longformer.tokenization_longformer import LongformerTokenizer +from ..luke.tokenization_luke import LukeTokenizer from ..lxmert.tokenization_lxmert import LxmertTokenizer from ..mobilebert.tokenization_mobilebert import MobileBertTokenizer from ..mpnet.tokenization_mpnet import MPNetTokenizer @@ -78,6 +79,7 @@ LayoutLMConfig, LEDConfig, LongformerConfig, + LukeConfig, LxmertConfig, MarianConfig, 
MBartConfig, @@ -201,6 +203,7 @@ TOKENIZER_MAPPING = OrderedDict( [ + (LukeConfig, (LukeTokenizer, None)), (RetriBertConfig, (RetriBertTokenizer, RetriBertTokenizerFast)), (T5Config, (T5Tokenizer, T5TokenizerFast)), (MT5Config, (MT5Tokenizer, MT5TokenizerFast)), diff --git a/src/transformers/models/luke/__init__.py b/src/transformers/models/luke/__init__.py index 17a9772a389a..c54f8b086912 100644 --- a/src/transformers/models/luke/__init__.py +++ b/src/transformers/models/luke/__init__.py @@ -30,7 +30,6 @@ _import_structure["modeling_luke"] = [ "LUKE_PRETRAINED_MODEL_ARCHIVE_LIST", "LukeModel", - "LukeEntityAwareAttentionModel", "LukeForEntityClassification", "LukeForEntityPairClassification", "LukeForEntitySpanClassification", @@ -41,7 +40,7 @@ from .configuration_luke import LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, LukeConfig if is_torch_available(): - from .modeling_luke import LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, LukeEntityAwareAttentionModel, LukeModel + from .modeling_luke import LUKE_PRETRAINED_MODEL_ARCHIVE_LIST, LukeModel else: import importlib diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index 7b353c30f700..50ac7fc1ead8 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -46,9 +46,17 @@ class LukeConfig(RobertaConfig): """ model_type = "luke" - def __init__(self, entity_vocab_size: int = 500000, entity_emb_size: int = 256, **kwargs): + def __init__( + self, + vocab_size: int = 50267, + entity_vocab_size: int = 500000, + entity_emb_size: int = 256, + use_entity_aware_attention=False, + **kwargs + ): """Constructs LukeConfig.""" - super(LukeConfig, self).__init__(**kwargs) + super(LukeConfig, self).__init__(vocab_size=vocab_size, **kwargs) self.entity_vocab_size = entity_vocab_size self.entity_emb_size = entity_emb_size + self.use_entity_aware_attention = use_entity_aware_attention diff --git 
a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py index 3916cef29332..34d042a06a74 100644 --- a/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py @@ -20,7 +20,7 @@ import torch -from transformers import LukeConfig, LukeEntityAwareAttentionModel, LukeTokenizer, RobertaTokenizer +from transformers import LukeConfig, LukeModel, LukeTokenizer, RobertaTokenizer from transformers.tokenization_utils_base import AddedToken @@ -29,7 +29,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p # Load configuration defined in the metadata file with open(metadata_path) as metadata_file: metadata = json.load(metadata_file) - config = LukeConfig(**metadata["model_config"]) + config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"]) # Load in the weights from the checkpoint_path state_dict = torch.load(checkpoint_path, map_location="cpu") @@ -70,7 +70,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"] entity_emb[entity_vocab["[MASK2]"]] = entity_emb[entity_vocab["[MASK]"]] - model = LukeEntityAwareAttentionModel(config=config).eval() + model = LukeModel(config=config).eval() missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) assert len(missing_keys) == 1 and missing_keys[0] == "embeddings.position_ids" @@ -81,9 +81,7 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p text = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ." 
span = (39, 42) - encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True) - for key, value in encoding.items(): - encoding[key] = torch.as_tensor(encoding[key]).unsqueeze(0) + encoding = tokenizer(text, entity_spans=[span], add_prefix_space=True, return_tensors="pt") outputs = model(**encoding) @@ -97,8 +95,8 @@ def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, p expected_shape = torch.Size((1, 42, 768)) expected_slice = torch.tensor([[0.0037, 0.1368, -0.0091], [0.1099, 0.3329, -0.1095], [0.0765, 0.5335, 0.1179]]) - assert outputs.last_hidden_state.shape == expected_shape - assert torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) + assert outputs.word_last_hidden_state.shape == expected_shape + assert torch.allclose(outputs.word_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4) # Verify entity hidden states if model_size == "large": diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index df2a0ddb8aa6..a938c41f6883 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -22,9 +22,8 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint -from torch.nn import CrossEntropyLoss, MSELoss -from ...activations import ACT2FN, gelu +from ...activations import ACT2FN from ...file_utils import ( ModelOutput, add_code_sample_docstrings, @@ -35,13 +34,10 @@ from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, - MaskedLMOutput, ) from ...modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, ) from ...utils import logging from .configuration_luke import LukeConfig @@ -83,23 +79,37 @@ class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling): weighted average in the self-attention heads. 
""" + word_last_hidden_state: torch.FloatTensor = None entity_last_hidden_state: torch.FloatTensor = None + word_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass -class BaseLukeEntityAwareAttentionModelOutputWithPooling(ModelOutput): +class BaseLukeModelOutput(BaseModelOutput): """ - Base class for entity-aware model's outputs with entity-aware self-attention mechanism. + Base class for model's outputs, with potential hidden states and attentions. Args: - last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length - max_entity_length, hidden_size)`): + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - entity_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, max_entity_length, hidden_size)`): - Sequence of entity hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
""" - last_hidden_state: torch.FloatTensor = None + word_last_hidden_state: torch.FloatTensor = None entity_last_hidden_state: torch.FloatTensor = None + word_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass @@ -129,6 +139,7 @@ class EntityClassificationOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + word_hidden_states: Optional[Tuple[torch.FloatTensor]] = None entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -160,6 +171,7 @@ class EntityPairClassificationOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + word_hidden_states: Optional[Tuple[torch.FloatTensor]] = None entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -191,6 +203,7 @@ class EntitySpanClassificationOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + word_hidden_states: Optional[Tuple[torch.FloatTensor]] = None entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -215,7 +228,6 @@ def __init__(self, config): # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") # End copy self.padding_idx = config.pad_token_id @@ -224,13 +236,16 @@ def __init__(self, config): ) def forward( - self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, + self, + input_ids=None, + 
token_type_ids=None, + position_ids=None, + inputs_embeds=None, ): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = create_position_ids_from_input_ids( - input_ids, self.padding_idx).to(input_ids.device) + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) else: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) @@ -244,12 +259,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) + + position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + embeddings = inputs_embeds + position_embeddings + token_type_embeddings embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings @@ -272,9 +286,9 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds): return position_ids.unsqueeze(0).expand(input_shape) -class EntityEmbeddings(nn.Module): +class LukeEntityEmbeddings(nn.Module): def __init__(self, config: LukeConfig): - super(EntityEmbeddings, self).__init__() + super(LukeEntityEmbeddings, self).__init__() self.config = config self.entity_embeddings = nn.Embedding(config.entity_vocab_size, config.entity_emb_size, padding_idx=0) @@ -289,7 +303,7 @@ def __init__(self, config: LukeConfig): def forward( self, entity_ids: torch.LongTensor, position_ids: torch.LongTensor, token_type_ids: torch.LongTensor = None - ): + ): if token_type_ids is None: token_type_ids = torch.zeros_like(entity_ids) @@ -325,16 +339,18 @@ def __init__(self, config): self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / 
config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size + self.use_entity_aware_attention = config.use_entity_aware_attention self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) + if self.use_entity_aware_attention: + self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size) + self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size) + self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -343,42 +359,46 @@ def transpose_for_scores(self, x): def forward( self, - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, output_attentions=False, ): - mixed_query_layer = self.query(hidden_states) + word_size = word_hidden_states.size(1) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
- is_cross_attention = encoder_hidden_states is not None + if entity_hidden_states is None: + concat_hidden_states = word_hidden_states + else: + concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.transpose_for_scores(self.key(concat_hidden_states)) + value_layer = self.transpose_for_scores(self.value(concat_hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) + if self.use_entity_aware_attention and entity_hidden_states is not None: + w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states)) + w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states)) + e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states)) + e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states)) - # Take the dot product between "query" and "key" to get the raw attention scores. 
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + w2w_key_layer = key_layer[:, :, :word_size, :] + e2w_key_layer = key_layer[:, :, :word_size, :] + w2e_key_layer = key_layer[:, :, word_size:, :] + e2e_key_layer = key_layer[:, :, word_size:, :] - if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) - position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2)) + w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2)) + e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2)) + e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2)) - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) - relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) - attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3) + entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3) + attention_scores = 
torch.cat([word_attention_scores, entity_attention_scores], dim=2) + + else: + query_layer = self.transpose_for_scores(self.query(concat_hidden_states)) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: @@ -402,7 +422,16 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + output_word_hidden_states = context_layer[:, :word_size, :] + if entity_hidden_states is None: + output_entity_hidden_states = None + else: + output_entity_hidden_states = context_layer[:, word_size:, :] + + if output_attentions: + outputs = (output_word_hidden_states, output_entity_hidden_states, attention_probs) + else: + outputs = (output_word_hidden_states, output_entity_hidden_states) return outputs @@ -431,40 +460,45 @@ def __init__(self, config): self.pruned_heads = set() def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len(heads) - self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) + raise NotImplementedError("LUKE does not support the pruning of attention heads") def forward( self, - hidden_states, + word_hidden_states, + 
entity_hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, output_attentions=False, ): + word_size = word_hidden_states.size(1) self_outputs = self.self( - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask, head_mask, encoder_hidden_states, output_attentions, ) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + if entity_hidden_states is None: + concat_self_outputs = self_outputs[0] + concat_hidden_states = word_hidden_states + else: + concat_self_outputs = torch.cat(self_outputs[:2], dim=1) + concat_hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + + attention_output = self.output(concat_self_outputs, concat_hidden_states) + + word_attention_output = attention_output[:, :word_size, :] + if entity_hidden_states is None: + entity_attention_output = None + else: + entity_attention_output = attention_output[:, word_size:, :] + + outputs = (word_attention_output, entity_attention_output) + self_outputs[ + 2: + ] # add attentions if we output them + return outputs @@ -511,26 +545,39 @@ def __init__(self, config): def forward( self, - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, output_attentions=False, ): + word_size = word_hidden_states.size(1) + self_attention_outputs = self.attention( - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask, head_mask, output_attentions=output_attentions, ) - attention_output = self_attention_outputs[0] + if entity_hidden_states is None: + concat_attention_output = self_attention_outputs[0] + else: + concat_attention_output = torch.cat(self_attention_outputs[:2], dim=1) - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[2:] # add self attentions if we output attention weights 
layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, concat_attention_output ) - outputs = (layer_output,) + outputs + word_layer_output = layer_output[:, :word_size, :] + if entity_hidden_states is None: + entity_layer_output = None + else: + entity_layer_output = layer_output[:, word_size:, :] + + outputs = (word_layer_output, entity_layer_output) + outputs return outputs @@ -549,7 +596,8 @@ def __init__(self, config): def forward( self, - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, @@ -558,11 +606,20 @@ def forward( return_dict=True, ): all_hidden_states = () if output_hidden_states else None + all_word_hidden_states = () if output_hidden_states else None + all_entity_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None + if entity_hidden_states is None: + hidden_states = word_hidden_states + else: + hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + all_word_hidden_states = all_word_hidden_states + (word_hidden_states,) + all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None if getattr(self.config, "gradient_checkpointing", False): @@ -575,26 +632,36 @@ def custom_forward(*inputs): layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, ) else: layer_outputs = layer_module( - hidden_states, + word_hidden_states, + entity_hidden_states, attention_mask, layer_head_mask, 
encoder_hidden_states, output_attentions, ) - hidden_states = layer_outputs[0] + word_hidden_states = layer_outputs[0] + if entity_hidden_states is None: + hidden_states = word_hidden_states + else: + entity_hidden_states = layer_outputs[1] + hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) + if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + all_word_hidden_states = all_word_hidden_states + (word_hidden_states,) + all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) if not return_dict: return tuple( @@ -603,13 +670,21 @@ def custom_forward(*inputs): hidden_states, all_hidden_states, all_self_attentions, + word_hidden_states, + entity_hidden_states, + all_word_hidden_states, + all_entity_hidden_states, ] if v is not None ) - return BaseModelOutput( + return BaseLukeModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, + word_last_hidden_state=word_hidden_states, + entity_last_hidden_state=entity_hidden_states, + word_hidden_states=all_word_hidden_states, + entity_hidden_states=all_entity_hidden_states, ) @@ -762,7 +837,7 @@ def __init__(self, config, add_pooling_layer=True): self.embeddings = LukeEmbeddings(config) self.embeddings.token_type_embeddings.requires_grad = False - self.entity_embeddings = EntityEmbeddings(config) + self.entity_embeddings = LukeEntityEmbeddings(config) self.encoder = LukeEncoder(config) self.pooler = LukePooler(config) if add_pooling_layer else None @@ -782,12 +857,7 @@ def set_entity_embeddings(self, value): self.entity_embeddings.entity_embeddings = value def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) + raise NotImplementedError("LUKE does not support the pruning of attention heads") @add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) # @add_code_sample_docstrings( @@ -839,17 +909,15 @@ def forward( device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length)), device=device) + attention_mask = torch.ones((batch_size, seq_length), device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - encoder_extended_attention_mask = None + if entity_ids is not None: + entity_seq_length = entity_ids.size(1) + if entity_attention_mask is None: + entity_attention_mask = torch.ones((batch_size, entity_seq_length), device=device) + if entity_token_type_ids is None: + entity_token_type_ids = torch.zeros((batch_size, entity_seq_length), dtype=torch.long, device=device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -859,7 +927,7 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) # First, compute word embeddings - embedding_output = self.embeddings( + word_embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, 
token_type_ids=token_type_ids, @@ -868,14 +936,17 @@ def forward( # Second, compute extended attention mask extended_attention_mask = self._compute_extended_attention_mask(attention_mask, entity_attention_mask) + # Third, compute entity embeddings and concatenate with word embeddings - if entity_ids is not None: + if entity_ids is None: + entity_embedding_output = None + else: entity_embedding_output = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids) - embedding_output = torch.cat([embedding_output, entity_embedding_output], dim=1) # Fourth, send embeddings through the model encoder_outputs = self.encoder( - embedding_output, + word_embedding_output, + entity_embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, @@ -889,38 +960,40 @@ def forward( # Sixth, we compute the pooled_output, word_sequence_output and entity_sequence_output based on the sequence_output pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - word_seq_len = input_ids.shape[1] - word_sequence_output = sequence_output[:, :word_seq_len, :] - entity_sequence_output = None - if entity_ids is not None: - entity_sequence_output = sequence_output[:, word_seq_len:, :] if not return_dict: return ( - word_sequence_output, - entity_sequence_output, + sequence_output, pooled_output, ) + encoder_outputs[1:] - return BaseLukeModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=word_sequence_output, - entity_last_hidden_state=entity_sequence_output, + return BaseLukeModelOutputWithPooling( + last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, + word_last_hidden_state=encoder_outputs.word_last_hidden_state, + entity_last_hidden_state=encoder_outputs.entity_last_hidden_state, + 
word_hidden_states=encoder_outputs.word_hidden_states, + entity_hidden_states=encoder_outputs.entity_hidden_states, ) def _compute_extended_attention_mask( - self, word_attention_mask: torch.LongTensor, entity_attention_mask: torch.LongTensor + self, word_attention_mask: torch.LongTensor, entity_attention_mask: Optional[torch.LongTensor] ): attention_mask = word_attention_mask if entity_attention_mask is not None: - attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=1) - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + attention_mask = torch.cat([attention_mask, entity_attention_mask], dim=-1) + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError("Wrong shape for attention_mask (shape {})".format(attention_mask.shape)) + + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask @@ -940,158 +1013,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx): return incremental_indices.long() + padding_idx -@add_start_docstrings( - """ - The bare LUKE Model transformer with an entity aware self-attention mechanism outputting raw hidden-states for both - word tokens and entities without any specific head on top. 
- """, - LUKE_START_DOCSTRING, -) -class LukeEntityAwareAttentionModel(LukeModel): - def __init__(self, config): - super(LukeEntityAwareAttentionModel, self).__init__(config) - self.config = config - self.encoder = EntityAwareEncoder(config) - - def forward( - self, - input_ids, - token_type_ids, - attention_mask, - entity_ids, - entity_attention_mask, - entity_token_type_ids, - entity_position_ids, - return_dict=None, - ): - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - word_embeddings = self.embeddings(input_ids, token_type_ids) - entity_embeddings = self.entity_embeddings(entity_ids, entity_position_ids, entity_token_type_ids) - attention_mask = self._compute_extended_attention_mask(attention_mask, entity_attention_mask) - - word_hidden_states, entity_hidden_states = self.encoder(word_embeddings, entity_embeddings, attention_mask) - - if not return_dict: - return (word_hidden_states, entity_hidden_states) - - return BaseLukeEntityAwareAttentionModelOutputWithPooling( - last_hidden_state=word_hidden_states, - entity_last_hidden_state=entity_hidden_states, - ) - - -class EntityAwareSelfAttention(nn.Module): - def __init__(self, config): - super(EntityAwareSelfAttention, self).__init__() - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.w2e_query = nn.Linear(config.hidden_size, self.all_head_size) - self.e2w_query = nn.Linear(config.hidden_size, self.all_head_size) - self.e2e_query = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = 
x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - return x.view(*new_x_shape).permute(0, 2, 1, 3) - - def forward(self, word_hidden_states, entity_hidden_states, attention_mask): - word_size = word_hidden_states.size(1) - - w2w_query_layer = self.transpose_for_scores(self.query(word_hidden_states)) - w2e_query_layer = self.transpose_for_scores(self.w2e_query(word_hidden_states)) - e2w_query_layer = self.transpose_for_scores(self.e2w_query(entity_hidden_states)) - e2e_query_layer = self.transpose_for_scores(self.e2e_query(entity_hidden_states)) - - key_layer = self.transpose_for_scores(self.key(torch.cat([word_hidden_states, entity_hidden_states], dim=1))) - - w2w_key_layer = key_layer[:, :, :word_size, :] - e2w_key_layer = key_layer[:, :, :word_size, :] - w2e_key_layer = key_layer[:, :, word_size:, :] - e2e_key_layer = key_layer[:, :, word_size:, :] - - w2w_attention_scores = torch.matmul(w2w_query_layer, w2w_key_layer.transpose(-1, -2)) - w2e_attention_scores = torch.matmul(w2e_query_layer, w2e_key_layer.transpose(-1, -2)) - e2w_attention_scores = torch.matmul(e2w_query_layer, e2w_key_layer.transpose(-1, -2)) - e2e_attention_scores = torch.matmul(e2e_query_layer, e2e_key_layer.transpose(-1, -2)) - - word_attention_scores = torch.cat([w2w_attention_scores, w2e_attention_scores], dim=3) - entity_attention_scores = torch.cat([e2w_attention_scores, e2e_attention_scores], dim=3) - attention_scores = torch.cat([word_attention_scores, entity_attention_scores], dim=2) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_scores = attention_scores + attention_mask - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - - value_layer = self.transpose_for_scores( - self.value(torch.cat([word_hidden_states, entity_hidden_states], dim=1)) - ) - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer[:, :word_size, :], context_layer[:, word_size:, :] - - -class EntityAwareAttention(nn.Module): - def __init__(self, config): - super(EntityAwareAttention, self).__init__() - self.self = EntityAwareSelfAttention(config) - self.output = LukeSelfOutput(config) - - def forward(self, word_hidden_states, entity_hidden_states, attention_mask): - word_self_output, entity_self_output = self.self(word_hidden_states, entity_hidden_states, attention_mask) - hidden_states = torch.cat([word_hidden_states, entity_hidden_states], dim=1) - self_output = torch.cat([word_self_output, entity_self_output], dim=1) - output = self.output(self_output, hidden_states) - return output[:, : word_hidden_states.size(1), :], output[:, word_hidden_states.size(1) :, :] - - -class EntityAwareLayer(nn.Module): - def __init__(self, config): - super(EntityAwareLayer, self).__init__() - - self.attention = EntityAwareAttention(config) - self.intermediate = LukeIntermediate(config) - self.output = LukeOutput(config) - - def forward(self, word_hidden_states, entity_hidden_states, attention_mask): - word_attention_output, entity_attention_output = self.attention( - word_hidden_states, entity_hidden_states, attention_mask - ) - attention_output = torch.cat([word_attention_output, entity_attention_output], dim=1) - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - - return layer_output[:, : word_hidden_states.size(1), :], layer_output[:, word_hidden_states.size(1) :, :] - - -class EntityAwareEncoder(nn.Module): - def __init__(self, config): - super(EntityAwareEncoder, self).__init__() - self.layer = nn.ModuleList([EntityAwareLayer(config) for _ in range(config.num_hidden_layers)]) - - def forward(self, word_hidden_states, entity_hidden_states, attention_mask): - for 
idx, layer_module in enumerate(self.layer): - word_hidden_states, entity_hidden_states = layer_module( - word_hidden_states, entity_hidden_states, attention_mask - ) - - return word_hidden_states, entity_hidden_states - - @add_start_docstrings( """ The LUKE Model with a classification head on top (a linear layer on top of the hidden state of the entity @@ -1103,7 +1024,7 @@ class LukeForEntityClassification(nn.Module): def __init__(self, config): super().__init__(config) - self.luke = LukeEntityAwareAttentionModel(config) + self.luke = LukeModel(config) self.num_labels = config.num_labels self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1114,12 +1035,12 @@ def __init__(self, config): def forward( self, input_ids, - attention_mask, - token_type_ids, entity_ids, - entity_attention_mask, - entity_token_type_ids, entity_position_ids, + attention_mask=None, + token_type_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, labels=None, return_dict=None, ): @@ -1141,7 +1062,7 @@ def forward( entity_attention_mask=entity_attention_mask, entity_token_type_ids=entity_token_type_ids, entity_position_ids=entity_position_ids, - return_dict=return_dict, + return_dict=True, ) feature_vector = outputs.entity_last_hidden_state[:, 0, :] @@ -1155,15 +1076,22 @@ def forward( loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = (logits,) + outputs[2:] + output = ( + logits, + outputs.hidden_states, + outputs.word_hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ) return ((loss,) + output) if loss is not None else output return EntityClassificationOutput( loss=loss, logits=logits, - hidden_states=None, # currently not supported - entity_hidden_states=None, # currently not supported - attentions=None, # currently not supported + hidden_states=outputs.hidden_states, + word_hidden_states=outputs.word_hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + 
attentions=outputs.attentions, ) @@ -1178,7 +1106,7 @@ class LukeForEntityPairClassification(nn.Module): def __init__(self, config): super().__init__(config) - self.luke = LukeEntityAwareAttentionModel(config) + self.luke = LukeModel(config) self.num_labels = config.num_labels self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1189,12 +1117,12 @@ def __init__(self, config): def forward( self, input_ids, - attention_mask, - token_type_ids, entity_ids, - entity_attention_mask, - entity_token_type_ids, entity_position_ids, + attention_mask=None, + token_type_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, labels=None, return_dict=None, ): @@ -1216,11 +1144,12 @@ def forward( entity_attention_mask=entity_attention_mask, entity_token_type_ids=entity_token_type_ids, entity_position_ids=entity_position_ids, - return_dict=return_dict, + return_dict=True, ) - feature_vector = torch.cat([outputs.entity_last_hidden_state[:, 0, :], - outputs.entity_last_hidden_state[1][:, 1, :]], dim=1) + feature_vector = torch.cat( + [outputs.entity_last_hidden_state[:, 0, :], outputs.entity_last_hidden_state[:, 1, :]], dim=1 + ) feature_vector = self.dropout(feature_vector) logits = self.classifier(feature_vector) @@ -1231,15 +1160,22 @@ def forward( loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = (logits,) + outputs[2:] + output = ( + logits, + outputs.hidden_states, + outputs.word_hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ) return ((loss,) + output) if loss is not None else output return EntityPairClassificationOutput( loss=loss, logits=logits, - hidden_states=None, # currently not supported - entity_hidden_states=None, # currently not supported - attentions=None, # currently not supported + hidden_states=outputs.hidden_states, + word_hidden_states=outputs.word_hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + 
attentions=outputs.attentions, ) @@ -1254,7 +1190,7 @@ class LukeForEntitySpanClassification(nn.Module): def __init__(self, config): super().__init__(config) - self.luke = LukeEntityAwareAttentionModel(config) + self.luke = LukeModel(config) self.num_labels = config.num_labels self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -1265,14 +1201,14 @@ def __init__(self, config): def forward( self, input_ids, - attention_mask, - token_type_ids, entity_ids, - entity_attention_mask, - entity_token_type_ids, entity_position_ids, entity_start_positions, entity_end_positions, + attention_mask=None, + token_type_ids=None, + entity_attention_mask=None, + entity_token_type_ids=None, labels=None, return_dict=None, ): @@ -1298,7 +1234,7 @@ def forward( entity_attention_mask=entity_attention_mask, entity_token_type_ids=entity_token_type_ids, entity_position_ids=entity_position_ids, - return_dict=return_dict, + return_dict=True, ) hidden_size = outputs.last_hidden_state.size(-1) @@ -1318,13 +1254,20 @@ def forward( loss = F.binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) if not return_dict: - output = (logits,) + outputs[2:] + output = ( + logits, + outputs.hidden_states, + outputs.word_hidden_states, + outputs.entity_hidden_states, + outputs.attentions, + ) return ((loss,) + output) if loss is not None else output return EntitySpanClassificationOutput( loss=loss, logits=logits, - hidden_states=None, # currently not supported - entity_hidden_states=None, # currently not supported - attentions=None, # currently not supported + hidden_states=outputs.hidden_states, + word_hidden_states=outputs.word_hidden_states, + entity_hidden_states=outputs.entity_hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/models/luke/tokenization_luke.py b/src/transformers/models/luke/tokenization_luke.py index 0704181fea49..1fed50c162b1 100644 --- a/src/transformers/models/luke/tokenization_luke.py +++ 
b/src/transformers/models/luke/tokenization_luke.py @@ -15,6 +15,7 @@ """Tokenization classes for LUKE.""" import itertools import json +import os from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -95,9 +96,11 @@ class LukeTokenizer(RobertaTokenizer): max_mention_length (:obj:`int`, `optional`, defaults to 30): The maximum number of tokens inside an entity span. entity_token_1 (:obj:`str`, `optional`, defaults to :obj:``): - The special token representing an entity span. This token is only used when `task` is set to `entity_classification` or `entity_pair_classification`. + The special token representing an entity span. This token is only used when `task` is set to + `entity_classification` or `entity_pair_classification`. entity_token_2 (:obj:`str`, `optional`, defaults to :obj:``): - The special token representing an entity span. This token is only used when `task` is set to `entity_pair_classification`. + The special token representing an entity span. This token is only used when `task` is set to + `entity_pair_classification`. 
""" vocab_files_names = VOCAB_FILES_NAMES @@ -128,12 +131,18 @@ def __init__( if isinstance(entity_token_2, str) else entity_token_2 ) - additional_special_tokens = [entity_token_1, entity_token_2] + kwargs["additional_special_tokens"] = [entity_token_1, entity_token_2] + kwargs.get( + "additional_special_tokens", [] + ) super().__init__( vocab_file=vocab_file, merges_file=merges_file, - additional_special_tokens=additional_special_tokens, + task=task, + max_entity_length=32, + max_mention_length=30, + entity_token_1="", + entity_token_2="", **kwargs, ) @@ -170,8 +179,8 @@ def __call__( is_split_into_words: Optional[bool] = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = True, - return_attention_mask: Optional[bool] = True, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, @@ -185,22 +194,30 @@ def __call__( Args: text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized string. text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 
+ The sequence or batch of sequences to be encoded. Each sequence must be a string. Note that this + tokenizer does not support tokenization based on pretokenized string. entities (:obj:`List[str]`, obj:`List[List[str]]`, `optional`):: The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New - York). This argument is ignored if you specify the `task` argument. + York). This argument is ignored if you specify the `task` argument in the constructor. + entities_pair (:obj:`List[str]`, obj:`List[List[str]]`, `optional`):: + The sequence or batch of sequences of entities to be encoded. Each sequence consists of strings + representing entities, i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New + York). This argument is ignored if you specify the `task` argument in the constructor. entity_spans (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: - The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples with - two integers denoting start and end positions of entities. - If you specify `entity` or `entity_pair` as the `task` argument, the length of `entity_spans` must be 1 - or 2, respectively. If you specify `entities`, the length of spans must be equal to that of entities. + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting start and end positions of entities. If you specify `entity_classification` + or `entity_pair_classification` as the `task` argument in the constructor, the length of each sequence + must be 1 or 2, respectively. If you specify `entities`, the length of each sequence must be equal to + each sequence of `entities`. 
+ entity_spans_pair (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: + The sequence or batch of sequences of entity spans to be encoded. Each sequence consists of tuples each + with two integers denoting start and end positions of entities. If you specify the `task` argument in + the constructor, this argument is ignored. If you specify `entities_pair`, the length of each sequence + must be equal to each sequence of `entities_pair`. max_entity_length (:obj:`int`, `optional`): The maximum length of :obj:`entity_ids`. """ @@ -314,14 +331,31 @@ def encode_plus( ``__call__`` should be used instead. Args: - text (:obj:`str`, :obj:`List[str]` or :obj:`List[int]` (the latter only for not-fast tokenizers)): - The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the - ``tokenize`` method) or a list of integers (tokenized string ids using the ``convert_tokens_to_ids`` - method). - text_pair (:obj:`str`, :obj:`List[str]` or :obj:`List[int]`, `optional`): - Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using - the ``tokenize`` method) or a list of integers (tokenized string ids using the - ``convert_tokens_to_ids`` method). + text (:obj:`str`): + The first sequence to be encoded. Each sequence must be a string. + text_pair (:obj:`str`): + The second sequence to be encoded. Each sequence must be a string. + entities (:obj:`List[str]` `optional`):: + The first sequence of entities to be encoded. The sequence consists of strings representing entities, + i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New York). This argument is + ignored if you specify the `task` argument in the constructor. + entities_pair (:obj:`List[str]`, obj:`List[List[str]]`, `optional`):: + The second sequence of entities to be encoded. 
The sequence consists of strings representing entities, + i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., New York). This argument is + ignored if you specify the `task` argument in the constructor. + entity_spans (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: + The first sequence of entity spans to be encoded. The sequence consists of tuples each with two integers + denoting start and end positions of entities. If you specify `entity_classification` or + `entity_pair_classification` as the `task` argument in the constructor, the length of each sequence must + be 1 or 2, respectively. If you specify `entities`, the length of the sequence must be equal to the + sequence of `entities`. + entity_spans_pair (:obj:`List[Tuple]`, obj:`List[List[Tuple]]`, `optional`):: + The second sequence of entity spans to be encoded. The sequence consists of tuples each with two integers + denoting start and end positions of entities. If you specify the `task` argument in the constructor, + this argument is ignored. If you specify `entities_pair`, the length of the sequence must be equal to + the sequence of `entities_pair`. + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. """ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' @@ -413,7 +447,7 @@ def _encode_plus( entities_pair=entities_pair, entity_spans=entity_spans, entity_spans_pair=entity_spans_pair, - **kwargs + **kwargs, ) # prepare_for_model will create the attention_mask and token_type_ids @@ -453,6 +487,7 @@ def batch_encode_plus( padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = False, max_length: Optional[int] = None, + max_entity_length: Optional[int] = None, stride: int = 0, is_split_into_words: Optional[bool] = False, pad_to_multiple_of: Optional[int] = None, @@ -473,10 +508,18 @@ def batch_encode_plus( This method is deprecated, ``__call__`` should be used instead. 
Args: - batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`, :obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also :obj:`List[List[int]]`, :obj:`List[Tuple[List[int], List[int]]]`): - Batch of sequences or pair of sequences to be encoded. This can be a list of - string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see - details in ``encode_plus``). + batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of string or a list of pair of + string (see details in ``encode_plus``). + batch_entities_or_entities_pairs (:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`, + `optional`): + Batch of entity sequences or pairs of entity sequences to be encoded (see details in ``encode_plus``). + batch_entity_spans_or_entity_spans_pairs (:obj:`List[List[Tuple[int, int]]]`, + :obj:`List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]`, `optional`):: + Batch of entity span sequences or pairs of entity span sequences to be encoded (see details in + ``encode_plus``). + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. 
""" # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' @@ -491,12 +534,13 @@ def batch_encode_plus( return self._batch_encode_plus( batch_text_or_text_pairs=batch_text_or_text_pairs, - batch_entities_or_entitiies_pairs=batch_entities_or_entities_pairs, + batch_entities_or_entities_pairs=batch_entities_or_entities_pairs, batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, max_length=max_length, + max_entity_length=max_entity_length, stride=stride, is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, @@ -588,7 +632,7 @@ def _batch_encode_plus( entities_pair=entities_pair, entity_spans=entity_spans, entity_spans_pair=entity_spans_pair, - **kwargs + **kwargs, ) input_ids.append((first_ids, second_ids)) entity_ids.append((first_entity_ids, second_entity_ids)) @@ -596,8 +640,8 @@ def _batch_encode_plus( batch_outputs = self._batch_prepare_for_model( input_ids, - batch_entity_id_pairs=entity_ids, - batch_entity_token_span_pairs=entity_token_spans, + batch_entity_ids_pairs=entity_ids, + batch_entity_token_spans_pairs=entity_token_spans, add_special_tokens=add_special_tokens, padding_strategy=padding_strategy, truncation_strategy=truncation_strategy, @@ -626,7 +670,6 @@ def _create_input_sequence( entity_spans_pair: Optional[List[Tuple[int, int]]] = None, **kwargs ) -> Tuple[list, list, list, list, list, list]: - def get_input_ids(text): tokens = self.tokenize(text, **kwargs) return self.convert_tokens_to_ids(tokens) @@ -757,7 +800,7 @@ def _batch_prepare_for_model( self, batch_ids_pairs: List[Tuple[List[int], None]], batch_entity_ids_pairs: List[Tuple[Optional[List[int]], Optional[List[int]]]], - batch_entity_token_span_pairs: List[Tuple[Optional[List[Tuple[int, int]]], Optional[List[Tuple[int, int]]]]], + batch_entity_token_spans_pairs: List[Tuple[Optional[List[Tuple[int, int]]], 
Optional[List[Tuple[int, int]]]]], add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, @@ -780,15 +823,18 @@ def _batch_prepare_for_model( Args: batch_ids_pairs: list of tokenized input ids or input ids pairs + batch_entity_ids_pairs: list of entity ids or entity ids pairs + batch_entity_token_spans_pairs: list of entity spans or entity spans pairs + max_entity_length: The maximum length of the entity sequence. """ batch_outputs = {} for input_ids, entity_ids, entity_token_span_pairs in zip( - batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_span_pairs + batch_ids_pairs, batch_entity_ids_pairs, batch_entity_token_spans_pairs ): first_ids, second_ids = input_ids first_entity_ids, second_entity_ids = entity_ids - first_entity_token_spans, second_entity_token_spans = batch_entity_token_span_pairs + first_entity_token_spans, second_entity_token_spans = entity_token_span_pairs outputs = self.prepare_for_model( first_ids, second_ids, @@ -858,19 +904,26 @@ def prepare_for_model( **kwargs ) -> BatchEncoding: """ - Prepares a sequence of input id, or a pair of sequences of inputs ids together with entity ids so that it can - be used by the model. It adds special tokens, truncates sequences if overflowing while taking into account the - special tokens and manages a moving window (with user defined stride) for overflowing tokens + Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids, + entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing + while taking into account the special tokens and manages a moving window (with user defined stride) for + overflowing tokens Args: ids (:obj:`List[int]`): - Tokenized input ids of the first sequence. Can be obtained from a string by chaining the ``tokenize`` - and ``convert_tokens_to_ids`` methods. 
+ Tokenized input ids of the first sequence. pair_ids (:obj:`List[int]`, `optional`): - Tokenized input ids of the second sequence. Can be obtained from a string by chaining the ``tokenize`` - and ``convert_tokens_to_ids`` methods. + Tokenized input ids of the second sequence. entity_ids (:obj:`List[int]`, `optional`): - Tokenized entity ids. + Entity ids of the first sequence. + pair_entity_ids (:obj:`List[int]`, `optional`): + Entity ids of the second sequence. + entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the first sequence. + pair_entity_token_spans (:obj:`List[Tuple[int, int]]`, `optional`): + Entity spans of the second sequence. + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. """ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' @@ -944,37 +997,35 @@ def prepare_for_model( else: encoded_inputs["special_tokens_mask"] = [0] * len(sequence) - if entity_ids is None: - encoded_inputs["entity_ids"] = [] - encoded_inputs["entity_position_ids"] = [] - - else: - # Set max entity length - if not max_entity_length: - max_entity_length = self.max_entity_length + # Set max entity length + if not max_entity_length: + max_entity_length = self.max_entity_length + if entity_ids is not None: total_entity_len = 0 num_invalid_entities = 0 - valid_entity_ids, valid_entity_token_spans = zip( - *[(ent_id, span) for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)] - ) + valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)] + valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)] + total_entity_len += len(valid_entity_ids) num_invalid_entities += len(entity_ids) - len(valid_entity_ids) valid_pair_entity_ids, valid_pair_entity_token_spans = None, None if pair_entity_ids is not None: - valid_pair_entity_ids, valid_pair_entity_token_spans = zip( - *[ - (ent_id, span) - for 
ent_id, span in zip(pair_entity_ids, pair_entity_token_spans) - if span[1] <= len(pair_ids) - ] - ) + valid_pair_entity_ids = [ + ent_id + for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans) + if span[1] <= len(pair_ids) + ] + valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)] total_entity_len += len(valid_pair_entity_ids) num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids) if num_invalid_entities != 0: - logger.warning('%d entities are ignored because their entity spans are invalid due to the truncation of input tokens', num_invalid_entities) + logger.warning( + "%d entities are ignored because their entity spans are invalid due to the truncation of input tokens", + num_invalid_entities, + ) if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length: # truncate entities up to max_entity_length @@ -1010,8 +1061,8 @@ def prepare_for_model( encoded_inputs["entity_position_ids"] = entity_position_ids - if return_token_type_ids: - encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) + if return_token_type_ids: + encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"]) # Check lengths self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) @@ -1022,6 +1073,7 @@ def prepare_for_model( encoded_inputs = self.pad( encoded_inputs, max_length=max_length, + max_entity_length=max_entity_length, padding=padding_strategy.value, pad_to_multiple_of=pad_to_multiple_of, return_attention_mask=return_attention_mask, @@ -1081,6 +1133,8 @@ def pad( different lengths). max_length (:obj:`int`, `optional`): Maximum length of the returned list and optionally padding length (see above). + max_entity_length (:obj:`int`, `optional`): + The maximum length of the entity sequence. 
pad_to_multiple_of (:obj:`int`, `optional`): If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability @@ -1172,7 +1226,9 @@ def pad( if padding_strategy == PaddingStrategy.LONGEST: max_length = max(len(inputs) for inputs in required_input) - max_entity_length = max(len(inputs) for inputs in encoded_inputs["entity_ids"]) + max_entity_length = ( + max(len(inputs) for inputs in encoded_inputs["entity_ids"]) if "entity_ids" in encoded_inputs else 0 + ) padding_strategy = PaddingStrategy.MAX_LENGTH batch_outputs = {} @@ -1210,6 +1266,7 @@ def _pad( encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). max_length: maximum length of the returned list and optionally padding length (see below). Will truncate by taking into account the special tokens. + max_entity_length: The maximum length of the entity sequence. padding_strategy: PaddingStrategy to use for padding. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch @@ -1224,73 +1281,98 @@ def _pad( >= 7.5 (Volta). 
return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ + entities_provided = bool("entity_ids" in encoded_inputs) + # Load from model defaults if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names if padding_strategy == PaddingStrategy.LONGEST: max_length = len(encoded_inputs["input_ids"]) - max_entity_length = len(encoded_inputs["entity_ids"]) + if entities_provided: + max_entity_length = len(encoded_inputs["entity_ids"]) if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of if ( - max_entity_length is not None + entities_provided + and max_entity_length is not None and pad_to_multiple_of is not None and (max_entity_length % pad_to_multiple_of != 0) ): max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and ( - len(encoded_inputs["input_ids"]) != max_length or len(encoded_inputs["entity_ids"]) != max_entity_length + len(encoded_inputs["input_ids"]) != max_length + or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length) ) if needs_to_be_padded: difference = max_length - len(encoded_inputs["input_ids"]) - entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) + if entities_provided: + entity_difference = max_entity_length - len(encoded_inputs["entity_ids"]) if self.padding_side == "right": if return_attention_mask: encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference - encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + [ - 0 - ] * entity_difference + if entities_provided: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + [ + 0 + ] * entity_difference if "token_type_ids" in 
encoded_inputs: encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference - encoded_inputs["entity_token_type_ids"] = ( - encoded_inputs["entity_token_type_ids"] + [0] * entity_difference - ) + if entities_provided: + encoded_inputs["entity_token_type_ids"] = ( + encoded_inputs["entity_token_type_ids"] + [0] * entity_difference + ) if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference - encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference - encoded_inputs["entity_position_ids"] = ( - encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference - ) + if entities_provided: + encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference + encoded_inputs["entity_position_ids"] = ( + encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference + ) elif self.padding_side == "left": if return_attention_mask: encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) - encoded_inputs["entity_attention_mask"] = [0] * entity_difference + [1] * len( - encoded_inputs["entity_ids"] - ) + if entities_provided: + encoded_inputs["entity_attention_mask"] = [0] * entity_difference + [1] * len( + encoded_inputs["entity_ids"] + ) if "token_type_ids" in encoded_inputs: encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"] - encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[ - "entity_token_type_ids" - ] + if entities_provided: + encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[ + "entity_token_type_ids" + ] if "special_tokens_mask" in encoded_inputs: encoded_inputs["special_tokens_mask"] = [1] * difference + 
encoded_inputs["special_tokens_mask"] encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] - encoded_inputs["entity_ids"] = [0] * entity_difference + encoded_inputs["entity_ids"] - encoded_inputs["entity_position_ids"] = [ - [-1] * self.max_mention_length - ] * entity_difference + encoded_inputs["entity_position_ids"] + if entities_provided: + encoded_inputs["entity_ids"] = [0] * entity_difference + encoded_inputs["entity_ids"] + encoded_inputs["entity_position_ids"] = [ + [-1] * self.max_mention_length + ] * entity_difference + encoded_inputs["entity_position_ids"] else: raise ValueError("Invalid padding strategy:" + str(self.padding_side)) else: if return_attention_mask: encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) - encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + if entities_provided: + encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"]) + + return encoded_inputs + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + vocab_file, merge_file = super(LukeTokenizer, self).save_vocabulary(save_directory, filename_prefix) + + entity_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["entity_vocab_file"] + ) + + with open(entity_vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.entity_vocab, ensure_ascii=False)) - return encoded_inputs \ No newline at end of file + return vocab_file, merge_file, entity_vocab_file diff --git a/tests/test_tokenization_luke.py b/tests/test_tokenization_luke.py index e16ae928d323..bdea64e8b30a 100644 --- a/tests/test_tokenization_luke.py +++ b/tests/test_tokenization_luke.py @@ -34,9 +34,9 @@ def setUp(self): self.special_tokens_map = {"entity_token_1": "", "entity_token_2": ""} - def get_tokenizer(self, **kwargs): + def get_tokenizer(self, task=None, **kwargs): 
kwargs.update(self.special_tokens_map) - return self.tokenizer_class.from_pretrained("studio-ousia/luke-large", **kwargs) + return self.tokenizer_class.from_pretrained("studio-ousia/luke-base", task=task, **kwargs) def get_input_output_texts(self, tokenizer): input_text = "lower newer" @@ -44,14 +44,14 @@ def get_input_output_texts(self, tokenizer): return input_text, output_text def test_full_tokenizer(self): - tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-large") + tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/luke-base") text = "lower newer" - bpe_tokens = ["l", "o", "w", "er", "\u0120", "n", "e", "w", "er"] + bpe_tokens = ["lower", "\u0120newer"] tokens = tokenizer.tokenize(text) # , add_prefix_space=True) self.assertListEqual(tokens, bpe_tokens) input_tokens = tokens + [tokenizer.unk_token] - input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19] + input_bpe_tokens = [29668, 13964, 3] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) def luke_dict_integration_testing(self): @@ -158,24 +158,364 @@ class LukeTokenizerIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() + def test_single_text_no_padding_or_truncation(self): + tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) + sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck." 
+ entities = ["Ana Ivanovic", "Thursday", "Dummy Entity"] + spans = [(9, 21), (30, 38), (39, 42)] + + encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][9:10], spaces_between_special_tokens=False), " she") + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["Ana Ivanovic"], + tokenizer.entity_vocab["Thursday"], + tokenizer.entity_vocab["[UNK]"], + ], + ) + self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) + self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + self.assertEqual( + encoding["entity_position_ids"], + [ + [ + 3, + 4, + 5, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + [ + 8, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + [ + 9, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + ], + ) + + def test_single_text_padding_pytorch_tensors(self): + tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) + sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck." 
+ entities = ["Ana Ivanovic", "Thursday", "Dummy Entity"] + spans = [(9, 21), (30, 38), (39, 42)] + + encoding = tokenizer( + sentence, + entities=entities, + entity_spans=spans, + return_token_type_ids=True, + padding="max_length", + max_length=30, + max_entity_length=16, + return_tensors="pt", + ) + + # test words + self.assertEqual(encoding["input_ids"].shape, (1, 30)) + self.assertEqual(encoding["attention_mask"].shape, (1, 30)) + self.assertEqual(encoding["token_type_ids"].shape, (1, 30)) + + # test entities + self.assertEqual(encoding["entity_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16)) + self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length)) + + def test_text_pair_no_padding_or_truncation(self): + tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) + sentence = "Top seed Ana Ivanovic said on Thursday" + sentence_pair = "She could hardly believe her luck." 
+ entities = ["Ana Ivanovic", "Thursday"] + entities_pair = ["Dummy Entity"] + spans = [(9, 21), (30, 38)] + spans_pair = [(0, 3)] + + encoding = tokenizer( + sentence, + sentence_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=spans, + entity_spans_pair=spans_pair, + return_token_type_ids=True, + ) + + self.assertEqual( + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + "Top seed Ana Ivanovic said on ThursdayShe could hardly believe her luck.", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][3:6], spaces_between_special_tokens=False), " Ana Ivanovic" + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][8:9], spaces_between_special_tokens=False), " Thursday" + ) + self.assertEqual(tokenizer.decode(encoding["input_ids"][11:12], spaces_between_special_tokens=False), "She") + + self.assertEqual( + encoding["entity_ids"], + [ + tokenizer.entity_vocab["Ana Ivanovic"], + tokenizer.entity_vocab["Thursday"], + tokenizer.entity_vocab["[UNK]"], + ], + ) + self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1]) + self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0]) + self.assertEqual( + encoding["entity_position_ids"], + [ + [ + 3, + 4, + 5, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + [ + 8, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + [ + 11, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + ], + ], + ) + + def test_text_pair_padding_pytorch_tensors(self): + tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", return_token_type_ids=True) + sentence = "Top seed Ana Ivanovic said on 
Thursday" + sentence_pair = "She could hardly believe her luck." + entities = ["Ana Ivanovic", "Thursday"] + entities_pair = ["Dummy Entity"] + spans = [(9, 21), (30, 38)] + spans_pair = [(0, 3)] + + encoding = tokenizer( + sentence, + sentence_pair, + entities=entities, + entities_pair=entities_pair, + entity_spans=spans, + entity_spans_pair=spans_pair, + return_token_type_ids=True, + padding="max_length", + max_length=30, + max_entity_length=16, + return_tensors="pt", + ) + + # test words + self.assertEqual(encoding["input_ids"].shape, (1, 30)) + self.assertEqual(encoding["attention_mask"].shape, (1, 30)) + self.assertEqual(encoding["token_type_ids"].shape, (1, 30)) + + # test entities + self.assertEqual(encoding["entity_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16)) + self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16)) + self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length)) + def test_entity_classification_no_padding_or_truncation(self): tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification") sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ." 
span = (39, 42) - encoding = tokenizer(sentence, additional_info=[span]) + encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True) # test words self.assertEqual(len(encoding["input_ids"]), 42) self.assertEqual(len(encoding["attention_mask"]), 42) self.assertEqual(len(encoding["token_type_ids"]), 42) self.assertEqual( - tokenizer.decode(encoding["input_ids"]), - "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.", + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon.", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][9:12], spaces_between_special_tokens=False), " she" ) # test entities - self.assertEqual(encoding["entity_ids"], [1]) + self.assertEqual(encoding["entity_ids"], [2]) self.assertEqual(encoding["entity_attention_mask"], [1]) self.assertEqual(encoding["entity_token_type_ids"], [0]) self.assertEqual( @@ -217,12 +557,16 @@ def test_entity_classification_no_padding_or_truncation(self): ) def test_entity_classification_padding_pytorch_tensors(self): - tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification") + tokenizer = LukeTokenizer.from_pretrained( + "studio-ousia/luke-base", task="entity_classification", return_token_type_ids=True + ) sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck as a fortuitous netcord helped the new world number one avoid a humiliating second- round exit at Wimbledon ." 
# entity information span = (39, 42) - encoding = tokenizer(sentence, entity_spans=[span], padding="max_length", return_tensors="pt") + encoding = tokenizer( + sentence, entity_spans=[span], return_token_type_ids=True, padding="max_length", return_tensors="pt" + ) # test words self.assertEqual(encoding["input_ids"].shape, (1, 512)) @@ -233,21 +577,33 @@ def test_entity_classification_padding_pytorch_tensors(self): self.assertEqual(encoding["entity_ids"].shape, (1, 1)) self.assertEqual(encoding["entity_attention_mask"].shape, (1, 1)) self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 1)) - self.assertEqual(encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)) + self.assertEqual( + encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length) + ) def test_entity_pair_classification_no_padding_or_truncation(self): - tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification") + tokenizer = LukeTokenizer.from_pretrained( + "studio-ousia/luke-base", task="entity_pair_classification", return_token_type_ids=True + ) sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck." 
# head and tail information spans = [(9, 21), (39, 42)] - encoding = tokenizer(sentence, entity_spans=spans) + encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True) self.assertEqual( - tokenizer.decode(encoding["input_ids"]), - "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.", + tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False), + "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck.", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][3:8], spaces_between_special_tokens=False), + " Ana Ivanovic", + ) + self.assertEqual( + tokenizer.decode(encoding["input_ids"][11:14], spaces_between_special_tokens=False), " she" ) - self.assertEqual(encoding["entity_ids"], [1, 2]) + + self.assertEqual(encoding["entity_ids"], [2, 3]) self.assertEqual(encoding["entity_attention_mask"], [1, 1]) self.assertEqual(encoding["entity_token_type_ids"], [0, 0]) self.assertEqual( @@ -259,7 +615,7 @@ def test_entity_pair_classification_no_padding_or_truncation(self): 5, 6, 7, - 8, + -1, -1, -1, -1, @@ -286,9 +642,9 @@ def test_entity_pair_classification_no_padding_or_truncation(self): -1, ], [ + 11, + 12, 13, - 14, - 15, -1, -1, -1, @@ -321,12 +677,21 @@ def test_entity_pair_classification_no_padding_or_truncation(self): ) def test_entity_pair_classification_padding_pytorch_tensors(self): - tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification") + tokenizer = LukeTokenizer.from_pretrained( + "studio-ousia/luke-base", task="entity_pair_classification", return_token_type_ids=True + ) sentence = "Top seed Ana Ivanovic said on Thursday she could hardly believe her luck." 
# head and tail information spans = [(9, 21), (39, 42)] - encoding = tokenizer(sentence, entity_spans=spans, padding="max_length", max_length=30, return_tensors="pt") + encoding = tokenizer( + sentence, + entity_spans=spans, + return_token_type_ids=True, + padding="max_length", + max_length=30, + return_tensors="pt", + ) # test words self.assertEqual(encoding["input_ids"].shape, (1, 30)) @@ -337,4 +702,6 @@ def test_entity_pair_classification_padding_pytorch_tensors(self): self.assertEqual(encoding["entity_ids"].shape, (1, 2)) self.assertEqual(encoding["entity_attention_mask"].shape, (1, 2)) self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 2)) - self.assertEqual(encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)) \ No newline at end of file + self.assertEqual( + encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length) + )