Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/models/tapas/tokenization_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ def prepare_for_model(
num_columns = self._get_num_columns(raw_table)
_, _, num_tokens = self._get_table_boundaries(tokenized_table)

- if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE and max_length:
+ if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE:
num_rows, num_tokens = self._get_truncated_table_rows(query_tokens, tokenized_table, num_rows, num_columns,
max_length, truncation_strategy=truncation)
table_data = list(self._get_table_values(tokenized_table, num_columns, num_rows, num_tokens))
Expand Down Expand Up @@ -1280,6 +1280,9 @@ def _get_truncated_table_rows(
if not isinstance(truncation_strategy, TapasTruncationStrategy):
truncation_strategy = TapasTruncationStrategy(truncation_strategy)

+ if max_length is None:
+ max_length = self.model_max_length

if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT:
while True:
num_tokens = self._get_max_num_tokens(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_modeling_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
else None
)
test_pruning = False
- test_torchscript = True
+ test_torchscript = False
test_resize_embeddings = True
test_head_masking = False

Expand Down
8 changes: 8 additions & 0 deletions tests/test_tokenization_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,14 @@ def test_tapas_truncation_integration_test(self):
# Ensure that the input IDs are less than the max length defined.
self.assertLessEqual(len(new_encoded_inputs), i)

+ tokenizer.model_max_length = 20
+ new_encoded_inputs = tokenizer.encode(table=table, query=queries[0], truncation=True)
+ dropped_encoded_inputs = tokenizer.encode(table=table, query=queries[0], truncation="drop_rows_to_fit")
+
+ # Ensure that the input IDs are still truncated when no max_length is specified
+ self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs)
+ self.assertLessEqual(len(new_encoded_inputs), 20)

@is_pt_tf_cross_test
def test_batch_encode_plus_tensors(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
Expand Down