Skip to content

Commit

Permalink
Merge branch 'master' into filter_relations
Browse files Browse the repository at this point in the history
  • Loading branch information
dobbersc authored Jan 22, 2025
2 parents de8b7f4 + 30974f2 commit 6789a6a
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 27 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# -- Project information -----------------------------------------------------
from sphinx_github_style import get_linkcode_resolve
from torch.nn import Module

version = "0.15.0"
release = "0.15.0"
Expand Down
7 changes: 2 additions & 5 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,8 +1389,7 @@ def __init__(
sample_missing_splits: Union[bool, str] = True,
random_seed: Optional[int] = None,
) -> None:
"""
Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split
"""Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split
by passing the corresponding Dataset object to the constructor. At least one split should be defined.
If the option `sample_missing_splits` is set to True, missing splits will be randomly sampled from the
train split.
Expand Down Expand Up @@ -1484,7 +1483,6 @@ def downsample(
Returns:
A pointer to itself for optional use in method chaining.
"""

if downsample_train and self._train is not None:
self._train = self._downsample_to_proportion(self._train, percentage, random_seed)

Expand All @@ -1511,8 +1509,7 @@ def filter_empty_sentences(self):
log.info(self)

def filter_long_sentences(self, max_charlength: int):
"""
A method that filters all sentences for which the plain text is longer than a specified number of characters.
"""A method that filters all sentences for which the plain text is longer than a specified number of characters.
This is an in-place operation that directly modifies the Corpus object itself by removing these sentences.
Expand Down
4 changes: 2 additions & 2 deletions flair/datasets/document_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def __init__(
if not rebalance_corpus and dataset == "test":
data_file = test_data_file

with open(data_file, "at") as f_p:
with open(data_file, "a") as f_p:
current_path = data_path / "aclImdb" / dataset / label
for file_name in current_path.iterdir():
if file_name.is_file() and file_name.name.endswith(".txt"):
Expand Down Expand Up @@ -891,7 +891,7 @@ def __init__(
data_path / "original",
members=[m for m in f_in.getmembers() if f"{dataset}/{label}" in m.name],
)
with open(f"{data_path}/{dataset}.txt", "at", encoding="utf-8") as f_p:
with open(f"{data_path}/{dataset}.txt", "a", encoding="utf-8") as f_p:
current_path = data_path / "original" / dataset / label
for file_name in current_path.iterdir():
if file_name.is_file():
Expand Down
3 changes: 2 additions & 1 deletion flair/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def aggregate(value, aggregation_fn=np.mean):

def validate_corpus_same_each_process(corpus: Corpus) -> None:
"""Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable"""
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable
"""
for dataset in [corpus.train, corpus.dev, corpus.test]:
if dataset is not None:
_validate_dataset_same_each_process(dataset)
Expand Down
32 changes: 21 additions & 11 deletions flair/models/regexp_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def get_token_span(self, span: tuple[int, int]) -> Span:


class RegexpTagger:
def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> None:
def __init__(
self, mapping: Union[list[Union[tuple[str, str], tuple[str, str, int]]], tuple[str, str], tuple[str, str, int]]
) -> None:
r"""This tagger is capable of tagging sentence objects with given regexp -> label mappings.
I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to
Expand All @@ -58,24 +60,33 @@ def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> No
Args:
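mapping: A list of tuples or a single tuple of the form (regexp, label) or (regexp, label, group), where the optional group is the index of the regexp group to tag (defaults to 0, the whole match)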
"""
self._regexp_mapping: dict[str, typing.Pattern] = {}
self._regexp_mapping: list = []
self.register_labels(mapping=mapping)

def label_type(self):
    """Return the label of the first registered mapping.

    Returns:
        The label of the first entry in the registered mappings, or None if no mapping is registered.
    """
    # Entries are stored as (compiled pattern, label, group) tuples; only the label is needed here.
    if not self._regexp_mapping:
        return None
    return self._regexp_mapping[0][1]

@property
def registered_labels(self):
    """The currently registered mappings, in registration order.

    Returns:
        The internal list of mapping entries itself (not a copy).
    """
    mappings = self._regexp_mapping
    return mappings

def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]):
def register_labels(self, mapping):
"""Register a regexp -> label mapping.
Args:
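mapping: A list of tuples or a single tuple of the form (regexp, label) or (regexp, label, group), where the optional group is the index of the regexp group to tag (defaults to 0, the whole match)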
"""
mapping = self._listify(mapping)

for regexp, label in mapping:
for entry in mapping:
regexp = entry[0]
label = entry[1]
group = entry[2] if len(entry) > 2 else 0
try:
self._regexp_mapping[label] = re.compile(regexp)
pattern = re.compile(regexp)
self._regexp_mapping.append((pattern, label, group))

except re.error as err:
raise re.error(
f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'"
Expand All @@ -89,10 +100,7 @@ def remove_labels(self, labels: Union[list[str], str]):
"""
labels = self._listify(labels)

for label in labels:
if not self._regexp_mapping.get(label):
continue
self._regexp_mapping.pop(label)
self._regexp_mapping = [mapping for mapping in self._regexp_mapping if mapping[1] not in labels]

@staticmethod
def _listify(element: object) -> list:
Expand Down Expand Up @@ -120,9 +128,11 @@ def _label(self, sentence: Sentence):
"""
collection = TokenCollection(sentence)

for label, pattern in self._regexp_mapping.items():
for pattern, label, group in self._regexp_mapping:
for match in pattern.finditer(sentence.to_original_text()):
span: tuple[int, int] = match.span()
# print(match)
span: tuple[int, int] = match.span(group)
# print(span)
try:
token_span = collection.get_token_span(span)
except ValueError:
Expand Down
1 change: 0 additions & 1 deletion flair/models/sequence_tagger_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]:
A tuple consisting of the loss tensor and the number of tokens in the batch.
"""

# if there are no sentences, there is no loss
if len(sentences) == 0:
return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0
Expand Down
3 changes: 1 addition & 2 deletions flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
return model

def print_model_card(self):
"""
This method produces a log message that includes all recorded parameters the model was trained with.
"""This method produces a log message that includes all recorded parameters the model was trained with.
The model card includes information such as the Flair, PyTorch and Transformers versions used during training,
and the training parameters.
Expand Down
3 changes: 1 addition & 2 deletions flair/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class SentenceSplitter(ABC):
"""

def split(self, text: str, link_sentences: bool = True) -> list[Sentence]:
"""
Takes as input a text as a plain string and outputs a list of :class:`flair.data.Sentence` objects.
"""Takes as input a text as a plain string and outputs a list of :class:`flair.data.Sentence` objects.
If link_sentences is set (by default, it is), the :class:`flair.data.Sentence` objects will include pointers
to the preceding and following sentences in the original text. This way, the original sequence information will
Expand Down
2 changes: 0 additions & 2 deletions flair/trainers/plugins/functional/checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
from typing import Any

import torch

from flair.trainers.plugins.base import TrainerPlugin

log = logging.getLogger("flair")
Expand Down
16 changes: 16 additions & 0 deletions tests/models/test_regexp_tagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from flair.data import Sentence
from flair.models import RegexpTagger


def test_regexp_tagger():
    """Check tagging of a regexp group versus tagging of the whole match."""
    quote_regexp = r'["„»]((?:(?=(\\?))\2.)*?)[”"“«]'
    quoted = Sentence('Der sagte: "das ist durchaus interessant"')

    # The first mapping passes group index 1, so only the text inside the quotation marks
    # is labelled; the second mapping gives no group index and labels the full match.
    mappings = [
        (quote_regexp, "quote_part", 1),
        (quote_regexp, "quote"),
    ]
    regexp_tagger = RegexpTagger(mapping=mappings)
    regexp_tagger.predict(quoted)

    assert quoted.get_label("quote_part").data_point.text == "das ist durchaus interessant"
    assert quoted.get_label("quote").data_point.text == '"das ist durchaus interessant"'

0 comments on commit 6789a6a

Please sign in to comment.