Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions src/transformers/models/tapas/modeling_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TableQuestionAnsweringOutput(ModelOutput):
Output type of :class:`~transformers.TapasForQuestionAnswering`.

Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label_ids` (and possibly :obj:`answer`, :obj:`aggregation_labels`, :obj:`numeric_values` and :obj:`numeric_values_scale` are provided)):
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` (and possibly :obj:`answer`, :obj:`aggregation_labels`, :obj:`numeric_values` and :obj:`numeric_values_scale`) are provided):
Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the
semi-supervised regression loss and (optionally) supervised loss for aggregations.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def forward(self, features, **kwargs):
TAPAS_START_DOCSTRING,
)
class TapasForQuestionAnswering(TapasPreTrainedModel):
def __init__(self, config):
def __init__(self, config: TapasConfig):
super().__init__(config)

# base model
Expand All @@ -1036,11 +1036,11 @@ def __init__(self, config):
else:
self.output_weights = nn.Parameter(torch.empty(config.hidden_size))
nn.init.normal_(
self.output_weights, std=0.02
self.output_weights, std=config.initializer_range
) # here, a truncated normal is used in the original implementation
self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size))
nn.init.normal_(
self.column_output_weights, std=0.02
self.column_output_weights, std=config.initializer_range
) # here, a truncated normal is used in the original implementation
self.output_bias = nn.Parameter(torch.zeros([]))
self.column_output_bias = nn.Parameter(torch.zeros([]))
Expand All @@ -1062,7 +1062,7 @@ def forward(
head_mask=None,
inputs_embeds=None,
table_mask=None,
label_ids=None,
labels=None,
aggregation_labels=None,
float_answer=None,
numeric_values=None,
Expand All @@ -1075,7 +1075,7 @@ def forward(
table_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`, `optional`):
Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and
padding are 0.
label_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`, `optional`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, seq_length)`, `optional`):
Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the
answer appearing in the table. Can be obtained using :class:`~transformers.TapasTokenizer`.

Expand Down Expand Up @@ -1156,7 +1156,7 @@ def forward(
"segment_ids",
"column_ids",
"row_ids",
"prev_label_ids",
"prev_labels",
"column_ranks",
"inv_column_ranks",
"numeric_relations",
Expand Down Expand Up @@ -1214,7 +1214,7 @@ def forward(
# Total loss calculation
total_loss = 0.0
calculate_loss = False
if label_ids is not None:
if labels is not None:
calculate_loss = True
is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision

Expand All @@ -1226,18 +1226,18 @@ def forward(
# some ambiguous cases, see utils._calculate_aggregate_mask for more info.
# `aggregate_mask` is 1 for examples where we chose to aggregate and 0
# for examples where we chose to select the answer directly.
# `label_ids` encodes the positions of the answer appearing in the table.
# `labels` encodes the positions of the answer appearing in the table.
if is_supervised:
aggregate_mask = None
else:
if float_answer is not None:
assert label_ids.shape[0] == float_answer.shape[0], "Make sure the answers are a FloatTensor of shape (batch_size,)"
assert labels.shape[0] == float_answer.shape[0], "Make sure the answers are a FloatTensor of shape (batch_size,)"
# <float32>[batch_size]
aggregate_mask = _calculate_aggregate_mask(
float_answer,
pooled_output,
self.config.cell_selection_preference,
label_ids,
labels,
self.aggregation_classifier,
)
else:
Expand All @@ -1255,17 +1255,17 @@ def forward(
selection_loss_per_example = None
if not self.config.select_one_column:
weight = torch.where(
label_ids == 0,
torch.ones_like(label_ids, dtype=torch.float32),
self.config.positive_label_weight * torch.ones_like(label_ids, dtype=torch.float32),
labels == 0,
torch.ones_like(labels, dtype=torch.float32),
self.config.positive_label_weight * torch.ones_like(labels, dtype=torch.float32),
)
selection_loss_per_token = -dist_per_token.log_prob(label_ids) * weight
selection_loss_per_token = -dist_per_token.log_prob(labels) * weight
selection_loss_per_example = torch.sum(selection_loss_per_token * input_mask_float, dim=1) / (
torch.sum(input_mask_float, dim=1) + EPSILON_ZERO_DIVISION
)
else:
selection_loss_per_example, logits = _single_column_cell_selection_loss(
logits, column_logits, label_ids, cell_index, col_index, cell_mask
logits, column_logits, labels, cell_index, col_index, cell_mask
)
dist_per_token = torch.distributions.Bernoulli(logits=logits)

Expand All @@ -1285,7 +1285,7 @@ def forward(
if is_supervised:
# Note that `aggregate_mask` is None if the setting is supervised.
if aggregation_labels is not None:
assert label_ids.shape[0] == aggregation_labels.shape[0], "Make sure the aggregation labels are a LongTensor of shape (batch_size,)"
assert labels.shape[0] == aggregation_labels.shape[0], "Make sure the aggregation labels are a LongTensor of shape (batch_size,)"
per_example_additional_loss = _calculate_aggregation_loss(
logits_aggregation, aggregate_mask, aggregation_labels,
self.config.use_answer_as_supervision, self.config.num_aggregation_labels,
Expand All @@ -1297,7 +1297,7 @@ def forward(
)
else:
# Set aggregation labels to zeros
aggregation_labels = torch.zeros(label_ids.shape[0], dtype=torch.long, device=label_ids.device)
aggregation_labels = torch.zeros(labels.shape[0], dtype=torch.long, device=labels.device)
per_example_additional_loss = _calculate_aggregation_loss(
logits_aggregation, aggregate_mask, aggregation_labels,
self.config.use_answer_as_supervision, self.config.num_aggregation_labels,
Expand Down Expand Up @@ -1330,16 +1330,16 @@ def forward(

else:
# if no label ids are provided, set them to zeros in order to properly compute logits
label_ids = torch.zeros_like(logits)
labels = torch.zeros_like(logits)
_, logits = _single_column_cell_selection_loss(
logits, column_logits, label_ids, cell_index, col_index, cell_mask
logits, column_logits, labels, cell_index, col_index, cell_mask
)
if not return_dict:
output = (logits, logits_aggregation) + outputs[2:]
return ((total_loss,) + output) if calculate_loss else output

return TableQuestionAnsweringOutput(
loss=total_loss,
loss=total_loss if calculate_loss else None,
logits=logits,
logits_aggregation=logits_aggregation,
hidden_states=outputs.hidden_states,
Expand Down Expand Up @@ -1854,7 +1854,7 @@ def compute_column_logits(
return column_logits


def _single_column_cell_selection_loss(token_logits, column_logits, label_ids, cell_index, col_index, cell_mask):
def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
"""
Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
Expand All @@ -1865,7 +1865,7 @@ def _single_column_cell_selection_loss(token_logits, column_logits, label_ids, c
Tensor containing the logits per token.
column_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, max_num_cols)`):
Tensor containing the logits per column.
label_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Labels per token.
cell_index (:obj:`ProductIndexMap`):
Index that groups tokens into cells.
Expand All @@ -1885,7 +1885,7 @@ def _single_column_cell_selection_loss(token_logits, column_logits, label_ids, c
# First find the column we should select. We use the column with maximum
# number of selected cells.
labels_per_column, _ = reduce_sum(
torch.as_tensor(label_ids, dtype=torch.float32, device=label_ids.device), col_index
torch.as_tensor(labels, dtype=torch.float32, device=labels.device), col_index
)
# shape of labels_per_column is (batch_size, max_num_cols). It contains the number of label ids for every column, for every example
column_label = torch.argmax(labels_per_column, dim=-1) # shape (batch_size,)
Expand All @@ -1894,7 +1894,7 @@ def _single_column_cell_selection_loss(token_logits, column_logits, label_ids, c
no_cell_selected = torch.eq(
torch.max(labels_per_column, dim=-1)[0], 0
) # no_cell_selected is of shape (batch_size,) and equals True
# if an example of the batch has no cells selected (i.e. if there are no label_ids set to 1 for that example)
# if an example of the batch has no cells selected (i.e. if there are no labels set to 1 for that example)
column_label = torch.where(
no_cell_selected.view(column_label.size()), torch.zeros_like(column_label), column_label
)
Expand All @@ -1909,7 +1909,7 @@ def _single_column_cell_selection_loss(token_logits, column_logits, label_ids, c
logits_per_cell, _ = reduce_mean(token_logits, cell_index)
# labels_per_cell: shape (batch_size, 64*32), indicating whether each cell should be selected (1) or not (0)
labels_per_cell, labels_index = reduce_max(
torch.as_tensor(label_ids, dtype=torch.long, device=label_ids.device), cell_index
torch.as_tensor(labels, dtype=torch.long, device=labels.device), cell_index
)

# Mask for the selected column.
Expand Down Expand Up @@ -1986,7 +1986,7 @@ def compute_token_logits(sequence_output, temperature, output_weights, output_bi
return logits


def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, label_ids, aggregation_classifier):
def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
"""
Finds examples where the model should select cells with no aggregation.

Expand All @@ -2004,7 +2004,7 @@ def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference,
Output of the pooler (BertPooler) on top of the encoder layer.
cell_selection_preference (:obj:`float`):
Preference for cell selection in ambiguous cases.
label_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Labels per token.
aggregation_classifier (:obj:`torch.nn.Linear`):
Aggregation head.

Returns:
Expand All @@ -2022,7 +2022,7 @@ def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference,
is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference

# Examples with non-empty cell selection supervision.
is_cell_supervision_available = torch.sum(label_ids, dim=1) > 0
is_cell_supervision_available = torch.sum(labels, dim=1) > 0

# torch.where is not equivalent to tf.where (in tensorflow 1)
# hence the added .view on the condition to match the shape of the first tensor
Expand Down
33 changes: 13 additions & 20 deletions src/transformers/models/tapas/tokenization_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@
}


PRETRAINED_INIT_CONFIGURATION = {
"nielsr/tapas-base-finetuned-sqa": {"do_lower_case": True},
"nielsr/tapas-base-finetuned-wtq": {"do_lower_case": True},
"nielsr/tapas-base-finetuned-wikisql-supervised": {"do_lower_case": True},
}


class TapasTruncationStrategy(ExplicitEnum):
"""
Expand Down Expand Up @@ -178,15 +172,15 @@ class TapasTokenizer(PreTrainedTokenizer):
Users should refer to this superclass for more information regarding those methods.
:class:`~transformers.TapasTokenizer` creates several token type ids to encode tabular structure. To be more
precise, it adds 7 token type ids, in the following order: :obj:`segment_ids`, :obj:`column_ids`, :obj:`row_ids`,
:obj:`prev_label_ids`, :obj:`column_ranks`, :obj:`inv_column_ranks` and :obj:`numeric_relations`:
:obj:`prev_labels`, :obj:`column_ranks`, :obj:`inv_column_ranks` and :obj:`numeric_relations`:

- segment_ids: indicate whether a token belongs to the question (0) or the table (1). 0 for special tokens and
padding.
- column_ids: indicate to which column of the table a token belongs (starting from 1). Is 0 for all question
tokens, special tokens and padding.
- row_ids: indicate to which row of the table a token belongs (starting from 1). Is 0 for all question tokens,
special tokens and padding. Tokens of column headers are also 0.
- prev_label_ids: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful
- prev_labels: indicate whether a token was (part of) an answer to the previous question (1) or not (0). Useful
in a conversational setup (such as SQA).
- column_ranks: indicate the rank of a table token relative to a column, if applicable. For example, if you have a
column "number of movies" with values 87, 53 and 69, then the column ranks of these tokens are 3, 1 and 2 respectively.
Expand Down Expand Up @@ -252,7 +246,6 @@ class TapasTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION

def __init__(
self,
Expand Down Expand Up @@ -1153,10 +1146,10 @@ def prepare_for_model(
column_ids = self.create_column_token_type_ids_from_sequences(query_ids, table_data)
row_ids = self.create_row_token_type_ids_from_sequences(query_ids, table_data)
if not is_part_of_batch or (prev_answer_coordinates is None and prev_answer_text is None):
# simply set the prev_label_ids to zeros
prev_label_ids = [0] * len(row_ids)
# simply set the prev_labels to zeros
prev_labels = [0] * len(row_ids)
else:
prev_label_ids = self.get_answer_ids(
prev_labels = self.get_answer_ids(
column_ids, row_ids, table_data, prev_answer_text, prev_answer_coordinates
)

Expand Down Expand Up @@ -1185,13 +1178,13 @@ def prepare_for_model(
encoded_inputs["attention_mask"] = attention_mask

if answer_coordinates is not None and answer_text is not None:
label_ids = self.get_answer_ids(
labels = self.get_answer_ids(
column_ids, row_ids, table_data, answer_text, answer_coordinates
)
numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids)
numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids)

encoded_inputs["label_ids"] = label_ids
encoded_inputs["labels"] = labels
encoded_inputs["numeric_values"] = numeric_values
encoded_inputs["numeric_values_scale"] = numeric_values_scale

Expand All @@ -1200,7 +1193,7 @@ def prepare_for_model(
segment_ids,
column_ids,
row_ids,
prev_label_ids,
prev_labels,
column_ranks,
inv_column_ranks,
numeric_relations,
Expand Down Expand Up @@ -1829,8 +1822,8 @@ def _pad(
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference
)
if "label_ids" in encoded_inputs:
encoded_inputs["label_ids"] = encoded_inputs["label_ids"] + [0] * difference
if "labels" in encoded_inputs:
encoded_inputs["labels"] = encoded_inputs["labels"] + [0] * difference
if "numeric_values" in encoded_inputs:
encoded_inputs["numeric_values"] = encoded_inputs["numeric_values"] + [float("nan")] * difference
if "numeric_values_scale" in encoded_inputs:
Expand All @@ -1845,8 +1838,8 @@ def _pad(
encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[
"token_type_ids"
]
if "label_ids" in encoded_inputs:
encoded_inputs["label_ids"] = [0] * difference + encoded_inputs["label_ids"]
if "labels" in encoded_inputs:
encoded_inputs["labels"] = [0] * difference + encoded_inputs["labels"]
if "numeric_values" in encoded_inputs:
encoded_inputs["numeric_values"] = [float("nan")] * difference + encoded_inputs["numeric_values"]
if "numeric_values_scale" in encoded_inputs:
Expand Down Expand Up @@ -1918,7 +1911,7 @@ def convert_logits_to_predictions(
"segment_ids",
"column_ids",
"row_ids",
"prev_label_ids",
"prev_labels",
"column_ranks",
"inv_column_ranks",
"numeric_relations",
Expand Down
Loading