diff --git a/CHANGELOG.md b/CHANGELOG.md
index 584b59be6f7..50be7dd1f53 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added new metric `ProcrustesDistance` to new domain Shape ([#2723](https://github.com/Lightning-AI/torchmetrics/pull/2723)
 
+- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))
+
+
 ### Changed
 
 - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
 
diff --git a/src/torchmetrics/functional/text/bert.py b/src/torchmetrics/functional/text/bert.py
index c073ac6e74f..71bec857a72 100644
--- a/src/torchmetrics/functional/text/bert.py
+++ b/src/torchmetrics/functional/text/bert.py
@@ -276,6 +276,7 @@ def bert_score(
     rescale_with_baseline: bool = False,
     baseline_path: Optional[str] = None,
     baseline_url: Optional[str] = None,
+    truncation: bool = False,
 ) -> Dict[str, Union[Tensor, List[float], str]]:
     """`Bert_score Evaluating Text Generation`_ for text similarity matching.
 
@@ -323,6 +324,7 @@ def bert_score(
             of the files from `BERT_score`_
         baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
         baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
+        truncation: An indication of whether the input sequences should be truncated to the maximum length.
 
     Returns:
         Python dictionary containing the keys ``precision``, ``recall`` and ``f1`` with corresponding values.
@@ -417,13 +419,14 @@ def bert_score(
 
     # We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that.
     if _are_valid_lists:
-        target_dataset = TextDataset(target, tokenizer, max_length, idf=idf)  # type: ignore
+        target_dataset = TextDataset(target, tokenizer, max_length, idf=idf, truncation=truncation)  # type: ignore
         preds_dataset = TextDataset(
             preds,  # type: ignore
             tokenizer,
             max_length,
             idf=idf,
             tokens_idf=target_dataset.tokens_idf,
+            truncation=truncation,
         )
     elif _are_valid_tensors:
         target_dataset = TokenizedDataset(**target, idf=idf)  # type: ignore
diff --git a/src/torchmetrics/functional/text/helper_embedding_metric.py b/src/torchmetrics/functional/text/helper_embedding_metric.py
index 1ab911b9395..f2b59126c7d 100644
--- a/src/torchmetrics/functional/text/helper_embedding_metric.py
+++ b/src/torchmetrics/functional/text/helper_embedding_metric.py
@@ -195,10 +195,11 @@ def __init__(
         tokenizer: Any,
         max_length: int = 512,
         preprocess_text_fn: Callable[
-            [List[str], Any, int], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
+            [List[str], Any, int, bool], Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Optional[Tensor]]]
         ] = _preprocess_text,
         idf: bool = False,
         tokens_idf: Optional[Dict[int, float]] = None,
+        truncation: bool = False,
     ) -> None:
         """Initialize text dataset class.
 
@@ -209,9 +210,10 @@ def __init__(
             preprocess_text_fn: A function used for processing the input sentences.
             idf: An indication of whether to calculate token inverse document frequencies to weight the model embeddings.
             tokens_idf: Inverse document frequencies (these should be calculated on reference sentences).
+            truncation: An indication of whether the input sequences should be truncated to the maximum length.
 
         """
-        _text = preprocess_text_fn(text, tokenizer, max_length)
+        _text = preprocess_text_fn(text, tokenizer, max_length, truncation)
         if isinstance(_text, tuple):
             self.text, self.sorting_indices = _text
         else:
diff --git a/src/torchmetrics/text/bert.py b/src/torchmetrics/text/bert.py
index cd7d4ce0aaf..6e1bab1b9bd 100644
--- a/src/torchmetrics/text/bert.py
+++ b/src/torchmetrics/text/bert.py
@@ -107,6 +107,7 @@ class BERTScore(Metric):
             of the files from `BERT_score`_.
         baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
         baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
+        truncation: An indication of whether the input sequences should be truncated to the ``max_length``.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example:
@@ -150,6 +151,7 @@ def __init__(
         rescale_with_baseline: bool = False,
         baseline_path: Optional[str] = None,
         baseline_url: Optional[str] = None,
+        truncation: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -169,6 +171,7 @@ def __init__(
         self.rescale_with_baseline = rescale_with_baseline
         self.baseline_path = baseline_path
         self.baseline_url = baseline_url
+        self.truncation = truncation
 
         if user_tokenizer:
             self.tokenizer = user_tokenizer
@@ -210,7 +213,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s
             preds,
             self.tokenizer,
             self.max_length,
-            truncation=False,
+            truncation=self.truncation,
             sort_according_length=False,
             own_tokenizer=self.user_tokenizer,
         )
@@ -218,7 +221,7 @@ def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[s
             target,
             self.tokenizer,
             self.max_length,
-            truncation=False,
+            truncation=self.truncation,
             sort_according_length=False,
             own_tokenizer=self.user_tokenizer,
         )
diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py
index 99605b69cc2..1d74f0c858d 100644
--- a/tests/unittests/text/test_bertscore.py
+++ b/tests/unittests/text/test_bertscore.py
@@ -188,3 +188,20 @@ def test_bertscore_sorting(idf: bool):
 
     # First index should be the self-comparison - sorting by length should not shuffle this
     assert score["f1"][0] > score["f1"][1]
+
+
+@skip_on_connection_issues()
+@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4")
+@pytest.mark.parametrize("truncation", [True, False])
+def test_bertscore_truncation(truncation: bool):
+    """Test that BERTScore truncation works as expected."""
+    pred = ["abc " * 2000]
+    gt = ["def " * 2000]
+    bert_score = BERTScore(truncation=truncation)
+
+    if truncation:
+        res = bert_score(pred, gt)
+        assert res["f1"] > 0.0
+    else:
+        with pytest.raises(RuntimeError, match="The expanded size of the tensor.*must match.*"):
+            bert_score(pred, gt)
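
Usage sketch (illustrative, not part of the diff): with the new flag, over-long inputs are clipped to the
model's `max_length` during tokenization instead of failing inside the encoder. The example strings are
assumptions for illustration; the default checkpoint is downloaded on first use.

    from torchmetrics.text.bert import BERTScore

    # Inputs far longer than a typical 512-token encoder limit.
    preds = ["hello there " * 1000]
    target = ["master kenobi " * 1000]

    # truncation=True clips both sequences to max_length before embedding; with the
    # default truncation=False such inputs raise a RuntimeError, as the new test asserts.
    bertscore = BERTScore(truncation=True)
    score = bertscore(preds, target)
    print(score["f1"])

The functional form from `torchmetrics.functional.text.bert` accepts the same keyword, e.g.
`bert_score(preds, target, truncation=True)`.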