diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index b8c6637db3..26fb9add7a 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -222,6 +222,58 @@ def document_max_pooling(sentence_hidden_states: torch.Tensor, sentence_lengths:
     return result
 
 
+def map_tokens_to_subtokens(subtoken_offsets, token_offsets, verbose: bool = False, subtokens=None, tokens=None):
+
+    mapping: list[Optional[int]] = []
+    for subtoken_id, subtoken in enumerate(subtoken_offsets):
+
+        # subtokens of length 0 should not be mapped to anything
+        if subtoken[0] == subtoken[1]:
+            mapping.append(None)
+            continue
+
+        mapping_found = False
+
+        if verbose and subtokens:
+            print(f"trying to match {subtokens[subtoken_id]} ({subtoken})")
+
+        # check if the subtoken is wholly contained within a token. If so, it should be mapped to this token
+        for token_id, token in enumerate(token_offsets):
+
+            if verbose and tokens:
+                print(f" ... does {tokens[token_id]} (#{token_id}, {token}) match?")
+
+            if token[0] - 1 <= subtoken[0] and token[1] >= subtoken[1]:
+                if verbose:
+                    print(" ... yes!")
+                mapping.append(token_id)
+                mapping_found = True
+                break
+
+        if mapping_found:
+            continue
+
+        # if the subtoken is not wholly contained within a token, it may be partially contained
+        # in this case, take the first token in which it is partially contained
+        for token_id, token in enumerate(token_offsets):
+            if verbose and tokens:
+                print(f" ... does {tokens[token_id]} (#{token_id}, {token}) partially match?")
+            if token[0] >= subtoken[0]:
+                if verbose:
+                    print(" ... yes!")
+                mapping.append(token_id)
+                mapping_found = True
+                break
+
+        if mapping_found:
+            continue
+
+        # if a subtoken cannot be mapped, the mapping is None
+        mapping.append(None)
+
+    return mapping
+
+
 def _legacy_reconstruct_word_ids(
     embedding: "TransformerBaseEmbeddings", flair_tokens: list[list[str]]
 ) -> list[list[Optional[int]]]:
@@ -354,6 +406,8 @@ def __init__(
         feature_extractor: Optional[FeatureExtractionMixin] = None,
         needs_manual_ocr: Optional[bool] = None,
         use_context_separator: bool = True,
+        use_raw_text_as_input: bool = False,
+        **kwargs,
     ) -> None:
         self.name = name
         super().__init__()
@@ -374,6 +428,7 @@ def __init__(
         self.feature_extractor = feature_extractor
         self.use_context_separator = use_context_separator
         self.cls_pooling = cls_pooling
+        self.use_raw_text_as_input = use_raw_text_as_input
 
         tokenizer_params = list(inspect.signature(self.tokenizer.__call__).parameters.keys())
         self.tokenizer_needs_ocr_boxes = "boxes" in tokenizer_params
@@ -417,6 +472,7 @@ def to_args(self):
             "feature_extractor": self.feature_extractor,
             "use_context_separator": self.use_context_separator,
             "cls_pooling": self.cls_pooling,
+            "use_raw_text_as_input": self.use_raw_text_as_input,
         }
         if hasattr(self, "needs_manual_ocr"):
             args["needs_manual_ocr"] = self.needs_manual_ocr
@@ -568,6 +624,17 @@ def __build_transformer_model_inputs(
             tokenizer_kwargs["is_split_into_words"] = True
             tokenizer_kwargs["text"] = [[t.text for t in tokens] for tokens in flair_tokens]
 
+        # if we use raw text as input, tokenize the reconstructed sentence text and request offset mappings
+        if self.use_raw_text_as_input:
+            tokenizer_kwargs["is_split_into_words"] = False
+            tokenizer_kwargs["return_offsets_mapping"] = True
+
+            # reconstruct text of sentences and preserve whitespace_after information
+            tokenizer_kwargs["text"] = [
+                "".join([t.text if t.whitespace_after == 0 else t.text + " " * t.whitespace_after for t in tokens])
+                for tokens in flair_tokens
+            ]
+
         batch_encoding = self.tokenizer(
             **tokenizer_kwargs,
            stride=self.stride,
@@ -627,12 +694,35 @@ def __build_transformer_model_inputs(
         if "bbox" in batch_encoding:
             model_kwargs["bbox"] = batch_encoding["bbox"].to(device, non_blocking=True)
 
+        # If we need a token-level embedding, we need to derive mappings between subtokens and flair tokens
         if self.token_embedding or self.needs_manual_ocr:
             assert sentence_lengths is not None  # for type checking
             model_kwargs["token_lengths"] = torch.tensor(sentence_lengths, device=device)
 
             if self.tokenizer.is_fast:
-                word_ids_list = [batch_encoding.word_ids(i) for i in range(input_ids.size()[0])]
+
+                if self.use_raw_text_as_input:
+                    word_ids_list = []
+                    assert flair_tokens  # assert that this is not None for mypy type checking
+                    for sentence_no, sentence_tokens in enumerate(flair_tokens):
+
+                        subtoken_offsets = batch_encoding["offset_mapping"][sentence_no]
+
+                        offset = 0
+                        token_offsets = []
+                        for token in sentence_tokens:
+                            token_offsets.append((offset, offset + len(token.text)))
+                            offset += len(token.text) + token.whitespace_after
+
+                        mapping = map_tokens_to_subtokens(
+                            subtoken_offsets=subtoken_offsets,
+                            token_offsets=token_offsets,
+                        )
+
+                        word_ids_list.append(mapping)
+
+                else:
+                    word_ids_list = [batch_encoding.word_ids(i) for i in range(input_ids.size()[0])]
             else:
                 word_ids_list = _legacy_reconstruct_word_ids(
                     self,
@@ -1053,6 +1143,7 @@ def __init__(
         transformers_model_kwargs: dict[str, Any] = {},
         peft_config=None,
         peft_gradient_checkpointing_kwargs: Optional[dict[str, Any]] = {},
+        use_raw_text_as_input: bool = False,
         **kwargs,
     ) -> None:
         """Instantiate transformers embeddings.
@@ -1099,6 +1190,7 @@ def __init__(
             logging.set_verbosity_error()
 
         self.tokenizer: PreTrainedTokenizer
+        self.use_raw_text_as_input = use_raw_text_as_input
         self.feature_extractor: Optional[FeatureExtractionMixin]
 
         if tokenizer_data is None:
diff --git a/tests/embeddings/test_transformer_word_embeddings.py b/tests/embeddings/test_transformer_word_embeddings.py
index a2ca3716a5..604d957de3 100644
--- a/tests/embeddings/test_transformer_word_embeddings.py
+++ b/tests/embeddings/test_transformer_word_embeddings.py
@@ -4,10 +4,12 @@
 import pytest
 import torch
 from PIL import Image
+from torch import tensor
 from transformers.utils import is_detectron2_available
 
 from flair.data import BoundingBox, Dictionary, Sentence
 from flair.embeddings import TransformerJitWordEmbeddings, TransformerWordEmbeddings
+from flair.embeddings.transformer import map_tokens_to_subtokens
 from flair.models import SequenceTagger
 
 from tests.embedding_test_utils import BaseEmbeddingsTest
@@ -323,3 +325,82 @@ def test_onnx_export_works(self, results_base_path):
         for sent_a, sent_b in zip(normal_sentences, onnx_sentences):
             for token_a, token_b in zip(sent_a, sent_b):
                 assert torch.isclose(token_a.get_embedding(), token_b.get_embedding(), atol=1e-6).all()
+
+    def test_token_subtoken_mapping(self):
+        ### Test Case 1: Normal text
+        # text = "BEST DENTIST EVER -"
+
+        # Token and subtoken offsets
+        # tokens = ["[FLERT]", "BEST", "DENTIST", "EVER", "-", "[FLERT]"]
+        token_offsets = [(0, 7), (8, 12), (13, 20), (21, 25), (26, 27), (27, 34)]
+
+        # subtokens = ["[CLS]", "[FLERT]", "▁BEST", "▁D", "ENT", "IST", "▁EVER", "▁-", "[FLERT]", "[SEP]"]
+        subtoken_offsets = tensor(
+            [[0, 0], [0, 7], [8, 12], [12, 14], [14, 17], [17, 20], [20, 25], [25, 27], [27, 34], [0, 0]]
+        )
+
+        mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets)
+
+        assert [None, 0, 1, 2, 2, 2, 3, 4, 5, None] == mapping
+
+        ### Test Case 2: Differing tokenizations
+        # text = "So don't be afraid"
text = "So don't be afraid" + + # Token and subtoken offsets + # tokens = ["[FLERT]", "So", "do", "n't", "be", "afraid", "[FLERT]"] + token_offsets = [(0, 7), (8, 10), (11, 13), (13, 16), (17, 19), (20, 26), (26, 33)] + + # subtokens = ["[CLS]", "[FLERT]", "▁So", "▁don", "'", "t", "▁be", "▁afraid", "[FLERT]", "[SEP]"] + subtoken_offsets = tensor( + [[0, 0], [0, 7], [8, 10], [10, 14], [14, 15], [15, 16], [16, 19], [19, 26], [26, 33], [0, 0]] + ) + + mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets) + + assert [None, 0, 1, 2, 3, 3, 4, 5, 6, None] == mapping + + ### Test Case 3: Text with punctuation and no whitespaces + # text = "this and/or that," + + # Token and subtoken offsets + # tokens = ["[FLERT]", "this", "and", "/", "or", "that", ",", "[FLERT]"] + token_offsets = [(0, 7), (8, 12), (13, 16), (16, 17), (17, 19), (20, 24), (24, 25), (25, 32)] + + # subtokens = ["[CLS]", "[FLERT]", "▁this", "▁and", "/", "or", "▁that", ",", "[FLERT]", "[SEP]"] + subtoken_offsets = tensor( + [[0, 0], [0, 7], [8, 12], [12, 16], [16, 17], [17, 19], [19, 24], [24, 25], [25, 32], [0, 0]] + ) + + mapping = map_tokens_to_subtokens(subtoken_offsets=subtoken_offsets, token_offsets=token_offsets) + + assert [None, 0, 1, 2, 3, 4, 5, 6, 7, None] == mapping + + ### Test Case 4: Suboptimal tokenization caused by limited vocabulary without whitespace + # text = "number of public-diplomacy officers" + + # Token and subtoken offsets + # tokens = ['number', 'of', 'public', '-', 'diplomacy', 'officers'] + token_offsets = [(0, 6), (7, 9), (10, 16), (16, 17), (17, 26), (27, 35)] + + # new_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '-', 'diploma', 'cy', '▁officers', '[SEP]'] + # old_subtokens = ['[CLS]', '▁number', '▁of', '▁public', '▁-', '▁diplomacy', '▁officers', '[SEP]'] + subtoken_offsets = tensor([[0, 0], [0, 6], [6, 9], [9, 16], [16, 17], [17, 24], [24, 26], [26, 35], [0, 0]]) + + assert [None, 0, 1, 2, 3, 4, 5, 6, 7, None] == mapping + + ### Test Case 5: Suboptimal tokenization in which two tokenizer words become one subtoken ("wan" "na" -> "wanna") + # text = "I gotta have it" + + # Token and subtoken offsets + # tokens = ['I', 'got', 'ta', 'have', 'it'] + token_offsets = [(0, 1), (2, 5), (5, 7), (8, 12), (13, 15)] + + # new subtokens = ['[CLS]', '▁I', '▁gotta', '▁have', '▁it', '[SEP]'] + # old subtokens = ['[CLS]', '▁I', '▁got', '▁ta', '▁have', '▁it', '[SEP]'] + subtoken_offsets = tensor([[0, 0], [0, 1], [1, 7], [7, 12], [12, 15], [0, 0]]) + assert [None, 0, 1, 2, 3, 4, 5, 6, 7, None] == mapping