Add possibility to control the truncation in bert score #2776

Merged: 10 commits, Oct 11, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multi-output support for MAE metric ([#2605](https://github.com/Lightning-AI/torchmetrics/pull/2605))


- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
5 changes: 4 additions & 1 deletion src/torchmetrics/functional/text/bert.py
@@ -276,6 +276,7 @@ def bert_score(
rescale_with_baseline: bool = False,
baseline_path: Optional[str] = None,
baseline_url: Optional[str] = None,
truncation: bool = False,
) -> Dict[str, Union[Tensor, List[float], str]]:
"""`Bert_score Evaluating Text Generation`_ for text similirity matching.

@@ -323,6 +324,7 @@ def bert_score(
of the files from `BERT_score`_
baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
truncation: An indication of whether the input sequences should be truncated to the maximum length.

Returns:
Python dictionary containing the keys ``precision``, ``recall`` and ``f1`` with corresponding values.
@@ -417,13 +419,14 @@ def bert_score(

# We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that.
if _are_valid_lists:
target_dataset = TextDataset(target, tokenizer, max_length, idf=idf) # type: ignore
target_dataset = TextDataset(target, tokenizer, max_length, idf=idf, truncation=truncation) # type: ignore
preds_dataset = TextDataset(
preds, # type: ignore
tokenizer,
max_length,
idf=idf,
tokens_idf=target_dataset.tokens_idf,
truncation=truncation,
)
elif _are_valid_tensors:
target_dataset = TokenizedDataset(**target, idf=idf) # type: ignore
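With the flag exposed on the functional API, over-long inputs can now be clipped instead of crashing. A minimal usage sketch (assuming the default pretrained checkpoint can be downloaded in the current environment):

```python
from torchmetrics.functional.text import bert_score

# Inputs far longer than the default 512-token max_length.
preds = ["hello there " * 600]
target = ["general kenobi " * 600]

# truncation=True clips both sequences to max_length before embedding,
# so the call succeeds instead of overflowing the position embeddings.
score = bert_score(preds, target, truncation=True)
print(score["precision"], score["recall"], score["f1"])
```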
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/text/helper_embedding_metric.py
@@ -195,10 +195,11 @@ def __init__(
tokenizer: Any,
max_length: int = 512,
preprocess_text_fn: Callable[
[List[str], Any, int], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
[List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
] = _preprocess_text,
idf: bool = False,
tokens_idf: Optional[Dict[int, float]] = None,
truncation: bool = False,
) -> None:
"""Initialize text dataset class.

@@ -209,9 +210,10 @@ def __init__(
preprocess_text_fn: A function used for processing the input sentences.
idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
truncation: An indication of whether tokenized sequences should be truncated to the maximum length.

"""
_text = preprocess_text_fn(text, tokenizer, max_length)
_text = preprocess_text_fn(text, tokenizer, max_length, truncation)
if isinstance(_text, tuple):
self.text, self.sorting_indices = _text
else:
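`TextDataset` simply forwards the new flag to its `preprocess_text_fn`. As a rough illustration (a sketch, not the library's actual `_preprocess_text`, which also handles IDF weighting and length sorting), a function with this shape can hand `truncation` straight to a Hugging Face tokenizer:

```python
from typing import Any, Dict, List

from torch import Tensor


def preprocess_text_sketch(
    text: List[str], tokenizer: Any, max_length: int, truncation: bool = False
) -> Dict[str, Tensor]:
    """Tokenize sentences, optionally clipping them to ``max_length``."""
    # Hugging Face tokenizers accept these keyword arguments directly; with
    # truncation=False, over-long inputs keep their full token length and only
    # fail later, inside the model's position embeddings.
    return tokenizer(
        text,
        padding=True,
        max_length=max_length,
        truncation=truncation,
        return_tensors="pt",
    )
```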
7 changes: 5 additions & 2 deletions src/torchmetrics/text/bert.py
@@ -107,6 +107,7 @@ class BERTScore(Metric):
of the files from `BERT_score`_.
baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
truncation: An indication of whether the input sequences should be truncated to the ``max_length``.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
@@ -150,6 +151,7 @@ def __init__(
rescale_with_baseline: bool = False,
baseline_path: Optional[str] = None,
baseline_url: Optional[str] = None,
truncation: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -169,6 +171,7 @@ def __init__(
self.rescale_with_baseline = rescale_with_baseline
self.baseline_path = baseline_path
self.baseline_url = baseline_url
self.truncation = truncation

if user_tokenizer:
self.tokenizer = user_tokenizer
@@ -210,15 +213,15 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]]) -> None:
preds,
self.tokenizer,
self.max_length,
truncation=False,
truncation=self.truncation,
sort_according_length=False,
own_tokenizer=self.user_tokenizer,
)
target_dict, _ = _preprocess_text(
target,
self.tokenizer,
self.max_length,
truncation=False,
truncation=self.truncation,
sort_according_length=False,
own_tokenizer=self.user_tokenizer,
)
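The module-based metric stores the flag at construction and threads it through both `_preprocess_text` calls above, so the class API mirrors the functional one. A short usage sketch:

```python
from torchmetrics.text import BERTScore

# truncation=True makes over-long inputs safe to score.
bertscore = BERTScore(truncation=True)
bertscore.update(["a very long prediction " * 400], ["a very long reference " * 400])
print(bertscore.compute()["f1"])
```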
17 changes: 17 additions & 0 deletions tests/unittests/text/test_bertscore.py
@@ -188,3 +188,20 @@ def test_bertscore_sorting(idf: bool):

# First index should be the self-comparison - sorting by length should not shuffle this
assert score["f1"][0] > score["f1"][1]


@skip_on_connection_issues()
@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4")
@pytest.mark.parametrize("truncation", [True, False])
def test_bertscore_truncation(truncation: bool):
"""Test that BERTScore truncation works as expected."""
pred = ["abc " * 2000]
gt = ["def " * 2000]
bert_score = BERTScore(truncation=truncation)

if truncation:
res = bert_score(pred, gt)
assert res["f1"] > 0.0
else:
with pytest.raises(RuntimeError, match="The expanded size of the tensor.*must match.*"):
bert_score(pred, gt)
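The `RuntimeError` expected in the untruncated branch comes from the model's fixed position-embedding table, not from the tokenizer itself. A quick way to see the overflow (assuming the `roberta-large` checkpoint, which `BERTScore` falls back to when no model is specified):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-large")
ids = tok("abc " * 2000)["input_ids"]
# About 2000 tokens plus special tokens, far past the 512 positions the model
# allocates; hence the tensor-size mismatch when truncation is disabled.
print(len(ids))
```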