Skip to content

Commit

Permalink
Merge pull request #3593 from flairNLP/filter_relations
Browse files Browse the repository at this point in the history
Optimize RelationClassifier by adding the option to filter long sentences and truncate context
  • Loading branch information
alanakbik authored Feb 4, 2025
2 parents ae592bf + 863d903 commit 8acd698
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 37 deletions.
1 change: 0 additions & 1 deletion flair/models/regexp_tagger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import typing
from dataclasses import dataclass, field
from typing import Union

Expand Down
86 changes: 80 additions & 6 deletions flair/models/relation_classifier_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,13 @@ def __init__(
entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]],
entity_pair_labels: Optional[set[tuple[str, str]]] = None,
entity_threshold: Optional[float] = None,
max_allowed_tokens_between_entities: Optional[int] = 20,
max_surrounding_context_length: Optional[int] = 10,
cross_augmentation: bool = True,
encoding_strategy: EncodingStrategy = TypedEntityMarker(),
zero_tag_value: str = "O",
allow_unk_tag: bool = True,
**classifierargs,
**classifierargs: Any,
) -> None:
"""Initializes a `RelationClassifier`.
Expand All @@ -267,6 +269,8 @@ def __init__(
entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'.
entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'ORG')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient).
entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model.
max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled.
max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context, in between the entity pairs will always be included. If `None`, the filter will be disabled.
cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentenece`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence.
encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol
zero_tag_value: The label to use for out-of-class relations
Expand Down Expand Up @@ -302,6 +306,8 @@ def __init__(
self.entity_pair_labels = entity_pair_labels

self.entity_threshold = entity_threshold
self.max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
self.max_surrounding_context_length = max_surrounding_context_length
self.cross_augmentation = cross_augmentation
self.encoding_strategy = encoding_strategy

Expand Down Expand Up @@ -393,12 +399,41 @@ def _entity_pair_permutations(

yield head, tail, gold_label

@staticmethod
def _truncate_context_around_entities(
encoded_sentence_tokens: list[str],
head_idx: int,
tail_idx: int,
context_length: int,
) -> list[str]:
"""Truncates the encoded sentence to include the head and tail entity and their surrounding context.
The context, in between the entity pairs will always be included.
Args:
encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence.
head_idx: The index of the head entity in the token list.
tail_idx: The index of the tail entity in the token list.
context_length: The maximum number of tokens to include as surrounding context around the head and tail entities.
Returns:
The tokens of the truncated sentence.
"""
begin_slice: int = min(head_idx, tail_idx)
end_slice: int = max(head_idx, tail_idx)

# Preserve context around the entities. Always include their in-between context.
begin_slice = max(begin_slice - context_length, 0)
end_slice = min(end_slice + context_length + 1, len(encoded_sentence_tokens))

return encoded_sentence_tokens[begin_slice:end_slice]

def _encode_sentence(
self,
head: _Entity,
tail: _Entity,
gold_label: Optional[str] = None,
) -> EncodedSentence:
) -> Optional[EncodedSentence]:
"""Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.
If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
Expand All @@ -414,6 +449,12 @@ def _encode_sentence(
original_sentence: Sentence = head.span.sentence
assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."

# Sanity check: Do not create a labeled span if one entity contains the other
if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
return None
if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
return None

# Pre-compute non-leading head and tail tokens for entity masking
non_leading_head_tokens: list[Token] = head.span.tokens[1:]
non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
Expand All @@ -422,11 +463,15 @@ def _encode_sentence(
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
encoded_sentence_tokens: list[str] = []
head_idx: Optional[int] = None
tail_idx: Optional[int] = None
for token in original_sentence:
if token is head.span[0]:
head_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

elif token is tail.span[0]:
tail_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

elif all(
Expand All @@ -435,6 +480,27 @@ def _encode_sentence(
):
encoded_sentence_tokens.append(token.text)

msg: str
if head_idx is None:
msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)
if tail_idx is None:
msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)

# Filter cases in which the distance between the two entities is too large
if (
self.max_allowed_tokens_between_entities is not None
and abs(head_idx - tail_idx) > self.max_allowed_tokens_between_entities
):
return None

# Remove excess tokens left and right of entity pair to make encoded sentence shorter
if self.max_surrounding_context_length is not None:
encoded_sentence_tokens = self._truncate_context_around_entities(
encoded_sentence_tokens, head_idx, tail_idx, self.max_surrounding_context_length
)

# Create masked sentence
encoded_sentence: EncodedSentence = EncodedSentence(
" ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
Expand All @@ -445,6 +511,7 @@ def _encode_sentence(
# Using the sentence label instead of annotating a separate `Relation` object is easier to manage since,
# during prediction, the forward pass does not need any knowledge about the entities in the sentence.
encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0)

encoded_sentence.copy_context_from_sentence(original_sentence)
return encoded_sentence

Expand All @@ -469,13 +536,15 @@ def _encode_sentence_for_inference(
Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
"""
for head, tail, gold_label in self._entity_pair_permutations(sentence):
masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label if gold_label is not None else self.zero_tag_value,
)
original_relation: Relation = Relation(first=head.span, second=tail.span)
yield masked_sentence, original_relation

if masked_sentence is not None:
yield masked_sentence, original_relation

def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
"""Create Encoded Sentences and Relation pairs for Training.
Expand All @@ -492,13 +561,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
else:
continue # Skip generated data points that do not express an originally annotated relation

masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label,
)

yield masked_sentence
if masked_sentence is not None:
yield masked_sentence

def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
"""Transforms sentences into encoded sentences specific to the `RelationClassifier`.
Expand Down Expand Up @@ -702,6 +772,8 @@ def _get_state_dict(self) -> dict[str, Any]:
"entity_label_types": self.entity_label_types,
"entity_pair_labels": self.entity_pair_labels,
"entity_threshold": self.entity_threshold,
"max_allowed_tokens_between_entities": self.max_allowed_tokens_between_entities,
"max_surrounding_context_length": self.max_surrounding_context_length,
"cross_augmentation": self.cross_augmentation,
"encoding_strategy": self.encoding_strategy,
"zero_tag_value": self.zero_tag_value,
Expand All @@ -719,6 +791,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
entity_label_types=state["entity_label_types"],
entity_pair_labels=state["entity_pair_labels"],
entity_threshold=state["entity_threshold"],
max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"),
max_surrounding_context_length=state.get("max_surrounding_context_length"),
cross_augmentation=state["cross_augmentation"],
encoding_strategy=state["encoding_strategy"],
zero_tag_value=state["zero_tag_value"],
Expand Down
20 changes: 12 additions & 8 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,14 +966,18 @@ def _initialize_model_card(self, **training_parameters):
except ImportError:
pass

# remember all parameters used in train() call
model_card["training_parameters"] = {
k: str(v) if isinstance(v, Path) else v for k, v in training_parameters.items()
}

model_card["training_parameters"] = {
k: f"{v.__module__}.{v.__name__}" if inspect.isclass(v) else v for k, v in training_parameters.items()
}
# remember the training parameters
model_card["training_parameters"] = {}
for k, v in training_parameters.items():

# special rule for Path variables to make sure models can be deserialized on other OS
if isinstance(v, Path):
v = str(v)
# classes are only serialized as names
if inspect.isclass(v):
v = f"{v.__module__}.{v.__name__}"

model_card["training_parameters"][k] = v

plugins = [plugin.get_state() for plugin in model_card["training_parameters"]["plugins"]]
model_card["training_parameters"]["plugins"] = plugins
Expand Down
Loading

0 comments on commit 8acd698

Please sign in to comment.