diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2122d7579c86..df1c3ee1f2bb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -191,7 +191,7 @@ from .tokenization_retribert import RetriBertTokenizer from .tokenization_roberta import RobertaTokenizer from .tokenization_squeezebert import SqueezeBertTokenizer -from .tokenization_tapas import TapasTokenizer +from .tokenization_tapas import TapasTokenizer, TapasTruncationStrategy from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer from .tokenization_utils import PreTrainedTokenizer from .tokenization_utils_base import ( diff --git a/src/transformers/tokenization_tapas.py b/src/transformers/tokenization_tapas.py index 80f55a688820..5f2337b48a8f 100644 --- a/src/transformers/tokenization_tapas.py +++ b/src/transformers/tokenization_tapas.py @@ -24,12 +24,12 @@ import os import re import unicodedata -import warnings from dataclasses import dataclass from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union import pandas as pd import torch +from transformers import add_end_docstrings from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from .tokenization_utils_base import ( @@ -39,7 +39,7 @@ PreTokenizedInput, TensorType, TextInput, - TruncationStrategy, + ExplicitEnum, ENCODE_KWARGS_DOCSTRING, ) from .utils import logging @@ -65,6 +65,16 @@ } +class TapasTruncationStrategy(ExplicitEnum): + """ + Possible values for the ``truncation`` argument in :meth:`~transformers.TapasTokenizer.__call__`. Useful for + tab-completion in an IDE. 
+ """ + + DROP_ROWS_TO_FIT = "drop_rows_to_fit" + DO_NOT_TRUNCATE = "do_not_truncate" + + TableValue = collections.namedtuple("TokenValue", ["token", "column_id", "row_id"]) @@ -112,6 +122,46 @@ def whitespace_tokenize(text): tokens = text.split() return tokens +TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + add_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.TapasTruncationStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate row by row, removing rows from the table. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + max_length (:obj:`int`, `optional`): + Controls the maximum length to use by one of the truncation/padding parameters. 
+ + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the input is already pre-tokenized (e.g., split into words), in which case the tokenizer + will skip the pre-tokenization step. This is useful for NER or token classification. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. 
+""" + class TapasTokenizer(PreTrainedTokenizer): r""" @@ -218,8 +268,16 @@ def __init__( strip_column_names: bool = False, update_answer_coordinates: bool = False, drop_rows_to_fit: bool = False, + model_max_length: int = 512, + additional_special_tokens: Optional[List[str]] = None, **kwargs ): + if additional_special_tokens is not None: + if empty_token not in additional_special_tokens: + additional_special_tokens.append(empty_token) + else: + additional_special_tokens = [empty_token] + super().__init__( do_lower_case=do_lower_case, do_basic_tokenize=do_basic_tokenize, @@ -229,6 +287,7 @@ def __init__( pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, + empty_token=empty_token, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, cell_trim_length=cell_trim_length, @@ -237,7 +296,8 @@ def __init__( strip_column_names=strip_column_names, update_answer_coordinates=update_answer_coordinates, drop_rows_to_fit=drop_rows_to_fit, - additional_special_tokens=[empty_token], + model_max_length=model_max_length, + additional_special_tokens=additional_special_tokens, **kwargs, ) @@ -327,23 +387,67 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = return (vocab_file,) def create_attention_mask_from_sequences(self, query_ids: List[int], table_values: List[TableValue]) -> List[int]: + """ + Creates the attention mask according to the query token IDs and a list of table values. + + Args: + query_ids (:obj:`List[int]`): list of token IDs corresponding to the ID. + table_values (:obj:`List[TableValue]`): lift of table values, which are named tuples containing the + token value, the column ID and the row ID of said token. + + Returns: + :obj:`List[int]`: List of ints containing the attention mask values. 
+ """ return [1] * (1 + len(query_ids) + 1 + len(table_values)) def create_segment_token_type_ids_from_sequences( self, query_ids: List[int], table_values: List[TableValue] ) -> List[int]: + """ + Creates the segment token type IDs according to the query token IDs and a list of table values. + + Args: + query_ids (:obj:`List[int]`): list of token IDs corresponding to the ID. + table_values (:obj:`List[TableValue]`): lift of table values, which are named tuples containing the + token value, the column ID and the row ID of said token. + + Returns: + :obj:`List[int]`: List of ints containing the segment token type IDs values. + """ table_ids = list(zip(*table_values))[0] if table_values else [] return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids) def create_column_token_type_ids_from_sequences( self, query_ids: List[int], table_values: List[TableValue] ) -> List[int]: + """ + Creates the column token type IDs according to the query token IDs and a list of table values. + + Args: + query_ids (:obj:`List[int]`): list of token IDs corresponding to the ID. + table_values (:obj:`List[TableValue]`): lift of table values, which are named tuples containing the + token value, the column ID and the row ID of said token. + + Returns: + :obj:`List[int]`: List of ints containing the column token type IDs values. + """ table_column_ids = list(zip(*table_values))[1] if table_values else [] return [0] * (1 + len(query_ids) + 1) + list(table_column_ids) def create_row_token_type_ids_from_sequences( self, query_ids: List[int], table_values: List[TableValue] ) -> List[int]: + """ + Creates the row token type IDs according to the query token IDs and a list of table values. + + Args: + query_ids (:obj:`List[int]`): list of token IDs corresponding to the ID. + table_values (:obj:`List[TableValue]`): lift of table values, which are named tuples containing the + token value, the column ID and the row ID of said token. 
+ + Returns: + :obj:`List[int]`: List of ints containing the row token type IDs values. + """ table_row_ids = list(zip(*table_values))[2] if table_values else [] return [0] * (1 + len(query_ids) + 1) + list(table_row_ids) @@ -397,6 +501,7 @@ def get_special_tokens_mask( return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) return [1] + ([0] * len(token_ids_0)) + [1] + @add_end_docstrings(TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def __call__( self, table: pd.DataFrame, @@ -424,10 +529,8 @@ def __call__( ] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -446,25 +549,20 @@ def __call__( table (:obj:`pd.DataFrame`): Table containing tabular data. Note that all cell values must be text. Use `.astype(str)` on a Pandas dataframe to convert it to string. - queries (:obj:`str`, :obj:`List[str]`): - Question or batch of questions related to a table to be encoded. Each query can be a string or a list - of strings (pretokenized string). If the queries are provided as list of strings (pretokenized), you - must set :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). Note that + queries (:obj:`str` or :obj:`List[str]`): + Question or batch of questions related to a table to be encoded. Note that in case of a batch, all questions must refer to the **same** table. - answer_coordinates (:obj:`List[Tuple]`, :obj:`List[List[Tuple]]`, `optional`): + answer_coordinates (:obj:`List[Tuple]` or :obj:`List[List[Tuple]]`, `optional`): Answer coordinates of each table-question pair in the batch. 
In case only a single table-question pair is provided, then the answer_coordinates must be a single list of one or more tuples. Each tuple must be a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The first column has index 0. In case a batch of table-question pairs is provided, then the answer_coordinates must be a list of lists of tuples (each list corresponding to a single table-question pair). - answer_text (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`): + answer_text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): Answer text of each table-question pair in the batch. In case only a single table-question pair is provided, then the answer_text must be a single list of one or more strings. Each string must be the answer text of a corresponding answer coordinate. In case a batch of table-question pairs is provided, then the answer_coordinates must be a list of lists of strings (each list corresponding to a single table-question pair). - - For the other parameters, we refer to the documentation of :meth:`~transformers.PreTrainedTokenizer.__call__`. 
- """ assert isinstance(table, pd.DataFrame), "Table must be of type pd.DataFrame" @@ -502,8 +600,6 @@ def __call__( padding=padding, truncation=truncation, max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, @@ -525,8 +621,6 @@ def __call__( padding=padding, truncation=truncation, max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, @@ -539,6 +633,7 @@ def __call__( **kwargs, ) + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def batch_encode_plus( self, table: pd.DataFrame, @@ -553,10 +648,8 @@ def batch_encode_plus( answer_text: Optional[List[List[TextInput]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, @@ -568,16 +661,29 @@ def batch_encode_plus( verbose: bool = True, **kwargs ) -> BatchEncoding: + """ + Prepare a table and a list of strings for the model. - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) + .. warning:: + This method is deprecated, ``__call__`` should be used instead. + Args: + table (:obj:`pd.DataFrame`): + Table containing tabular data. Note that all cell values must be text. 
Use `.astype(str)` on a Pandas dataframe to + convert it to string. + queries (:obj:`List[str]`): + Batch of questions related to a table to be encoded. Note that all questions must refer to + the **same** table. + answer_coordinates (:obj:`List[Tuple]` or :obj:`List[List[Tuple]]`, `optional`): + Answer coordinates of each table-question pair in the batch. Each tuple must be + a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The first column + has index 0. The answer_coordinates must be a + list of lists of tuples (each list corresponding to a single table-question pair). + answer_text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): + Answer text of each table-question pair in the batch. In case a batch of table-question pairs is provided, then + the answer_coordinates must be a list of lists of strings (each list corresponding to a single table-question pair). Each string must be + the answer text of a corresponding answer coordinate. 
+ """ if return_token_type_ids is not None and not add_special_tokens: raise ValueError( "Asking to return token_type_ids while setting add_special_tokens to False " @@ -606,11 +712,9 @@ def batch_encode_plus( answer_coordinates=answer_coordinates, answer_text=answer_text, add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, + padding=padding, + truncation=truncation, max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, @@ -634,11 +738,9 @@ def _batch_encode_plus( answer_coordinates: Optional[List[List[Tuple]]] = None, answer_text: Optional[List[List[TextInput]]] = None, add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = True, @@ -650,40 +752,24 @@ def _batch_encode_plus( verbose: bool = True, **kwargs ) -> BatchEncoding: - table_tokens = self._tokenize_table(table) queries_tokens = [] - queries_ids = [] for query in queries: query_tokens = self.tokenize(query) queries_tokens.append(query_tokens) - queries_ids.append(self.convert_tokens_to_ids(query_tokens)) - - num_rows = self._get_num_rows(table, self.drop_rows_to_fit) - num_columns = self._get_num_columns(table) - - _, _, num_tokens = self._get_table_boundaries(table_tokens) - - table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) - - table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else 
list(zip(*table_data)) - table_ids = self.convert_tokens_to_ids(list(table_ids)) batch_outputs = self._batch_prepare_for_model( - table_ids, - queries_ids, table, queries, - table_data=table_data, + tokenized_table=table_tokens, queries_tokens=queries_tokens, answer_coordinates=answer_coordinates, + padding=padding, + truncation=truncation, answer_text=answer_text, add_special_tokens=add_special_tokens, - padding=padding_strategy.value, - truncation=truncation_strategy.value, max_length=max_length, - stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, prepend_batch_axis=True, @@ -699,27 +785,24 @@ def _batch_encode_plus( def _batch_prepare_for_model( self, - table_ids: List[int], - queries_ids: List[List[int]], raw_table: pd.DataFrame, raw_queries: Union[ List[TextInput], List[PreTokenizedInput], List[EncodedInput], ], + tokenized_table: Optional[TokenizedTable] = None, + queries_tokens: Optional[List[List[str]]] = None, answer_coordinates: Optional[List[List[Tuple]]] = None, answer_text: Optional[List[List[TextInput]]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = True, return_attention_mask: Optional[bool] = True, - return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, @@ -727,39 +810,14 @@ def _batch_prepare_for_model( prepend_batch_axis: bool = False, **kwargs ) -> BatchEncoding: - """ - Prepares a sequence of strings (queries) related to a table so that it can be used by the model. 
It creates - input ids, adds special tokens, truncates the table if overflowing (if the drop_rows_to_fit parameter is set to - True) while taking into account the special tokens and manages a moving window (with user defined stride) for - overflowing tokens - - This function is based on prepare_for_model (but in Tapas, training examples depend on each other, so we - defined it at a batch level) - - Args: - table: Pandas dataframe - queries: List of Strings, containing questions related to the table - """ batch_outputs = {} - if "table_data" in kwargs and "queries_tokens" in kwargs: - table_data = kwargs["table_data"] - queries_tokens = kwargs["queries_tokens"] - else: - table_data = None - queries_tokens = [None] * len(queries_ids) - - for index, example in enumerate(zip( - queries_ids, raw_queries, queries_tokens, answer_coordinates, answer_text - ) - ): - query_ids, raw_query, query_tokens, answer_coords, answer_txt = example + for index, example in enumerate(zip(raw_queries, queries_tokens, answer_coordinates, answer_text)): + raw_query, query_tokens, answer_coords, answer_txt = example outputs = self.prepare_for_model( - table_ids, - query_ids, raw_table, raw_query, - table_data=table_data, + tokenized_table=tokenized_table, query_tokens=query_tokens, answer_coordinates=answer_coords, answer_text=answer_txt, @@ -767,11 +825,9 @@ def _batch_prepare_for_model( padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterwards truncation=truncation, max_length=max_length, - stride=stride, pad_to_multiple_of=None, # we pad in batch afterwards return_attention_mask=False, # we pad in batch afterwards return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, return_tensors=None, # We convert the whole batch to tensors at the end @@ -798,6 +854,7 @@ def _batch_prepare_for_model( return batch_outputs + 
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING) def encode( self, table: pd.DataFrame, @@ -810,12 +867,23 @@ def encode( ] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs ) -> List[int]: + """ + Prepare a table and a string for the model. This method does not return token type IDs, attention masks, etc. + which are necessary for the model to work correctly. Use that method if you want to build your processing + on your own, otherwise refer to ``__call__``. + + Args: + table (:obj:`pd.DataFrame`): + Table containing tabular data. Note that all cell values must be text. Use `.astype(str)` on a Pandas dataframe to + convert it to string. + query (:obj:`str` or :obj:`List[str]`): + Question related to a table to be encoded. 
+ """ encoded_inputs = self.encode_plus( table, query=query, @@ -823,13 +891,13 @@ def encode( padding=padding, truncation=truncation, max_length=max_length, - stride=stride, return_tensors=return_tensors, **kwargs, ) return encoded_inputs["input_ids"] + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def encode_plus( self, table: pd.DataFrame, @@ -844,30 +912,37 @@ def encode_plus( answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = None, return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, verbose: bool = True, **kwargs ) -> BatchEncoding: - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) + """ + Prepare a table and a string for the model. + Args: + table (:obj:`pd.DataFrame`): + Table containing tabular data. Note that all cell values must be text. Use `.astype(str)` on a Pandas dataframe to + convert it to string. + query (:obj:`str` or :obj:`List[str]`): + Question related to a table to be encoded. + answer_coordinates (:obj:`List[Tuple]` or :obj:`List[List[Tuple]]`, `optional`): + Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single + list of one or more tuples. 
Each tuple must be + a (row_index, column_index) pair. The first data row (not the column header row) has index 0. The first column + has index 0. + answer_text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): + Answer text of each table-question pair in the batch. The answer_text must be a single list of one + or more strings. Each string must be + the answer text of a corresponding answer coordinate. + """ if return_token_type_ids is not None and not add_special_tokens: raise ValueError( "Asking to return token_type_ids while setting add_special_tokens to False " @@ -894,16 +969,13 @@ def encode_plus( answer_coordinates=answer_coordinates, answer_text=answer_text, add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, + truncation=truncation, + padding=padding, max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, return_token_type_ids=return_token_type_ids, return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_offsets_mapping=return_offsets_mapping, return_length=return_length, @@ -922,16 +994,13 @@ def _encode_plus( answer_coordinates: Optional[List[Tuple]] = None, answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = True, return_attention_mask: Optional[bool] = True, - 
return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, @@ -948,65 +1017,48 @@ def _encode_plus( table_tokens = self._tokenize_table(table) query_tokens = self.tokenize(query) - num_rows = self._get_num_rows(table, self.drop_rows_to_fit) - num_columns = self._get_num_columns(table) - - _, _, num_tokens = self._get_table_boundaries(table_tokens) - - table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) - - query_ids = self.convert_tokens_to_ids(query_tokens) - table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data)) - table_ids = self.convert_tokens_to_ids(list(table_ids)) - return self.prepare_for_model( - table_ids, - query_ids, table, query, - table_data=table_data, + tokenized_table=table_tokens, query_tokens=query_tokens, answer_coordinates=answer_coordinates, answer_text=answer_text, add_special_tokens=add_special_tokens, - padding=padding_strategy.value, - truncation=truncation_strategy.value, + truncation=truncation, + padding=padding, max_length=max_length, - stride=stride, pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, prepend_batch_axis=True, return_attention_mask=return_attention_mask, return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask, return_length=return_length, verbose=verbose, ) + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPAS_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) def prepare_for_model( self, - table_ids: List[int], - query_ids: List[int], raw_table: pd.DataFrame, raw_query: Union[ TextInput, PreTokenizedInput, EncodedInput, ], + tokenized_table: Optional[TokenizedTable] = None, + query_tokens: Optional[TokenizedTable] = None, answer_coordinates: Optional[List[Tuple]] = None, answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, 
padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, + truncation: Union[bool, str, TapasTruncationStrategy] = False, max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, pad_to_multiple_of: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_token_type_ids: Optional[bool] = True, return_attention_mask: Optional[bool] = True, - return_overflowing_tokens: bool = False, return_special_tokens_mask: bool = False, return_offsets_mapping: bool = False, return_length: bool = False, @@ -1014,16 +1066,44 @@ def prepare_for_model( prepend_batch_axis: bool = False, **kwargs ) -> BatchEncoding: + """ + Prepares a sequence of input id so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens. - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - padding=padding, - truncation=truncation, - max_length=max_length, - pad_to_multiple_of=pad_to_multiple_of, - verbose=verbose, - **kwargs, - ) + Args: + raw_table (:obj:`pd.DataFrame`): + The original table before any transformation (like tokenization) was applied to it. + raw_query (:obj:`TextInput` or :obj:`PreTokenizedInput` or :obj:`EncodedInput`): + The original query before any transformation (like tokenization) was applied to it. + tokenized_table (:obj:`TokenizedTable`): + The table after tokenization. + query_tokens (:obj:`List[str]`): + The query after tokenization. + answer_coordinates (:obj:`List[Tuple]` or :obj:`List[List[Tuple]]`, `optional`): + Answer coordinates of each table-question pair in the batch. The answer_coordinates must be a single + list of one or more tuples. Each tuple must be + a (row_index, column_index) pair. The first data row (not the column header row) has index 0. 
The first column + has index 0. + answer_text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`): + Answer text of each table-question pair in the batch. The answer_text must be a single list of one + or more strings. Each string must be + the answer text of a corresponding answer coordinate. + """ + if isinstance(padding, bool): + if padding and (max_length is not None or pad_to_multiple_of is not None): + padding = PaddingStrategy.MAX_LENGTH + else: + padding = PaddingStrategy.DO_NOT_PAD + elif not isinstance(padding, PaddingStrategy): + padding = PaddingStrategy(padding) + + if isinstance(truncation, bool): + if truncation: + truncation = TapasTruncationStrategy.DROP_ROWS_TO_FIT + else: + truncation = TapasTruncationStrategy.DO_NOT_TRUNCATE + elif not isinstance(truncation, TapasTruncationStrategy): + truncation = TapasTruncationStrategy(truncation) encoded_inputs = {} @@ -1033,50 +1113,34 @@ def prepare_for_model( is_part_of_batch = True prev_answer_coordinates = kwargs["prev_answer_coordinates"] prev_answer_text = kwargs["prev_answer_text"] - - # This can be retrieved from the encoding step, which prevents recomputing. - # We still need to handle recomputing as `prepare_for_model` should be callable on raw IDs/table/query as well. 
- if ( - "table_data" not in kwargs - or "query_tokens" not in kwargs - or ( - ("table_data" in kwargs and kwargs["table_data"] is None) - and ("query_tokens" in kwargs and kwargs["query_tokens"] is None) - ) - ): - table_tokens = self._tokenize_table(raw_table) - num_rows = self._get_num_rows(raw_table, self.drop_rows_to_fit) - num_columns = self._get_num_columns(raw_table) - _, _, num_tokens = self._get_table_boundaries(table_tokens) - table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) - query_tokens = self.tokenize(raw_query) - else: - table_data = kwargs["table_data"] - query_tokens = kwargs["query_tokens"] - total_len = ( - len(query_ids) + len(table_ids) + (self.num_special_tokens_to_add(pair=True) if add_special_tokens else 0) - ) + num_rows = self._get_num_rows(raw_table, self.drop_rows_to_fit) + num_columns = self._get_num_columns(raw_table) + _, _, num_tokens = self._get_table_boundaries(tokenized_table) - overflowing_tokens = [] - if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: - query_ids, table_ids, overflowing_tokens = self.truncate_sequences( - query_ids, - pair_ids=table_ids, - num_tokens_to_remove=total_len - max_length, - truncation_strategy=truncation_strategy, - stride=stride, - ) + if truncation != TapasTruncationStrategy.DO_NOT_TRUNCATE and max_length: + num_rows, num_tokens = self._get_truncated_table_rows(query_tokens, tokenized_table, num_rows, num_columns, + max_length, truncation_strategy=truncation) + table_data = list(self._get_table_values(tokenized_table, num_columns, num_rows, num_tokens)) - if return_overflowing_tokens: - encoded_inputs["overflowing_tokens"] = overflowing_tokens - encoded_inputs["num_truncated_tokens"] = total_len - max_length + query_ids = self.convert_tokens_to_ids(query_tokens) + table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data)) + table_ids = 
self.convert_tokens_to_ids(list(table_ids))
+
+        if "return_overflowing_tokens" in kwargs and kwargs["return_overflowing_tokens"]:
+            raise ValueError("TAPAS does not return overflowing tokens as it works on tables.")
 
         if add_special_tokens:
             input_ids = self.build_inputs_with_special_tokens(query_ids, table_ids)
         else:
             input_ids = query_ids + table_ids
 
+        if max_length is not None and len(input_ids) > max_length:
+            raise ValueError(
+                "Could not encode the query and table header given the maximum length. Encoding the query and table "
+                f"header results in a length of {len(input_ids)} which is higher than the max_length of {max_length}"
+            )
+
         encoded_inputs["input_ids"] = input_ids
 
         segment_ids = self.create_segment_token_type_ids_from_sequences(query_ids, table_data)
@@ -1156,11 +1220,11 @@ def prepare_for_model(
             self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
 
         # Padding
-        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+        if padding != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
             encoded_inputs = self.pad(
                 encoded_inputs,
                 max_length=max_length,
-                padding=padding_strategy.value,
+                padding=padding.value,
                 pad_to_multiple_of=pad_to_multiple_of,
                 return_attention_mask=return_attention_mask,
             )
@@ -1174,6 +1238,64 @@ def prepare_for_model(
 
         return batch_outputs
 
+    def _get_truncated_table_rows(
+        self,
+        query_tokens: List[str],
+        tokenized_table: TokenizedTable,
+        num_rows: int,
+        num_columns: int,
+        max_length: int,
+        truncation_strategy: Union[str, TapasTruncationStrategy],
+    ) -> Tuple[int, int]:
+        """
+        Truncates a sequence pair in-place following the strategy.
+
+        Args:
+            query_tokens (:obj:`List[str]`):
+                List of strings corresponding to the tokenized query.
+            tokenized_table (:obj:`TokenizedTable`):
+                Tokenized table
+            num_rows (:obj:`int`):
+                Total number of table rows
+            num_columns (:obj:`int`):
+                Total number of table columns
+            max_length (:obj:`int`):
+                Total maximum length.
+ truncation_strategy (:obj:`str` or :obj:`~transformers.TapasTruncationStrategy`): + Truncation strategy to use. Seeing as this method should only be called when truncating, the only + available strategy is the "drop_rows_to_fit" strategy. + + Returns: + :obj:`Tuple(int, int)`: tuple containing the number of rows after truncation, and the number of tokens + available for each table element. + """ + if not isinstance(truncation_strategy, TapasTruncationStrategy): + truncation_strategy = TapasTruncationStrategy(truncation_strategy) + + if truncation_strategy == TapasTruncationStrategy.DROP_ROWS_TO_FIT: + while True: + num_tokens = self._get_max_num_tokens( + query_tokens, + tokenized_table, + num_rows=num_rows, + num_columns=num_columns, + max_length=max_length + ) + + if num_tokens is not None: + # We could fit the table. + break + + # Try to drop a row to fit the table. + num_rows -= 1 + + if num_rows < 1: + break + elif truncation_strategy != TapasTruncationStrategy.DO_NOT_TRUNCATE: + raise ValueError(f"Unknown truncation strategy {truncation_strategy}.") + + return num_rows, num_tokens or 1 + def _tokenize_table( self, table=None, @@ -1223,7 +1345,7 @@ def _question_encoding_cost(self, question_tokens): # Two extra spots of SEP and CLS. return len(question_tokens) + 2 - def _get_token_budget(self, question_tokens): + def _get_token_budget(self, question_tokens, max_length=None): """ Computes the number of tokens left for the table after tokenizing a question, taking into account the max sequence length of the model. @@ -1233,7 +1355,7 @@ def _get_token_budget(self, question_tokens): List of question tokens. Returns: :obj:`int`: the number of tokens left for the table, given the model max length. 
""" - return self.model_max_length - self._question_encoding_cost(question_tokens) + return (max_length if max_length is not None else self.model_max_length) - self._question_encoding_cost(question_tokens) def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]: """Iterates over partial table and returns token, column and row indexes.""" @@ -1276,9 +1398,10 @@ def _get_max_num_tokens( tokenized_table, num_columns, num_rows, + max_length ): """Computes max number of tokens that can be squeezed into the budget.""" - token_budget = self._get_token_budget(question_tokens) + token_budget = self._get_token_budget(question_tokens, max_length) _, _, max_num_tokens = self._get_table_boundaries(tokenized_table) if self.cell_trim_length >= 0 and max_num_tokens > self.cell_trim_length: max_num_tokens = self.cell_trim_length diff --git a/tests/test_modeling_tapas.py b/tests/test_modeling_tapas.py index a519117857c6..76412cf7377b 100644 --- a/tests/test_modeling_tapas.py +++ b/tests/test_modeling_tapas.py @@ -368,55 +368,22 @@ def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) - # @slow - # def test_lm_outputs_same_as_reference_model(self): - # """Write something that could help someone fixing this here.""" - # checkpoint_path = "XXX/bart-large" - # model = self.big_model - # tokenizer = AutoTokenizer.from_pretrained( - # checkpoint_path - # ) # same with AutoTokenizer (see tokenization_auto.py). This is not mandatory - # # MODIFY THIS DEPENDING ON YOUR MODELS RELEVANT TASK. 
- # batch = tokenizer(["I went to the yesterday"]).to(torch_device) - # desired_mask_result = tokenizer.decode("store") # update this - # logits = model(**batch).logits - # masked_index = (batch.input_ids == self.tokenizer.mask_token_id).nonzero() - # assert model.num_parameters() == 175e9 # a joke - # mask_entry_logits = logits[0, masked_index.item(), :] - # probs = mask_entry_logits.softmax(dim=0) - # _, predictions = probs.topk(1) - # self.assertEqual(tokenizer.decode(predictions), desired_mask_result) - - # @cached_property - # def big_model(self): - # """Cached property means this code will only be executed once.""" - # checkpoint_path = "XXX/bart-large" - # model = AutoModelForMaskedLM.from_pretrained(checkpoint_path).to( - # torch_device - # ) # test whether AutoModel can determine your model_class from checkpoint name - # if torch_device == "cuda": - # model.half() - - # optional: do more testing! This will save you time later! - # @slow - # def test_that_XXX_can_be_used_in_a_pipeline(self): - # """We can use self.big_model here without calling __init__ again.""" - # pass - - # def test_XXX_loss_doesnt_change_if_you_add_padding(self): - # pass - - # def test_XXX_bad_args(self): - # pass - - # def test_XXX_backward_pass_reduces_loss(self): - # """Test loss/gradients same as reference implementation, for example.""" - # pass - - # @require_torch_and_cuda - # def test_large_inputs_in_fp16_dont_cause_overflow(self): - # pass +class TapasModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = TapasForQuestionAnswering.from_pretrained("google/tapas-xxx") + + input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + output = model(input_ids)[0] + expected_shape = torch.Size((1, 11, 50265)) + self.assertEqual(output.shape, expected_shape) + # compare the actual values for a slice. 
+ expected_slice = torch.tensor( + [[[33.8802, -4.3103, 22.7761], [4.6539, -2.8098, 13.6253], [1.8228, -3.6898, 8.8600]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) # Below: tests for Tapas utilities, based on segmented_tensor_test.py of the original implementation. # These test the operations on segmented tensors. diff --git a/tests/test_tokenization_tapas.py b/tests/test_tokenization_tapas.py index baae26423406..5d8cba376515 100644 --- a/tests/test_tokenization_tapas.py +++ b/tests/test_tokenization_tapas.py @@ -18,6 +18,7 @@ import tempfile import unittest from typing import List, Tuple +import numpy as np import pandas as pd @@ -38,16 +39,15 @@ @require_tokenizers class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): - tokenizer_class = TapasTokenizer test_rust_tokenizer = False space_between_special_tokens = True from_pretrained_filter = filter_non_english def get_table( - self, - tokenizer: TapasTokenizer, - length=5, + self, + tokenizer: TapasTokenizer, + length=5, ): toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] @@ -61,10 +61,9 @@ def get_table( return table def get_table_and_query( - self, - tokenizer: TapasTokenizer, - add_special_tokens: bool = True, - length=5, + self, + tokenizer: TapasTokenizer, + length=5, ): toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] table = self.get_table(tokenizer, length=length - 3) @@ -73,14 +72,14 @@ def get_table_and_query( return table, query def get_clean_sequence( - self, - tokenizer: TapasTokenizer, - with_prefix_space=False, - max_length=20, - min_length=5, - empty_table: bool = False, - add_special_tokens: bool = True, - return_table_and_query: bool = False, + self, + tokenizer: TapasTokenizer, + with_prefix_space=False, + max_length=20, + min_length=5, + empty_table: bool = False, + add_special_tokens: bool = True, + return_table_and_query: bool = False, ): 
toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] @@ -104,19 +103,6 @@ def get_clean_sequence( return output_txt, output_ids - # def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]: - # data = { - # 'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - # 'Age': ["56", "45", "59"], - # 'Number of movies': ["87", "53", "69"], - # 'Date of birth': ["18 december 1963", "11 november 1974", "6 may 1961"] - # } - # table = pd.DataFrame.from_dict(data) - # output_ids = tokenizer.encode(table, add_special_tokens=False, max_length=max_length) - # output_txt = tokenizer.decode(output_ids) - # - # return output_txt, output_ids - def setUp(self): super().setUp() @@ -302,14 +288,9 @@ def test_is_punctuation(self): def test_clean_text(self): tokenizer = self.get_tokenizer() - # rust_tokenizer = self.get_rust_tokenizer() # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340 - self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]) - - # self.assertListEqual( - # [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]] - # ) + self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], ["[EMPTY]"], ["[UNK]"]]) @slow def test_sequence_builders(self): @@ -376,159 +357,6 @@ def test_offsets_with_special_characters(self): ) self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) - def test_tapas_integration_test(self): - data = { - "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], - "Age": ["56", "45", "59"], - "Number of movies": ["87", "53", "69"], - "Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], - } - queries = [ - "When was Brad Pitt born?", - "Which actor appeared in the least number of movies?", - "What is the average number of movies?", - ] - table = 
pd.DataFrame.from_dict(data) - - # TODO: Should update this in the future - tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", model_max_length=512) - - expected_results = { - "input_ids": [ - 101, - 2043, - 2001, - 8226, - 15091, - 2141, - 1029, - 102, - 5889, - 2287, - 2193, - 1997, - 5691, - 3058, - 1997, - 4182, - 8226, - 15091, - 5179, - 6584, - 2324, - 2285, - 3699, - 14720, - 4487, - 6178, - 9488, - 3429, - 5187, - 2340, - 2281, - 3326, - 2577, - 18856, - 7828, - 3240, - 5354, - 6353, - 1020, - 2089, - 3777, - ], - "attention_mask": [ - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - ], - "token_type_ids": [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], - [1, 2, 0, 0, 0, 0, 0], - [1, 3, 0, 0, 0, 0, 0], - [1, 3, 0, 0, 0, 0, 0], - [1, 3, 0, 0, 0, 0, 0], - [1, 4, 0, 0, 0, 0, 0], - [1, 4, 0, 0, 0, 0, 0], - [1, 4, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 2, 1, 0, 2, 2, 0], - [1, 3, 1, 0, 3, 1, 0], - [1, 4, 1, 0, 2, 2, 0], - [1, 4, 1, 0, 2, 2, 0], - [1, 4, 1, 0, 2, 2, 0], - [1, 1, 2, 0, 0, 0, 0], - [1, 1, 2, 0, 0, 0, 0], - [1, 1, 2, 0, 0, 0, 0], - [1, 1, 2, 0, 0, 0, 0], - [1, 2, 2, 0, 1, 3, 0], - [1, 3, 2, 0, 1, 3, 0], - [1, 4, 2, 0, 3, 1, 0], - [1, 4, 2, 0, 3, 1, 0], - [1, 4, 2, 0, 3, 1, 0], - [1, 1, 3, 0, 0, 0, 0], - [1, 1, 3, 0, 0, 0, 0], - [1, 1, 3, 0, 0, 0, 0], - [1, 1, 3, 0, 0, 0, 0], - [1, 2, 3, 0, 3, 1, 0], - [1, 3, 3, 0, 2, 2, 0], - [1, 4, 3, 0, 1, 3, 0], - [1, 4, 3, 0, 1, 3, 0], - [1, 4, 3, 0, 1, 3, 0], - ], - } - - new_encoded_inputs = tokenizer.encode_plus(table=table, query=queries[0], padding="max_length") - - self.assertDictEqual(new_encoded_inputs, 
expected_results) - def test_add_special_tokens(self): tokenizers: List[TapasTokenizer] = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: @@ -642,7 +470,7 @@ def test_encode_plus_with_padding(self): not_padded_sequence = tokenizer.encode_plus( table, sequence, - padding=True, + padding=False, return_special_tokens_mask=True, ) not_padded_input_ids = not_padded_sequence["input_ids"] @@ -711,7 +539,8 @@ def test_encode_plus_with_padding(self): right_padded_token_type_ids = right_padded_sequence["token_type_ids"] assert ( - token_type_ids + [[token_type_padding_idx] * 7] * padding_size == right_padded_token_type_ids + token_type_ids + [ + [token_type_padding_idx] * 7] * padding_size == right_padded_token_type_ids ) assert [[token_type_padding_idx] * 7] * padding_size + token_type_ids == left_padded_token_type_ids @@ -749,8 +578,8 @@ def test_mask_output(self): table, query = self.get_table_and_query(tokenizer) if ( - tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer" - and "token_type_ids" in tokenizer.model_input_names + tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer" + and "token_type_ids" in tokenizer.model_input_names ): information = tokenizer.encode_plus(table, query, add_special_tokens=True) sequences, mask = information["input_ids"], information["token_type_ids"] @@ -800,7 +629,7 @@ def test_padding_to_max_length(self): sequence_length = len(encoded_sequence) # FIXME: the next line should be padding(max_length) to avoid warning padded_sequence = tokenizer.encode( - table, sequence, max_length=sequence_length + padding_size, pad_to_max_length=True + table, sequence, max_length=sequence_length + padding_size, padding=True ) padded_sequence_length = len(padded_sequence) assert sequence_length + padding_size == padded_sequence_length @@ -937,18 +766,9 @@ def test_batch_encode_plus_batch_sequence_length(self): encoded_sequences_batch_padded_2[key], ) + 
@unittest.skip("batch_encode_plus does not handle overflowing tokens.") def test_batch_encode_plus_overflowing_tokens(self): - tokenizers = self.get_tokenizers(do_lower_case=False) - for tokenizer in tokenizers: - table = self.get_table(tokenizer, length=0) - string_sequences = ["Testing the prepare_for_model method.", "Test"] - - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - tokenizer.batch_encode_plus( - table, string_sequences, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3 - ) + pass def test_batch_encode_plus_padding(self): # Test that padded sequences are equivalent between batch_encode_plus and encode_plus @@ -1031,18 +851,6 @@ def test_padding_to_multiple_of(self): for key, value in normal_tokens.items(): self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) - # truncation to something which is not a multiple of pad_to_multiple_of raises an error - self.assertRaises( - ValueError, - tokenizer.__call__, - table, - "This", - padding=True, - truncation=True, - max_length=12, - pad_to_multiple_of=8, - ) - @unittest.skip("TAPAS cannot handle `prepare_for_model` without passing by `encode_plus` or `batch_encode_plus`") def test_prepare_for_model(self): pass @@ -1193,3 +1001,2287 @@ def test_right_and_left_padding(self): @unittest.skip("TAPAS doesn't handle pre-tokenized inputs.") def test_pretokenized_inputs(self): pass + + # TODO SET TO SLOW + def test_tapas_truncation_integration_test(self): + data = { + "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + "Age": ["56", "45", "59"], + "Number of movies": ["87", "53", "69"], + "Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], + } + queries = [ + "When was Brad Pitt born?", + "Which actor appeared in the least number of movies?", + "What is the average number of movies?", + ] + table = pd.DataFrame.from_dict(data) + + # TODO: Should update this in the future + tokenizer 
= TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", model_max_length=512) + + for i in range(12): + # The table cannot even encode the headers, so raise an error + with self.assertRaises(ValueError): + tokenizer.encode(table=table, query=queries[0], max_length=i, truncation="drop_rows_to_fit") + + for i in range(12, 512): + new_encoded_inputs = tokenizer.encode(table=table, query=queries[0], max_length=i, truncation="drop_rows_to_fit") + + # Ensure that the input IDs are less than the max length defined. + self.assertLessEqual(len(new_encoded_inputs), i) + + # TODO SET TO SLOW + def test_tapas_integration_test(self): + data = { + "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + "Age": ["56", "45", "59"], + "Number of movies": ["87", "53", "69"], + "Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], + } + queries = [ + "When was Brad Pitt born?", + "Which actor appeared in the least number of movies?", + "What is the average number of movies?", + ] + table = pd.DataFrame.from_dict(data) + + # TODO: Should update this in the future + tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", model_max_length=512) + + expected_results = { + "input_ids": [ + 101, + 2043, + 2001, + 8226, + 15091, + 2141, + 1029, + 102, + 5889, + 2287, + 2193, + 1997, + 5691, + 3058, + 1997, + 4182, + 8226, + 15091, + 5179, + 6584, + 2324, + 2285, + 3699, + 14720, + 4487, + 6178, + 9488, + 3429, + 5187, + 2340, + 2281, + 3326, + 2577, + 18856, + 7828, + 3240, + 5354, + 6353, + 1020, + 2089, + 3777, + ], + "attention_mask": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "token_type_ids": [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 
0, 0], + [0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 2, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 2, 1, 0, 2, 2, 0], + [1, 3, 1, 0, 3, 1, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 2, 2, 0, 1, 3, 0], + [1, 3, 2, 0, 1, 3, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 2, 3, 0, 3, 1, 0], + [1, 3, 3, 0, 2, 2, 0], + [1, 4, 3, 0, 1, 3, 0], + [1, 4, 3, 0, 1, 3, 0], + [1, 4, 3, 0, 1, 3, 0], + ], + } + + new_encoded_inputs = tokenizer.encode_plus(table=table, query=queries[0]) + + self.assertDictEqual(dict(new_encoded_inputs), expected_results) + + # TODO SET TO SLOW + def test_full_tokenizer(self): + data = [ + ["Pos", "No", "Driver", "Team", "Laps", "Time/Retired", "Grid", "Points"], + ["1", "32", "Patrick Carpentier", "Team Player's", "87", "1:48:11.023", "1", "22"], + ["2", "1", "Bruno Junqueira", "Newman/Haas Racing", "87", "+0.8 secs", "2", "17"], + ["3", "3", "Paul Tracy", "Team Player's", "87", "+28.6 secs", "3", "14"], + ["4", "9", "Michel Jourdain, Jr.", "Team Rahal", "87", "+40.8 secs", "13", "12"], + ["5", "34", "Mario Haberfeld", "Mi-Jack Conquest Racing", "87", "+42.1 secs", "6", "10"], + ["6", "20", "Oriol Servia", "Patrick Racing", "87", "+1:00.2", "10", "8"], + ["7", "51", "Adrian Fernandez", "Fernandez Racing", "87", "+1:01.4", "5", "6"], + ["8", "12", "Jimmy Vasser", "American Spirit Team Johansson", "87", "+1:01.8", "8", "5"], + ["9", "7", "Tiago Monteiro", "Fittipaldi-Dingman Racing", "86", "+ 1 Lap", "15", "4"], + ["10", "55", "Mario Dominguez", "Herdez Competition", "86", "+ 1 Lap", "11", 
"3"], + ["11", "27", "Bryan Herta", "PK Racing", "86", "+ 1 Lap", "12", "2"], + ["12", "31", "Ryan Hunter-Reay", "American Spirit Team Johansson", "86", "+ 1 Lap", "17", "1"], + ["13", "19", "Joel Camathias", "Dale Coyne Racing", "85", "+ 2 Laps", "18", "0"], + ["14", "33", "Alex Tagliani", "Rocketsports Racing", "85", "+ 2 Laps", "14", "0"], + ["15", "4", "Roberto Moreno", "Herdez Competition", "85", "+ 2 Laps", "9", "0"], + ["16", "11", "Geoff Boss", "Dale Coyne Racing", "83", "Mechanical", "19", "0"], + ["17", "2", "Sebastien Bourdais", "Newman/Haas Racing", "77", "Mechanical", "4", "0"], + ["18", "15", "Darren Manning", "Walker Racing", "12", "Mechanical", "7", "0"], + ["19", "5", "Rodolfo Lavin", "Walker Racing", "10", "Mechanical", "16", "0"], + ] + query = "what were the drivers names?" + table = pd.DataFrame.from_records(data[1:], columns=data[0]) + + # TODO: Should update this in the future + tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", model_max_length=512) + model_inputs = tokenizer(table, query, padding="max_length") + + input_ids = model_inputs["input_ids"] + token_type_ids = np.array(model_inputs["token_type_ids"]) + segment_ids = token_type_ids[:, 0] + column_ids = token_type_ids[:, 1] + row_ids = token_type_ids[:, 2] + + expected_results = { + "input_ids": [ + 101, + 2054, + 2020, + 1996, + 6853, + 3415, + 1029, + 102, + 13433, + 2015, + 2053, + 4062, + 2136, + 10876, + 2051, + 1013, + 3394, + 8370, + 2685, + 1015, + 3590, + 4754, + 29267, + 4765, + 3771, + 2136, + 2447, + 1005, + 1055, + 6584, + 1015, + 1024, + 4466, + 1024, + 2340, + 1012, + 6185, + 2509, + 1015, + 2570, + 1016, + 1015, + 10391, + 12022, + 4226, + 7895, + 10625, + 1013, + 22996, + 3868, + 6584, + 1009, + 1014, + 1012, + 1022, + 10819, + 2015, + 1016, + 2459, + 1017, + 1017, + 2703, + 10555, + 2136, + 2447, + 1005, + 1055, + 6584, + 1009, + 2654, + 1012, + 1020, + 10819, + 2015, + 1017, + 2403, + 1018, + 1023, + 8709, + 8183, + 3126, + 21351, + 2078, 
+ 1010, + 3781, + 1012, + 2136, + 10958, + 8865, + 6584, + 1009, + 2871, + 1012, + 1022, + 10819, + 2015, + 2410, + 2260, + 1019, + 4090, + 7986, + 5292, + 5677, + 8151, + 2771, + 1011, + 2990, + 9187, + 3868, + 6584, + 1009, + 4413, + 1012, + 1015, + 10819, + 2015, + 1020, + 2184, + 1020, + 2322, + 2030, + 20282, + 14262, + 9035, + 4754, + 3868, + 6584, + 1009, + 1015, + 1024, + 4002, + 1012, + 1016, + 2184, + 1022, + 1021, + 4868, + 7918, + 12023, + 12023, + 3868, + 6584, + 1009, + 1015, + 1024, + 5890, + 1012, + 1018, + 1019, + 1020, + 1022, + 2260, + 5261, + 12436, + 18116, + 2137, + 4382, + 2136, + 26447, + 6584, + 1009, + 1015, + 1024, + 5890, + 1012, + 1022, + 1022, + 1019, + 1023, + 1021, + 27339, + 3995, + 10125, + 9711, + 4906, + 25101, + 24657, + 1011, + 22033, + 2386, + 3868, + 6564, + 1009, + 1015, + 5001, + 2321, + 1018, + 2184, + 4583, + 7986, + 14383, + 2075, + 29488, + 14906, + 9351, + 2971, + 6564, + 1009, + 1015, + 5001, + 2340, + 1017, + 2340, + 2676, + 8527, + 2014, + 2696, + 1052, + 2243, + 3868, + 6564, + 1009, + 1015, + 5001, + 2260, + 1016, + 2260, + 2861, + 4575, + 4477, + 1011, + 2128, + 4710, + 2137, + 4382, + 2136, + 26447, + 6564, + 1009, + 1015, + 5001, + 2459, + 1015, + 2410, + 2539, + 8963, + 11503, + 25457, + 3022, + 8512, + 2522, + 9654, + 3868, + 5594, + 1009, + 1016, + 10876, + 2324, + 1014, + 2403, + 3943, + 4074, + 6415, + 15204, + 2072, + 12496, + 25378, + 3868, + 5594, + 1009, + 1016, + 10876, + 2403, + 1014, + 2321, + 1018, + 10704, + 17921, + 14906, + 9351, + 2971, + 5594, + 1009, + 1016, + 10876, + 1023, + 1014, + 2385, + 2340, + 14915, + 5795, + 8512, + 2522, + 9654, + 3868, + 6640, + 6228, + 2539, + 1014, + 2459, + 1016, + 28328, + 8945, + 3126, + 21351, + 2015, + 10625, + 1013, + 22996, + 3868, + 6255, + 6228, + 1018, + 1014, + 2324, + 2321, + 12270, + 11956, + 5232, + 3868, + 2260, + 6228, + 1021, + 1014, + 2539, + 1019, + 8473, + 28027, + 2080, + 2474, + 6371, + 5232, + 3868, + 2184, + 6228, + 2385, + 1014, + 0, + 0, 
+ 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "column_ids": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 2, + 3, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 4, + 4, + 4, + 5, + 6, + 
6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 4, + 4, + 4, + 5, + 6, + 6, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 5, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 4, + 4, + 5, + 6, + 7, + 8, + 1, + 2, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 5, + 6, + 7, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "row_ids": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 
4, + 4, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 10, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 11, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 12, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 13, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 14, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 15, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 16, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 17, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 18, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 19, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 
0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "segment_ids": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 
0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + } + + self.assertListEqual(input_ids, expected_results["input_ids"]) + self.assertListEqual(segment_ids.tolist(), expected_results["segment_ids"]) + self.assertListEqual(column_ids.tolist(), expected_results["column_ids"]) + self.assertListEqual(row_ids.tolist(), expected_results["row_ids"])