Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/transformers/models/luke/configuration_luke.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ class LukeConfig(RobertaConfig):
entity_emb_size (:obj:`int`, `optional`, defaults to 256):
The number of dimensions of the entity embedding.
use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`):
Whether or not the model should use entity-aware self-attention mechanism proposed in the original paper
<https://arxiv.org/abs/2010.01057>.
Whether or not the model should use the entity-aware self-attention mechanism proposed in
`LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)
<https://arxiv.org/abs/2010.01057>`__.

Examples::
>>> from transformers import LukeConfig, LukeModel
Expand Down
66 changes: 33 additions & 33 deletions src/transformers/models/luke/modeling_luke.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling):
Sequence of entity hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
Linear layer and a Tanh activation function.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of
each layer plus the initial embedding outputs.
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial embedding outputs.
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length + entity_length, sequence_length + entity_length)`. Attention weights after the attention
Expand Down Expand Up @@ -102,7 +101,7 @@ class BaseLukeModelOutput(BaseModelOutput):
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial embedding outputs.
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Expand Down Expand Up @@ -134,7 +133,7 @@ class EntityClassificationOutput(ModelOutput):
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial embedding outputs.
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
Expand Down Expand Up @@ -167,7 +166,7 @@ class EntityPairClassificationOutput(ModelOutput):
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial embedding outputs.
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
Expand Down Expand Up @@ -200,7 +199,7 @@ class EntitySpanClassificationOutput(ModelOutput):
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial embedding outputs.
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`. Attention weights after the attention softmax, used to compute the
Expand Down Expand Up @@ -701,18 +700,21 @@ class LukePreTrainedModel(PreTrainedModel):
base_model_prefix = "luke"

def _init_weights(self, module: nn.Module):
""" Initialize the weights """
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
if module.embedding_dim == 1: # embedding for bias parameters
module.weight.data.zero_()
else:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()


LUKE_START_DOCSTRING = r"""
Expand Down Expand Up @@ -771,22 +773,22 @@ def _init_weights(self, module: nn.Module):

`What are position IDs? <../glossary.html#position-ids>`_

entity_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
entity_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`):
Indices of entity tokens in the entity vocabulary.

Indices can be obtained using :class:`~transformers.LukeTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.

entity_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
entity_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length)`, `optional`):
Mask to avoid performing attention on padding entity token indices. Mask values selected in ``[0, 1]``:



- 1 for entity tokens that are **not masked**,
- 0 for entity tokens that are **masked**.

entity_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
entity_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`, `optional`):
Segment token indices to indicate first and second portions of the entity token inputs. Indices are
selected in ``[0, 1]``:

Expand Down Expand Up @@ -888,23 +890,22 @@ def forward(

>>> from transformers import LukeTokenizer, LukeModel

>>> tokenizer = LukeTokenizer.from_pretrained('studio-ousia/luke-base')
>>> model = LukeModel.from_pretrained('studio-ousia/luke-base')
>>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
>>> model = LukeModel.from_pretrained("studio-ousia/luke-base")

# Compute the contextualized entity representation for "Beyoncé" from the <mask> entity token.
# Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"
>>> text = "Beyoncé lives in New York."
>>> entities = ["<mask>"]
>>> entity_spans = [(0, 7)]
>>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"

>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
>>> word_last_hidden_state = outputs.word_last_hidden_state
>>> entity_last_hidden_state = outputs.entity_last_hidden_state

# Input Wikipedia entities to enrich the contextualized representation.
# Input Wikipedia entities to obtain enriched contextualized representations.
>>> text = "Beyoncé lives in New York."
>>> entities = ["Beyoncé", "New York City"]
>>> entity_spans = [(0, 7), (17, 25)]
>>> entities = ["Beyoncé", "New York City"] # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "New York"
>>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York"

>>> encoding = tokenizer(text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
>>> outputs = model(**encoding)
Expand Down Expand Up @@ -1034,7 +1035,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):

@add_start_docstrings(
"""
The LUKE model with a classification head on top (a linear layer on top of the hidden state of the <mask> entity
The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
token) for entity classification tasks, such as Open Entity.
""",
LUKE_START_DOCSTRING,
Expand Down Expand Up @@ -1084,11 +1085,11 @@ def forward(

>>> from transformers import LukeTokenizer, LukeForEntityClassification

>>> tokenizer = LukeTokenizer.from_pretrained('studio-ousia/luke-base', task="entity_classification")
>>> model = LukeForEntityClassification.from_pretrained('studio-ousia/luke-base')
>>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_classification")
>>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-base")

>>> text = "Beyoncé lives in New York."
>>> entity_spans = [(0, 7)]
>>> entity_spans = [(0, 7)] # character-based entity span corresponding to "Beyoncé"
>>> inputs = tokenizer(text, entity_spans=entity_spans, task="entity_classification", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
Expand Down Expand Up @@ -1142,8 +1143,8 @@ def forward(

@add_start_docstrings(
"""
The LUKE Model with a classification head on top (a linear layer on top of the hidden states of the two <mask>
entity tokens) for entity pair classification tasks, such as TACRED.
The LUKE Model with a classification head on top (a linear layer on top of the hidden states of the two entity
tokens) for entity pair classification tasks, such as TACRED.
""",
LUKE_START_DOCSTRING,
)
Expand Down Expand Up @@ -1193,7 +1194,7 @@ def forward(
>>> from transformers import LukeTokenizer, LukeForEntityPairClassification

>>> text = "Beyoncé lives in New York."
>>> entity_spans = [(0, 7), (17, 25)]
>>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York"
>>> inputs = tokenizer(text, entity_spans=entity_spans, task="entity_pair_classification", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
Expand Down Expand Up @@ -1306,13 +1307,12 @@ def forward(

>>> from transformers import LukeTokenizer, LukeForEntitySpanClassification

>>> tokenizer = LukeTokenizer.from_pretrained('studio-ousia/luke-base')
>>> model = LukeForEntitySpanClassification.from_pretrained('studio-ousia/luke-base')
>>> tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
>>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-base")

>>> text = "Beyoncé lives in New York."
>>> entities = ["<mask>", "<mask>"]
>>> entity_spans = [(0, 7), (17, 25)]
>>> inputs = tokenizer(text, entities=entities, entity_spans=entity_spans, return_tensors="pt")
>>> entity_spans = [(0, 7), (17, 25)] # character-based entity spans corresponding to "Beyoncé" and "New York"
>>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
Expand Down
Loading