From 62604a35d008506aa0744636509ae4bb44fc149a Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 13:07:36 +0200 Subject: [PATCH 1/9] text: temp drop `Chrf` implementation --- docs/source/conf.py | 2 - docs/source/links.rst | 2 - docs/source/text/chrf_score.rst | 21 - src/torchmetrics/functional/text/__init__.py | 1 - .../functional/text/_deprecated.py | 22 +- src/torchmetrics/functional/text/chrf.py | 649 ------------------ src/torchmetrics/text/__init__.py | 1 - src/torchmetrics/text/_deprecated.py | 34 - src/torchmetrics/text/chrf.py | 249 ------- .../deprecations/root_class_imports.py | 2 - tests/unittests/text/test_chrf.py | 163 ----- 11 files changed, 2 insertions(+), 1144 deletions(-) delete mode 100644 docs/source/text/chrf_score.rst delete mode 100644 src/torchmetrics/functional/text/chrf.py delete mode 100644 src/torchmetrics/text/chrf.py delete mode 100644 tests/unittests/text/test_chrf.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 9484761ba7a..9ef621f90df 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -438,8 +438,6 @@ def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 "https://ieeexplore.ieee.org/abstract/document/4317530", # Robust parameter estimation with a small bias against heavy contamination "https://www.sciencedirect.com/science/article/pii/S0047259X08000456", - # chrF++: words helping character n-grams - "https://aclanthology.org/W17-4770", # A wavelet transform method to merge Landsat TM and SPOT panchromatic data "https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013", ] diff --git a/docs/source/links.rst b/docs/source/links.rst index 1fc3ab5755d..d7038032e1e 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -77,8 +77,6 @@ .. _Scikit_Learn-Ranking.py: https: //github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py .. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155 .. _SQuAD Metric: https://arxiv.org/abs/1606.05250 -.. _chrF score: https://aclanthology.org/W15-3049 -.. _chrF++ score: https://aclanthology.org/W17-4770 .. _TER: https://aclanthology.org/2006.amta-papers.25 .. _ExtendedEditDistance: https://aclanthology.org/W19-5359 .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim diff --git a/docs/source/text/chrf_score.rst b/docs/source/text/chrf_score.rst deleted file mode 100644 index c541014ae61..00000000000 --- a/docs/source/text/chrf_score.rst +++ /dev/null @@ -1,21 +0,0 @@ -.. customcarditem:: - :header: ChrF Score - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg - :tags: Text - -.. include:: ../links.rst - -########## -ChrF Score -########## - -Module Interface -________________ - -.. autoclass:: torchmetrics.text.CHRFScore - :exclude-members: update, compute - -Functional Interface -____________________ - -.. autofunction:: torchmetrics.functional.text.chrf_score diff --git a/src/torchmetrics/functional/text/__init__.py b/src/torchmetrics/functional/text/__init__.py index 9282be6fbae..14ba75ce5a7 100644 --- a/src/torchmetrics/functional/text/__init__.py +++ b/src/torchmetrics/functional/text/__init__.py @@ -14,7 +14,6 @@ from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate -from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.functional.text.edit import edit_distance from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.mer import match_error_rate diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index fabfca2c0eb..73fc2623928 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -8,7 +8,6 @@ from torchmetrics.functional.text.bert import bert_score from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate -from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.infolm import ( _ALLOWED_INFORMATION_MEASURE_LITERAL as _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL, @@ -134,25 +133,8 @@ def _chrf_score( whitespace: bool = False, return_sentence_level_score: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Wrapper for deprecated import. - - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> _chrf_score(preds, target) - tensor(0.8640) - - """ - _deprecated_root_import_func("chrf_score", "text") - return chrf_score( - preds=preds, - target=target, - n_char_order=n_char_order, - n_word_order=n_word_order, - beta=beta, - lowercase=lowercase, - whitespace=whitespace, - return_sentence_level_score=return_sentence_level_score, - ) + """Wrapper for deprecated import.""" + raise NotImplementedError("Chrf was temporarily removed.") def _extended_edit_distance( diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py deleted file mode 100644 index 375355b85cb..00000000000 --- a/src/torchmetrics/functional/text/chrf.py +++ /dev/null @@ -1,649 +0,0 @@ -# Copyright The Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# referenced from -# Library Name: torchtext -# Authors: torchtext authors -# Date: 2021-11-25 -# Link: - -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -# Copyright 2017 Maja Popovic - -# The program is distributed under the terms -# of the GNU General Public Licence (GPL) - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. - -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - -from collections import defaultdict -from itertools import chain -from typing import Dict, List, Optional, Sequence, Tuple, Union - -import torch -from torch import Tensor, tensor - -from torchmetrics.functional.text.helper import _validate_inputs - -_EPS_SMOOTHING = tensor(1e-16) -# Taken from https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py -_PUNCTUATIONS = set("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~") - - -def _prepare_n_grams_dicts( - n_char_order: int, n_word_order: int -) -> Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] -]: - """Prepare dictionaries with default zero values for total ref, hypothesis and matching character and word n-grams. - - Args: - n_char_order: A character n-gram order. - n_word_order: A word n-gram order. - - Return: - Dictionaries with default zero values for total reference, hypothesis and matching character and word - n-grams. - - """ - total_preds_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_preds_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_target_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_target_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_matching_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - - return ( - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - ) - - -def _get_characters(sentence: str, whitespace: bool) -> List[str]: - """Split sentence into individual characters. - - Args: - sentence: An input sentence to split. - whitespace: An indication whether to keep whitespaces during character n-gram extraction. - - Return: - A list of separated characters. - - """ - if whitespace: - return list(sentence) - return list(sentence.strip().replace(" ", "")) - - -def _separate_word_and_punctuation(word: str) -> List[str]: - """Separates out punctuations from beginning and end of words for chrF. - - Adapted from https://github.com/m-popovic/chrF and - https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. - - Args: - word: An input word to be separated from a punctuation if present. - - Return: - A list of a single word or a separated word and punctuation. - - """ - if len(word) == 1: - return [word] - - if word[-1] in _PUNCTUATIONS: - return [word[:-1], word[-1]] - if word[0] in _PUNCTUATIONS: - return [word[0], word[1:]] - return [word] - - -def _get_words_and_punctuation(sentence: str) -> List[str]: - """Separates out punctuations from beginning and end of words for chrF for all words in the sentence. - - Args: - sentence: An input sentence to split - - Return: - An aggregated list of separated words and punctuations. - - """ - return list(chain.from_iterable(_separate_word_and_punctuation(word) for word in sentence.strip().split())) - - -def _ngram_counts(char_or_word_list: List[str], n_gram_order: int) -> Dict[int, Dict[Tuple[str, ...], Tensor]]: - """Calculate n-gram counts. - - Args: - char_or_word_list: A list of characters of words - n_gram_order: The largest number of n-gram. - - Return: - A dictionary of dictionaries with a counts of given n-grams. - - """ - ngrams: Dict[int, Dict[Tuple[str, ...], Tensor]] = defaultdict(lambda: defaultdict(lambda: tensor(0.0))) - for n in range(1, n_gram_order + 1): - for ngram in (tuple(char_or_word_list[i : i + n]) for i in range(len(char_or_word_list) - n + 1)): - ngrams[n][ngram] += tensor(1) - return ngrams - - -def _get_n_grams_counts_and_total_ngrams( - sentence: str, n_char_order: int, n_word_order: int, lowercase: bool, whitespace: bool -) -> Tuple[ - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Dict[Tuple[str, ...], Tensor]], - Dict[int, Tensor], - Dict[int, Tensor], -]: - """Get n-grams and total n-grams. - - Args: - sentence: An input sentence - n_char_order: A character n-gram order. - n_word_order: A word n-gram order. - lowercase: An indication whether to enable case-insensitivity. - whitespace: An indication whether to keep whitespaces during character n-gram extraction. - - Return: - char_n_grams_counts: A dictionary of dictionaries with sentence character n-grams. - word_n_grams_counts: A dictionary of dictionaries with sentence word n-grams. - total_char_n_grams: A dictionary containing a total number of sentence character n-grams. - total_word_n_grams: A dictionary containing a total number of sentence word n-grams. - - """ - - def _char_and_word_ngrams_counts( - sentence: str, n_char_order: int, n_word_order: int, lowercase: bool - ) -> Tuple[Dict[int, Dict[Tuple[str, ...], Tensor]], Dict[int, Dict[Tuple[str, ...], Tensor]]]: - """Get a dictionary of dictionaries with a counts of given n-grams.""" - if lowercase: - sentence = sentence.lower() - char_n_grams_counts = _ngram_counts(_get_characters(sentence, whitespace), n_char_order) - word_n_grams_counts = _ngram_counts(_get_words_and_punctuation(sentence), n_word_order) - return char_n_grams_counts, word_n_grams_counts - - def _get_total_ngrams(n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]]) -> Dict[int, Tensor]: - """Get total sum of n-grams over n-grams w.r.t n.""" - total_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - for n in n_grams_counts: - total_n_grams[n] = sum(n_grams_counts[n].values()).detach().clone() # type: ignore - return total_n_grams - - char_n_grams_counts, word_n_grams_counts = _char_and_word_ngrams_counts( - sentence, n_char_order, n_word_order, lowercase - ) - total_char_n_grams = _get_total_ngrams(char_n_grams_counts) - total_word_n_grams = _get_total_ngrams(word_n_grams_counts) - - return char_n_grams_counts, word_n_grams_counts, total_char_n_grams, total_word_n_grams - - -def _get_ngram_matches( - hyp_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - ref_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], -) -> Dict[int, Tensor]: - """Get a number of n-gram matches between reference and hypothesis n-grams. - - Args: - hyp_n_grams_counts: n-grams counts for hypothesis - ref_n_grams_counts: n-grams counts for reference - - Return: - matching_n_grams - - """ - matching_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - for n in hyp_n_grams_counts: - min_n_grams = [ - torch.min(ref_n_grams_counts[n][n_gram], hyp_n_grams_counts[n][n_gram]) for n_gram in hyp_n_grams_counts[n] - ] - matching_n_grams[n] = sum(min_n_grams).detach().clone() # type: ignore - return matching_n_grams - - -def _sum_over_dicts(total_n_grams: Dict[int, Tensor], n_grams: Dict[int, Tensor]) -> Dict[int, Tensor]: - """Aggregate total n-grams to keep corpus-level statistics. - - Args: - total_n_grams: A dictionary containing a total corpus-level number of n-grams. - n_grams: A dictionary containing a sentence-level number of n-grams. - - Return: - A dictionary containing a total corpus-level number of n-grams. - - """ - for n in n_grams: - total_n_grams[n] += n_grams[n] - return total_n_grams - - -def _calculate_fscore( - matching_char_n_grams: Dict[int, Tensor], - matching_word_n_grams: Dict[int, Tensor], - hyp_char_n_grams: Dict[int, Tensor], - hyp_word_n_grams: Dict[int, Tensor], - ref_char_n_grams: Dict[int, Tensor], - ref_word_n_grams: Dict[int, Tensor], - n_order: float, - beta: float, -) -> Tensor: - """Calculate sentence-level chrF/chrF++ score. - - For given hypothesis and reference statistics (either sentence-level or corpus-level) - the chrF/chrF++ score is returned. - - Args: - matching_char_n_grams: - A total number of matching character n-grams between the best matching reference and hypothesis. - matching_word_n_grams: - A total number of matching word n-grams between the best matching reference and hypothesis. - hyp_char_n_grams: A total number of hypothesis character n-grams. - hyp_word_n_grams: A total number of hypothesis word n-grams. - ref_char_n_grams: A total number of reference character n-grams. - ref_word_n_grams: A total number of reference word n-grams. - n_order: A sum of character and word n-gram order. - beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. - - Return: - A chrF/chrF++ score. This function is universal both for sentence-level and corpus-level calculation. - - """ - - def _get_n_gram_fscore( - matching_n_grams: Dict[int, Tensor], ref_n_grams: Dict[int, Tensor], hyp_n_grams: Dict[int, Tensor], beta: float - ) -> Dict[int, Tensor]: - """Get n-gram level f-score.""" - precision: Dict[int, Tensor] = { - n: matching_n_grams[n] / hyp_n_grams[n] if hyp_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams - } - recall: Dict[int, Tensor] = { - n: matching_n_grams[n] / ref_n_grams[n] if ref_n_grams[n] > 0 else tensor(0.0) for n in matching_n_grams - } - denominator: Dict[int, Tensor] = { - n: torch.max(beta**2 * precision[n] + recall[n], _EPS_SMOOTHING) for n in matching_n_grams - } - f_score: Dict[int, Tensor] = { - n: (1 + beta**2) * precision[n] * recall[n] / denominator[n] for n in matching_n_grams - } - - return f_score - - char_n_gram_f_score = _get_n_gram_fscore(matching_char_n_grams, ref_char_n_grams, hyp_char_n_grams, beta) - word_n_gram_f_score = _get_n_gram_fscore(matching_word_n_grams, ref_word_n_grams, hyp_word_n_grams, beta) - - return (sum(char_n_gram_f_score.values()) + sum(word_n_gram_f_score.values())) / tensor(n_order) - - -def _calculate_sentence_level_chrf_score( - targets: List[str], - pred_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - pred_char_n_grams: Dict[int, Tensor], - pred_word_n_grams: Dict[int, Tensor], - n_char_order: int, - n_word_order: int, - n_order: float, - beta: float, - lowercase: bool, - whitespace: bool, -) -> Tuple[Tensor, Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor]]: - """Calculate the best sentence-level chrF/chrF++ score. - - For a given pre-processed hypothesis, all references are evaluated and score and statistics - for the best matching reference is returned. - - Args: - targets: An iterable of references. - pred_char_n_grams_counts: A dictionary of dictionaries with hypothesis character n-grams. - pred_word_n_grams_counts: A dictionary of dictionaries with hypothesis word n-grams. - pred_char_n_grams: A total number of hypothesis character n-grams. - pred_word_n_grams: A total number of hypothesis word n-grams. - n_char_order: A character n-gram order. - n_word_order: A word n-gram order. - n_order: A sum of character and word n-gram order. - beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. - lowercase: An indication whether to enable case-insensitivity. - whitespace: An indication whether to keep whitespaces during character n-gram extraction. - - Return: - Return chrF/chrF++ score and statistics for the best matching hypothesis and reference. - - f_score: A sentence-level chrF/chrF++ score. - matching_char_n_grams: - A total number of matching character n-grams between the best matching reference and hypothesis. - matching_word_n_grams: - A total number of matching word n-grams between the best matching reference and hypothesis. - target_char_n_grams: A total number of reference character n-grams. - target_word_n_grams: A total number of reference word n-grams. - - """ - best_f_score = tensor(0.0) - best_matching_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_matching_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_target_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - - for target in targets: - ( - target_char_n_grams_counts, - target_word_n_grams_counts, - target_char_n_grams, - target_word_n_grams, - ) = _get_n_grams_counts_and_total_ngrams(target, n_char_order, n_word_order, lowercase, whitespace) - matching_char_n_grams = _get_ngram_matches(target_char_n_grams_counts, pred_char_n_grams_counts) - matching_word_n_grams = _get_ngram_matches(target_word_n_grams_counts, pred_word_n_grams_counts) - - f_score = _calculate_fscore( - matching_char_n_grams, - matching_word_n_grams, - pred_char_n_grams, - pred_word_n_grams, - target_char_n_grams, - target_word_n_grams, - n_order, - beta, - ) - - if f_score > best_f_score: - best_f_score = f_score - best_matching_char_n_grams = matching_char_n_grams - best_matching_word_n_grams = matching_word_n_grams - best_target_char_n_grams = target_char_n_grams - best_target_word_n_grams = target_word_n_grams - - return ( - best_f_score, - best_matching_char_n_grams, - best_matching_word_n_grams, - best_target_char_n_grams, - best_target_word_n_grams, - ) - - -def _chrf_score_update( - preds: Union[str, Sequence[str]], - target: Union[Sequence[str], Sequence[Sequence[str]]], - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], - n_char_order: int, - n_word_order: int, - n_order: float, - beta: float, - lowercase: bool, - whitespace: bool, - sentence_chrf_score: Optional[List[Tensor]] = None, -) -> Tuple[ - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Dict[int, Tensor], - Optional[List[Tensor]], -]: - """Update function for chrf score. - - Args: - preds: An iterable of hypothesis corpus. - target: An iterable of iterables of reference corpus. - total_preds_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. - total_preds_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. - total_target_char_n_grams: A dictionary containing a total number of reference character n-grams. - total_target_word_n_grams: A dictionary containing a total number of reference word n-grams. - total_matching_char_n_grams: - A dictionary containing a total number of matching character n-grams between references and hypotheses. - total_matching_word_n_grams: - A dictionary containing a total number of total matching word n-grams between references and hypotheses. - n_char_order: A character n-gram order. - n_word_order: A word n-gram order. - n_order: Sum of character and word n-gram order. - beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. - lowercase: An indication whether to enable case-insensitivity. - whitespace: An indication whether to keep whitespaces during character n-gram extraction. - sentence_chrf_score: A list of sentence-level chrF/chrF++ scores. - - Return: - total_target_char_n_grams: number of reference character n-grams. - total_target_word_n_grams: number of reference word n-grams. - total_preds_char_n_grams: number of hypothesis character n-grams. - total_preds_word_n_grams: number of hypothesis word n-grams. - total_matching_char_n_grams: number of matching character n-grams between references and hypotheses. - total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses. - sentence_chrf_score: A list of sentence-level chrF/chrF++ scores. - - Raises: - ValueError: - If length of ``preds`` and ``target`` differs. - - """ - target_corpus, preds = _validate_inputs(target, preds) - - for pred, targets in zip(preds, target_corpus): - ( - pred_char_n_grams_counts, - pred_word_n_grams_counts, - pred_char_n_grams, - pred_word_n_grams, - ) = _get_n_grams_counts_and_total_ngrams(pred, n_char_order, n_word_order, lowercase, whitespace) - total_preds_char_n_grams = _sum_over_dicts(total_preds_char_n_grams, pred_char_n_grams) - total_preds_word_n_grams = _sum_over_dicts(total_preds_word_n_grams, pred_word_n_grams) - - ( - sentence_level_f_score, - matching_char_n_grams, - matching_word_n_grams, - target_char_n_grams, - target_word_n_grams, - ) = _calculate_sentence_level_chrf_score( - targets, # type: ignore - pred_char_n_grams_counts, - pred_word_n_grams_counts, - pred_char_n_grams, - pred_word_n_grams, - n_char_order, - n_word_order, - n_order, - beta, - lowercase, - whitespace, - ) - - if sentence_chrf_score is not None: - sentence_chrf_score.append(sentence_level_f_score.unsqueeze(0)) - - total_target_char_n_grams = _sum_over_dicts(total_target_char_n_grams, target_char_n_grams) - total_target_word_n_grams = _sum_over_dicts(total_target_word_n_grams, target_word_n_grams) - total_matching_char_n_grams = _sum_over_dicts(total_matching_char_n_grams, matching_char_n_grams) - total_matching_word_n_grams = _sum_over_dicts(total_matching_word_n_grams, matching_word_n_grams) - - return ( - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - sentence_chrf_score, - ) - - -def _chrf_score_compute( - total_preds_char_n_grams: Dict[int, Tensor], - total_preds_word_n_grams: Dict[int, Tensor], - total_target_char_n_grams: Dict[int, Tensor], - total_target_word_n_grams: Dict[int, Tensor], - total_matching_char_n_grams: Dict[int, Tensor], - total_matching_word_n_grams: Dict[int, Tensor], - n_order: float, - beta: float, -) -> Tensor: - """Compute chrF/chrF++ score based on pre-computed target, prediction and matching character and word n-grams. - - Args: - total_preds_char_n_grams: number of hypothesis character n-grams. - total_preds_word_n_grams: number of hypothesis word n-grams. - total_target_char_n_grams: number of reference character n-grams. - total_target_word_n_grams: number of reference word n-grams. - total_matching_char_n_grams: number of matching character n-grams between references and hypotheses. - total_matching_word_n_grams: number of total matching word n-grams between references and hypotheses. - n_order: A sum of character and word n-gram order. - beta: - A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. - - Return: - A corpus-level chrF/chrF++ score. - - """ - return _calculate_fscore( - total_matching_char_n_grams, - total_matching_word_n_grams, - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - n_order, - beta, - ) - - -def chrf_score( - preds: Union[str, Sequence[str]], - target: Sequence[Union[str, Sequence[str]]], - n_char_order: int = 6, - n_word_order: int = 2, - beta: float = 2.0, - lowercase: bool = False, - whitespace: bool = False, - return_sentence_level_score: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Calculate `chrF score`_ of machine translated text with one or more references. - - This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in - `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and - https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. - - Args: - preds: An iterable of hypothesis corpus. - target: An iterable of iterables of reference corpus. - n_char_order: - A character n-gram order. If `n_char_order=6`, the metrics refers to the official chrF/chrF++. - n_word_order: - A word n-gram order. If `n_word_order=2`, the metric refers to the official chrF++. If `n_word_order=0`, the - metric is equivalent to the original chrF. - beta: - A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. - lowercase: An indication whether to enable case-insensitivity. - whitespace: An indication whether to keep whitespaces during character n-gram extraction. - return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. - - Return: - A corpus-level chrF/chrF++ score. - (Optionally) A list of sentence-level chrF/chrF++ scores if `return_sentence_level_score=True`. - - Raises: - ValueError: - If ``n_char_order`` is not an integer greater than or equal to 1. - ValueError: - If ``n_word_order`` is not an integer greater than or equal to 0. - ValueError: - If ``beta`` is smaller than 0. - - Example: - >>> from torchmetrics.functional.text import chrf_score - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf_score(preds, target) - tensor(0.8640) - - References: - [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ - - [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ - - """ - if not isinstance(n_char_order, int) or n_char_order < 1: - raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") - if not isinstance(n_word_order, int) or n_word_order < 0: - raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") - if beta < 0: - raise ValueError("Expected argument `beta` to be greater than 0.") - - n_order = float(n_char_order + n_word_order) - - ( - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - ) = _prepare_n_grams_dicts(n_char_order, n_word_order) - - sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None - - ( - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - sentence_chrf_score, - ) = _chrf_score_update( - preds, - target, - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - n_char_order, - n_word_order, - n_order, - beta, - lowercase, - whitespace, - sentence_chrf_score, - ) - - chrf_f_score = _chrf_score_compute( - total_preds_char_n_grams, - total_preds_word_n_grams, - total_target_char_n_grams, - total_target_word_n_grams, - total_matching_char_n_grams, - total_matching_word_n_grams, - n_order, - beta, - ) - - if sentence_chrf_score: - return chrf_f_score, torch.cat(sentence_chrf_score) - return chrf_f_score diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py index 48807a98fc4..23ff3c1e037 100644 --- a/src/torchmetrics/text/__init__.py +++ b/src/torchmetrics/text/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate -from torchmetrics.text.chrf import CHRFScore from torchmetrics.text.edit import EditDistance from torchmetrics.text.eed import ExtendedEditDistance from torchmetrics.text.mer import MatchErrorRate diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index d3ba1c4010e..166d48a37e9 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -2,7 +2,6 @@ from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate -from torchmetrics.text.chrf import CHRFScore from torchmetrics.text.eed import ExtendedEditDistance from torchmetrics.text.mer import MatchErrorRate from torchmetrics.text.perplexity import Perplexity @@ -56,39 +55,6 @@ def __init__( super().__init__(**kwargs) -class _CHRFScore(CHRFScore): - """Wrapper for deprecated import. - - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf = _CHRFScore() - >>> chrf(preds, target) - tensor(0.8640) - - """ - - def __init__( - self, - n_char_order: int = 6, - n_word_order: int = 2, - beta: float = 2.0, - lowercase: bool = False, - whitespace: bool = False, - return_sentence_level_score: bool = False, - **kwargs: Any, - ) -> None: - _deprecated_root_import_class("CHRFScore", "text") - super().__init__( - n_char_order=n_char_order, - n_word_order=n_word_order, - beta=beta, - lowercase=lowercase, - whitespace=whitespace, - return_sentence_level_score=return_sentence_level_score, - **kwargs, - ) - - class _ExtendedEditDistance(ExtendedEditDistance): """Wrapper for deprecated import. diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py deleted file mode 100644 index 1ff412ab1a4..00000000000 --- a/src/torchmetrics/text/chrf.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright The Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# referenced from -# Library Name: torchtext -# Authors: torchtext authors and @sluks -# Date: 2021-11-25 -# Link: - -import itertools -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import torch -from torch import Tensor, tensor - -from torchmetrics import Metric -from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE -from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE - -if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["CHRFScore.plot"] - - -_N_GRAM_LEVELS = ("char", "word") -_TEXT_LEVELS = ("preds", "target", "matching") - -_DICT_STATES_NAMES = ( - "total_preds_char_n_grams", - "total_preds_word_n_grams", - "total_target_char_n_grams", - "total_target_word_n_grams", - "total_matching_char_n_grams", - "total_matching_word_n_grams", -) - -_DICT_STATES_TYPES = Tuple[ - Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] -] - - -class CHRFScore(Metric): - """Calculate `chrf score`_ of machine translated text with one or more references. - - This implementation supports both ChrF score computation introduced in `chrF score`_ and `chrF++ score`_ introduced - in `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and - https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. - - As input to ``forward`` and ``update`` the metric accepts the following input: - - - ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus - - ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus - - As output of ``forward`` and ``compute`` the metric returns the following output: - - - ``chrf`` (:class:`~torch.Tensor`): If `return_sentence_level_score=True` return a list of sentence-level - chrF/chrF++ scores, else return a corpus-level chrF/chrF++ score - - Args: - n_char_order: A character n-gram order. If ``n_char_order=6``, the metrics refers to the official chrF/chrF++. - n_word_order: A word n-gram order. If ``n_word_order=2``, the metric refers to the official chrF++. - If ``n_word_order=0``, the metric is equivalent to the original ChrF. - beta: parameter determining an importance of recall w.r.t. precision. If ``beta=1``, their importance is equal. - lowercase: An indication whether to enable case-insensitivity. - whitespace: An indication whether keep whitespaces during n-gram extraction. - return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``n_char_order`` is not an integer greater than or equal to 1. - ValueError: - If ``n_word_order`` is not an integer greater than or equal to 0. - ValueError: - If ``beta`` is smaller than 0. - - Example: - >>> from torchmetrics.text import CHRFScore - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf = CHRFScore() - >>> chrf(preds, target) - tensor(0.8640) - - """ - - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = True - plot_lower_bound: float = 0.0 - plot_upper_bound: float = 1.0 - - sentence_chrf_score: Optional[List[Tensor]] = None - - def __init__( - self, - n_char_order: int = 6, - n_word_order: int = 2, - beta: float = 2.0, - lowercase: bool = False, - whitespace: bool = False, - return_sentence_level_score: bool = False, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - if not isinstance(n_char_order, int) or n_char_order < 1: - raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") - self.n_char_order = n_char_order - if not isinstance(n_word_order, int) or n_word_order < 0: - raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") - self.n_word_order = n_word_order - if beta < 0: - raise ValueError("Expected argument `beta` to be greater than 0.") - self.beta = beta - self.lowercase = lowercase - self.whitespace = whitespace - self.return_sentence_level_score = return_sentence_level_score - - self.n_order = float(n_char_order + n_word_order) - - # Adding state dynamically - for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): - for n in range(1, n_gram_order + 1): - state_name = self._get_state_name(text, n_gram_level, n) - self.add_state(state_name, tensor(0.0), dist_reduce_fx="sum") - - if self.return_sentence_level_score: - self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - - def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: - """Update state with predictions and targets.""" - n_grams_dicts_tuple = _chrf_score_update( - preds, - target, - *self._convert_states_to_dicts(), - self.n_char_order, - self.n_word_order, - self.n_order, - self.beta, - self.lowercase, - self.whitespace, - self.sentence_chrf_score if self.return_sentence_level_score else None, - ) - self._update_states_from_dicts(n_grams_dicts_tuple[:-1]) - if self.sentence_chrf_score is not None: - self.sentence_chrf_score = n_grams_dicts_tuple[-1] - - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: - """Calculate chrF/chrF++ score.""" - if self.sentence_chrf_score is not None: - return ( - _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta), - torch.cat(self.sentence_chrf_score), - ) - return _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta) - - def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: - """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" - n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( - zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) - ) - - for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): - for n in range(1, n_gram_order + 1): - dict_name = self._get_dict_name(text, n_gram_level) - state_name = self._get_state_name(text, n_gram_level, n) - - n_grams_dicts[dict_name][n] = getattr(self, state_name) - - return tuple(n_grams_dicts.values()) # type: ignore - - def _update_states_from_dicts(self, n_grams_dicts_tuple: _DICT_STATES_TYPES) -> None: - """Update global metric states based on the n-gram dictionaries calculated on the current batch.""" - n_grams_dicts = dict(zip(_DICT_STATES_NAMES, n_grams_dicts_tuple)) - for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): - for n in range(1, n_gram_order + 1): - dict_name = self._get_dict_name(text, n_gram_level) - state_name = self._get_state_name(text, n_gram_level, n) - - setattr(self, state_name, n_grams_dicts[dict_name][n]) - - @staticmethod - def _get_dict_name(text: str, n_gram_level: str) -> str: - """Return a dictionary name w.r.t input args.""" - return f"total_{text}_{n_gram_level}_n_grams" - - @staticmethod - def _get_state_name(text: str, n_gram_level: str, n: int) -> str: - """Return a metric state name w.r.t input args.""" - return f"total_{text}_{n_gram_level}_{n}_grams" - - def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: - """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" - return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) - - def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: - """Plot a single or multiple values from the metric. - - Args: - val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. - If no value is provided, will automatically call `metric.compute` and plot that result. - ax: An matplotlib axis object. If provided will add plot to that axis - - Returns: - Figure and Axes object - - Raises: - ModuleNotFoundError: - If `matplotlib` is not installed - - .. plot:: - :scale: 75 - - >>> # Example plotting a single value - >>> from torchmetrics.text import CHRFScore - >>> metric = CHRFScore() - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> metric.update(preds, target) - >>> fig_, ax_ = metric.plot() - - .. plot:: - :scale: 75 - - >>> # Example plotting multiple values - >>> from torchmetrics.text import CHRFScore - >>> metric = CHRFScore() - >>> preds = ['the cat is on the mat'] - >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> values = [ ] - >>> for _ in range(10): - ... values.append(metric(preds, target)) - >>> fig_, ax_ = metric.plot(values) - - """ - return self._plot(val, ax) diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py index 5c4aa1a7155..2ae8f5f20a2 100644 --- a/tests/unittests/deprecations/root_class_imports.py +++ b/tests/unittests/deprecations/root_class_imports.py @@ -6,7 +6,6 @@ from torchmetrics import ( BLEUScore, CharErrorRate, - CHRFScore, ErrorRelativeGlobalDimensionlessSynthesis, ExtendedEditDistance, MatchErrorRate, @@ -86,7 +85,6 @@ # Text BLEUScore, CharErrorRate, - CHRFScore, ExtendedEditDistance, MatchErrorRate, Perplexity, diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py deleted file mode 100644 index 233c9451381..00000000000 --- a/tests/unittests/text/test_chrf.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright The Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import partial -from typing import Sequence - -import pytest -from torch import Tensor, tensor -from torchmetrics.functional.text.chrf import chrf_score -from torchmetrics.text.chrf import CHRFScore - -from unittests.text._helpers import TextTester -from unittests.text._inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references - - -def _reference_sacrebleu_chrf( - preds: Sequence[str], - targets: Sequence[Sequence[str]], - char_order: int, - word_order: int, - lowercase: bool, - whitespace: bool, -) -> Tensor: - try: - from sacrebleu import CHRF - except ImportError: - pytest.skip("test requires sacrebleu package to be installed") - - sacrebleu_chrf = CHRF( - char_order=char_order, word_order=word_order, lowercase=lowercase, whitespace=whitespace, eps_smoothing=True - ) - # Sacrebleu CHRF expects different format of input - targets = [[target[i] for target in targets] for i in range(len(targets[0]))] - sacrebleu_chrf = sacrebleu_chrf.corpus_score(preds, targets).score / 100 - return tensor(sacrebleu_chrf) - - -@pytest.mark.parametrize( - ["char_order", "word_order", "lowercase", "whitespace"], - [ - (6, 2, False, False), - (6, 2, False, True), - (4, 2, True, False), - (6, 0, True, False), - (6, 0, True, True), - (4, 0, False, True), - ], -) -@pytest.mark.parametrize( - ["preds", "targets"], - [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], -) -class TestCHRFScore(TextTester): - """Test class for `CHRFScore` metric.""" - - @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_chrf_score_class(self, ddp, preds, targets, char_order, word_order, lowercase, whitespace): - """Test class implementation of metric.""" - metric_args = { - "n_char_order": char_order, - "n_word_order": word_order, - "lowercase": lowercase, - "whitespace": whitespace, - } - nltk_metric = partial( - _reference_sacrebleu_chrf, - char_order=char_order, - word_order=word_order, - lowercase=lowercase, - whitespace=whitespace, - ) - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - targets=targets, - metric_class=CHRFScore, - reference_metric=nltk_metric, - metric_args=metric_args, - ) - - def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace): - """Test functional implementation of metric.""" - metric_args = { - "n_char_order": char_order, - "n_word_order": word_order, - "lowercase": lowercase, - "whitespace": whitespace, - } - nltk_metric = partial( - _reference_sacrebleu_chrf, - char_order=char_order, - word_order=word_order, - lowercase=lowercase, - whitespace=whitespace, - ) - - self.run_functional_metric_test( - preds, - targets, - metric_functional=chrf_score, - reference_metric=nltk_metric, - metric_args=metric_args, - ) - - def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace): - """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" - metric_args = { - "n_char_order": char_order, - "n_word_order": word_order, - "lowercase": lowercase, - "whitespace": whitespace, - } - - self.run_differentiability_test( - preds=preds, - targets=targets, - metric_module=CHRFScore, - metric_functional=chrf_score, - metric_args=metric_args, - ) - - -def test_chrf_empty_functional(): - """Test that eed returns 0 when no input is provided.""" - preds = [] - targets = [[]] - assert chrf_score(preds, targets) == tensor(0.0) - - -def test_chrf_empty_class(): - """Test that eed returns 0 when no input is provided.""" - chrf = CHRFScore() - preds = [] - targets = [[]] - assert chrf(preds, targets) == tensor(0.0) - - -def test_chrf_return_sentence_level_score_functional(): - """Test that chrf can return sentence level scores.""" - preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.target - _, chrf_sentence_score = chrf_score(preds, targets, return_sentence_level_score=True) - isinstance(chrf_sentence_score, Tensor) - - -def test_chrf_return_sentence_level_class(): - """Test that chrf can return sentence level scores.""" - chrf = CHRFScore(return_sentence_level_score=True) - preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.target - _, chrf_sentence_score = chrf(preds, targets) - isinstance(chrf_sentence_score, Tensor) From 86bafc67e845fc5e48c420b2f0d372cd5990748c Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 13:15:30 +0200 Subject: [PATCH 2/9] ex --- src/torchmetrics/text/_deprecated.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index 166d48a37e9..c75acafde37 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -55,6 +55,22 @@ def __init__( super().__init__(**kwargs) +class _CHRFScore: + """Wrapper for deprecated import.""" + + def __init__( + self, + n_char_order: int = 6, + n_word_order: int = 2, + beta: float = 2.0, + lowercase: bool = False, + whitespace: bool = False, + return_sentence_level_score: bool = False, + **kwargs: Any, + ) -> None: + raise NotImplementedError("Chrf was temporarily removed.") + + class _ExtendedEditDistance(ExtendedEditDistance): """Wrapper for deprecated import. From cf2a49651f5bea9882ca23edd76a55546530ae33 Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 14:03:59 +0200 Subject: [PATCH 3/9] simplify --- docs/source/conf.py | 2 + docs/source/links.rst | 2 + docs/source/text/chrf_score.rst | 21 ++ src/torchmetrics/functional/text/__init__.py | 1 + .../functional/text/_deprecated.py | 13 +- src/torchmetrics/functional/text/chrf.py | 58 +++++ src/torchmetrics/text/__init__.py | 1 + src/torchmetrics/text/_deprecated.py | 14 +- src/torchmetrics/text/chrf.py | 206 ++++++++++++++++++ .../deprecations/root_class_imports.py | 2 + 10 files changed, 317 insertions(+), 3 deletions(-) create mode 100644 docs/source/text/chrf_score.rst create mode 100644 src/torchmetrics/functional/text/chrf.py create mode 100644 src/torchmetrics/text/chrf.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 9ef621f90df..9484761ba7a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -438,6 +438,8 @@ def linkcode_resolve(domain, info) -> Optional[str]: # noqa: ANN001 "https://ieeexplore.ieee.org/abstract/document/4317530", # Robust parameter estimation with a small bias against heavy contamination "https://www.sciencedirect.com/science/article/pii/S0047259X08000456", + # chrF++: words helping character n-grams + "https://aclanthology.org/W17-4770", # A wavelet transform method to merge Landsat TM and SPOT panchromatic data "https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013", ] diff --git a/docs/source/links.rst b/docs/source/links.rst index d7038032e1e..1fc3ab5755d 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -77,6 +77,8 @@ .. _Scikit_Learn-Ranking.py: https: //github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py .. _Verified Uncertainty Calibration: https://arxiv.org/abs/1909.10155 .. _SQuAD Metric: https://arxiv.org/abs/1606.05250 +.. _chrF score: https://aclanthology.org/W15-3049 +.. _chrF++ score: https://aclanthology.org/W17-4770 .. _TER: https://aclanthology.org/2006.amta-papers.25 .. _ExtendedEditDistance: https://aclanthology.org/W19-5359 .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim diff --git a/docs/source/text/chrf_score.rst b/docs/source/text/chrf_score.rst new file mode 100644 index 00000000000..c541014ae61 --- /dev/null +++ b/docs/source/text/chrf_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: ChrF Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg + :tags: Text + +.. include:: ../links.rst + +########## +ChrF Score +########## + +Module Interface +________________ + +.. autoclass:: torchmetrics.text.CHRFScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.text.chrf_score diff --git a/src/torchmetrics/functional/text/__init__.py b/src/torchmetrics/functional/text/__init__.py index 14ba75ce5a7..9282be6fbae 100644 --- a/src/torchmetrics/functional/text/__init__.py +++ b/src/torchmetrics/functional/text/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate +from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.functional.text.edit import edit_distance from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.mer import match_error_rate diff --git a/src/torchmetrics/functional/text/_deprecated.py b/src/torchmetrics/functional/text/_deprecated.py index 73fc2623928..62ea5645048 100644 --- a/src/torchmetrics/functional/text/_deprecated.py +++ b/src/torchmetrics/functional/text/_deprecated.py @@ -8,6 +8,7 @@ from torchmetrics.functional.text.bert import bert_score from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.functional.text.cer import char_error_rate +from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.functional.text.eed import extended_edit_distance from torchmetrics.functional.text.infolm import ( _ALLOWED_INFORMATION_MEASURE_LITERAL as _INFOLM_ALLOWED_INFORMATION_MEASURE_LITERAL, @@ -134,7 +135,17 @@ def _chrf_score( return_sentence_level_score: bool = False, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Wrapper for deprecated import.""" - raise NotImplementedError("Chrf was temporarily removed.") + _deprecated_root_import_func("chrf_score", "text") + return chrf_score( + preds=preds, + target=target, + n_char_order=n_char_order, + n_word_order=n_word_order, + beta=beta, + lowercase=lowercase, + whitespace=whitespace, + return_sentence_level_score=return_sentence_level_score, + ) def _extended_edit_distance( diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py new file mode 100644 index 00000000000..256b616af0e --- /dev/null +++ b/src/torchmetrics/functional/text/chrf.py @@ -0,0 +1,58 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence, Tuple, Union + +from torch import Tensor + + +def chrf_score( + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], + n_char_order: int = 6, + n_word_order: int = 2, + beta: float = 2.0, + lowercase: bool = False, + whitespace: bool = False, + return_sentence_level_score: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Calculate `chrF score`_ of machine translated text with one or more references. + + This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in + `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and + https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. + + .. attention:: + CHRF has been temporarily removed from the TorchMetrics package. + + Args: + preds: An iterable of hypothesis corpus. + target: An iterable of iterables of reference corpus. + n_char_order: + A character n-gram order. If `n_char_order=6`, the metrics refers to the official chrF/chrF++. + n_word_order: + A word n-gram order. If `n_word_order=2`, the metric refers to the official chrF++. If `n_word_order=0`, the + metric is equivalent to the original chrF. + beta: + A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. + lowercase: An indication whether to enable case-insensitivity. + whitespace: An indication whether to keep whitespaces during character n-gram extraction. + return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. + + References: + [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ + + [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ + + """ + raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package.") diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py index 23ff3c1e037..48807a98fc4 100644 --- a/src/torchmetrics/text/__init__.py +++ b/src/torchmetrics/text/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate +from torchmetrics.text.chrf import CHRFScore from torchmetrics.text.edit import EditDistance from torchmetrics.text.eed import ExtendedEditDistance from torchmetrics.text.mer import MatchErrorRate diff --git a/src/torchmetrics/text/_deprecated.py b/src/torchmetrics/text/_deprecated.py index c75acafde37..77e32730711 100644 --- a/src/torchmetrics/text/_deprecated.py +++ b/src/torchmetrics/text/_deprecated.py @@ -2,6 +2,7 @@ from torchmetrics.text.bleu import BLEUScore from torchmetrics.text.cer import CharErrorRate +from torchmetrics.text.chrf import CHRFScore from torchmetrics.text.eed import ExtendedEditDistance from torchmetrics.text.mer import MatchErrorRate from torchmetrics.text.perplexity import Perplexity @@ -55,7 +56,7 @@ def __init__( super().__init__(**kwargs) -class _CHRFScore: +class _CHRFScore(CHRFScore): """Wrapper for deprecated import.""" def __init__( @@ -68,7 +69,16 @@ def __init__( return_sentence_level_score: bool = False, **kwargs: Any, ) -> None: - raise NotImplementedError("Chrf was temporarily removed.") + _deprecated_root_import_class("CHRFScore", "text") + super().__init__( + n_char_order=n_char_order, + n_word_order=n_word_order, + beta=beta, + lowercase=lowercase, + whitespace=whitespace, + return_sentence_level_score=return_sentence_level_score, + **kwargs, + ) class _ExtendedEditDistance(ExtendedEditDistance): diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py new file mode 100644 index 00000000000..e27f04f37ee --- /dev/null +++ b/src/torchmetrics/text/chrf.py @@ -0,0 +1,206 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from torch import Tensor + +from torchmetrics import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["CHRFScore.plot"] + + +_N_GRAM_LEVELS = ("char", "word") +_TEXT_LEVELS = ("preds", "target", "matching") + +_DICT_STATES_NAMES = ( + "total_preds_char_n_grams", + "total_preds_word_n_grams", + "total_target_char_n_grams", + "total_target_word_n_grams", + "total_matching_char_n_grams", + "total_matching_word_n_grams", +) + +_DICT_STATES_TYPES = Tuple[ + Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor], Dict[int, Tensor] +] + + +class CHRFScore(Metric): + """Calculate `chrf score`_ of machine translated text with one or more references. + + This implementation supports both ChrF score computation introduced in `chrF score`_ and `chrF++ score`_ introduced + in `chrF++ score`_. This implementation follows the implementations from https://github.com/m-popovic/chrF and + https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. + + .. attention:: + CHRF has been temporarily removed from the TorchMetrics package. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus + - ``target`` (:class:`~Sequence`): An iterable of iterables of reference corpus + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``chrf`` (:class:`~torch.Tensor`): If `return_sentence_level_score=True` return a list of sentence-level + chrF/chrF++ scores, else return a corpus-level chrF/chrF++ score + + Args: + n_char_order: A character n-gram order. If ``n_char_order=6``, the metrics refers to the official chrF/chrF++. + n_word_order: A word n-gram order. If ``n_word_order=2``, the metric refers to the official chrF++. + If ``n_word_order=0``, the metric is equivalent to the original ChrF. + beta: parameter determining an importance of recall w.r.t. precision. If ``beta=1``, their importance is equal. + lowercase: An indication whether to enable case-insensitivity. + whitespace: An indication whether keep whitespaces during n-gram extraction. + return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + sentence_chrf_score: Optional[List[Tensor]] = None + + def __init__( + self, + n_char_order: int = 6, + n_word_order: int = 2, + beta: float = 2.0, + lowercase: bool = False, + whitespace: bool = False, + return_sentence_level_score: bool = False, + **kwargs: Any, + ) -> None: + # super().__init__(**kwargs) + # + # if not isinstance(n_char_order, int) or n_char_order < 1: + # raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") + # self.n_char_order = n_char_order + # if not isinstance(n_word_order, int) or n_word_order < 0: + # raise ValueError("Expected argument `n_word_order` to be an integer greater than or equal to 0.") + # self.n_word_order = n_word_order + # if beta < 0: + # raise ValueError("Expected argument `beta` to be greater than 0.") + # self.beta = beta + # self.lowercase = lowercase + # self.whitespace = whitespace + # self.return_sentence_level_score = return_sentence_level_score + # + # self.n_order = float(n_char_order + n_word_order) + # + # # Adding state dynamically + # for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): + # for n in range(1, n_gram_order + 1): + # state_name = self._get_state_name(text, n_gram_level, n) + # self.add_state(state_name, tensor(0.0), dist_reduce_fx="sum") + # + # if self.return_sentence_level_score: + # self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") + raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package.") + + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: + """Update state with predictions and targets.""" + # n_grams_dicts_tuple = _chrf_score_update( + # preds, + # target, + # *self._convert_states_to_dicts(), + # self.n_char_order, + # self.n_word_order, + # self.n_order, + # self.beta, + # self.lowercase, + # self.whitespace, + # self.sentence_chrf_score if self.return_sentence_level_score else None, + # ) + # self._update_states_from_dicts(n_grams_dicts_tuple[:-1]) + # if self.sentence_chrf_score is not None: + # self.sentence_chrf_score = n_grams_dicts_tuple[-1] + + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Calculate chrF/chrF++ score.""" + # if self.sentence_chrf_score is not None: + # return ( + # _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta), + # torch.cat(self.sentence_chrf_score), + # ) + # return _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta) + + def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: + """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" + # n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( + # zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) + # ) + # + # for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): + # for n in range(1, n_gram_order + 1): + # dict_name = self._get_dict_name(text, n_gram_level) + # state_name = self._get_state_name(text, n_gram_level, n) + # + # n_grams_dicts[dict_name][n] = getattr(self, state_name) + # + # return tuple(n_grams_dicts.values()) # type: ignore + + def _update_states_from_dicts(self, n_grams_dicts_tuple: _DICT_STATES_TYPES) -> None: + """Update global metric states based on the n-gram dictionaries calculated on the current batch.""" + n_grams_dicts = dict(zip(_DICT_STATES_NAMES, n_grams_dicts_tuple)) + for (n_gram_level, n_gram_order), text in self._get_text_n_gram_iterator(): + for n in range(1, n_gram_order + 1): + dict_name = self._get_dict_name(text, n_gram_level) + state_name = self._get_state_name(text, n_gram_level, n) + + setattr(self, state_name, n_grams_dicts[dict_name][n]) + + @staticmethod + def _get_dict_name(text: str, n_gram_level: str) -> str: + """Return a dictionary name w.r.t input args.""" + return f"total_{text}_{n_gram_level}_n_grams" + + @staticmethod + def _get_state_name(text: str, n_gram_level: str, n: int) -> str: + """Return a metric state name w.r.t input args.""" + return f"total_{text}_{n_gram_level}_{n}_grams" + + def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: + """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" + return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + """ + return self._plot(val, ax) diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py index 2ae8f5f20a2..5c4aa1a7155 100644 --- a/tests/unittests/deprecations/root_class_imports.py +++ b/tests/unittests/deprecations/root_class_imports.py @@ -6,6 +6,7 @@ from torchmetrics import ( BLEUScore, CharErrorRate, + CHRFScore, ErrorRelativeGlobalDimensionlessSynthesis, ExtendedEditDistance, MatchErrorRate, @@ -85,6 +86,7 @@ # Text BLEUScore, CharErrorRate, + CHRFScore, ExtendedEditDistance, MatchErrorRate, Perplexity, From 6e421dd3c803024c19a2341d9b292190f791684b Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:10:19 +0200 Subject: [PATCH 4/9] Apply suggestions from code review --- src/torchmetrics/functional/text/chrf.py | 2 +- src/torchmetrics/text/chrf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 256b616af0e..cc85f66d7ba 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -33,7 +33,7 @@ def chrf_score( https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - CHRF has been temporarily removed from the TorchMetrics package. + ChrF has been temporarily removed from the TorchMetrics package. Args: preds: An iterable of hypothesis corpus. diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index e27f04f37ee..97beb167f38 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -50,7 +50,7 @@ class CHRFScore(Metric): https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - CHRF has been temporarily removed from the TorchMetrics package. + ChrF has been temporarily removed from the TorchMetrics package. As input to ``forward`` and ``update`` the metric accepts the following input: From fd1099cf9c1b63ad6527f4ed28b903da9be140e6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 2 Aug 2024 14:51:30 +0200 Subject: [PATCH 5/9] Update src/torchmetrics/text/chrf.py Co-authored-by: Luca Antiga --- src/torchmetrics/text/chrf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index 97beb167f38..f73c230f077 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -50,7 +50,7 @@ class CHRFScore(Metric): https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - ChrF has been temporarily removed from the TorchMetrics package. + ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package. As input to ``forward`` and ``update`` the metric accepts the following input: From 017f3746e67086855d432c4b267f95de35cb95f3 Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 14:52:12 +0200 Subject: [PATCH 6/9] docs --- src/torchmetrics/functional/text/chrf.py | 4 ++-- src/torchmetrics/text/chrf.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index cc85f66d7ba..48c597cfd46 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -33,7 +33,7 @@ def chrf_score( https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - ChrF has been temporarily removed from the TorchMetrics package. + ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package. Args: preds: An iterable of hypothesis corpus. @@ -55,4 +55,4 @@ def chrf_score( [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package.") + raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index f73c230f077..e3018014655 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -117,7 +117,7 @@ def __init__( # # if self.return_sentence_level_score: # self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package.") + raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: """Update state with predictions and targets.""" From 13b45452c460293277f96801d366f071065bce12 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:57:03 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/text/chrf.py | 4 +++- src/torchmetrics/text/chrf.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 48c597cfd46..add1372b1fb 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -55,4 +55,6 @@ def chrf_score( [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") + raise NotImplementedError( + "ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package." + ) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index e3018014655..e7fdc2022c9 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -117,7 +117,9 @@ def __init__( # # if self.return_sentence_level_score: # self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") + raise NotImplementedError( + "ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package." + ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: """Update state with predictions and targets.""" From bf4b36434cd9464575797591315b7b80ee935ffc Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 15:01:44 +0200 Subject: [PATCH 8/9] types --- src/torchmetrics/functional/text/chrf.py | 8 ++++++-- src/torchmetrics/text/chrf.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 48c597cfd46..397777f7310 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -33,7 +33,8 @@ def chrf_score( https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package. + ChrF has been temporarily removed from the TorchMetrics package + due to licensing issues with the upstream package. Args: preds: An iterable of hypothesis corpus. @@ -55,4 +56,7 @@ def chrf_score( [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") + raise NotImplementedError( + "ChrF has been temporarily removed from the TorchMetrics package" + " due to licensing issues with the upstream package." + ) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index e3018014655..c7274d0b993 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -50,7 +50,8 @@ class CHRFScore(Metric): https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. .. attention:: - ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package. + ChrF has been temporarily removed from the TorchMetrics package + due to licensing issues with the upstream package. As input to ``forward`` and ``update`` the metric accepts the following input: @@ -117,7 +118,10 @@ def __init__( # # if self.return_sentence_level_score: # self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - raise NotImplementedError("ChrF has been temporarily removed from the TorchMetrics package due to licensing issues with the upstream package.") + raise NotImplementedError( + "ChrF has been temporarily removed from the TorchMetrics package" + " due to licensing issues with the upstream package." + ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: """Update state with predictions and targets.""" @@ -137,7 +141,7 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # if self.sentence_chrf_score is not None: # self.sentence_chrf_score = n_grams_dicts_tuple[-1] - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: # type: ignore[empty-body] """Calculate chrF/chrF++ score.""" # if self.sentence_chrf_score is not None: # return ( @@ -146,7 +150,7 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: # ) # return _chrf_score_compute(*self._convert_states_to_dicts(), self.n_order, self.beta) - def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: + def _convert_states_to_dicts(self) -> _DICT_STATES_TYPES: # type: ignore[empty-body] """Convert global metric states to the n-gram dictionaries to be passed in ``_chrf_score_update``.""" # n_grams_dicts: Dict[str, Dict[int, Tensor]] = dict( # zip(_DICT_STATES_NAMES, _prepare_n_grams_dicts(self.n_char_order, self.n_word_order)) @@ -183,7 +187,7 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" - return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) + return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) # type: ignore[return-value] def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None From 3960f552df006ea3c2efbcc886d06ba30fe3c880 Mon Sep 17 00:00:00 2001 From: jirka Date: Fri, 2 Aug 2024 15:41:57 +0200 Subject: [PATCH 9/9] types --- src/torchmetrics/text/chrf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/text/chrf.py b/src/torchmetrics/text/chrf.py index c7274d0b993..742388d06ff 100644 --- a/src/torchmetrics/text/chrf.py +++ b/src/torchmetrics/text/chrf.py @@ -187,7 +187,7 @@ def _get_state_name(text: str, n_gram_level: str, n: int) -> str: def _get_text_n_gram_iterator(self) -> Iterator[Tuple[Tuple[str, int], str]]: """Get iterator over char/word and reference/hypothesis/matching n-gram level.""" - return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) # type: ignore[return-value] + return itertools.product(zip(_N_GRAM_LEVELS, [self.n_char_order, self.n_word_order]), _TEXT_LEVELS) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None