Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/transformers/models/tapas/tokenization_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1901,9 +1901,9 @@ def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_clas
data (:obj:`dict`):
Dictionary mapping features to actual values. Should be created using
:class:`~transformers.TapasTokenizer`.
logits (:obj:`torch.FloatTensor` of shape ``(batch_size, sequence_length)``):
logits (:obj:`np.ndarray` of shape ``(batch_size, sequence_length)``):
Tensor containing the logits at the token level.
logits_agg (:obj:`torch.FloatTensor` of shape ``(batch_size, num_aggregation_labels)``, `optional`):
logits_agg (:obj:`np.ndarray` of shape ``(batch_size, num_aggregation_labels)``, `optional`):
Tensor containing the aggregation logits.
cell_classification_threshold (:obj:`float`, `optional`, defaults to 0.5):
Threshold to be used for cell selection. All table cells for which their probability is larger than
Expand All @@ -1917,8 +1917,13 @@ def convert_logits_to_predictions(self, data, logits, logits_agg=None, cell_clas
predicted_aggregation_indices (`optional`, returned when ``logits_aggregation`` is provided) ``List[int]``
of length ``batch_size``: Predicted aggregation operator indices of the aggregation head.
"""
# compute probabilities from token logits
# DO sigmoid here
# input data is of type float32
# np.log(np.finfo(np.float32).max) = 88.72284
# Any value over 88.72284 will overflow when passed through the exponential, sending a warning
# We disable this warning by truncating the logits.
logits[logits < -88.7] = -88.7

# Compute probabilities from token logits
probabilities = 1 / (1 + np.exp(-logits)) * data["attention_mask"]
token_types = [
"segment_ids",
Expand Down