From e6f1736b2dc31d32af625e2b5694b22947482dd9 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 6 Jan 2022 15:16:30 +0100 Subject: [PATCH 01/22] [WIP] Unify some text metrics --- tests/text/test_bertscore.py | 10 +- torchmetrics/functional/text/bert.py | 66 ++++----- torchmetrics/functional/text/bleu.py | 92 ++++++------ torchmetrics/functional/text/chrf.py | 212 +++++++++++++-------------- torchmetrics/functional/text/wil.py | 67 ++++----- torchmetrics/functional/text/wip.py | 65 ++++---- torchmetrics/text/bert.py | 38 ++--- torchmetrics/text/bleu.py | 35 +++-- torchmetrics/text/chrf.py | 27 ++-- torchmetrics/text/sacre_bleu.py | 25 ++-- torchmetrics/text/wil.py | 25 ++-- torchmetrics/text/wip.py | 25 ++-- 12 files changed, 342 insertions(+), 345 deletions(-) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index fe707fa8994..0def5e32fcf 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -205,7 +205,7 @@ def test_score(preds, refs): original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3) - Scorer.update(predictions=preds, references=refs) + Scorer.update(preds=preds, target=refs) metrics_score = Scorer.compute() for metric in _METRICS: @@ -223,7 +223,7 @@ def test_score_with_idf(preds, refs): original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=True, batch_size=3) - Scorer.update(predictions=preds, references=refs) + Scorer.update(preds=preds, target=refs) metrics_score = Scorer.compute() for metric in _METRICS: @@ -241,7 +241,7 @@ def test_score_all_layers(preds, refs): original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3) - Scorer.update(predictions=preds, references=refs) + Scorer.update(preds=preds, target=refs) metrics_score = Scorer.compute() for metric in _METRICS: @@ -259,7 +259,7 @@ def test_score_all_layers_with_idf(preds, refs): original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3) - Scorer.update(predictions=preds, references=refs) + Scorer.update(preds=preds, target=refs) metrics_score = Scorer.compute() for metric in _METRICS: @@ -280,7 +280,7 @@ def test_accumulation(preds, refs): Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3) for p, r in zip(preds, refs): - Scorer.update(predictions=p, references=r) + Scorer.update(preds=p, target=r) metrics_score = Scorer.compute() for metric in _METRICS: diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 9a80cf84362..9009049a24f 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -350,24 +350,24 @@ def _get_scaled_precision_or_recall(cos_sim: Tensor, metric: str, idf_scale: Ten def _get_precision_recall_f1( - pred_embeddings: Tensor, ref_embeddings: Tensor, pred_idf_scale: Tensor, ref_idf_scale: Tensor + preds_embeddings: Tensor, target_embeddings: Tensor, preds_idf_scale: Tensor, target_idf_scale: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: """Calculate precision, recall and F1 score over candidate and reference sentences. Args: - pred_embeddings: Embeddings of candidate sentenecs. - ref_embeddings: Embeddings of reference sentences. 
- pred_idf_scale: An IDF scale factor for candidate sentences. - ref_idf_scale: An IDF scale factor for reference sentences. + preds_embeddings: Embeddings of candidate sentenecs. + target_embeddings: Embeddings of reference sentences. + preds_idf_scale: An IDF scale factor for candidate sentences. + target_idf_scale: An IDF scale factor for reference sentences. Return: Tensors containing precision, recall and F1 score, respectively. """ # Dimensions: b = batch_size, l = num_layers, p = predictions_seq_len, r = references_seq_len, d = bert_dim - cos_sim = torch.einsum("blpd, blrd -> blpr", pred_embeddings, ref_embeddings) + cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings) # Final metrics shape = (batch_size * num_layers | batch_size) - precision = _get_scaled_precision_or_recall(cos_sim, "precision", pred_idf_scale) - recall = _get_scaled_precision_or_recall(cos_sim, "recall", ref_idf_scale) + precision = _get_scaled_precision_or_recall(cos_sim, "precision", preds_idf_scale) + recall = _get_scaled_precision_or_recall(cos_sim, "recall", target_idf_scale) f1_score = 2 * precision * recall / (precision + recall) f1_score = f1_score.masked_fill(torch.isnan(f1_score), 0.0) @@ -450,8 +450,8 @@ def _rescale_metrics_with_baseline( def bert_score( - predictions: Union[List[str], Dict[str, Tensor]], - references: Union[List[str], Dict[str, Tensor]], + preds: Union[List[str], Dict[str, Tensor]], + target: Union[List[str], Dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, @@ -478,10 +478,10 @@ def bert_score( This implemenation follows the original implementation from `BERT_score`_ Args: - predictions: + preds: Either an iterable of predicted sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and `attention_mask` `torch.Tensor`. - references: + target: Either an iterable of target sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and `attention_mask` `torch.Tensor`. model_name_or_path: @@ -536,7 +536,7 @@ def bert_score( Raises: ValueError: - If `len(predictions) != len(references)`. + If `len(preds) != len(target)`. ValueError: If `tqdm` package is required and not installed. 
ValueError: @@ -548,14 +548,14 @@ def bert_score( Example: >>> from torchmetrics.functional.text.bert import bert_score - >>> predictions = ["hello there", "general kenobi"] - >>> references = ["hello there", "master kenobi"] - >>> bert_score(predictions=predictions, references=references, lang="en") # doctest: +SKIP + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> bert_score(preds=preds, target=target, lang="en") # doctest: +SKIP {'precision': [0.99..., 0.99...], 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} """ - if len(predictions) != len(references): + if len(preds) != len(target): raise ValueError("Number of predicted and reference sententes must be the same!") if verbose and (not _TQDM_AVAILABLE): @@ -585,12 +585,12 @@ def bert_score( except AttributeError: warnings.warn("It was not possible to retrieve the parameter `num_layers` from the model specification.") - _are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (predictions, references)) + _are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (preds, target)) _are_valid_lists = all( - isinstance(text, list) and len(text) > 0 and isinstance(text[0], str) for text in (predictions, references) + isinstance(text, list) and len(text) > 0 and isinstance(text[0], str) for text in (preds, target) ) _are_valid_tensors = all( - isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (predictions, references) + isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (preds, target) ) if _are_empty_lists: warnings.warn("Predictions and references are empty.") @@ -608,32 +608,32 @@ def bert_score( # We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that. 
if _are_valid_lists: - ref_dataset = TextDataset(references, tokenizer, max_length, idf=idf) # type: ignore - pred_dataset = TextDataset( - predictions, # type: ignore + target_dataset = TextDataset(target, tokenizer, max_length, idf=idf) # type: ignore + preds_dataset = TextDataset( + preds, # type: ignore tokenizer, max_length, idf=idf, - tokens_idf=ref_dataset.tokens_idf, + tokens_idf=target_dataset.tokens_idf, ) elif _are_valid_tensors: - ref_dataset = TokenizedDataset(**references, idf=idf) # type: ignore - pred_dataset = TokenizedDataset(**predictions, idf=idf, tokens_idf=ref_dataset.tokens_idf) # type: ignore + target_dataset = TokenizedDataset(**target, idf=idf) # type: ignore + preds_dataset = TokenizedDataset(**preds, idf=idf, tokens_idf=target_dataset.tokens_idf) # type: ignore else: raise ValueError("Invalid input provided.") - ref_loader = DataLoader(ref_dataset, batch_size=batch_size, num_workers=num_threads) - pred_loader = DataLoader(pred_dataset, batch_size=batch_size, num_workers=num_threads) + target_loader = DataLoader(target_dataset, batch_size=batch_size, num_workers=num_threads) + preds_loader = DataLoader(preds_dataset, batch_size=batch_size, num_workers=num_threads) - ref_embeddings, ref_idf_scale = _get_embeddings_and_idf_scale( - ref_loader, ref_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn + target_embeddings, target_idf_scale = _get_embeddings_and_idf_scale( + target_loader, target_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn ) - pred_embeddings, pred_idf_scale = _get_embeddings_and_idf_scale( - pred_loader, pred_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn + preds_embeddings, preds_idf_scale = _get_embeddings_and_idf_scale( + preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn ) precision, recall, f1_score = _get_precision_recall_f1( - pred_embeddings, ref_embeddings, pred_idf_scale, ref_idf_scale + preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale ) if baseline is not None: diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 19176f229fd..627799f6a71 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -58,57 +58,55 @@ def _tokenize_fn(sentence: str) -> Sequence[str]: def _bleu_score_update( - translate_corpus: Sequence[str], - reference_corpus: Sequence[Sequence[str]], + preds: Sequence[str], + target: Sequence[Sequence[str]], numerator: Tensor, denominator: Tensor, - trans_len: Tensor, - ref_len: Tensor, + preds_len: Tensor, + target_len: Tensor, n_gram: int = 4, tokenizer: Callable[[str], Sequence[str]] = _tokenize_fn, ) -> Tuple[Tensor, Tensor]: """Updates and returns variables required to compute the BLEU score. 
     Args:
-        translate_corpus: An iterable of machine translated corpus
-        reference_corpus: An iterable of iterables of reference corpus
+        preds: An iterable of machine translated corpus
+        target: An iterable of iterables of reference corpus
         numerator: Numerator of precision score (true positives)
         denominator: Denominator of precision score (true positives + false positives)
-        trans_len: count of words in a candidate prediction
-        ref_len: count of words in a reference translation
+        preds_len: count of words in a candidate prediction
+        target_len: count of words in a reference translation
         n_gram: gram value ranged 1 to 4
         tokenizer: A function that turns sentence into list of words
     """
-    reference_corpus_: Sequence[Sequence[Sequence[str]]] = [
-        [tokenizer(line) if line else [] for line in reference] for reference in reference_corpus
-    ]
-    translate_corpus_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in translate_corpus]
+    target_: Sequence[Sequence[Sequence[str]]] = [[tokenizer(line) if line else [] for line in t] for t in target]
+    preds_: Sequence[Sequence[str]] = [tokenizer(line) if line else [] for line in preds]
 
-    for (translation, references) in zip(translate_corpus_, reference_corpus_):
-        trans_len += len(translation)
-        ref_len_list = [len(ref) for ref in references]
-        ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
-        ref_len += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
-        translation_counter: Counter = _count_ngram(translation, n_gram)
-        reference_counter: Counter = Counter()
+    for (pred, targets) in zip(preds_, target_):
+        preds_len += len(pred)
+        target_len_list = [len(ref) for ref in targets]
+        target_len_diff = [abs(len(pred) - x) for x in target_len_list]
+        target_len += target_len_list[target_len_diff.index(min(target_len_diff))]
+        preds_counter: Counter = _count_ngram(pred, n_gram)
+        target_counter: Counter = Counter()
 
-        for ref in references:
-            reference_counter |= _count_ngram(ref, n_gram)
+        for ref in targets:
+            target_counter |= _count_ngram(ref, n_gram)
 
-        ngram_counter_clip = translation_counter & reference_counter
+        ngram_counter_clip = preds_counter & target_counter
         for counter_clip in ngram_counter_clip:
             numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
 
-        for counter in translation_counter:
-            denominator[len(counter) - 1] += translation_counter[counter]
+        for counter in preds_counter:
+            denominator[len(counter) - 1] += preds_counter[counter]
 
-    return trans_len, ref_len
+    return preds_len, target_len
 
 
 def _bleu_score_compute(
-    trans_len: Tensor,
-    ref_len: Tensor,
+    preds_len: Tensor,
+    target_len: Tensor,
     numerator: Tensor,
     denominator: Tensor,
     n_gram: int = 4,
@@ -117,8 +115,8 @@ def _bleu_score_compute(
     """Computes the BLEU score.
Args: - trans_len: count of words in a candidate translation - ref_len: count of words in a reference translation + preds_len: count of words in a candidate translation + target_len: count of words in a reference translation numerator: Numerator of precision score (true positives) denominator: Denominator of precision score (true positives + false positives) n_gram: gram value ranged 1 to 4 @@ -139,24 +137,24 @@ def _bleu_score_compute( log_precision_scores = tensor([1.0 / n_gram] * n_gram, device=device) * torch.log(precision_scores) geometric_mean = torch.exp(torch.sum(log_precision_scores)) - brevity_penalty = tensor(1.0, device=device) if trans_len > ref_len else torch.exp(1 - (ref_len / trans_len)) + brevity_penalty = tensor(1.0, device=device) if preds_len > target_len else torch.exp(1 - (target_len / preds_len)) bleu = brevity_penalty * geometric_mean return bleu def bleu_score( - translate_corpus: Union[str, Sequence[str]], - reference_corpus: Sequence[Union[str, Sequence[str]]], + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], n_gram: int = 4, smooth: bool = False, ) -> Tensor: """Calculate `BLEU score`_ of machine translated text with one or more references. Args: - translate_corpus: + preds: An iterable of machine translated corpus - reference_corpus: + target: An iterable of iterables of reference corpus n_gram: Gram value ranged from 1 to 4 (Default 4) @@ -168,9 +166,9 @@ def bleu_score( Example: >>> from torchmetrics.functional import bleu_score - >>> translate_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> bleu_score(translate_corpus, reference_corpus) + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> bleu_score(preds=preds, target=target) tensor(0.7598) References: @@ -184,21 +182,19 @@ def bleu_score( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." 
) - translate_corpus_ = [translate_corpus] if isinstance(translate_corpus, str) else translate_corpus - reference_corpus_ = [ - [reference_text] if isinstance(reference_text, str) else reference_text for reference_text in reference_corpus - ] + preds_ = [preds] if isinstance(preds, str) else preds + target_ = [[target_text] if isinstance(target_text, str) else target_text for target_text in target] - if len(translate_corpus_) != len(reference_corpus_): - raise ValueError(f"Corpus has different size {len(translate_corpus_)} != {len(reference_corpus_)}") + if len(preds_) != len(target_): + raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") numerator = torch.zeros(n_gram) denominator = torch.zeros(n_gram) - trans_len = tensor(0, dtype=torch.float) - ref_len = tensor(0, dtype=torch.float) + preds_len = tensor(0.0) + target_len = tensor(0.0) - trans_len, ref_len = _bleu_score_update( - translate_corpus_, reference_corpus_, numerator, denominator, trans_len, ref_len, n_gram, _tokenize_fn + preds_len, target_len = _bleu_score_update( + preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn ) - return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth) + return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth) diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index ba8b2214931..20891c3fa6d 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -311,11 +311,11 @@ def _get_n_gram_fscore( def _calculate_sentence_level_chrf_score( - references: List[str], - hyp_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - hyp_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], - hyp_char_n_grams: Dict[int, Tensor], - hyp_word_n_grams: Dict[int, Tensor], + targets: List[str], + pred_char_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + pred_word_n_grams_counts: Dict[int, Dict[Tuple[str, ...], Tensor]], + preds_char_n_grams: Dict[int, Tensor], + preds_word_n_grams: Dict[int, Tensor], n_char_order: int, n_word_order: int, n_order: float, @@ -327,15 +327,15 @@ def _calculate_sentence_level_chrf_score( are evaluated and score and statistics for the best matching reference is returned. Args: - references: + targets: An iterable of references. - hyp_char_n_grams_counts: + preds_char_n_grams_counts: A dictionary of dictionaries with hypothesis character n-grams. - hyp_word_n_grams_counts: + preds_word_n_grams_counts: A dictionary of dictionaries with hypothesis word n-grams. - hyp_char_n_grams: + pred_char_n_grams: A total number of hypothesis character n-grams. - hyp_word_n_grams: + pred_word_n_grams: A total number of hypothesis word n-grams. n_char_order: A character n-gram order. @@ -359,35 +359,35 @@ def _calculate_sentence_level_chrf_score( A total number of matching character n-grams between the best matching reference and hypothesis. matching_word_n_grams: A total number of matching word n-grams between the best matching reference and hypothesis. - ref_char_n_grams: + target_char_n_grams: A total number of reference character n-grams. - ref_word_n_grams: + target_word_n_grams: A total number of reference word n-grams. 
""" best_f_score = tensor(0.0) best_matching_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) best_matching_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_ref_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - best_ref_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_char_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) + best_target_word_n_grams: Dict[int, Tensor] = defaultdict(lambda: tensor(0.0)) - for reference in references: + for target in targets: ( - ref_char_n_grams_counts, - ref_word_n_grams_counts, - ref_char_n_grams, - ref_word_n_grams, - ) = _get_n_grams_counts_and_total_ngrams(reference, n_char_order, n_word_order, lowercase, whitespace) - matching_char_n_grams = _get_ngram_matches(ref_char_n_grams_counts, hyp_char_n_grams_counts) - matching_word_n_grams = _get_ngram_matches(ref_word_n_grams_counts, hyp_word_n_grams_counts) + target_char_n_grams_counts, + target_word_n_grams_counts, + target_char_n_grams, + target_word_n_grams, + ) = _get_n_grams_counts_and_total_ngrams(target, n_char_order, n_word_order, lowercase, whitespace) + matching_char_n_grams = _get_ngram_matches(target_char_n_grams_counts, pred_char_n_grams_counts) + matching_word_n_grams = _get_ngram_matches(target_word_n_grams_counts, pred_word_n_grams_counts) f_score = _calculate_fscore( matching_char_n_grams, matching_word_n_grams, - hyp_char_n_grams, - hyp_word_n_grams, - ref_char_n_grams, - ref_word_n_grams, + preds_char_n_grams, + preds_word_n_grams, + target_char_n_grams, + target_word_n_grams, n_order, beta, ) @@ -396,25 +396,25 @@ def _calculate_sentence_level_chrf_score( best_f_score = f_score best_matching_char_n_grams = matching_char_n_grams best_matching_word_n_grams = matching_word_n_grams - best_ref_char_n_grams = ref_char_n_grams - best_ref_word_n_grams = ref_word_n_grams + best_target_char_n_grams = target_char_n_grams + best_target_word_n_grams = target_word_n_grams return ( best_f_score, best_matching_char_n_grams, best_matching_word_n_grams, - best_ref_char_n_grams, - best_ref_word_n_grams, + best_target_char_n_grams, + best_target_word_n_grams, ) def _chrf_score_update( - hypothesis_corpus: Union[str, Sequence[str]], - reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], - total_hyp_char_n_grams: Dict[int, Tensor], - total_hyp_word_n_grams: Dict[int, Tensor], - total_ref_char_n_grams: Dict[int, Tensor], - total_ref_word_n_grams: Dict[int, Tensor], + preds: Union[str, Sequence[str]], + target: Union[Sequence[str], Sequence[Sequence[str]]], + total_preds_char_n_grams: Dict[int, Tensor], + total_preds_word_n_grams: Dict[int, Tensor], + total_target_char_n_grams: Dict[int, Tensor], + total_target_word_n_grams: Dict[int, Tensor], total_matching_char_n_grams: Dict[int, Tensor], total_matching_word_n_grams: Dict[int, Tensor], n_char_order: int, @@ -435,17 +435,17 @@ def _chrf_score_update( ]: """ Args: - hypothesis_corpus: + preds: An iterable of hypothesis corpus. - reference_corpus: + target: An iterable of iterables of reference corpus. - total_hyp_char_n_grams: + total_preds_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. - total_hyp_word_n_grams: + total_preds_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. - total_ref_char_n_grams: + total_target_char_n_grams: A dictionary containing a total number of reference character n-grams. 
- total_ref_word_n_grams: + total_target_word_n_grams: A dictionary containing a total number of reference word n-grams. total_matching_char_n_grams: A dictionary containing a total number of matching character n-grams between references and hypotheses. @@ -467,13 +467,13 @@ def _chrf_score_update( A list of sentence-level chrF/chrF++ scores. Return: - total_ref_char_n_grams: + total_target_char_n_grams: An updated dictionary containing a total number of reference character n-grams. - total_ref_word_n_grams: + total_target_word_n_grams: An updated dictionary containing a total number of reference word n-grams. - total_hyp_char_n_grams: + total_preds_char_n_grams: An updated dictionary containing a total number of hypothesis character n-grams. - total_hyp_word_n_grams: + total_preds_word_n_grams: An updated dictionary containing a total number of hypothesis word n-grams. total_matching_char_n_grams: An updated dictionary containing a total number of matching character n-grams between references and @@ -486,32 +486,32 @@ def _chrf_score_update( Raises: ValueError: - If length of `reference_corpus` and `hypothesis_corpus` differs. + If length of `preds` and `target` differs. """ - reference_corpus, hypothesis_corpus = _validate_inputs(reference_corpus, hypothesis_corpus) + target_corpus, preds = _validate_inputs(target, preds) - for (hypothesis, references) in zip(hypothesis_corpus, reference_corpus): + for (pred, targets) in zip(preds, target_corpus): ( - hyp_char_n_grams_counts, - hyp_word_n_grams_counts, - hyp_char_n_grams, - hyp_word_n_grams, - ) = _get_n_grams_counts_and_total_ngrams(hypothesis, n_char_order, n_word_order, lowercase, whitespace) - total_hyp_char_n_grams = _sum_over_dicts(total_hyp_char_n_grams, hyp_char_n_grams) - total_hyp_word_n_grams = _sum_over_dicts(total_hyp_word_n_grams, hyp_word_n_grams) + pred_char_n_grams_counts, + pred_word_n_grams_counts, + pred_char_n_grams, + pred_word_n_grams, + ) = _get_n_grams_counts_and_total_ngrams(pred, n_char_order, n_word_order, lowercase, whitespace) + total_preds_char_n_grams = _sum_over_dicts(total_preds_char_n_grams, pred_char_n_grams) + total_preds_word_n_grams = _sum_over_dicts(total_preds_word_n_grams, pred_word_n_grams) ( sentence_level_f_score, matching_char_n_grams, matching_word_n_grams, - ref_char_n_grams, - ref_word_n_grams, + target_char_n_grams, + target_word_n_grams, ) = _calculate_sentence_level_chrf_score( - references, # type: ignore - hyp_char_n_grams_counts, - hyp_word_n_grams_counts, - hyp_char_n_grams, - hyp_word_n_grams, + targets, # type: ignore + pred_char_n_grams_counts, + pred_word_n_grams_counts, + pred_char_n_grams, + pred_word_n_grams, n_char_order, n_word_order, n_order, @@ -523,16 +523,16 @@ def _chrf_score_update( if sentence_chrf_score is not None: sentence_chrf_score.append(sentence_level_f_score.unsqueeze(0)) - total_ref_char_n_grams = _sum_over_dicts(total_ref_char_n_grams, ref_char_n_grams) - total_ref_word_n_grams = _sum_over_dicts(total_ref_word_n_grams, ref_word_n_grams) + total_target_char_n_grams = _sum_over_dicts(total_target_char_n_grams, target_char_n_grams) + total_target_word_n_grams = _sum_over_dicts(total_target_word_n_grams, target_word_n_grams) total_matching_char_n_grams = _sum_over_dicts(total_matching_char_n_grams, matching_char_n_grams) total_matching_word_n_grams = _sum_over_dicts(total_matching_word_n_grams, matching_word_n_grams) return ( - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + 
total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, sentence_chrf_score, @@ -540,10 +540,10 @@ def _chrf_score_update( def _chrf_score_compute( - total_hyp_char_n_grams: Dict[int, Tensor], - total_hyp_word_n_grams: Dict[int, Tensor], - total_ref_char_n_grams: Dict[int, Tensor], - total_ref_word_n_grams: Dict[int, Tensor], + total_preds_char_n_grams: Dict[int, Tensor], + total_preds_word_n_grams: Dict[int, Tensor], + total_target_char_n_grams: Dict[int, Tensor], + total_target_word_n_grams: Dict[int, Tensor], total_matching_char_n_grams: Dict[int, Tensor], total_matching_word_n_grams: Dict[int, Tensor], n_order: float, @@ -552,13 +552,13 @@ def _chrf_score_compute( """Compute chrF/chrF++ score based on pre-computed target, prediction and matching character and word n-grams. Args: - total_hyp_char_n_grams: + total_preds_char_n_grams: A dictionary containing a total number of hypothesis character n-grams. - total_hyp_word_n_grams: + total_preds_word_n_grams: A dictionary containing a total number of hypothesis word n-grams. - total_ref_char_n_grams: + total_target_char_n_grams: A dictionary containing a total number of reference character n-grams. - total_ref_word_n_grams: + total_target_word_n_grams: A dictionary containing a total number of reference word n-grams. total_matching_char_n_grams: A dictionary containing a total number of matching character n-grams between references and hypotheses. @@ -575,10 +575,10 @@ def _chrf_score_compute( chrf_f_score = _calculate_fscore( total_matching_char_n_grams, total_matching_word_n_grams, - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, n_order, beta, ) @@ -586,8 +586,8 @@ def _chrf_score_compute( def chrf_score( - hypothesis_corpus: Union[str, Sequence[str]], - reference_corpus: Union[Sequence[str], Sequence[Sequence[str]]], + preds: Union[str, Sequence[str]], + target: Union[Sequence[str], Sequence[Sequence[str]]], n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, @@ -601,9 +601,9 @@ def chrf_score( https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/metrics/chrf.py. Args: - hypothesis_corpus: + preds: An iterable of hypothesis corpus. - reference_corpus: + target: An iterable of iterables of reference corpus. n_char_order: A character n-gram order. If `n_char_order=6`, the metrics refers to the official chrF/chrF++. 
@@ -633,9 +633,9 @@ def chrf_score( Example: >>> from torchmetrics.functional import chrf_score - >>> hypothesis_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf_score(hypothesis_corpus, reference_corpus) + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> chrf_score(preds=preds, target=target) tensor(0.8640) References: @@ -652,10 +652,10 @@ def chrf_score( n_order = float(n_char_order + n_word_order) ( - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, ) = _prepare_n_grams_dicts(n_char_order, n_word_order) @@ -663,20 +663,20 @@ def chrf_score( sentence_chrf_score: Optional[List[Tensor]] = [] if return_sentence_level_score else None ( - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, sentence_chrf_score, ) = _chrf_score_update( - hypothesis_corpus, - reference_corpus, - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + preds, + target, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, n_char_order, @@ -689,10 +689,10 @@ def chrf_score( ) chrf_f_score = _chrf_score_compute( - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, n_order, diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index 7dbd5982c9f..a431f737e84 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -21,68 +21,69 @@ def _wil_update( - predictions: Union[str, List[str]], - references: Union[str, List[str]], + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> Tuple[Tensor, Tensor, Tensor]: """Update the wil score with the current set of references and predictions. 
Args: - predictions: Transcription(s) to score as a string or list of strings - references: Reference(s) for each speech input as a string or list of strings + preds: Transcription(s) to score as a string or list of strings + target: Reference(s) for each speech input as a string or list of strings Returns: Number of edit operations to get from the reference to the prediction, summed over all samples Number of words overall references Number of words overall predictions """ - if isinstance(predictions, str): - predictions = [predictions] - if isinstance(references, str): - references = [references] - total = tensor(0, dtype=torch.float) - errors = tensor(0, dtype=torch.float) - reference_total = tensor(0, dtype=torch.float) - prediction_total = tensor(0, dtype=torch.float) - for prediction, reference in zip(predictions, references): - prediction_tokens = prediction.split() - reference_tokens = reference.split() - errors += _edit_distance(prediction_tokens, reference_tokens) - reference_total += len(reference_tokens) - prediction_total += len(prediction_tokens) - total += max(len(reference_tokens), len(prediction_tokens)) + if isinstance(preds, str): + preds = [preds] + if isinstance(target, str): + target = [target] + total = tensor(0.0) + errors = tensor(0.0) + target_total = tensor(0.0) + preds_total = tensor(0.0) + for pred, tar in zip(preds, target): + pred_tokens = pred.split() + target_tokens = tar.split() + errors += _edit_distance(pred_tokens, target_tokens) + target_total += len(target_tokens) + preds_total += len(pred_tokens) + total += max(len(target_tokens), len(pred_tokens)) - return errors - total, reference_total, prediction_total + return errors - total, target_total, preds_total -def _wil_compute(errors: Tensor, reference_total: Tensor, prediction_total: Tensor) -> Tensor: +def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor: """Compute the Word Information Lost. Args: errors: Number of edit operations to get from the reference to the prediction, summed over all samples - reference_total: Number of words overall references - prediction_total: Number of words overall prediction + target_total: Number of words overall references + preds_total: Number of words overall prediction Returns: Word Information Lost score """ - return 1 - ((errors / reference_total) * (errors / prediction_total)) + return 1 - ((errors / target_total) * (errors / preds_total)) def word_information_lost( - predictions: Union[str, List[str]], - references: Union[str, List[str]], + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score. 
Args: - predictions: Transcription(s) to score as a string or list of strings - references: Reference(s) for each speech input as a string or list of strings + preds: Transcription(s) to score as a string or list of strings + target: Reference(s) for each speech input as a string or list of strings Returns: Word Information Lost rate Examples: - >>> predictions = ["this is the prediction", "there is an other sample"] - >>> references = ["this is the reference", "there is another one"] - >>> word_information_lost(predictions=predictions, references=references) + >>> from torchmetrics.functional import word_information_lost + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] + >>> word_information_lost(preds=preds, target=target) tensor(0.6528) """ - errors, reference_total, prediction_total = _wil_update(predictions, references) - return _wil_compute(errors, reference_total, prediction_total) + errors, target_total, preds_total = _wil_update(preds, target) + return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 1cf1ca00c22..74f213581ad 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -21,68 +21,69 @@ def _wip_update( - predictions: Union[str, List[str]], - references: Union[str, List[str]], + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> Tuple[Tensor, Tensor, Tensor]: """Update the wip score with the current set of references and predictions. Args: - predictions: Transcription(s) to score as a string or list of strings - references: Reference(s) for each speech input as a string or list of strings + preds: Transcription(s) to score as a string or list of strings + target: Reference(s) for each speech input as a string or list of strings Returns: Number of edit operations to get from the reference to the prediction, summed over all samples Number of words overall references Number of words overall prediction """ - if isinstance(predictions, str): - predictions = [predictions] - if isinstance(references, str): - references = [references] - total = tensor(0, dtype=torch.float) - errors = tensor(0, dtype=torch.float) - reference_total = tensor(0, dtype=torch.float) - prediction_total = tensor(0, dtype=torch.float) - for prediction, reference in zip(predictions, references): - prediction_tokens = prediction.split() - reference_tokens = reference.split() - errors += _edit_distance(prediction_tokens, reference_tokens) - reference_total += len(reference_tokens) - prediction_total += len(prediction_tokens) - total += max(len(reference_tokens), len(prediction_tokens)) + if isinstance(preds, str): + preds = [preds] + if isinstance(target, str): + target = [target] + total = tensor(0.0) + errors = tensor(0.0) + target_total = tensor(0.0) + preds_total = tensor(0.0) + for pred, tar in zip(preds, target): + pred_tokens = pred.split() + target_tokens = tar.split() + errors += _edit_distance(pred_tokens, target_tokens) + target_total += len(target_tokens) + preds_total += len(pred_tokens) + total += max(len(target_tokens), len(pred_tokens)) - return errors - total, reference_total, prediction_total + return errors - total, target_total, preds_total -def _wip_compute(errors: Tensor, reference_total: Tensor, prediction_total: Tensor) -> Tensor: +def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> Tensor: """Compute the Word Information 
Perserved.
 
     Args:
         errors: Number of edit operations to get from the reference to the prediction, summed over all samples
-        reference_total: Number of words overall references
-        prediction_total: Number of words overall prediction
+        target_total: Number of words overall references
+        preds_total: Number of words overall prediction
 
     Returns:
         Word Information Perserved score
     """
-    return (errors / reference_total) * (errors / prediction_total)
+    return (errors / target_total) * (errors / preds_total)
 
 
 def word_information_preserved(
-    predictions: Union[str, List[str]],
-    references: Union[str, List[str]],
+    preds: Union[str, List[str]],
+    target: Union[str, List[str]],
 ) -> Tensor:
     """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system.
 
     This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the
     better the performance of the ASR system with a Word Information preserved rate of 0 being a perfect score.
 
     Args:
-        predictions: Transcription(s) to score as a string or list of strings
-        references: Reference(s) for each speech input as a string or list of strings
+        preds: Transcription(s) to score as a string or list of strings
+        target: Reference(s) for each speech input as a string or list of strings
 
     Returns:
         Word Information preserved rate
 
     Examples:
-        >>> predictions = ["this is the prediction", "there is an other sample"]
-        >>> references = ["this is the reference", "there is another one"]
-        >>> word_information_preserved(predictions=predictions, references=references)
+        >>> from torchmetrics.functional import word_information_preserved
+        >>> preds = ["this is the prediction", "there is an other sample"]
+        >>> target = ["this is the reference", "there is another one"]
+        >>> word_information_preserved(preds=preds, target=target)
         tensor(0.3472)
     """
-    errors, reference_total, prediction_total = _wip_update(predictions, references)
+    errors, reference_total, prediction_total = _wip_update(preds, target)
     return _wip_compute(errors, reference_total, prediction_total)
diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py
index 6ebe6d13779..e687f4490c5 100644
--- a/torchmetrics/text/bert.py
+++ b/torchmetrics/text/bert.py
@@ -46,9 +46,9 @@ class BERTScore(Metric):
     This implemenation follows the original implementation from `BERT_score`_.
 
     Args:
-        predictions:
+        preds:
             An iterable of predicted sentences.
-        references:
+        target:
             An iterable of target sentences.
         model_type:
             A name or a model path used to load `transformers` pretrained model.
@@ -111,11 +111,11 @@ class BERTScore(Metric):
         Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
 
     Example:
-        >>> from torchmetrics.text.bert import BERTScore
-        >>> predictions = ["hello there", "general kenobi"]
-        >>> references = ["hello there", "master kenobi"]
+        >>> from torchmetrics import BERTScore
+        >>> preds = ["hello there", "general kenobi"]
+        >>> target = ["hello there", "master kenobi"]
         >>> bertscore = BERTScore()
-        >>> bertscore.update(predictions=predictions,references=references)
+        >>> bertscore.update(preds=preds,target=target)
         >>> bertscore.compute()  # doctest: +SKIP
         {'precision': [0.99..., 0.99...],
          'recall': [0.99..., 0.99...],
@@ -170,8 +170,8 @@ def __init__(
         self.rescale_with_baseline = rescale_with_baseline
         self.baseline_path = baseline_path
         self.baseline_url = baseline_url
-        self.predictions: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}
-        self.references: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}
+        self.preds: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}
+        self.target: Dict[str, List[torch.Tensor]] = {"input_ids": [], "attention_mask": []}
 
         if user_tokenizer:
             self.tokenizer = user_tokenizer
@@ -192,7 +192,7 @@ def __init__(
             self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
             self.user_tokenizer = False
 
-    def update(self, predictions: List[str], references: List[str]) -> None:  # type: ignore
+    def update(self, preds: List[str], target: List[str]) -> None:  # type: ignore
         """Store predictions/references for computing BERT scores.
 
         It is necessary to store sentences in a tokenized form to ensure the DDP mode working.
@@ -202,26 +202,26 @@ def update(self, predictions: List[str], references: List[str]) -> None:  # type
            references:
                An iterable of predicted sentences.
        """
-        predictions_dict = _preprocess_text(
-            predictions,
+        preds_dict = _preprocess_text(
+            preds,
             self.tokenizer,
             self.max_length,
             truncation=False,
             sort_according_length=False,
             own_tokenizer=self.user_tokenizer,
         )
-        references_dict = _preprocess_text(
-            references,
+        target_dict = _preprocess_text(
+            target,
             self.tokenizer,
             self.max_length,
             truncation=False,
             sort_according_length=False,
             own_tokenizer=self.user_tokenizer,
         )
-        self.predictions["input_ids"].append(predictions_dict["input_ids"])
-        self.predictions["attention_mask"].append(predictions_dict["attention_mask"])
-        self.references["input_ids"].append(references_dict["input_ids"])
-        self.references["attention_mask"].append(references_dict["attention_mask"])
+        self.preds["input_ids"].append(preds_dict["input_ids"])
+        self.preds["attention_mask"].append(preds_dict["attention_mask"])
+        self.target["input_ids"].append(target_dict["input_ids"])
+        self.target["attention_mask"].append(target_dict["attention_mask"])
 
     def compute(self) -> Dict[str, Union[List[float], str]]:
         """Calculate BERT scores.
@@ -230,8 +230,8 @@ def compute(self) -> Dict[str, Union[List[float], str]]:
            Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values.
        """
         return bert_score(
-            predictions=_concatenate(self.predictions),
-            references=_concatenate(self.references),
+            preds=_concatenate(self.preds),
+            target=_concatenate(self.target),
             model_name_or_path=self.model_name_or_path,
             num_layers=self.num_layers,
             all_layers=self.all_layers,
diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py
index 10c1f5ab5a1..090238b5873 100644
--- a/torchmetrics/text/bleu.py
+++ b/torchmetrics/text/bleu.py
@@ -46,10 +46,11 @@ class BLEUScore(Metric):
            will be used to perform the allgather.
Example: - >>> translate_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] + >>> from torchmetrics import BLEUScore + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = BLEUScore() - >>> metric(translate_corpus, reference_corpus) + >>> metric(preds=preds, target=target) tensor(0.7598) References: @@ -62,8 +63,8 @@ class BLEUScore(Metric): is_differentiable = False higher_is_better = True - trans_len: Tensor - ref_len: Tensor + preds_len: Tensor + target_len: Tensor numerator: Tensor denominator: Tensor @@ -89,28 +90,26 @@ def __init__( self.n_gram = n_gram self.smooth = smooth - self.add_state("trans_len", tensor(0, dtype=torch.float), dist_reduce_fx="sum") - self.add_state("ref_len", tensor(0, dtype=torch.float), dist_reduce_fx="sum") + self.add_state("preds_len", tensor(0.0), dist_reduce_fx="sum") + self.add_state("target_len", tensor(0.0), dist_reduce_fx="sum") self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - def update( # type: ignore - self, translate_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus + preds: An iterable of machine translated corpus + target: An iterable of iterables of reference corpus """ - self.trans_len, self.ref_len = _bleu_score_update( - translate_corpus, - reference_corpus, + self.preds_len, self.target_len = _bleu_score_update( + preds, + target, self.numerator, self.denominator, - self.trans_len, - self.ref_len, + self.preds_len, + self.target_len, self.n_gram, _tokenize_fn, ) @@ -122,5 +121,5 @@ def compute(self) -> Tensor: Tensor with BLEU Score """ return _bleu_score_compute( - self.trans_len, self.ref_len, self.numerator, self.denominator, self.n_gram, self.smooth + self.preds_len, self.target_len, self.numerator, self.denominator, self.n_gram, self.smooth ) diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index 1059f13cf37..0b5c0f26cb8 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -30,10 +30,10 @@ _TEXT_LEVELS = ("ref", "hyp", "matching") _DICT_STATES_NAMES = ( - "total_hyp_char_n_grams", - "total_hyp_word_n_grams", - "total_ref_char_n_grams", - "total_ref_word_n_grams", + "total_preds_char_n_grams", + "total_preds_word_n_grams", + "total_target_char_n_grams", + "total_target_word_n_grams", "total_matching_char_n_grams", "total_matching_word_n_grams", ) @@ -83,10 +83,11 @@ class CHRFScore(Metric): If ``beta`` is smaller than 0. 
Example: - >>> hypothesis_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] + >>> from torchmetrics import CHRFScore + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = CHRFScore() - >>> metric(hypothesis_corpus, reference_corpus) + >>> metric(preds=preds, target=target) tensor(0.8640) References: @@ -142,20 +143,18 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - def update( # type: ignore - self, hypothesis_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: - hypothesis_corpus: + preds: An iterable of hypothesis corpus. - reference_corpus: + target: An iterable of iterables of reference corpus. """ n_grams_dicts_tuple = _chrf_score_update( - hypothesis_corpus, - reference_corpus, + preds, + target, *self._convert_states_to_dicts(), self.n_char_order, self.n_word_order, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 97ae380a4a1..57eb023551e 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -66,10 +66,11 @@ class SacreBLEUScore(BLEUScore): Example: - >>> translate_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] + >>> from torchmetrics import SacreBLEUScore + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = SacreBLEUScore() - >>> metric(translate_corpus, reference_corpus) + >>> metric(preds=preds, target=target) tensor(0.7598) References: @@ -115,22 +116,20 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) - def update( # type: ignore - self, translate_corpus: Sequence[str], reference_corpus: Sequence[Sequence[str]] - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. 
Args: - translate_corpus: An iterable of machine translated corpus - reference_corpus: An iterable of iterables of reference corpus + preds: An iterable of machine translated corpus + target: An iterable of iterables of reference corpus """ - self.trans_len, self.ref_len = _bleu_score_update( - translate_corpus, - reference_corpus, + self.preds_len, self.target_len = _bleu_score_update( + preds, + target, self.numerator, self.denominator, - self.trans_len, - self.ref_len, + self.preds_len, + self.target_len, self.n_gram, self.tokenizer, ) diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 9eb0031bf74..ca0b25730df 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -53,17 +53,18 @@ class WordInfoLost(Metric): Examples: - >>> predictions = ["this is the prediction", "there is an other sample"] - >>> references = ["this is the reference", "there is another one"] + >>> from torchmetrics import WordInfoLost + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] >>> metric = WordInfoLost() - >>> metric(predictions, references) + >>> metric(preds=preds, target=target) tensor(0.6528) """ is_differentiable = False higher_is_better = False errors: Tensor - reference_total: Tensor - prediction_total: Tensor + target_total: Tensor + preds_total: Tensor def __init__( self, @@ -82,17 +83,17 @@ def __init__( self.add_state("reference_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("prediction_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, predictions: Union[str, List[str]], references: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Word Information Lost scores. Args: - predictions: Transcription(s) to score as a string or list of strings - references: Reference(s) for each speech input as a string or list of strings + preds: Transcription(s) to score as a string or list of strings + target: Reference(s) for each speech input as a string or list of strings """ - errors, reference_total, prediction_total = _wil_update(predictions, references) + errors, target_total, preds_total = _wil_update(preds, target) self.errors += errors - self.reference_total += reference_total - self.prediction_total += prediction_total + self.target_total += target_total + self.preds_total += preds_total def compute(self) -> Tensor: """Calculate the Word Information Lost. 
@@ -100,4 +101,4 @@ def compute(self) -> Tensor: Returns: Word Information Lost score """ - return _wil_compute(self.errors, self.reference_total, self.prediction_total) + return _wil_compute(self.errors, self.target_total, self.preds_total) diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 0ba45a506a0..fa91c898c60 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -53,17 +53,18 @@ class WordInfoPreserved(Metric): Examples: - >>> predictions = ["this is the prediction", "there is an other sample"] - >>> references = ["this is the reference", "there is another one"] + >>> from torchmetrics import WordInfoPreserved + >>> preds = ["this is the prediction", "there is an other sample"] + >>> target = ["this is the reference", "there is another one"] >>> metric = WordInfoPreserved() - >>> metric(predictions, references) + >>> metric(preds=preds, target=target) tensor(0.3472) """ is_differentiable = False higher_is_better = False errors: Tensor - reference_total: Tensor - prediction_total: Tensor + target_total: Tensor + preds_total: Tensor def __init__( self, @@ -82,17 +83,17 @@ def __init__( self.add_state("reference_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("prediction_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, predictions: Union[str, List[str]], references: Union[str, List[str]]) -> None: # type: ignore + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing word Information Preserved scores. Args: - predictions: Transcription(s) to score as a string or list of strings - references: Reference(s) for each speech input as a string or list of strings + preds: Transcription(s) to score as a string or list of strings + target: Reference(s) for each speech input as a string or list of strings """ - errors, reference_total, prediction_total = _wip_update(predictions, references) + errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors - self.reference_total += reference_total - self.prediction_total += prediction_total + self.target_total += target_total + self.preds_total += preds_total def compute(self) -> Tensor: """Calculate the word Information Preserved. @@ -100,4 +101,4 @@ def compute(self) -> Tensor: Returns: word Information Preserved score """ - return _wip_compute(self.errors, self.reference_total, self.prediction_total) + return _wip_compute(self.errors, self.target_total, self.preds_total) From 6e690be1b98a41da8e4c1a9aca300fc300e63b11 Mon Sep 17 00:00:00 2001 From: Scott Cronin Date: Wed, 5 Jan 2022 17:49:34 -0500 Subject: [PATCH 02/22] Removed input_order from text unit tests (#717) --- tests/text/helpers.py | 58 +++++------------------------------ tests/text/test_bleu.py | 5 +-- tests/text/test_cer.py | 5 +-- tests/text/test_chrf.py | 5 +-- tests/text/test_mer.py | 5 +-- tests/text/test_rouge.py | 4 +-- tests/text/test_sacre_bleu.py | 5 +-- tests/text/test_ter.py | 5 +-- tests/text/test_wer.py | 5 +-- tests/text/test_wil.py | 5 +-- tests/text/test_wip.py | 5 +-- 11 files changed, 17 insertions(+), 90 deletions(-) diff --git a/tests/text/helpers.py b/tests/text/helpers.py index 9d328e81362..2a3620cdf92 100644 --- a/tests/text/helpers.py +++ b/tests/text/helpers.py @@ -13,7 +13,6 @@ # limitations under the License. 
import pickle import sys -from enum import Enum, unique from functools import partial from typing import Any, Callable, Sequence, Union @@ -31,12 +30,6 @@ pass -@unique -class INPUT_ORDER(Enum): - PREDS_FIRST = 1 - TARGETS_FIRST = 2 - - TEXT_METRIC_INPUT = Union[Sequence[str], Sequence[Sequence[str]], Sequence[Sequence[Sequence[str]]]] NUM_BATCHES = 2 @@ -56,7 +49,6 @@ def _class_test( device: str = "cpu", fragment_kwargs: bool = False, check_scriptable: bool = True, - input_order: INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, key: str = None, **kwargs_update: Any, ): @@ -78,7 +70,6 @@ def _class_test( calculated across devices for each batch (and not just at the end) device: determine which device to run on, either 'cuda' or 'cpu' fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `targets` among processes - input_order: Define the ordering for the preds and targets positional arguments. key: The key passed onto the `_assert_allclose` to compare the respective metric from the Dict output against the sk_metric. kwargs_update: Additional keyword arguments that will be passed with preds and @@ -106,11 +97,7 @@ def _class_test( for i in range(rank, NUM_BATCHES, worldsize): batch_kwargs_update = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - - if input_order == INPUT_ORDER.PREDS_FIRST: - batch_result = metric(preds[i], targets[i], **batch_kwargs_update) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - batch_result = metric(targets[i], preds[i], **batch_kwargs_update) + batch_result = metric(preds[i], targets[i], **batch_kwargs_update) if metric.dist_sync_on_step and check_dist_sync_on_step and rank == 0: # Concatenation of Sequence of strings @@ -124,10 +111,7 @@ def _class_test( for k, v in (kwargs_update if fragment_kwargs else batch_kwargs_update).items() } - if input_order == INPUT_ORDER.PREDS_FIRST: - sk_batch_result = sk_metric(ddp_preds, ddp_targets, **ddp_kwargs_upd) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - sk_batch_result = sk_metric(ddp_targets, ddp_preds, **ddp_kwargs_upd) + sk_batch_result = sk_metric(ddp_preds, ddp_targets, **ddp_kwargs_upd) _assert_allclose(batch_result, sk_batch_result, atol=atol, key=key) elif check_batch and not metric.dist_sync_on_step: @@ -135,11 +119,7 @@ def _class_test( k: v.cpu() if isinstance(v, Tensor) else v for k, v in (batch_kwargs_update if fragment_kwargs else kwargs_update).items() } - if input_order == INPUT_ORDER.PREDS_FIRST: - sk_batch_result = sk_metric(preds[i], targets[i], **batch_kwargs_update) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - sk_batch_result = sk_metric(targets[i], preds[i], **batch_kwargs_update) - + sk_batch_result = sk_metric(preds[i], targets[i], **batch_kwargs_update) _assert_allclose(batch_result, sk_batch_result, atol=atol, key=key) # check that metrics are hashable @@ -159,11 +139,7 @@ def _class_test( k: torch.cat([v[i] for i in range(NUM_BATCHES)]).cpu() if isinstance(v, Tensor) else v for k, v in kwargs_update.items() } - if input_order == INPUT_ORDER.PREDS_FIRST: - sk_result = sk_metric(total_preds, total_targets, **total_kwargs_update) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - sk_result = sk_metric(total_targets, total_preds, **total_kwargs_update) - + sk_result = sk_metric(total_preds, total_targets, **total_kwargs_update) # assert after aggregation _assert_allclose(result, sk_result, atol=atol, key=key) @@ -177,7 +153,6 @@ def _functional_test( atol: float = 1e-8, device: str = "cpu", fragment_kwargs: bool = False, - input_order: 
INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, key: str = None, **kwargs_update, ): @@ -191,7 +166,6 @@ def _functional_test( metric_args: dict with additional arguments used for class initialization device: determine which device to run on, either 'cuda' or 'cpu' fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `targets` among processes - input_order: Define the ordering for the preds and targets positional arguments. key: The key passed onto the `_assert_allclose` to compare the respective metric from the Dict output against the sk_metric. kwargs_update: Additional keyword arguments that will be passed with preds and @@ -207,19 +181,13 @@ def _functional_test( for i in range(NUM_BATCHES): extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} - if input_order == INPUT_ORDER.PREDS_FIRST: - lightning_result = metric(preds[i], targets[i], **extra_kwargs) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - lightning_result = metric(targets[i], preds[i], **extra_kwargs) + lightning_result = metric(preds[i], targets[i], **extra_kwargs) extra_kwargs = { k: v.cpu() if isinstance(v, Tensor) else v for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items() } - if input_order == INPUT_ORDER.PREDS_FIRST: - sk_result = sk_metric(preds[i], targets[i], **extra_kwargs) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - sk_result = sk_metric(targets[i], preds[i], **extra_kwargs) + sk_result = sk_metric(preds[i], targets[i], **extra_kwargs) # assert its the same _assert_allclose(lightning_result, sk_result, atol=atol, key=key) @@ -271,7 +239,6 @@ def run_functional_metric_test( sk_metric: Callable, metric_args: dict = None, fragment_kwargs: bool = False, - input_order: INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, key: str = None, **kwargs_update, ): @@ -284,7 +251,6 @@ def run_functional_metric_test( sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `targets` among processes - input_order: Define the ordering for the preds and targets positional arguments. key: The key passed onto the `_assert_allclose` to compare the respective metric from the Dict output against the sk_metric. kwargs_update: Additional keyword arguments that will be passed with preds and @@ -301,7 +267,6 @@ def run_functional_metric_test( atol=self.atol, device=device, fragment_kwargs=fragment_kwargs, - input_order=input_order, key=key, **kwargs_update, ) @@ -319,7 +284,6 @@ def run_class_metric_test( check_batch: bool = True, fragment_kwargs: bool = False, check_scriptable: bool = True, - input_order: INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, key: str = None, **kwargs_update, ): @@ -339,7 +303,6 @@ def run_class_metric_test( check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `targets` among processes - input_order: Define the ordering for the preds and targets positional arguments. key: The key passed onto the `_assert_allclose` to compare the respective metric from the Dict output against the sk_metric. 
kwargs_update: Additional keyword arguments that will be passed with preds and @@ -365,7 +328,6 @@ def run_class_metric_test( atol=self.atol, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, - input_order=input_order, key=key, **kwargs_update, ), @@ -389,7 +351,6 @@ def run_class_metric_test( device=device, fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, - input_order=input_order, key=key, **kwargs_update, ) @@ -449,7 +410,6 @@ def run_differentiability_test( metric_module: Metric, metric_functional: Callable, metric_args: dict = None, - input_order: INPUT_ORDER = INPUT_ORDER.PREDS_FIRST, key: str = None, ): """Test if a metric is differentiable or not. @@ -459,17 +419,13 @@ def run_differentiability_test( targets: torch tensor with targets metric_module: the metric module to test metric_args: dict with additional arguments used for class initialization - input_order: Define the ordering for the preds and targets positional arguments. key: The key passed onto the `_assert_allclose` to compare the respective metric from the Dict output against the sk_metric. """ metric_args = metric_args or {} # only floating point tensors can require grad metric = metric_module(**metric_args) - if input_order == INPUT_ORDER.PREDS_FIRST: - out = metric(preds[0], targets[0]) - elif input_order == INPUT_ORDER.TARGETS_FIRST: - out = metric(targets[0], preds[0]) + out = metric(preds[0], targets[0]) # Check if requires_grad matches is_differentiable attribute _assert_requires_grad(metric, out, key=key) diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py index 48bbb12633d..1de2a43b12e 100644 --- a/tests/text/test_bleu.py +++ b/tests/text/test_bleu.py @@ -18,7 +18,7 @@ from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu from torch import tensor -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_multiple_references from torchmetrics.functional.text.bleu import bleu_score from torchmetrics.text.bleu import BLEUScore @@ -67,7 +67,6 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, weights, sk_metric=compute_bleu_metric_nltk, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_func, smooth): @@ -80,7 +79,6 @@ def test_bleu_score_functional(self, preds, targets, weights, n_gram, smooth_fun metric_functional=bleu_score, sk_metric=compute_bleu_metric_nltk, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smooth_func, smooth): @@ -92,7 +90,6 @@ def test_bleu_score_differentiability(self, preds, targets, weights, n_gram, smo metric_module=BLEUScore, metric_functional=bleu_score, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_cer.py b/tests/text/test_cer.py index 729ec27d9ce..0bdcd6a046f 100644 --- a/tests/text/test_cer.py +++ b/tests/text/test_cer.py @@ -2,7 +2,7 @@ import pytest -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.cer import char_error_rate from torchmetrics.text.cer import CharErrorRate @@ -41,7 +41,6 @@ def test_cer_class(self, ddp, dist_sync_on_step, preds, targets): 
metric_class=CharErrorRate, sk_metric=compare_fn, dist_sync_on_step=dist_sync_on_step, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_cer_functional(self, preds, targets): @@ -51,7 +50,6 @@ def test_cer_functional(self, preds, targets): targets, metric_functional=char_error_rate, sk_metric=compare_fn, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_cer_differentiability(self, preds, targets): @@ -61,5 +59,4 @@ def test_cer_differentiability(self, preds, targets): targets=targets, metric_module=CharErrorRate, metric_functional=char_error_rate, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_chrf.py b/tests/text/test_chrf.py index 4863d850dfc..54b689ee1a7 100644 --- a/tests/text/test_chrf.py +++ b/tests/text/test_chrf.py @@ -4,7 +4,7 @@ import pytest from torch import Tensor, tensor -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from torchmetrics.functional.text.chrf import chrf_score from torchmetrics.text.chrf import CHRFScore @@ -71,7 +71,6 @@ def test_chrf_score_class( sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_functional(self, preds, targets, char_order, word_order, lowercase, whitespace): @@ -91,7 +90,6 @@ def test_chrf_score_functional(self, preds, targets, char_order, word_order, low metric_functional=chrf_score, sk_metric=nltk_metric, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_differentiability(self, preds, targets, char_order, word_order, lowercase, whitespace): @@ -108,7 +106,6 @@ def test_chrf_score_differentiability(self, preds, targets, char_order, word_ord metric_module=CHRFScore, metric_functional=chrf_score, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_mer.py b/tests/text/test_mer.py index 3a9f126648a..190936d0a7e 100644 --- a/tests/text/test_mer.py +++ b/tests/text/test_mer.py @@ -2,7 +2,7 @@ import pytest -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.utilities.imports import _JIWER_AVAILABLE @@ -39,7 +39,6 @@ def test_mer_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=MatchErrorRate, sk_metric=_compute_mer_metric_jiwer, dist_sync_on_step=dist_sync_on_step, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_mer_functional(self, preds, targets): @@ -49,7 +48,6 @@ def test_mer_functional(self, preds, targets): targets, metric_functional=match_error_rate, sk_metric=_compute_mer_metric_jiwer, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_mer_differentiability(self, preds, targets): @@ -59,5 +57,4 @@ def test_mer_differentiability(self, preds, targets): targets=targets, metric_module=MatchErrorRate, metric_functional=match_error_rate, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_rouge.py b/tests/text/test_rouge.py index 67b84cbc3bd..f7ea1e2199f 100644 --- a/tests/text/test_rouge.py +++ b/tests/text/test_rouge.py @@ -18,7 +18,7 @@ import pytest import torch -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_single_reference from torchmetrics.functional.text.rouge import 
rouge_score from torchmetrics.text.rouge import ROUGEScore @@ -120,7 +120,6 @@ def test_rouge_score_class( sk_metric=rouge_metric, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, key=pl_rouge_metric_key, ) @@ -137,7 +136,6 @@ def test_rouge_score_functional(self, preds, targets, pl_rouge_metric_key, use_s metric_functional=rouge_score, sk_metric=rouge_metric, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, key=pl_rouge_metric_key, ) diff --git a/tests/text/test_sacre_bleu.py b/tests/text/test_sacre_bleu.py index 6cbe0aa8328..f49e5798c6c 100644 --- a/tests/text/test_sacre_bleu.py +++ b/tests/text/test_sacre_bleu.py @@ -18,7 +18,7 @@ import pytest from torch import Tensor, tensor -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_multiple_references from torchmetrics.functional.text.sacre_bleu import sacre_bleu_score from torchmetrics.text.sacre_bleu import SacreBLEUScore @@ -61,7 +61,6 @@ def test_bleu_score_class(self, ddp, dist_sync_on_step, preds, targets, tokenize sk_metric=original_sacrebleu, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): @@ -74,7 +73,6 @@ def test_bleu_score_functional(self, preds, targets, tokenize, lowercase): metric_functional=sacre_bleu_score, sk_metric=original_sacrebleu, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase): @@ -86,5 +84,4 @@ def test_bleu_score_differentiability(self, preds, targets, tokenize, lowercase) metric_module=SacreBLEUScore, metric_functional=sacre_bleu_score, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_ter.py b/tests/text/test_ter.py index 50c38049031..ad75d0f1797 100644 --- a/tests/text/test_ter.py +++ b/tests/text/test_ter.py @@ -4,7 +4,7 @@ import pytest from torch import Tensor, tensor -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_multiple_references, _inputs_single_sentence_multiple_references from torchmetrics.functional.text.ter import ter from torchmetrics.text.ter import TER @@ -75,7 +75,6 @@ def test_chrf_score_class( sk_metric=nltk_metric, dist_sync_on_step=dist_sync_on_step, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, asian_support, lowercase): @@ -99,7 +98,6 @@ def test_ter_score_functional(self, preds, targets, normalize, no_punctuation, a metric_functional=ter, sk_metric=nltk_metric, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctuation, asian_support, lowercase): @@ -116,7 +114,6 @@ def test_chrf_score_differentiability(self, preds, targets, normalize, no_punctu metric_module=TER, metric_functional=ter, metric_args=metric_args, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py index f9791594317..23fea84224a 100644 --- a/tests/text/test_wer.py +++ b/tests/text/test_wer.py @@ -2,7 +2,7 @@ import pytest -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, 
_inputs_error_rate_batch_size_2 from torchmetrics.utilities.imports import _JIWER_AVAILABLE @@ -39,7 +39,6 @@ def test_wer_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WER, sk_metric=_compute_wer_metric_jiwer, dist_sync_on_step=dist_sync_on_step, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wer_functional(self, preds, targets): @@ -49,7 +48,6 @@ def test_wer_functional(self, preds, targets): targets, metric_functional=wer, sk_metric=_compute_wer_metric_jiwer, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wer_differentiability(self, preds, targets): @@ -59,5 +57,4 @@ def test_wer_differentiability(self, preds, targets): targets=targets, metric_module=WER, metric_functional=wer, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_wil.py b/tests/text/test_wil.py index 08b42159d11..c901759d7e9 100644 --- a/tests/text/test_wil.py +++ b/tests/text/test_wil.py @@ -3,7 +3,7 @@ import pytest from jiwer import wil -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.wil import word_information_lost from torchmetrics.text.wil import WordInfoLost @@ -34,7 +34,6 @@ def test_wil_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WordInfoLost, sk_metric=_compute_wil_metric_jiwer, dist_sync_on_step=dist_sync_on_step, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wil_functional(self, preds, targets): @@ -44,7 +43,6 @@ def test_wil_functional(self, preds, targets): targets, metric_functional=word_information_lost, sk_metric=_compute_wil_metric_jiwer, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wil_differentiability(self, preds, targets): @@ -54,5 +52,4 @@ def test_wil_differentiability(self, preds, targets): targets=targets, metric_module=WordInfoLost, metric_functional=word_information_lost, - input_order=INPUT_ORDER.PREDS_FIRST, ) diff --git a/tests/text/test_wip.py b/tests/text/test_wip.py index 0d655823232..6d67459bd4e 100644 --- a/tests/text/test_wip.py +++ b/tests/text/test_wip.py @@ -3,7 +3,7 @@ import pytest from jiwer import wip -from tests.text.helpers import INPUT_ORDER, TextTester +from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 from torchmetrics.functional.text.wip import word_information_preserved from torchmetrics.text.wip import WordInfoPreserved @@ -34,7 +34,6 @@ def test_wip_class(self, ddp, dist_sync_on_step, preds, targets): metric_class=WordInfoPreserved, sk_metric=_compute_wip_metric_jiwer, dist_sync_on_step=dist_sync_on_step, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wip_functional(self, preds, targets): @@ -44,7 +43,6 @@ def test_wip_functional(self, preds, targets): targets, metric_functional=word_information_preserved, sk_metric=_compute_wip_metric_jiwer, - input_order=INPUT_ORDER.PREDS_FIRST, ) def test_wip_differentiability(self, preds, targets): @@ -54,5 +52,4 @@ def test_wip_differentiability(self, preds, targets): targets=targets, metric_module=WordInfoPreserved, metric_functional=word_information_preserved, - input_order=INPUT_ORDER.PREDS_FIRST, ) From ad72e6867ab72cc16f7e33f91adc2e11d1c6cf67 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 6 Jan 2022 15:32:29 +0100 Subject: [PATCH 03/22] Fix some unchanged names --- torchmetrics/functional/text/chrf.py | 16 ++++++++-------- 
torchmetrics/text/chrf.py | 2 +- torchmetrics/text/wil.py | 4 ++-- torchmetrics/text/wip.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 20891c3fa6d..555f0b8f8d8 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -64,18 +64,18 @@ def _prepare_n_grams_dicts( Dictionaries with default zero values for total reference, hypothesis and matching character and word n-grams. """ - total_ref_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_ref_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} - total_hyp_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} - total_hyp_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_preds_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_preds_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} + total_target_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} + total_target_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} total_matching_char_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_char_order)} total_matching_word_n_grams: Dict[int, Tensor] = {n + 1: tensor(0.0) for n in range(n_word_order)} return ( - total_hyp_char_n_grams, - total_hyp_word_n_grams, - total_ref_char_n_grams, - total_ref_word_n_grams, + total_preds_char_n_grams, + total_preds_word_n_grams, + total_target_char_n_grams, + total_target_word_n_grams, total_matching_char_n_grams, total_matching_word_n_grams, ) diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index 0b5c0f26cb8..a0d8e4f4933 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -27,7 +27,7 @@ from torchmetrics.functional.text.chrf import _chrf_score_compute, _chrf_score_update, _prepare_n_grams_dicts _N_GRAM_LEVELS = ("char", "word") -_TEXT_LEVELS = ("ref", "hyp", "matching") +_TEXT_LEVELS = ("preds", "target", "matching") _DICT_STATES_NAMES = ( "total_preds_char_n_grams", diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index ca0b25730df..3082ae9d6ec 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -80,8 +80,8 @@ def __init__( dist_sync_fn=dist_sync_fn, ) self.add_state("errors", tensor(0.0), dist_reduce_fx="sum") - self.add_state("reference_total", tensor(0.0), dist_reduce_fx="sum") - self.add_state("prediction_total", tensor(0.0), dist_reduce_fx="sum") + self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") + self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing Word Information Lost scores. 
diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index fa91c898c60..3fca37ca0ff 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -80,8 +80,8 @@ def __init__( dist_sync_fn=dist_sync_fn, ) self.add_state("errors", tensor(0.0), dist_reduce_fx="sum") - self.add_state("reference_total", tensor(0.0), dist_reduce_fx="sum") - self.add_state("prediction_total", tensor(0.0), dist_reduce_fx="sum") + self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") + self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store references/predictions for computing word Information Preserved scores. From b3fd97b4327c8095c248646e3120efd99d1ef3ac Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 6 Jan 2022 16:30:08 +0100 Subject: [PATCH 04/22] Unify some other stuff --- tests/text/test_bertscore.py | 140 ++++++++++++++++++----------------- tests/text/test_bleu.py | 30 ++++---- tests/text/test_chrf.py | 24 +++--- tests/text/test_wil.py | 4 +- tests/text/test_wip.py | 4 +- torchmetrics/text/bert.py | 7 +- torchmetrics/text/wil.py | 2 +- torchmetrics/text/wip.py | 4 +- 8 files changed, 107 insertions(+), 108 deletions(-) diff --git a/tests/text/test_bertscore.py b/tests/text/test_bertscore.py index 0def5e32fcf..a6d4018e59e 100644 --- a/tests/text/test_bertscore.py +++ b/tests/text/test_bertscore.py @@ -25,7 +25,7 @@ "The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at " 'him."', ] -refs = [ +targets = [ "28-Year-Old Chef Found Dead at San Francisco Mall", "A 28-year-old chef who had recently moved to San Francisco was found dead in the stairwell of a local mall this " "week.", @@ -39,9 +39,9 @@ MODEL_NAME = "albert-base-v2" -def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8): +def _assert_list(preds: Any, targets: Any, threshold: float = 1e-8): """Assert two lists are equal.""" - assert np.allclose(preds, refs, atol=threshold, equal_nan=True) + assert np.allclose(preds, targets, atol=threshold, equal_nan=True) def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]: @@ -51,21 +51,21 @@ def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]: preds_batched = [preds[0:2], preds[2:]] -refs_batched = [refs[0:2], refs[2:]] +targets_batched = [targets[0:2], targets[2:]] @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_fn(preds, refs): +def test_score_fn(preds, targets): """Tests for functional.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) original_score = _parse_original_bert_score(original_score) metrics_score = metrics_bert_score( - preds, refs, model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3 + preds, targets, model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3 ) for metric in _METRICS: @@ -73,17 +73,17 @@ def test_score_fn(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def 
test_score_fn_with_idf(preds, refs): +def test_score_fn_with_idf(preds, targets): """Tests for functional with IDF rescaling.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=12, idf=True, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=12, idf=True, batch_size=3) original_score = _parse_original_bert_score(original_score) metrics_score = metrics_bert_score( - preds, refs, model_name_or_path=MODEL_NAME, num_layers=12, idf=True, batch_size=3 + preds, targets, model_name_or_path=MODEL_NAME, num_layers=12, idf=True, batch_size=3 ) for metric in _METRICS: @@ -91,17 +91,19 @@ def test_score_fn_with_idf(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_fn_all_layers(preds, refs): +def test_score_fn_all_layers(preds, targets): """Tests for functional and all layers.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3) + original_score = original_bert_score( + preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3 + ) original_score = _parse_original_bert_score(original_score) metrics_score = metrics_bert_score( - preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3 + preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3 ) for metric in _METRICS: @@ -109,17 +111,17 @@ def test_score_fn_all_layers(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_fn_all_layers_with_idf(preds, refs): +def test_score_fn_all_layers_with_idf(preds, targets): """Tests for functional and all layers with IDF rescaling.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3) original_score = _parse_original_bert_score(original_score) metrics_score = metrics_bert_score( - preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3 + preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3 ) for metric in _METRICS: @@ -127,15 +129,15 @@ def test_score_fn_all_layers_with_idf(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_fn_all_layers_rescale_with_baseline(preds, refs): +def test_score_fn_all_layers_rescale_with_baseline(preds, targets): """Tests for functional with baseline rescaling.""" original_score = original_bert_score( preds, - refs, + targets, model_type=MODEL_NAME, lang="en", num_layers=8, @@ -147,7 +149,7 @@ def test_score_fn_all_layers_rescale_with_baseline(preds, refs): metrics_score = metrics_bert_score( preds, - refs, + targets, model_name_or_path=MODEL_NAME, lang="en", num_layers=8, @@ -161,15 +163,15 @@ def test_score_fn_all_layers_rescale_with_baseline(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires 
bert_score") -def test_score_fn_rescale_with_baseline(preds, refs): +def test_score_fn_rescale_with_baseline(preds, targets): """Tests for functional with baseline rescaling with all layers.""" original_score = original_bert_score( preds, - refs, + targets, model_type=MODEL_NAME, lang="en", all_layers=True, @@ -181,7 +183,7 @@ def test_score_fn_rescale_with_baseline(preds, refs): metrics_score = metrics_bert_score( preds, - refs, + targets, model_name_or_path=MODEL_NAME, lang="en", all_layers=True, @@ -195,17 +197,17 @@ def test_score_fn_rescale_with_baseline(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score(preds, refs): +def test_score(preds, targets): """Tests for metric.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3) - Scorer.update(preds=preds, target=refs) + Scorer.update(preds=preds, target=targets) metrics_score = Scorer.compute() for metric in _METRICS: @@ -213,17 +215,17 @@ def test_score(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_with_idf(preds, refs): +def test_score_with_idf(preds, targets): """Tests for metric with IDF rescaling.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=True, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=True, batch_size=3) original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=True, batch_size=3) - Scorer.update(preds=preds, target=refs) + Scorer.update(preds=preds, target=targets) metrics_score = Scorer.compute() for metric in _METRICS: @@ -231,17 +233,19 @@ def test_score_with_idf(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_all_layers(preds, refs): +def test_score_all_layers(preds, targets): """Tests for metric and all layers.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3) + original_score = original_bert_score( + preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3 + ) original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3) - Scorer.update(preds=preds, target=refs) + Scorer.update(preds=preds, target=targets) metrics_score = Scorer.compute() for metric in _METRICS: @@ -249,17 +253,17 @@ def test_score_all_layers(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_score_all_layers_with_idf(preds, refs): +def test_score_all_layers_with_idf(preds, targets): """Tests for metric and all layers with 
IDF rescaling.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3) original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3) - Scorer.update(preds=preds, target=refs) + Scorer.update(preds=preds, target=targets) metrics_score = Scorer.compute() for metric in _METRICS: @@ -267,19 +271,19 @@ def test_score_all_layers_with_idf(preds, refs): @pytest.mark.parametrize( - "preds,refs", - [(preds_batched, refs_batched)], + "preds,targets", + [(preds_batched, targets_batched)], ) @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") -def test_accumulation(preds, refs): +def test_accumulation(preds, targets): """Tests for metric works with accumulation.""" original_score = original_bert_score( - sum(preds, []), sum(refs, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3 + sum(preds, []), sum(targets, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3 ) original_score = _parse_original_bert_score(original_score) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3) - for p, r in zip(preds, refs): + for p, r in zip(preds, targets): Scorer.update(preds=p, target=r) metrics_score = Scorer.compute() @@ -287,32 +291,32 @@ def test_accumulation(preds, refs): _assert_list(metrics_score[metric], original_score[metric]) -def _bert_score_ddp(rank, world_size, preds, refs, original_score): +def _bert_score_ddp(rank, world_size, preds, targets, original_score): """Define a DDP process for BERTScore.""" os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "12355" dist.init_process_group("gloo", rank=rank, world_size=world_size) Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3, max_length=128) - Scorer.update(preds, refs) + Scorer.update(preds, targets) metrics_score = Scorer.compute() for metric in _METRICS: _assert_list(metrics_score[metric], original_score[metric]) dist.destroy_process_group() -def _test_score_ddp_fn(rank, world_size, preds, refs): +def _test_score_ddp_fn(rank, world_size, preds, targets): """Core functionality for the `test_score_ddp` test.""" - original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) + original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3) original_score = _parse_original_bert_score(original_score) - _bert_score_ddp(rank, world_size, preds, refs, original_score) + _bert_score_ddp(rank, world_size, preds, targets, original_score) @pytest.mark.parametrize( - "preds,refs", - [(preds, refs)], + "preds,targets", + [(preds, targets)], ) @pytest.mark.skipif(not (_BERTSCORE_AVAILABLE and dist.is_available()), reason="test requires bert_score") -def test_score_ddp(preds, refs): +def test_score_ddp(preds, targets): """Tests for metric using DDP.""" world_size = 2 - mp.spawn(_test_score_ddp_fn, args=(world_size, preds, refs), nprocs=world_size, join=False) + mp.spawn(_test_score_ddp_fn, args=(world_size, preds, targets), nprocs=world_size, join=False) diff --git a/tests/text/test_bleu.py b/tests/text/test_bleu.py index 1de2a43b12e..029b5c34d1d 100644 --- a/tests/text/test_bleu.py +++ b/tests/text/test_bleu.py @@ -27,15 +27,11 @@ smooth_func = 
SmoothingFunction().method2 -def _compute_bleu_metric_nltk(hypotheses, list_of_references, weights, smoothing_function, **kwargs): - hypotheses_ = [hypothesis.split() for hypothesis in hypotheses] - list_of_references_ = [[line.split() for line in ref] for ref in list_of_references] +def _compute_bleu_metric_nltk(preds, targets, weights, smoothing_function, **kwargs): + preds_ = [pred.split() for pred in preds] + targets_ = [[line.split() for line in target] for target in targets] return corpus_bleu( - list_of_references=list_of_references_, - hypotheses=hypotheses_, - weights=weights, - smoothing_function=smoothing_function, - **kwargs + list_of_references=targets_, hypotheses=preds_, weights=weights, smoothing_function=smoothing_function, **kwargs ) @@ -100,20 +96,20 @@ def test_bleu_empty_functional(): def test_no_4_gram_functional(): - hyps = ["My full pytorch-lightning"] - refs = [["My full pytorch-lightning test", "Completely Different"]] - assert bleu_score(hyps, refs) == tensor(0.0) + preds = ["My full pytorch-lightning"] + targets = [["My full pytorch-lightning test", "Completely Different"]] + assert bleu_score(preds, targets) == tensor(0.0) def test_bleu_empty_class(): bleu = BLEUScore() - hyp = [[]] - ref = [[[]]] - assert bleu(hyp, ref) == tensor(0.0) + preds = [[]] + targets = [[[]]] + assert bleu(preds, targets) == tensor(0.0) def test_no_4_gram_class(): bleu = BLEUScore() - hyps = ["My full pytorch-lightning"] - refs = [["My full pytorch-lightning test", "Completely Different"]] - assert bleu(hyps, refs) == tensor(0.0) + preds = ["My full pytorch-lightning"] + targets = [["My full pytorch-lightning test", "Completely Different"]] + assert bleu(preds, targets) == tensor(0.0) diff --git a/tests/text/test_chrf.py b/tests/text/test_chrf.py index 54b689ee1a7..4990a7b21a3 100644 --- a/tests/text/test_chrf.py +++ b/tests/text/test_chrf.py @@ -110,28 +110,28 @@ def test_chrf_score_differentiability(self, preds, targets, char_order, word_ord def test_chrf_empty_functional(): - hyp = [] - ref = [[]] - assert chrf_score(hyp, ref) == tensor(0.0) + preds = [] + targets = [[]] + assert chrf_score(preds, targets) == tensor(0.0) def test_chrf_empty_class(): chrf = CHRFScore() - hyp = [] - ref = [[]] - assert chrf(hyp, ref) == tensor(0.0) + preds = [] + targets = [[]] + assert chrf(preds, targets) == tensor(0.0) def test_chrf_return_sentence_level_score_functional(): - hyp = _inputs_single_sentence_multiple_references.preds - ref = _inputs_single_sentence_multiple_references.targets - _, chrf_sentence_score = chrf_score(hyp, ref, return_sentence_level_score=True) + preds = _inputs_single_sentence_multiple_references.preds + targets = _inputs_single_sentence_multiple_references.targets + _, chrf_sentence_score = chrf_score(preds, targets, return_sentence_level_score=True) isinstance(chrf_sentence_score, Tensor) def test_chrf_return_sentence_level_class(): chrf = CHRFScore(return_sentence_level_score=True) - hyp = _inputs_single_sentence_multiple_references.preds - ref = _inputs_single_sentence_multiple_references.targets - _, chrf_sentence_score = chrf(hyp, ref) + preds = _inputs_single_sentence_multiple_references.preds + targets = _inputs_single_sentence_multiple_references.targets + _, chrf_sentence_score = chrf(preds, targets) isinstance(chrf_sentence_score, Tensor) diff --git a/tests/text/test_wil.py b/tests/text/test_wil.py index c901759d7e9..3dbc085bafb 100644 --- a/tests/text/test_wil.py +++ b/tests/text/test_wil.py @@ -10,8 +10,8 @@ from torchmetrics.utilities.imports import 
_JIWER_AVAILABLE -def _compute_wil_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]): - return wil(reference, prediction) +def _compute_wil_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): + return wil(target, preds) @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") diff --git a/tests/text/test_wip.py b/tests/text/test_wip.py index 6d67459bd4e..ae98e8d030c 100644 --- a/tests/text/test_wip.py +++ b/tests/text/test_wip.py @@ -10,8 +10,8 @@ from torchmetrics.utilities.imports import _JIWER_AVAILABLE -def _compute_wip_metric_jiwer(prediction: Union[str, List[str]], reference: Union[str, List[str]]): - return wip(reference, prediction) +def _compute_wip_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): + return wip(target, preds) @pytest.mark.skipif(not _JIWER_AVAILABLE, reason="test requires jiwer") diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index e687f4490c5..bab0c5a3de6 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -115,8 +115,7 @@ class BERTScore(Metric): >>> preds = ["hello there", "general kenobi"] >>> target = ["hello there", "master kenobi"] >>> bertscore = BERTScore() - >>> bertscore.update(preds=preds,target=target) - >>> bertscore.compute() # doctest: +SKIP + >>> bertscore(preds=preds,target=target) # doctest: +SKIP {'precision': [0.99..., 0.99...], 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} @@ -197,9 +196,9 @@ def update(self, preds: List[str], target: List[str]) -> None: # type: ignore tokenized form to ensure the DDP mode working. Args: - predictions: + preds: An iterable of predicted sentences. - references: + target: An iterable of predicted sentences. """ preds_dict = _preprocess_text( diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 3082ae9d6ec..f5aea9ebb55 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -84,7 +84,7 @@ def __init__( self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore - """Store references/predictions for computing Word Information Lost scores. + """Store predictions/references for computing Word Information Lost scores. Args: preds: Transcription(s) to score as a string or list of strings diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 3fca37ca0ff..6a760234724 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -63,8 +63,8 @@ class WordInfoPreserved(Metric): is_differentiable = False higher_is_better = False errors: Tensor - target_total: Tensor preds_total: Tensor + target_total: Tensor def __init__( self, @@ -84,7 +84,7 @@ def __init__( self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore - """Store references/predictions for computing word Information Preserved scores. + """Store predictions/references for computing word Information Preserved scores. 
Args: preds: Transcription(s) to score as a string or list of strings From 5ff3534701dc48ec81f142f51ba0e4b790136864 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 6 Jan 2022 16:38:37 +0100 Subject: [PATCH 05/22] Fix flake8 + unwanted typo --- torchmetrics/functional/text/wil.py | 1 - torchmetrics/functional/text/wip.py | 1 - torchmetrics/text/bert.py | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index a431f737e84..ff3e14e2cdd 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union -import torch from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 74f213581ad..8062548f822 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -14,7 +14,6 @@ from typing import List, Tuple, Union -import torch from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index bab0c5a3de6..69884cee3be 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -230,7 +230,7 @@ def compute(self) -> Dict[str, Union[List[float], str]]: """ return bert_score( preds=_concatenate(self.preds), - test_average_precisiont=_concatenate(self.target), + target=_concatenate(self.target), model_name_or_path=self.model_name_or_path, num_layers=self.num_layers, all_layers=self.all_layers, From f7c65ebc9ed3f2e205fdf92823554db3fa513483 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Thu, 6 Jan 2022 16:45:15 +0100 Subject: [PATCH 06/22] Fix an import in bert doc --- torchmetrics/text/bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 69884cee3be..4cee1e76080 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -111,7 +111,7 @@ class BERTScore(Metric): Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. 
Example: - >>> from torchmetrics import BERTScore + >>> from torchmetrics.text.bert import BERTScore >>> preds = ["hello there", "general kenobi"] >>> target = ["hello there", "master kenobi"] >>> bertscore = BERTScore() From 8c7b9e0ba3a8332cbb7434cd325a28e1b1ee5f40 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 7 Jan 2022 09:40:23 +0100 Subject: [PATCH 07/22] Some nits --- torchmetrics/functional/text/bleu.py | 8 ++++---- torchmetrics/functional/text/sacre_bleu.py | 4 ++-- torchmetrics/functional/text/wil.py | 4 ++-- torchmetrics/functional/text/wip.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 627799f6a71..8edd2350227 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -84,14 +84,14 @@ def _bleu_score_update( for (pred, targets) in zip(preds_, target_): preds_len += len(pred) - target_len_list = [len(ref) for ref in targets] + target_len_list = [len(tgt) for tgt in targets] target_len_diff = [abs(len(pred) - x) for x in target_len_list] target_len += target_len_list[target_len_diff.index(min(target_len_diff))] preds_counter: Counter = _count_ngram(pred, n_gram) target_counter: Counter = Counter() - for ref in targets: - target_counter |= _count_ngram(ref, n_gram) + for tgt in targets: + target_counter |= _count_ngram(tgt, n_gram) ngram_counter_clip = preds_counter & target_counter @@ -183,7 +183,7 @@ def bleu_score( " Warning will be removed in v0.8." ) preds_ = [preds] if isinstance(preds, str) else preds - target_ = [[target_text] if isinstance(target_text, str) else target_text for target_text in target] + target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] if len(preds_) != len(target_): raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 835607579a0..1e4e11de16e 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -343,8 +343,8 @@ def sacre_bleu_score( numerator = torch.zeros(n_gram) denominator = torch.zeros(n_gram) - trans_len = tensor(0, dtype=torch.float) - ref_len = tensor(0, dtype=torch.float) + trans_len = tensor(0.0) + ref_len = tensor(0.0) tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase) trans_len, ref_len = _bleu_score_update( diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index ff3e14e2cdd..e3478b18ff8 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -41,9 +41,9 @@ def _wil_update( errors = tensor(0.0) target_total = tensor(0.0) preds_total = tensor(0.0) - for pred, tar in zip(preds, target): + for pred, tgt in zip(preds, target): pred_tokens = pred.split() - target_tokens = tar.split() + target_tokens = tgt.split() errors += _edit_distance(pred_tokens, target_tokens) target_total += len(target_tokens) preds_total += len(pred_tokens) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 8062548f822..faf572a8daf 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -41,9 +41,9 @@ def _wip_update( errors = tensor(0.0) target_total = tensor(0.0) preds_total = tensor(0.0) - for pred, tar in zip(preds, target): + for pred, tgt in zip(preds, target): pred_tokens = 
pred.split() - target_tokens = tar.split() + target_tokens = tgt.split() errors += _edit_distance(pred_tokens, target_tokens) target_total += len(target_tokens) preds_total += len(pred_tokens) From 8b330ea043e7d671293a918649b0b60ad24828c1 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Sun, 9 Jan 2022 14:10:40 +0100 Subject: [PATCH 08/22] Handle BC and add warnings --- torchmetrics/functional/text/bert.py | 18 +++++++ torchmetrics/functional/text/bleu.py | 21 +++++++- torchmetrics/functional/text/chrf.py | 20 +++++++- torchmetrics/functional/text/sacre_bleu.py | 57 ++++++++++++++-------- torchmetrics/functional/text/wil.py | 18 +++++++ torchmetrics/functional/text/wip.py | 18 +++++++ torchmetrics/text/bert.py | 27 ++++++++-- torchmetrics/text/bleu.py | 28 +++++++++-- torchmetrics/text/chrf.py | 24 ++++++++- torchmetrics/text/sacre_bleu.py | 30 ++++++++++-- torchmetrics/text/wil.py | 24 ++++++++- torchmetrics/text/wip.py | 24 ++++++++- 12 files changed, 271 insertions(+), 38 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 9009049a24f..f92ab1b69ab 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -17,6 +17,7 @@ import warnings from collections import Counter, defaultdict from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from warnings import warn import torch from torch import Tensor @@ -452,6 +453,8 @@ def _rescale_metrics_with_baseline( def bert_score( preds: Union[List[str], Dict[str, Tensor]], target: Union[List[str], Dict[str, Tensor]], + predictions: Union[None, List[str], Dict[str, Tensor]] = None, + references: Union[None, List[str], Dict[str, Tensor]] = None, model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, @@ -555,6 +558,21 @@ def bert_score( 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + if len(preds) != len(target): raise ValueError("Number of predicted and reference sententes must be the same!") diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 8edd2350227..a692d288704 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -16,9 +16,9 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -import warnings from collections import Counter from typing import Callable, Sequence, Tuple, Union +from warnings import warn import torch from torch import Tensor, tensor @@ -146,6 +146,8 @@ def _bleu_score_compute( def bleu_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], + translate_corpus: Union[None, str, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_gram: int = 4, smooth: bool = False, ) -> Tensor: @@ -178,10 +180,25 @@ def bleu_score( [2] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warnings.warn( + warn( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) + if translate_corpus is not None: + warn( + "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = translate_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus + preds_ = [preds] if isinstance(preds, str) else preds target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 555f0b8f8d8..f98750a3433 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -35,6 +35,7 @@ from collections import defaultdict from typing import Dict, List, Optional, Sequence, Tuple, Union +from warnings import warn import torch from torch import Tensor, tensor @@ -587,7 +588,9 @@ def _chrf_score_compute( def chrf_score( preds: Union[str, Sequence[str]], - target: Union[Sequence[str], Sequence[Sequence[str]]], + target: Sequence[Union[str, Sequence[str]]], + hypothesis_corpus: Union[str, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, @@ -642,6 +645,21 @@ def chrf_score( [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ + if hypothesis_corpus is not None: + warn( + "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = hypothesis_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus + if not isinstance(n_char_order, int) or n_char_order < 1: raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") if not isinstance(n_word_order, int) or n_word_order < 0: diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 1e4e11de16e..bb06e805efa 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -39,9 +39,9 @@ import re -import warnings from functools import partial -from typing import Sequence +from typing import Sequence, Union +from warnings import warn import torch from torch import Tensor, tensor @@ -278,8 +278,10 @@ def _lower(line: str, lowercase: bool) -> str: def sacre_bleu_score( - translate_corpus: Sequence[str], - reference_corpus: Sequence[Sequence[str]], + preds: Sequence[str], + target: Sequence[Sequence[str]], + translate_corpus: Union[None, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Sequence[str]]] = None, n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", @@ -289,9 +291,9 @@ def sacre_bleu_score( follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu. Args: - translate_corpus: + preds: An iterable of machine translated corpus - reference_corpus: + target: An iterable of iterables of reference corpus n_gram: Gram value ranged from 1 to 4 (Default 4) @@ -308,9 +310,9 @@ def sacre_bleu_score( Example: >>> from torchmetrics.functional import sacre_bleu_score - >>> translate_corpus = ['the cat is on the mat'] - >>> reference_corpus = [['there is a cat on the mat', 'a cat is on the mat']] - >>> sacre_bleu_score(translate_corpus, reference_corpus) + >>> preds = ['the cat is on the mat'] + >>> target = [['there is a cat on the mat', 'a cat is on the mat']] + >>> sacre_bleu_score(preds=preds, target=target) tensor(0.7598) References: @@ -322,10 +324,25 @@ def sacre_bleu_score( [3] Automatic Evaluation of Machine Translation Quality Using Longest Common Subsequence and Skip-Bigram Statistics by Chin-Yew Lin and Franz Josef Och `Machine Translation Evolution`_ """ - warnings.warn( + warn( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) + if translate_corpus is not None: + warn( + "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = translate_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus + if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") @@ -333,8 +350,8 @@ def sacre_bleu_score( raise ValueError( f"Unsupported tokenizer selected. 
Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}" ) - if len(translate_corpus) != len(reference_corpus): - raise ValueError(f"Corpus has different size {len(translate_corpus)} != {len(reference_corpus)}") + if len(preds) != len(target): + raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}") if tokenize == "intl" and not _REGEX_AVAILABLE: raise ValueError( "`'intl'` tokenization requires `regex` installed. Use `pip install regex` or `pip install " @@ -343,19 +360,19 @@ def sacre_bleu_score( numerator = torch.zeros(n_gram) denominator = torch.zeros(n_gram) - trans_len = tensor(0.0) - ref_len = tensor(0.0) + preds_len = tensor(0.0) + target_len = tensor(0.0) tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase) - trans_len, ref_len = _bleu_score_update( - translate_corpus, - reference_corpus, + preds_len, target_len = _bleu_score_update( + preds, + target, numerator, denominator, - trans_len, - ref_len, + preds_len, + target_len, n_gram, tokenize_fn, ) - return _bleu_score_compute(trans_len, ref_len, numerator, denominator, n_gram, smooth) + return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth) diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index e3478b18ff8..f5c94ea62cc 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import List, Tuple, Union +from warnings import warn from torch import Tensor, tensor @@ -68,6 +69,8 @@ def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T def word_information_lost( preds: Union[str, List[str]], target: Union[str, List[str]], + predictions: Union[None, str, List[str]] = None, + references: Union[None, str, List[str]] = None, ) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the @@ -84,5 +87,20 @@ def word_information_lost( >>> word_information_lost(preds=preds, target=target) tensor(0.6528) """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + errors, target_total, preds_total = _wil_update(preds, target) return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index faf572a8daf..83e88626c22 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -13,6 +13,7 @@ # limitations under the License. 
from typing import List, Tuple, Union +from warnings import warn from torch import Tensor, tensor @@ -68,6 +69,8 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T def word_information_preserved( preds: Union[str, List[str]], target: Union[str, List[str]], + predictions: Union[None, str, List[str]] = None, + references: Union[None, str, List[str]] = None, ) -> Tensor: """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the @@ -84,5 +87,20 @@ def word_information_preserved( >>> word_information_preserved(preds=preds, target=target) tensor(0.3472) """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + errors, reference_total, prediction_total = _wip_update(preds, target) return _wip_compute(errors, reference_total, prediction_total) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 4cee1e76080..39fa5a16fa5 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import Any, Callable, Dict, List, Optional, Union +from warnings import warn import torch @@ -183,7 +183,7 @@ def __init__( ) if not model_name_or_path: model_name_or_path = _DEFAULT_MODEL - warnings.warn( + warn( "The argument `model_name_or_path` was not specified while it is required when default " " `transformers` model are used." f"It is, therefore, used the default recommended model - {_DEFAULT_MODEL}." @@ -191,7 +191,13 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.user_tokenizer = False - def update(self, preds: List[str], target: List[str]) -> None: # type: ignore + def update( # type: ignore + self, + preds: List[str], + target: List[str], + predictions: Union[None, List[str]] = None, + references: Union[None, List[str]] = None, + ) -> None: """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. @@ -201,6 +207,21 @@ def update(self, preds: List[str], target: List[str]) -> None: # type: ignore target: An iterable of predicted sentences. """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + preds_dict = _preprocess_text( preds, self.tokenizer, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 090238b5873..e1fa631e616 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,8 +16,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -import warnings -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Union +from warnings import warn import torch from torch import Tensor, tensor @@ -83,7 +83,7 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - warnings.warn( + warn( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) @@ -95,13 +95,33 @@ def __init__( self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") - def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update( # type: ignore + self, + preds: Sequence[str], + target: Sequence[Sequence[str]], + translate_corpus: Union[None, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Sequence[str]]] = None, + ) -> None: """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus """ + if translate_corpus is not None: + warn( + "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = translate_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index a0d8e4f4933..e7185105193 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -19,6 +19,7 @@ import itertools from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from warnings import warn import torch from torch import Tensor, tensor @@ -143,7 +144,13 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update( # type: ignore + self, + preds: Sequence[str], + target: Sequence[Sequence[str]], + hypothesis_corpus: Union[None, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Sequence[str]]] = None, + ) -> None: """Compute Precision Scores. Args: @@ -152,6 +159,21 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: target: An iterable of iterables of reference corpus. """ + if hypothesis_corpus is not None: + warn( + "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = hypothesis_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus + n_grams_dicts_tuple = _chrf_score_update( preds, target, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 57eb023551e..864c3e9a088 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings - # referenced from # Library Name: torchtext # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Union +from warnings import warn from typing_extensions import Literal @@ -102,7 +101,7 @@ def __init__( process_group=process_group, dist_sync_fn=dist_sync_fn, ) - warnings.warn( + warn( "Input order of targets and preds were changed to predictions firsts and targets \ second in v0.7. Warning will be removed in v0.8" ) @@ -116,13 +115,34 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) - def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore + def update( # type: ignore + self, + preds: Sequence[str], + target: Sequence[Sequence[str]], + translate_corpus: Union[None, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Sequence[str]]] = None, + ) -> None: """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus """ + if translate_corpus is not None: + warn( + "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = translate_corpus + if reference_corpus is not None: + warn( + "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = reference_corpus + self.preds_len, self.target_len = _bleu_score_update( preds, target, diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index f5aea9ebb55..0f64c648531 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Callable, List, Optional, Union +from warnings import warn from torch import Tensor, tensor @@ -83,13 +84,34 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update( # type: ignore + self, + preds: Union[str, List[str]], + target: Union[str, List[str]], + predictions: Union[None, str, List[str]] = None, + references: Union[None, str, List[str]] = None, + ) -> None: """Store predictions/references for computing Word Information Lost scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + errors, target_total, preds_total = _wil_update(preds, target) self.errors += errors self.target_total += target_total diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 6a760234724..64c78d785c5 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Callable, List, Optional, Union +from warnings import warn from torch import Tensor, tensor @@ -83,13 +84,34 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore + def update( # type: ignore + self, + preds: Union[str, List[str]], + target: Union[str, List[str]], + predictions: Union[None, str, List[str]] = None, + references: Union[None, str, List[str]] = None, + ) -> None: """Store predictions/references for computing word Information Preserved scores. Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings """ + if predictions is not None: + warn( + "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + preds = predictions + if references is not None: + warn( + "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" + " The past argument will be removed in v0.8.", + DeprecationWarning, + ) + target = references + errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors self.target_total += target_total From ab36e72a02e48e661fdb10469f1414080eee5b22 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 10 Jan 2022 11:17:17 +0100 Subject: [PATCH 09/22] Apply suggestions from code review --- torchmetrics/functional/text/bert.py | 14 +++++++++++--- torchmetrics/functional/text/bleu.py | 12 +++++++++--- torchmetrics/functional/text/chrf.py | 12 +++++++++--- torchmetrics/functional/text/sacre_bleu.py | 12 +++++++++--- torchmetrics/functional/text/wil.py | 8 +++++++- torchmetrics/functional/text/wip.py | 8 +++++++- torchmetrics/text/bert.py | 8 +++++++- torchmetrics/text/bleu.py | 6 +++++- torchmetrics/text/chrf.py | 8 +++++++- torchmetrics/text/sacre_bleu.py | 8 +++++++- torchmetrics/text/wil.py | 8 +++++++- torchmetrics/text/wip.py | 8 +++++++- 12 files changed, 92 insertions(+), 20 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index f92ab1b69ab..bc2e43a56b1 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -453,8 +453,6 @@ def _rescale_metrics_with_baseline( def bert_score( preds: Union[List[str], Dict[str, Tensor]], target: Union[List[str], Dict[str, Tensor]], - predictions: Union[None, List[str], Dict[str, Tensor]] = None, - references: Union[None, List[str], Dict[str, Tensor]] = None, model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, @@ -472,6 +470,8 @@ def bert_score( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, + predictions: Union[None, List[str], Dict[str, Tensor]] = None, + references: Union[None, List[str], Dict[str, Tensor]] = None, ) -> Dict[str, Union[List[float], str]]: """`Bert_score Evaluating Text Generation`_ leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with @@ -533,6 +533,14 @@ def bert_score( A path to the user's own local csv/tsv file with the baseline scale. baseline_url: A url path to the user's own csv/tsv file with the baseline scale. + predictions: + Either an iterable of predicted sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and + `attention_mask` `torch.Tensor`. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + references: + Either an iterable of target sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and + `attention_mask` `torch.Tensor`. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. 
@@ -553,7 +561,7 @@ def bert_score( >>> from torchmetrics.functional.text.bert import bert_score >>> preds = ["hello there", "general kenobi"] >>> target = ["hello there", "master kenobi"] - >>> bert_score(preds=preds, target=target, lang="en") # doctest: +SKIP + >>> bert_score(preds, target, lang="en") # doctest: +SKIP {'precision': [0.99..., 0.99...], 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index a692d288704..5d382ad0dd1 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -146,10 +146,10 @@ def _bleu_score_compute( def bleu_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], - translate_corpus: Union[None, str, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_gram: int = 4, smooth: bool = False, + translate_corpus: Union[None, str, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, ) -> Tensor: """Calculate `BLEU score`_ of machine translated text with one or more references. @@ -162,6 +162,12 @@ def bleu_score( Gram value ranged from 1 to 4 (Default 4) smooth: Whether or not to apply smoothing – see [2] + translate_corpus: + An iterable of machine translated corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + reference_corpus: + An iterable of iterables of reference corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. Return: Tensor with BLEU Score @@ -170,7 +176,7 @@ def bleu_score( >>> from torchmetrics.functional import bleu_score >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> bleu_score(preds=preds, target=target) + >>> bleu_score(preds, target) tensor(0.7598) References: diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index f98750a3433..9bc6e2b164f 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -589,14 +589,14 @@ def _chrf_score_compute( def chrf_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], - hypothesis_corpus: Union[str, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, + hypothesis_corpus: Union[None, str, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate `chrF score`_ of machine translated text with one or more references. This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in `chrF++ score`_. This @@ -621,6 +621,12 @@ def chrf_score( An indication whether to keep whitespaces during character n-gram extraction. return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. + hypothesis_corpus: + An iterable of hypothesis corpus. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + reference_corpus: + An iterable of iterables of reference corpus. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. Return: A corpus-level chrF/chrF++ score. 
@@ -638,7 +644,7 @@ def chrf_score( >>> from torchmetrics.functional import chrf_score >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> chrf_score(preds=preds, target=target) + >>> chrf_score(preds, target) tensor(0.8640) References: diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index bb06e805efa..50219f52483 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -280,12 +280,12 @@ def _lower(line: str, lowercase: bool) -> str: def sacre_bleu_score( preds: Sequence[str], target: Sequence[Sequence[str]], - translate_corpus: Union[None, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Sequence[str]]] = None, n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", lowercase: bool = False, + translate_corpus: Union[None, Sequence[str]] = None, + reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> Tensor: """Calculate `BLEU score`_ [1] of machine translated text with one or more references. This implementation follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu. @@ -304,6 +304,12 @@ def sacre_bleu_score( Supported tokenization: ['none', '13a', 'zh', 'intl', 'char'] lowercase: If ``True``, BLEU score over lowercased text is calculated. + translate_corpus: + An iterable of machine translated corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + reference_corpus: + An iterable of iterables of reference corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. Return: Tensor with BLEU Score @@ -312,7 +318,7 @@ def sacre_bleu_score( >>> from torchmetrics.functional import sacre_bleu_score >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] - >>> sacre_bleu_score(preds=preds, target=target) + >>> sacre_bleu_score(preds, target) tensor(0.7598) References: diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index f5c94ea62cc..99f52f7287e 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -78,13 +78,19 @@ def word_information_lost( Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings + predictions: + Transcription(s) to score as a string or list of strings + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + references: + Reference(s) for each speech input as a string or list of strings + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. 
Returns: Word Information Lost rate Examples: >>> from torchmetrics.functional import word_information_lost >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] - >>> word_information_lost(preds=preds, target=target) + >>> word_information_lost(preds, target) tensor(0.6528) """ if predictions is not None: diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 83e88626c22..ebcbc09187e 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -78,13 +78,19 @@ def word_information_preserved( Args: preds: Transcription(s) to score as a string or list of strings total: Reference(s) for each speech input as a string or list of strings + predictions: + Transcription(s) to score as a string or list of strings + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + references: + Reference(s) for each speech input as a string or list of strings + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. Returns: Word Information preserved rate Examples: >>> from torchmetrics.functional import word_information_preserved >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] - >>> word_information_preserved(preds=preds, target=target) + >>> word_information_preserved(preds, target) tensor(0.3472) """ if predictions is not None: diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 39fa5a16fa5..13dea5122ae 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -115,7 +115,7 @@ class BERTScore(Metric): >>> preds = ["hello there", "general kenobi"] >>> target = ["hello there", "master kenobi"] >>> bertscore = BERTScore() - >>> bertscore(preds=preds,target=target) # doctest: +SKIP + >>> bertscore(preds, target) # doctest: +SKIP {'precision': [0.99..., 0.99...], 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} @@ -205,7 +205,13 @@ def update( # type: ignore preds: An iterable of predicted sentences. target: + An iterable of reference sentences. + predictions: An iterable of predicted sentences. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + references: + An iterable of reference sentences. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. 
""" if predictions is not None: warn( diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index e1fa631e616..c0648509dc0 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -50,7 +50,7 @@ class BLEUScore(Metric): >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = BLEUScore() - >>> metric(preds=preds, target=target) + >>> metric(preds, target) tensor(0.7598) References: @@ -107,6 +107,10 @@ def update( # type: ignore Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus + translate_corpus: + An iterable of machine translated corpus + reference_corpus: + An iterable of iterables of reference corpus """ if translate_corpus is not None: warn( diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index e7185105193..c87de243524 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -88,7 +88,7 @@ class CHRFScore(Metric): >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = CHRFScore() - >>> metric(preds=preds, target=target) + >>> metric(preds, target) tensor(0.8640) References: @@ -158,6 +158,12 @@ def update( # type: ignore An iterable of hypothesis corpus. target: An iterable of iterables of reference corpus. + hypotshesis_corpus: + An iterable of hypothesis corpus. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + reference_corpus: + An iterable of iterables of reference corpus. + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ if hypothesis_corpus is not None: warn( diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 864c3e9a088..330115488c3 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -69,7 +69,7 @@ class SacreBLEUScore(BLEUScore): >>> preds = ['the cat is on the mat'] >>> target = [['there is a cat on the mat', 'a cat is on the mat']] >>> metric = SacreBLEUScore() - >>> metric(preds=preds, target=target) + >>> metric(preds, target) tensor(0.7598) References: @@ -127,6 +127,12 @@ def update( # type: ignore Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus + translate_corpus: + An iterable of machine translated corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + reference_corpus: + An iterable of iterables of reference corpus + This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ if translate_corpus is not None: warn( diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 0f64c648531..17487120bf4 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -58,7 +58,7 @@ class WordInfoLost(Metric): >>> preds = ["this is the prediction", "there is an other sample"] >>> target = ["this is the reference", "there is another one"] >>> metric = WordInfoLost() - >>> metric(preds=preds, target=target) + >>> metric(preds, target) tensor(0.6528) """ is_differentiable = False @@ -96,6 +96,12 @@ def update( # type: ignore Args: preds: Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings + predictions: + Transcription(s) to score as a string or list of strings + This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. 
+            references:
+                Reference(s) for each speech input as a string or list of strings
+                This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead.
         """
         if predictions is not None:
             warn(
diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py
index 64c78d785c5..707840b43ab 100644
--- a/torchmetrics/text/wip.py
+++ b/torchmetrics/text/wip.py
@@ -58,7 +58,7 @@ class WordInfoPreserved(Metric):
         >>> preds = ["this is the prediction", "there is an other sample"]
         >>> target = ["this is the reference", "there is another one"]
         >>> metric = WordInfoPreserved()
-        >>> metric(preds=preds, target=target)
+        >>> metric(preds, target)
         tensor(0.3472)
     """
     is_differentiable = False
@@ -96,6 +96,12 @@ def update(  # type: ignore
         Args:
             preds: Transcription(s) to score as a string or list of strings
             target: Reference(s) for each speech input as a string or list of strings
+            predictions:
+                Transcription(s) to score as a string or list of strings
+                This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead.
+            references:
+                Reference(s) for each speech input as a string or list of strings
+                This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead.
         """
         if predictions is not None:
             warn(
From 8818569e02864b5109bde4e0072b1cb3cf995929 Mon Sep 17 00:00:00 2001
From: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Date: Mon, 10 Jan 2022 11:21:17 +0100
Subject: [PATCH 10/22] Add one unsaved change

---
 torchmetrics/text/bleu.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py
index c0648509dc0..141c53d8df6 100644
--- a/torchmetrics/text/bleu.py
+++ b/torchmetrics/text/bleu.py
@@ -109,8 +109,10 @@ def update(  # type: ignore
             target: An iterable of iterables of reference corpus
             translate_corpus:
                 An iterable of machine translated corpus
+                This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead.
             reference_corpus:
                 An iterable of iterables of reference corpus
+                This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead.
         """
         if translate_corpus is not None:
             warn(
From 5547156b29c1496a02b0b0ab8312d54305e103c9 Mon Sep 17 00:00:00 2001
From: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Date: Mon, 10 Jan 2022 11:36:08 +0100
Subject: [PATCH 11/22] Fix doc indentation for wip/wil

---
 torchmetrics/functional/text/wil.py | 13 +++++++++----
 torchmetrics/functional/text/wip.py | 15 ++++++++++-----
 torchmetrics/text/wil.py            |  6 ++++--
 torchmetrics/text/wip.py            |  6 ++++--
 4 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py
index 99f52f7287e..9d7de0a291c 100644
--- a/torchmetrics/functional/text/wil.py
+++ b/torchmetrics/functional/text/wil.py
@@ -73,19 +73,24 @@ def word_information_lost(
     references: Union[None, str, List[str]] = None,
 ) -> Tensor:
     """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This
-    value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the
-    performance of the ASR system with a Word Information Lost rate of 0 being a perfect score.
+    value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better
+    the performance of the ASR system with a Word Information Lost rate of 0 being a perfect score. 
+ Args: - preds: Transcription(s) to score as a string or list of strings - target: Reference(s) for each speech input as a string or list of strings + preds: + Transcription(s) to score as a string or list of strings + target: + Reference(s) for each speech input as a string or list of strings predictions: Transcription(s) to score as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. references: Reference(s) for each speech input as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + Returns: Word Information Lost rate + Examples: >>> from torchmetrics.functional import word_information_lost >>> preds = ["this is the prediction", "there is an other sample"] diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index ebcbc09187e..defe3842d88 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -72,20 +72,25 @@ def word_information_preserved( predictions: Union[None, str, List[str]] = None, references: Union[None, str, List[str]] = None, ) -> Tensor: - """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This - value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the - performance of the ASR system with a Word Information preserved rate of 0 being a perfect score. + """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. + This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the + better the performance of the ASR system with a Word Information preserved rate of 0 being a perfect score. + Args: - preds: Transcription(s) to score as a string or list of strings - total: Reference(s) for each speech input as a string or list of strings + preds: + Transcription(s) to score as a string or list of strings + total: + Reference(s) for each speech input as a string or list of strings predictions: Transcription(s) to score as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. references: Reference(s) for each speech input as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + Returns: Word Information preserved rate + Examples: >>> from torchmetrics.functional import word_information_preserved >>> preds = ["this is the prediction", "there is an other sample"] diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 17487120bf4..7ee70cb9efd 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -94,8 +94,10 @@ def update( # type: ignore """Store predictions/references for computing Word Information Lost scores. Args: - preds: Transcription(s) to score as a string or list of strings - target: Reference(s) for each speech input as a string or list of strings + preds: + Transcription(s) to score as a string or list of strings + target: + Reference(s) for each speech input as a string or list of strings predictions: Transcription(s) to score as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. 
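A minimal usage sketch of the argument rename these hunks implement (illustrative only, not part of the patch; it assumes only the public `torchmetrics.functional.word_information_lost` entry point and the deprecation shim added above):

    from torchmetrics.functional import word_information_lost

    preds = ["this is the prediction", "there is an other sample"]
    target = ["this is the reference", "there is another one"]

    # Old keyword names: still accepted in v0.7 but emit a DeprecationWarning,
    # scheduled for removal in v0.8.
    word_information_lost(predictions=preds, references=target)

    # Renamed keyword arguments introduced in v0.7; the supported spelling going forward.
    word_information_lost(preds=preds, target=target)  # tensor(0.6528)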
diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 707840b43ab..4a645d48c16 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -94,8 +94,10 @@ def update( # type: ignore """Store predictions/references for computing word Information Preserved scores. Args: - preds: Transcription(s) to score as a string or list of strings - target: Reference(s) for each speech input as a string or list of strings + preds: + Transcription(s) to score as a string or list of strings + target: + Reference(s) for each speech input as a string or list of strings predictions: Transcription(s) to score as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. From 71e82c16ad4e3a6c7b08a3e14d8327fe24c06763 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 10 Jan 2022 16:39:46 +0100 Subject: [PATCH 12/22] Set preds, target = None --- torchmetrics/functional/text/bert.py | 15 +++++++++++---- torchmetrics/functional/text/bleu.py | 15 +++++++++++---- torchmetrics/functional/text/chrf.py | 15 +++++++++++---- torchmetrics/functional/text/sacre_bleu.py | 15 +++++++++++---- torchmetrics/functional/text/wil.py | 15 +++++++++++---- torchmetrics/functional/text/wip.py | 15 +++++++++++---- torchmetrics/text/bert.py | 15 +++++++++++---- torchmetrics/text/bleu.py | 15 +++++++++++---- torchmetrics/text/chrf.py | 15 +++++++++++---- torchmetrics/text/sacre_bleu.py | 15 +++++++++++---- torchmetrics/text/wil.py | 15 +++++++++++---- torchmetrics/text/wip.py | 15 +++++++++++---- 12 files changed, 132 insertions(+), 48 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 906e6f8d8ab..eb94ed66151 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -451,8 +451,8 @@ def _rescale_metrics_with_baseline( def bert_score( - preds: Union[List[str], Dict[str, Tensor]], - target: Union[List[str], Dict[str, Tensor]], + preds: Union[None, List[str], Dict[str, Tensor]] = None, + target: Union[None, List[str], Dict[str, Tensor]] = None, model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, @@ -566,20 +566,27 @@ def bert_score( 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references if len(preds) != len(target): raise ValueError("Number of predicted and reference sententes must be the same!") diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 5d382ad0dd1..e3168c0a73e 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -144,8 +144,8 @@ def _bleu_score_compute( def bleu_score( - preds: Union[str, Sequence[str]], - target: Sequence[Union[str, Sequence[str]]], + preds: Union[None, str, Sequence[str]] = None, + target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_gram: int = 4, smooth: bool = False, translate_corpus: Union[None, str, Sequence[str]] = None, @@ -190,20 +190,27 @@ def bleu_score( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) + if preds is None and translate_corpus is None: + raise ValueError("Either `preds` or `translate_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if translate_corpus is not None: warn( "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = translate_corpus + warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") + preds = preds or translate_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus preds_ = [preds] if isinstance(preds, str) else preds target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 9bc6e2b164f..e8a845199d4 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -587,8 +587,8 @@ def _chrf_score_compute( def chrf_score( - preds: Union[str, Sequence[str]], - target: Sequence[Union[str, Sequence[str]]], + preds: Union[None, str, Sequence[str]] = None, + target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, @@ -651,20 +651,27 @@ def chrf_score( [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ + if preds is None and hypothesis_corpus is None: + raise ValueError("Either `preds` or `hypothesis_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if hypothesis_corpus is not None: warn( "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = hypothesis_corpus + warn("If you specify both `preds` and `hypothesis_corpus`, only `preds` is considered.") + preds = preds or hypothesis_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus if not isinstance(n_char_order, int) or n_char_order < 1: raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 5eee679e5fb..73c310fbc60 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -278,8 +278,8 @@ def _lower(line: str, lowercase: bool) -> str: def sacre_bleu_score( - preds: Sequence[str], - target: Sequence[Sequence[str]], + preds: Union[None, Sequence[str]] = None, + target: Union[None, Sequence[Sequence[str]]] = None, n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", @@ -334,20 +334,27 @@ def sacre_bleu_score( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) + if preds is None and translate_corpus is None: + raise ValueError("Either `preds` or `translate_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if translate_corpus is not None: warn( "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = translate_corpus + warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") + preds = preds or translate_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index 9d7de0a291c..1c229e50075 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -67,8 +67,8 @@ def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T def word_information_lost( - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[None, str, List[str]] = None, + target: Union[None, str, List[str]] = None, predictions: Union[None, str, List[str]] = None, references: Union[None, str, List[str]] = None, ) -> Tensor: @@ -98,20 +98,27 @@ def word_information_lost( >>> word_information_lost(preds, target) tensor(0.6528) """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references errors, target_total, preds_total = _wil_update(preds, target) return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index defe3842d88..90cc8fb9581 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -67,8 +67,8 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T def word_information_preserved( - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[None, str, List[str]] = None, + target: Union[None, str, List[str]] = None, predictions: Union[None, str, List[str]] = None, references: Union[None, str, List[str]] = None, ) -> Tensor: @@ -98,20 +98,27 @@ def word_information_preserved( >>> word_information_preserved(preds, target) tensor(0.3472) """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references errors, reference_total, prediction_total = _wip_update(preds, target) return _wip_compute(errors, reference_total, prediction_total) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 88269e8e9bc..98a408e42ce 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -193,8 +193,8 @@ def __init__( def update( # type: ignore self, - preds: List[str], - target: List[str], + preds: Union[None, List[str]] = None, + target: Union[None, List[str]] = None, predictions: Union[None, List[str]] = None, references: Union[None, List[str]] = None, ) -> None: @@ -213,20 +213,27 @@ def update( # type: ignore An iterable of reference sentences. This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references preds_dict = _preprocess_text( preds, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 141c53d8df6..89395f253b3 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -97,8 +97,8 @@ def __init__( def update( # type: ignore self, - preds: Sequence[str], - target: Sequence[Sequence[str]], + preds: Union[None, Sequence[str]] = None, + target: Union[None, Sequence[Sequence[str]]] = None, translate_corpus: Union[None, Sequence[str]] = None, reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: @@ -114,20 +114,27 @@ def update( # type: ignore An iterable of iterables of reference corpus This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. """ + if preds is None and translate_corpus is None: + raise ValueError("Either `preds` or `translate_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if translate_corpus is not None: warn( "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = translate_corpus + warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") + preds = preds or translate_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index c87de243524..67bbfa3eac2 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -146,8 +146,8 @@ def __init__( def update( # type: ignore self, - preds: Sequence[str], - target: Sequence[Sequence[str]], + preds: Union[None, Sequence[str]] = None, + target: Union[None, Sequence[Sequence[str]]] = None, hypothesis_corpus: Union[None, Sequence[str]] = None, reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: @@ -165,20 +165,27 @@ def update( # type: ignore An iterable of iterables of reference corpus. This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ + if preds is None and hypothesis_corpus is None: + raise ValueError("Either `preds` or `hypothesis_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if hypothesis_corpus is not None: warn( "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = hypothesis_corpus + warn("If you specify both `preds` and `hypothesis_corpus`, only `preds` is considered.") + preds = preds or hypothesis_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus n_grams_dicts_tuple = _chrf_score_update( preds, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 77ecb227d36..42124d098c7 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -117,8 +117,8 @@ def __init__( def update( # type: ignore self, - preds: Sequence[str], - target: Sequence[Sequence[str]], + preds: Union[None, Sequence[str]] = None, + target: Union[None, Sequence[Sequence[str]]] = None, translate_corpus: Union[None, Sequence[str]] = None, reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: @@ -134,20 +134,27 @@ def update( # type: ignore An iterable of iterables of reference corpus This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ + if preds is None and translate_corpus is None: + raise ValueError("Either `preds` or `translate_corpus` must be provided.") + if target is None and reference_corpus is None: + raise ValueError("Either `target` or `reference_corpus` must be provided.") + if translate_corpus is not None: warn( "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = translate_corpus + warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") + preds = preds or translate_corpus if reference_corpus is not None: warn( "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = reference_corpus + warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") + target = target or reference_corpus self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 7ee70cb9efd..25e91075f95 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -86,8 +86,8 @@ def __init__( def update( # type: ignore self, - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[None, str, List[str]] = None, + target: Union[None, str, List[str]] = None, predictions: Union[None, str, List[str]] = None, references: Union[None, str, List[str]] = None, ) -> None: @@ -105,20 +105,27 @@ def update( # type: ignore Reference(s) for each speech input as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references errors, target_total, preds_total = _wil_update(preds, target) self.errors += errors diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 4a645d48c16..e2289a50470 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -86,8 +86,8 @@ def __init__( def update( # type: ignore self, - preds: Union[str, List[str]], - target: Union[str, List[str]], + preds: Union[None, str, List[str]] = None, + target: Union[None, str, List[str]] = None, predictions: Union[None, str, List[str]] = None, references: Union[None, str, List[str]] = None, ) -> None: @@ -105,20 +105,27 @@ def update( # type: ignore Reference(s) for each speech input as a string or list of strings This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. """ + if preds is None and predictions is None: + raise ValueError("Either `preds` or `predictions` must be provided.") + if target is None and references is None: + raise ValueError("Either `target` or `references` must be provided.") + if predictions is not None: warn( "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " " The past argument will be removed in v0.8.", DeprecationWarning, ) - preds = predictions + warn("If you specify both `preds` and `predictions`, only `preds` is considered.") + preds = preds or predictions if references is not None: warn( "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" " The past argument will be removed in v0.8.", DeprecationWarning, ) - target = references + warn("If you specify both `target` and `references`, only `target` is considered.") + target = target or references errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors From 2ebc20f40f95d2f6f91e17e73e29cd1cfec45c3f Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Mon, 10 Jan 2022 23:08:51 +0100 Subject: [PATCH 13/22] Add ignore statements for mypy --- torchmetrics/functional/text/bert.py | 2 +- torchmetrics/functional/text/bleu.py | 8 ++++---- torchmetrics/functional/text/chrf.py | 4 ++-- torchmetrics/functional/text/sacre_bleu.py | 8 ++++---- torchmetrics/functional/text/wil.py | 2 +- torchmetrics/functional/text/wip.py | 2 +- torchmetrics/text/bert.py | 4 ++-- torchmetrics/text/bleu.py | 4 ++-- torchmetrics/text/chrf.py | 4 ++-- torchmetrics/text/sacre_bleu.py | 4 ++-- torchmetrics/text/wil.py | 2 +- torchmetrics/text/wip.py | 2 +- 12 files changed, 23 insertions(+), 23 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index eb94ed66151..293af8509a6 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -588,7 +588,7 @@ def bert_score( warn("If you specify both `target` and `references`, only `target` is considered.") target = target or references - if len(preds) != len(target): + if len(preds) != len(target): # type: ignore raise ValueError("Number of predicted and reference sententes must be the same!") if verbose and (not _TQDM_AVAILABLE): diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index e3168c0a73e..269be28f6dd 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -213,10 +213,10 @@ def bleu_score( target = target or reference_corpus preds_ = [preds] if isinstance(preds, str) else preds - target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] + target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] # type: ignore - if len(preds_) != len(target_): - raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") + if len(preds_) != len(target_): # type: ignore + raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") # type: ignore numerator = torch.zeros(n_gram) denominator = torch.zeros(n_gram) @@ -224,7 +224,7 @@ def bleu_score( target_len = tensor(0.0) preds_len, target_len = _bleu_score_update( - preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn + preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn # type: ignore ) return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth) diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index e8a845199d4..86122d6596b 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -702,8 +702,8 @@ def chrf_score( total_matching_word_n_grams, sentence_chrf_score, ) = _chrf_score_update( - preds, - target, + preds, # type: ignore + target, # type: ignore total_preds_char_n_grams, total_preds_word_n_grams, total_target_char_n_grams, diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 73c310fbc60..8b12878035c 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ 
-363,8 +363,8 @@ def sacre_bleu_score( raise ValueError( f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}" ) - if len(preds) != len(target): - raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}") + if len(preds) != len(target): # type: ignore + raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}") # type: ignore if tokenize == "intl" and not _REGEX_AVAILABLE: raise ModuleNotFoundError( "`'intl'` tokenization requires that `regex` is installed." @@ -378,8 +378,8 @@ def sacre_bleu_score( tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase) preds_len, target_len = _bleu_score_update( - preds, - target, + preds, # type: ignore + target, # type: ignore numerator, denominator, preds_len, diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index 1c229e50075..62d1cd3eacd 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -120,5 +120,5 @@ def word_information_lost( warn("If you specify both `target` and `references`, only `target` is considered.") target = target or references - errors, target_total, preds_total = _wil_update(preds, target) + errors, target_total, preds_total = _wil_update(preds, target) # type: ignore return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 90cc8fb9581..03269164e8e 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -120,5 +120,5 @@ def word_information_preserved( warn("If you specify both `target` and `references`, only `target` is considered.") target = target or references - errors, reference_total, prediction_total = _wip_update(preds, target) + errors, reference_total, prediction_total = _wip_update(preds, target) # type: ignore return _wip_compute(errors, reference_total, prediction_total) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 98a408e42ce..32a0152b43e 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -236,7 +236,7 @@ def update( # type: ignore target = target or references preds_dict = _preprocess_text( - preds, + preds, # type: ignore self.tokenizer, self.max_length, truncation=False, @@ -244,7 +244,7 @@ def update( # type: ignore own_tokenizer=self.user_tokenizer, ) target_dict = _preprocess_text( - target, + target, # type: ignore self.tokenizer, self.max_length, truncation=False, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 89395f253b3..73293298234 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -137,8 +137,8 @@ def update( # type: ignore target = target or reference_corpus self.preds_len, self.target_len = _bleu_score_update( - preds, - target, + preds, # type: ignore + target, # type: ignore self.numerator, self.denominator, self.preds_len, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index 67bbfa3eac2..4982cacfa8e 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -188,8 +188,8 @@ def update( # type: ignore target = target or reference_corpus n_grams_dicts_tuple = _chrf_score_update( - preds, - target, + preds, # type: ignore + target, # type: ignore *self._convert_states_to_dicts(), self.n_char_order, self.n_word_order, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 42124d098c7..ed007dd930c 100644 --- 
a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -157,8 +157,8 @@ def update( # type: ignore target = target or reference_corpus self.preds_len, self.target_len = _bleu_score_update( - preds, - target, + preds, # type: ignore + target, # type: ignore self.numerator, self.denominator, self.preds_len, diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 25e91075f95..e739891a96c 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -127,7 +127,7 @@ def update( # type: ignore warn("If you specify both `target` and `references`, only `target` is considered.") target = target or references - errors, target_total, preds_total = _wil_update(preds, target) + errors, target_total, preds_total = _wil_update(preds, target) # type: ignore self.errors += errors self.target_total += target_total self.preds_total += preds_total diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index e2289a50470..61481ba17ac 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -127,7 +127,7 @@ def update( # type: ignore warn("If you specify both `target` and `references`, only `target` is considered.") target = target or references - errors, target_total, preds_total = _wip_update(preds, target) + errors, target_total, preds_total = _wip_update(preds, target) # type: ignore self.errors += errors self.target_total += target_total self.preds_total += preds_total From 26f7fd02c6c7059efd8b09f4c9dbfd4e4aacf02a Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Tue, 11 Jan 2022 23:10:54 +0100 Subject: [PATCH 14/22] Use deprecate package --- torchmetrics/functional/text/bert.py | 44 +++++++--------------- torchmetrics/functional/text/bleu.py | 40 +++++++------------- torchmetrics/functional/text/chrf.py | 43 +++++++-------------- torchmetrics/functional/text/sacre_bleu.py | 41 +++++++------------- torchmetrics/functional/text/wil.py | 42 +++++++-------------- torchmetrics/functional/text/wip.py | 43 +++++++-------------- torchmetrics/text/bert.py | 41 +++++++------------- torchmetrics/text/bleu.py | 41 +++++++------------- torchmetrics/text/chrf.py | 44 +++++++--------------- torchmetrics/text/sacre_bleu.py | 41 +++++++------------- torchmetrics/text/wil.py | 40 +++++++------------- torchmetrics/text/wip.py | 42 +++++++-------------- 12 files changed, 157 insertions(+), 345 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 293af8509a6..86a9410ede2 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import csv +import logging import math import urllib import warnings from collections import Counter, defaultdict from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union -from warnings import warn import torch +from deprecate import deprecated from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -450,6 +451,13 @@ def _rescale_metrics_with_baseline( return all_metrics[..., 0], all_metrics[..., 1], all_metrics[..., 2] +@deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def bert_score( preds: Union[None, List[str], Dict[str, Tensor]] = None, target: Union[None, List[str], Dict[str, Tensor]] = None, @@ -470,8 +478,6 @@ def bert_score( rescale_with_baseline: bool = False, baseline_path: Optional[str] = None, baseline_url: Optional[str] = None, - predictions: Union[None, List[str], Dict[str, Tensor]] = None, - references: Union[None, List[str], Dict[str, Tensor]] = None, ) -> Dict[str, Union[List[float], str]]: """`Bert_score Evaluating Text Generation`_ leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with @@ -534,13 +540,11 @@ def bert_score( baseline_url: A url path to the user's own csv/tsv file with the baseline scale. predictions: - Either an iterable of predicted sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and - `attention_mask` `torch.Tensor`. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - Either an iterable of target sentences or a `Dict[str, torch.Tensor]` containing `input_ids` and - `attention_mask` `torch.Tensor`. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. @@ -566,28 +570,6 @@ def bert_score( 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - if len(preds) != len(target): # type: ignore raise ValueError("Number of predicted and reference sententes must be the same!") diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 269be28f6dd..b143808bd2e 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -16,11 +16,13 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +import logging from collections import Counter from typing import Callable, Sequence, Tuple, Union from warnings import warn import torch +from deprecate import deprecated from torch import Tensor, tensor @@ -143,13 +145,18 @@ def _bleu_score_compute( return bleu +@deprecated( + args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def bleu_score( preds: Union[None, str, Sequence[str]] = None, target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, n_gram: int = 4, smooth: bool = False, - translate_corpus: Union[None, str, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, ) -> Tensor: """Calculate `BLEU score`_ of machine translated text with one or more references. @@ -163,11 +170,11 @@ def bleu_score( smooth: Whether or not to apply smoothing – see [2] translate_corpus: - An iterable of machine translated corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: Tensor with BLEU Score @@ -190,27 +197,6 @@ def bleu_score( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) - if preds is None and translate_corpus is None: - raise ValueError("Either `preds` or `translate_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if translate_corpus is not None: - warn( - "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") - preds = preds or translate_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus preds_ = [preds] if isinstance(preds, str) else preds target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] # type: ignore diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 86122d6596b..4dbe78d7ed8 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -32,12 +32,12 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . - +import logging from collections import defaultdict from typing import Dict, List, Optional, Sequence, Tuple, Union -from warnings import warn import torch +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _validate_inputs @@ -586,6 +586,13 @@ def _chrf_score_compute( return chrf_f_score +@deprecated( + args_mapping={"hypothesis_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def chrf_score( preds: Union[None, str, Sequence[str]] = None, target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, @@ -595,8 +602,6 @@ def chrf_score( lowercase: bool = False, whitespace: bool = False, return_sentence_level_score: bool = False, - hypothesis_corpus: Union[None, str, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Union[str, Sequence[str]]]] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Calculate `chrF score`_ of machine translated text with one or more references. This implementation supports both chrF score computation introduced in [1] and chrF++ score introduced in `chrF++ score`_. This @@ -622,11 +627,11 @@ def chrf_score( return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. hypothesis_corpus: - An iterable of hypothesis corpus. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: A corpus-level chrF/chrF++ score. @@ -651,28 +656,6 @@ def chrf_score( [1] chrF: character n-gram F-score for automatic MT evaluation by Maja Popović `chrF score`_ [2] chrF++: words helping character n-grams by Maja Popović `chrF++ score`_ """ - if preds is None and hypothesis_corpus is None: - raise ValueError("Either `preds` or `hypothesis_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if hypothesis_corpus is not None: - warn( - "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `hypothesis_corpus`, only `preds` is considered.") - preds = preds or hypothesis_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus - if not isinstance(n_char_order, int) or n_char_order < 1: raise ValueError("Expected argument `n_char_order` to be an integer greater than or equal to 1.") if not isinstance(n_word_order, int) or n_word_order < 0: diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 8b12878035c..963ed7dfeea 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -37,13 +37,14 @@ # MIT License # Copyright (c) 2017 - Shujian Huang - +import logging import re from functools import partial from typing import Sequence, Union from warnings import warn import torch +from deprecate import deprecated from torch import Tensor, tensor from typing_extensions import Literal @@ -277,6 +278,13 @@ def _lower(line: str, lowercase: bool) -> str: return line +@deprecated( + args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def sacre_bleu_score( preds: Union[None, Sequence[str]] = None, target: Union[None, Sequence[Sequence[str]]] = None, @@ -284,8 +292,6 @@ def sacre_bleu_score( smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", lowercase: bool = False, - translate_corpus: Union[None, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> Tensor: """Calculate `BLEU score`_ [1] of machine translated text with one or more references. This implementation follows the behaviour of SacreBLEU [2] implementation from https://github.com/mjpost/sacrebleu. @@ -305,11 +311,11 @@ def sacre_bleu_score( lowercase: If ``True``, BLEU score over lowercased text is calculated. translate_corpus: - An iterable of machine translated corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: Tensor with BLEU Score @@ -334,27 +340,6 @@ def sacre_bleu_score( "Input order of targets and preds were changed to predictions firsts and targets second in v0.7." " Warning will be removed in v0.8." ) - if preds is None and translate_corpus is None: - raise ValueError("Either `preds` or `translate_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if translate_corpus is not None: - warn( - "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") - preds = preds or translate_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index 62d1cd3eacd..a7f0ac7e7f9 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import List, Tuple, Union -from warnings import warn +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance @@ -66,11 +67,16 @@ def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return 1 - ((errors / target_total) * (errors / preds_total)) +@deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def word_information_lost( preds: Union[None, str, List[str]] = None, target: Union[None, str, List[str]] = None, - predictions: Union[None, str, List[str]] = None, - references: Union[None, str, List[str]] = None, ) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better @@ -82,11 +88,11 @@ def word_information_lost( target: Reference(s) for each speech input as a string or list of strings predictions: - Transcription(s) to score as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - Reference(s) for each speech input as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Returns: Word Information Lost rate @@ -98,27 +104,5 @@ def word_information_lost( >>> word_information_lost(preds, target) tensor(0.6528) """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - errors, target_total, preds_total = _wil_update(preds, target) # type: ignore return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index 03269164e8e..bd39eaae02b 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import logging from typing import List, Tuple, Union -from warnings import warn +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance @@ -66,11 +66,16 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return (errors / target_total) * (errors / preds_total) +@deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", +) def word_information_preserved( preds: Union[None, str, List[str]] = None, target: Union[None, str, List[str]] = None, - predictions: Union[None, str, List[str]] = None, - references: Union[None, str, List[str]] = None, ) -> Tensor: """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the @@ -82,11 +87,11 @@ def word_information_preserved( total: Reference(s) for each speech input as a string or list of strings predictions: - Transcription(s) to score as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - Reference(s) for each speech input as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. Returns: Word Information preserved rate @@ -98,27 +103,5 @@ def word_information_preserved( >>> word_information_preserved(preds, target) tensor(0.3472) """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - errors, reference_total, prediction_total = _wip_update(preds, target) # type: ignore return _wip_compute(errors, reference_total, prediction_total) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 32a0152b43e..2741921ae46 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any, Callable, Dict, List, Optional, Union from warnings import warn import torch +from deprecate import deprecated from torchmetrics.functional import bert_score from torchmetrics.functional.text.bert import _preprocess_text @@ -191,12 +193,17 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.user_tokenizer = False + @deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, List[str]] = None, target: Union[None, List[str]] = None, - predictions: Union[None, List[str]] = None, - references: Union[None, List[str]] = None, ) -> None: """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. @@ -207,34 +214,12 @@ def update( # type: ignore target: An iterable of reference sentences. predictions: - An iterable of predicted sentences. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - An iterable of reference sentences. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - preds_dict = _preprocess_text( preds, # type: ignore self.tokenizer, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 73293298234..381ebd44df6 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,10 +16,12 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +import logging from typing import Any, Callable, Optional, Sequence, Union from warnings import warn import torch +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics import Metric @@ -95,12 +97,17 @@ def __init__( self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum") self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum") + @deprecated( + args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, Sequence[str]] = None, target: Union[None, Sequence[Sequence[str]]] = None, - translate_corpus: Union[None, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: """Compute Precision Scores. @@ -108,34 +115,12 @@ def update( # type: ignore preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus translate_corpus: - An iterable of machine translated corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and translate_corpus is None: - raise ValueError("Either `preds` or `translate_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if translate_corpus is not None: - warn( - "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") - preds = preds or translate_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus - self.preds_len, self.target_len = _bleu_score_update( preds, # type: ignore target, # type: ignore diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index 4982cacfa8e..4d44dbbba3e 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -18,10 +18,11 @@ # Link: import itertools +import logging from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union -from warnings import warn import torch +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics import Metric @@ -144,12 +145,17 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") + @deprecated( + args_mapping={"hypothesis_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, Sequence[str]] = None, target: Union[None, Sequence[Sequence[str]]] = None, - hypothesis_corpus: Union[None, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: """Compute Precision Scores. @@ -158,35 +164,13 @@ def update( # type: ignore An iterable of hypothesis corpus. target: An iterable of iterables of reference corpus. - hypotshesis_corpus: - An iterable of hypothesis corpus. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + hypothesis_corpus: + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus. - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and hypothesis_corpus is None: - raise ValueError("Either `preds` or `hypothesis_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if hypothesis_corpus is not None: - warn( - "You are using deprecated argument `hypothesis_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `hypothesis_corpus`, only `preds` is considered.") - preds = preds or hypothesis_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. 
" - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus - n_grams_dicts_tuple = _chrf_score_update( preds, # type: ignore target, # type: ignore diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index ed007dd930c..171d31e61d8 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -17,9 +17,11 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +import logging from typing import Any, Callable, Optional, Sequence, Union from warnings import warn +from deprecate import deprecated from typing_extensions import Literal from torchmetrics.functional.text.bleu import _bleu_score_update @@ -115,12 +117,17 @@ def __init__( ) self.tokenizer = _SacreBLEUTokenizer(tokenize, lowercase) + @deprecated( + args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, Sequence[str]] = None, target: Union[None, Sequence[Sequence[str]]] = None, - translate_corpus: Union[None, Sequence[str]] = None, - reference_corpus: Union[None, Sequence[Sequence[str]]] = None, ) -> None: """Compute Precision Scores. @@ -128,34 +135,12 @@ def update( # type: ignore preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus translate_corpus: - An iterable of machine translated corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. reference_corpus: - An iterable of iterables of reference corpus - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and translate_corpus is None: - raise ValueError("Either `preds` or `translate_corpus` must be provided.") - if target is None and reference_corpus is None: - raise ValueError("Either `target` or `reference_corpus` must be provided.") - - if translate_corpus is not None: - warn( - "You are using deprecated argument `translate_corpus` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `translate_corpus`, only `preds` is considered.") - preds = preds or translate_corpus - if reference_corpus is not None: - warn( - "You are using deprecated argument `reference_corpus` in v0.7 which was renamed to `target`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `reference_corpus`, only `target` is considered.") - target = target or reference_corpus - self.preds_len, self.target_len = _bleu_score_update( preds, # type: ignore target, # type: ignore diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index e739891a96c..801b5b6d3d9 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging from typing import Any, Callable, List, Optional, Union -from warnings import warn +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.wil import _wil_compute, _wil_update @@ -84,6 +85,13 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") + @deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, str, List[str]] = None, @@ -99,34 +107,12 @@ def update( # type: ignore target: Reference(s) for each speech input as a string or list of strings predictions: - Transcription(s) to score as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - Reference(s) for each speech input as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - errors, target_total, preds_total = _wil_update(preds, target) # type: ignore self.errors += errors self.target_total += target_total diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 61481ba17ac..a46411028d3 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any, Callable, List, Optional, Union -from warnings import warn +from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.wip import _wip_compute, _wip_update @@ -84,12 +85,17 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") + @deprecated( + args_mapping={"predictions": "preds", "references": "target"}, + target=True, + stream=logging.warning, + deprecated_in="0.7", + remove_in="0.8", + ) def update( # type: ignore self, preds: Union[None, str, List[str]] = None, target: Union[None, str, List[str]] = None, - predictions: Union[None, str, List[str]] = None, - references: Union[None, str, List[str]] = None, ) -> None: """Store predictions/references for computing word Information Preserved scores. 
@@ -99,34 +105,12 @@ def update( # type: ignore target: Reference(s) for each speech input as a string or list of strings predictions: - Transcription(s) to score as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `preds` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. references: - Reference(s) for each speech input as a string or list of strings - This argument is deprecated in v0.7 and will be removed in v0.8. Use `target` instead. + .. deprecated:: v0.7 + This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - if preds is None and predictions is None: - raise ValueError("Either `preds` or `predictions` must be provided.") - if target is None and references is None: - raise ValueError("Either `target` or `references` must be provided.") - - if predictions is not None: - warn( - "You are using deprecated argument `predictions` in v0.7 which was renamed to `preds`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `preds` and `predictions`, only `preds` is considered.") - preds = preds or predictions - if references is not None: - warn( - "You are using deprecated argument `references` in v0.7 which was renamed to `target`. " - " The past argument will be removed in v0.8.", - DeprecationWarning, - ) - warn("If you specify both `target` and `references`, only `target` is considered.") - target = target or references - errors, target_total, preds_total = _wip_update(preds, target) # type: ignore self.errors += errors self.target_total += target_total From 93cceb350d8603230fc1ee650739997199aa7f3b Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Tue, 11 Jan 2022 23:19:04 +0100 Subject: [PATCH 15/22] Add pyDeprecate==0.3.* to the requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index b47541f6e94..5d7a0bc464d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.17.2 torch>=1.3.1 +pyDeprecate==0.3.* packaging From c8a34637bc68c9062eca7760975220e05423d4d7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 12 Jan 2022 15:10:06 +0100 Subject: [PATCH 16/22] Apply suggestions from code review --- torchmetrics/functional/text/bert.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 86a9410ede2..3264c96eaa9 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -540,11 +540,15 @@ def bert_score( baseline_url: A url path to the user's own csv/tsv file with the baseline scale. predictions: + .. deprecated:: v0.7 This argument is deprecated in favor of `preds` and will be removed in v0.8. + references: + .. deprecated:: v0.7 This argument is deprecated in favor of `preds` and will be removed in v0.8. + Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. @@ -552,7 +556,7 @@ def bert_score( Raises: ValueError: If `len(preds) != len(target)`. - ValueError: + ModuleNotFoundError: If `tqdm` package is required and not installed. ModuleNotFoundError: If ``transformers`` package is required and not installed. 
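The behaviour these `@deprecated` decorators give the functional metrics can be illustrated with a small standalone sketch. It assumes pyDeprecate==0.3.* (as pinned in requirements.txt above) and mirrors the decorator arguments used in the diffs; the toy `my_metric` function, its body, and the exact wording of the logged message are illustrative only and not part of torchmetrics:

    import logging
    from typing import List, Union

    from deprecate import deprecated


    @deprecated(
        # Remap the old keyword names onto the new ones, as done for the text metrics above.
        args_mapping={"predictions": "preds", "references": "target"},
        target=True,              # forward to the decorated function itself
        stream=logging.warning,   # emit the deprecation message through logging
        deprecated_in="0.7",
        remove_in="0.8",
    )
    def my_metric(preds: Union[str, List[str]], target: Union[str, List[str]]) -> float:
        # Toy body standing in for e.g. word_information_lost / word_information_preserved.
        return len(preds) / max(len(target), 1)


    my_metric(preds="hello there", target="hello there")            # new names: no warning
    my_metric(predictions="hello there", references="hello there")  # old names: remapped and warned

With the decorator in place, the per-function `if preds is None ...` checks and manual `warn(...)` calls removed in the hunks above are no longer needed.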
From 3ecb965f98362fc85d9a4fd2b58b8c496550d1f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jan 2022 14:11:43 +0000 Subject: [PATCH 17/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/text/bert.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 3264c96eaa9..f63f5039330 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -540,15 +540,15 @@ def bert_score( baseline_url: A url path to the user's own csv/tsv file with the baseline scale. predictions: - + .. deprecated:: v0.7 This argument is deprecated in favor of `preds` and will be removed in v0.8. - + references: - + .. deprecated:: v0.7 This argument is deprecated in favor of `preds` and will be removed in v0.8. - + Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. From e632c1d21bdadd57ba5a38c63ff80653a332a203 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 12 Jan 2022 17:16:26 +0100 Subject: [PATCH 18/22] Apply some suggestions from code review --- torchmetrics/functional/text/bert.py | 16 +++------------- torchmetrics/functional/text/bleu.py | 18 ++++++------------ torchmetrics/functional/text/chrf.py | 14 ++++---------- torchmetrics/functional/text/sacre_bleu.py | 20 +++++++------------- torchmetrics/functional/text/wil.py | 12 +++--------- torchmetrics/functional/text/wip.py | 12 +++--------- torchmetrics/text/bert.py | 14 ++++---------- torchmetrics/text/bleu.py | 16 +++++----------- torchmetrics/text/chrf.py | 14 ++++---------- torchmetrics/text/sacre_bleu.py | 16 +++++----------- torchmetrics/text/wil.py | 14 +++----------- torchmetrics/text/wip.py | 12 +++--------- 12 files changed, 50 insertions(+), 128 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index f63f5039330..bcec4e26ee4 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -459,8 +459,8 @@ def _rescale_metrics_with_baseline( remove_in="0.8", ) def bert_score( - preds: Union[None, List[str], Dict[str, Tensor]] = None, - target: Union[None, List[str], Dict[str, Tensor]] = None, + preds: Union[List[str], Dict[str, Tensor]], + target: Union[List[str], Dict[str, Tensor]], model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, all_layers: bool = False, @@ -539,16 +539,6 @@ def bert_score( A path to the user's own local csv/tsv file with the baseline scale. baseline_url: A url path to the user's own csv/tsv file with the baseline scale. - predictions: - - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - - references: - - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. 
@@ -574,7 +564,7 @@ def bert_score( 'recall': [0.99..., 0.99...], 'f1': [0.99..., 0.99...]} """ - if len(preds) != len(target): # type: ignore + if len(preds) != len(target): raise ValueError("Number of predicted and reference sententes must be the same!") if verbose and (not _TQDM_AVAILABLE): diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index b143808bd2e..12ca015828b 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -153,8 +153,8 @@ def _bleu_score_compute( remove_in="0.8", ) def bleu_score( - preds: Union[None, str, Sequence[str]] = None, - target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], n_gram: int = 4, smooth: bool = False, ) -> Tensor: @@ -169,12 +169,6 @@ def bleu_score( Gram value ranged from 1 to 4 (Default 4) smooth: Whether or not to apply smoothing – see [2] - translate_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: Tensor with BLEU Score @@ -199,10 +193,10 @@ def bleu_score( ) preds_ = [preds] if isinstance(preds, str) else preds - target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] # type: ignore + target_ = [[tgt] if isinstance(tgt, str) else tgt for tgt in target] - if len(preds_) != len(target_): # type: ignore - raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") # type: ignore + if len(preds_) != len(target_): + raise ValueError(f"Corpus has different size {len(preds_)} != {len(target_)}") numerator = torch.zeros(n_gram) denominator = torch.zeros(n_gram) @@ -210,7 +204,7 @@ def bleu_score( target_len = tensor(0.0) preds_len, target_len = _bleu_score_update( - preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn # type: ignore + preds_, target_, numerator, denominator, preds_len, target_len, n_gram, _tokenize_fn ) return _bleu_score_compute(preds_len, target_len, numerator, denominator, n_gram, smooth) diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 4dbe78d7ed8..173b09be2a8 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -594,8 +594,8 @@ def _chrf_score_compute( remove_in="0.8", ) def chrf_score( - preds: Union[None, str, Sequence[str]] = None, - target: Union[None, Sequence[Union[str, Sequence[str]]]] = None, + preds: Union[str, Sequence[str]], + target: Sequence[Union[str, Sequence[str]]], n_char_order: int = 6, n_word_order: int = 2, beta: float = 2.0, @@ -626,12 +626,6 @@ def chrf_score( An indication whether to keep whitespaces during character n-gram extraction. return_sentence_level_score: An indication whether a sentence-level chrF/chrF++ score to be returned. - hypothesis_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: A corpus-level chrF/chrF++ score. 
@@ -685,8 +679,8 @@ def chrf_score( total_matching_word_n_grams, sentence_chrf_score, ) = _chrf_score_update( - preds, # type: ignore - target, # type: ignore + preds, + target, total_preds_char_n_grams, total_preds_word_n_grams, total_target_char_n_grams, diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 963ed7dfeea..bcf57cd702e 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -40,7 +40,7 @@ import logging import re from functools import partial -from typing import Sequence, Union +from typing import Sequence from warnings import warn import torch @@ -286,8 +286,8 @@ def _lower(line: str, lowercase: bool) -> str: remove_in="0.8", ) def sacre_bleu_score( - preds: Union[None, Sequence[str]] = None, - target: Union[None, Sequence[Sequence[str]]] = None, + preds: Sequence[str], + target: Sequence[Sequence[str]], n_gram: int = 4, smooth: bool = False, tokenize: Literal["none", "13a", "zh", "intl", "char"] = "13a", @@ -310,12 +310,6 @@ def sacre_bleu_score( Supported tokenization: ['none', '13a', 'zh', 'intl', 'char'] lowercase: If ``True``, BLEU score over lowercased text is calculated. - translate_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. Return: Tensor with BLEU Score @@ -348,8 +342,8 @@ def sacre_bleu_score( raise ValueError( f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}" ) - if len(preds) != len(target): # type: ignore - raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}") # type: ignore + if len(preds) != len(target): + raise ValueError(f"Corpus has different size {len(preds)} != {len(target)}") if tokenize == "intl" and not _REGEX_AVAILABLE: raise ModuleNotFoundError( "`'intl'` tokenization requires that `regex` is installed." @@ -363,8 +357,8 @@ def sacre_bleu_score( tokenize_fn = partial(_SacreBLEUTokenizer.tokenize, tokenize=tokenize, lowercase=lowercase) preds_len, target_len = _bleu_score_update( - preds, # type: ignore - target, # type: ignore + preds, + target, numerator, denominator, preds_len, diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index a7f0ac7e7f9..a88a2f64482 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -75,8 +75,8 @@ def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T remove_in="0.8", ) def word_information_lost( - preds: Union[None, str, List[str]] = None, - target: Union[None, str, List[str]] = None, + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> Tensor: """Word Information Lost rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better @@ -87,12 +87,6 @@ def word_information_lost( Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - predictions: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. 
Returns: Word Information Lost rate @@ -104,5 +98,5 @@ def word_information_lost( >>> word_information_lost(preds, target) tensor(0.6528) """ - errors, target_total, preds_total = _wil_update(preds, target) # type: ignore + errors, target_total, preds_total = _wil_update(preds, target) return _wil_compute(errors, target_total, preds_total) diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index bd39eaae02b..eced3869950 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -74,8 +74,8 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T remove_in="0.8", ) def word_information_preserved( - preds: Union[None, str, List[str]] = None, - target: Union[None, str, List[str]] = None, + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> Tensor: """Word Information Preserved rate is a metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the @@ -86,12 +86,6 @@ def word_information_preserved( Transcription(s) to score as a string or list of strings total: Reference(s) for each speech input as a string or list of strings - predictions: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. Returns: Word Information preserved rate @@ -103,5 +97,5 @@ def word_information_preserved( >>> word_information_preserved(preds, target) tensor(0.3472) """ - errors, reference_total, prediction_total = _wip_update(preds, target) # type: ignore + errors, reference_total, prediction_total = _wip_update(preds, target) return _wip_compute(errors, reference_total, prediction_total) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 2741921ae46..8add5d2aee9 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -202,8 +202,8 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, List[str]] = None, - target: Union[None, List[str]] = None, + preds: List[str], + target: List[str], ) -> None: """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. @@ -213,15 +213,9 @@ def update( # type: ignore An iterable of predicted sentences. target: An iterable of reference sentences. - predictions: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. 
""" preds_dict = _preprocess_text( - preds, # type: ignore + preds, self.tokenizer, self.max_length, truncation=False, @@ -229,7 +223,7 @@ def update( # type: ignore own_tokenizer=self.user_tokenizer, ) target_dict = _preprocess_text( - target, # type: ignore + target, self.tokenizer, self.max_length, truncation=False, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 381ebd44df6..0fedbfc704a 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -17,7 +17,7 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score import logging -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence from warnings import warn import torch @@ -106,24 +106,18 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, Sequence[str]] = None, - target: Union[None, Sequence[Sequence[str]]] = None, + preds: Sequence[str], + target: Sequence[Sequence[str]], ) -> None: """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - translate_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( - preds, # type: ignore - target, # type: ignore + preds, + target, self.numerator, self.denominator, self.preds_len, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index 4d44dbbba3e..f978daf4c8e 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -154,8 +154,8 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, Sequence[str]] = None, - target: Union[None, Sequence[Sequence[str]]] = None, + preds: Sequence[str], + target: Sequence[Sequence[str]], ) -> None: """Compute Precision Scores. @@ -164,16 +164,10 @@ def update( # type: ignore An iterable of hypothesis corpus. target: An iterable of iterables of reference corpus. - hypothesis_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. """ n_grams_dicts_tuple = _chrf_score_update( - preds, # type: ignore - target, # type: ignore + preds, + target, *self._convert_states_to_dicts(), self.n_char_order, self.n_word_order, diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 171d31e61d8..ceaa6d74b0e 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -18,7 +18,7 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score import logging -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence from warnings import warn from deprecate import deprecated @@ -126,24 +126,18 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, Sequence[str]] = None, - target: Union[None, Sequence[Sequence[str]]] = None, + preds: Sequence[str], + target: Sequence[Sequence[str]], ) -> None: """Compute Precision Scores. Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus - translate_corpus: - .. 
deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - reference_corpus: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( - preds, # type: ignore - target, # type: ignore + preds, + target, self.numerator, self.denominator, self.preds_len, diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 801b5b6d3d9..2752d2863e9 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -94,10 +94,8 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, str, List[str]] = None, - target: Union[None, str, List[str]] = None, - predictions: Union[None, str, List[str]] = None, - references: Union[None, str, List[str]] = None, + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> None: """Store predictions/references for computing Word Information Lost scores. @@ -106,14 +104,8 @@ def update( # type: ignore Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - predictions: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. """ - errors, target_total, preds_total = _wil_update(preds, target) # type: ignore + errors, target_total, preds_total = _wil_update(preds, target) self.errors += errors self.target_total += target_total self.preds_total += preds_total diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index a46411028d3..85ec29e748a 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -94,8 +94,8 @@ def __init__( ) def update( # type: ignore self, - preds: Union[None, str, List[str]] = None, - target: Union[None, str, List[str]] = None, + preds: Union[str, List[str]], + target: Union[str, List[str]], ) -> None: """Store predictions/references for computing word Information Preserved scores. @@ -104,14 +104,8 @@ def update( # type: ignore Transcription(s) to score as a string or list of strings target: Reference(s) for each speech input as a string or list of strings - predictions: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. - references: - .. deprecated:: v0.7 - This argument is deprecated in favor of `preds` and will be removed in v0.8. 
""" - errors, target_total, preds_total = _wip_update(preds, target) # type: ignore + errors, target_total, preds_total = _wip_update(preds, target) self.errors += errors self.target_total += target_total self.preds_total += preds_total From 0e4d98dc999672d830367b8d47b65c88a4c53dce Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 12 Jan 2022 17:21:12 +0100 Subject: [PATCH 19/22] Drop indentation --- torchmetrics/text/bert.py | 6 +----- torchmetrics/text/bleu.py | 6 +----- torchmetrics/text/chrf.py | 6 +----- torchmetrics/text/sacre_bleu.py | 6 +----- torchmetrics/text/wil.py | 6 +----- torchmetrics/text/wip.py | 6 +----- 6 files changed, 6 insertions(+), 30 deletions(-) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 8add5d2aee9..c5c1db96c1d 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -200,11 +200,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: List[str], - target: List[str], - ) -> None: + def update(self, preds: List[str], target: List[str]) -> None: # type: ignore """Store predictions/references for computing BERT scores. It is necessary to store sentences in a tokenized form to ensure the DDP mode working. diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 0fedbfc704a..ef547f116cf 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -104,11 +104,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: Sequence[str], - target: Sequence[Sequence[str]], - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index f978daf4c8e..ff4b8dc13b9 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -152,11 +152,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: Sequence[str], - target: Sequence[Sequence[str]], - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index ceaa6d74b0e..4e8a00b7e87 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -124,11 +124,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: Sequence[str], - target: Sequence[Sequence[str]], - ) -> None: + def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. Args: diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 2752d2863e9..8570ff25d7a 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -92,11 +92,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: Union[str, List[str]], - target: Union[str, List[str]], - ) -> None: + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing Word Information Lost scores. 
Args: diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index 85ec29e748a..e1ef1dfb44d 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -92,11 +92,7 @@ def __init__( deprecated_in="0.7", remove_in="0.8", ) - def update( # type: ignore - self, - preds: Union[str, List[str]], - target: Union[str, List[str]], - ) -> None: + def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing word Information Preserved scores. Args: From 47fd3b9641f4e0d242d4a384b3bb5f6c86e088b9 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 12 Jan 2022 17:36:14 +0100 Subject: [PATCH 20/22] Drop deprecated warning where not needed + add deprecated info to doc --- torchmetrics/functional/text/bert.py | 7 +++++++ torchmetrics/functional/text/bleu.py | 7 +++++++ torchmetrics/functional/text/chrf.py | 9 --------- torchmetrics/functional/text/sacre_bleu.py | 7 +++++++ torchmetrics/functional/text/wil.py | 9 --------- torchmetrics/functional/text/wip.py | 9 --------- torchmetrics/text/bert.py | 7 +++++++ torchmetrics/text/bleu.py | 7 +++++++ torchmetrics/text/chrf.py | 9 --------- torchmetrics/text/sacre_bleu.py | 7 +++++++ torchmetrics/text/wil.py | 9 --------- torchmetrics/text/wip.py | 9 --------- 12 files changed, 42 insertions(+), 54 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index bcec4e26ee4..e3a23c37796 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -543,6 +543,13 @@ def bert_score( Returns: Python dictionary containing the keys `precision`, `recall` and `f1` with corresponding values. + .. deprecated:: v0.7 + Args: + predictions: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + references: + This argument is deprecated in favor of `target` and will be removed in v0.8. + Raises: ValueError: If `len(preds) != len(target)`. diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 12ca015828b..fa4eb20cf25 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -173,6 +173,13 @@ def bleu_score( Return: Tensor with BLEU Score + .. deprecated:: v0.7 + Args: + translate_corpus: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + reference_corpus: + This argument is deprecated in favor of `target` and will be removed in v0.8. + Example: >>> from torchmetrics.functional import bleu_score >>> preds = ['the cat is on the mat'] diff --git a/torchmetrics/functional/text/chrf.py b/torchmetrics/functional/text/chrf.py index 173b09be2a8..efeb87c2f4d 100644 --- a/torchmetrics/functional/text/chrf.py +++ b/torchmetrics/functional/text/chrf.py @@ -32,12 +32,10 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
-import logging from collections import defaultdict from typing import Dict, List, Optional, Sequence, Tuple, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _validate_inputs @@ -586,13 +584,6 @@ def _chrf_score_compute( return chrf_f_score -@deprecated( - args_mapping={"hypothesis_corpus": "preds", "reference_corpus": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", -) def chrf_score( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index bcf57cd702e..a0fd7446281 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -314,6 +314,13 @@ def sacre_bleu_score( Return: Tensor with BLEU Score + .. deprecated:: v0.7 + Args: + translate_corpus: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + reference_corpus: + This argument is deprecated in favor of `target` and will be removed in v0.8. + Example: >>> from torchmetrics.functional import sacre_bleu_score >>> preds = ['the cat is on the mat'] diff --git a/torchmetrics/functional/text/wil.py b/torchmetrics/functional/text/wil.py index a88a2f64482..a2c4ee1342c 100644 --- a/torchmetrics/functional/text/wil.py +++ b/torchmetrics/functional/text/wil.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import List, Tuple, Union -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance @@ -67,13 +65,6 @@ def _wil_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return 1 - ((errors / target_total) * (errors / preds_total)) -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", -) def word_information_lost( preds: Union[str, List[str]], target: Union[str, List[str]], diff --git a/torchmetrics/functional/text/wip.py b/torchmetrics/functional/text/wip.py index eced3869950..a7fdc9878f0 100644 --- a/torchmetrics/functional/text/wip.py +++ b/torchmetrics/functional/text/wip.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import List, Tuple, Union -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.helper import _edit_distance @@ -66,13 +64,6 @@ def _wip_compute(errors: Tensor, target_total: Tensor, preds_total: Tensor) -> T return (errors / target_total) * (errors / preds_total) -@deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", -) def word_information_preserved( preds: Union[str, List[str]], target: Union[str, List[str]], diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index c5c1db96c1d..b0ff2c8e610 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -209,6 +209,13 @@ def update(self, preds: List[str], target: List[str]) -> None: # type: ignore An iterable of predicted sentences. target: An iterable of reference sentences. + + .. 
deprecated:: v0.7 + Args: + predictions: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + references: + This argument is deprecated in favor of `target` and will be removed in v0.8. """ preds_dict = _preprocess_text( preds, diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index ef547f116cf..0784a0ba945 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -110,6 +110,13 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus + + .. deprecated:: v0.7 + Args: + translate_corpus: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + reference_corpus: + This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/chrf.py b/torchmetrics/text/chrf.py index ff4b8dc13b9..487c26b627f 100644 --- a/torchmetrics/text/chrf.py +++ b/torchmetrics/text/chrf.py @@ -18,11 +18,9 @@ # Link: import itertools -import logging from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union import torch -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics import Metric @@ -145,13 +143,6 @@ def __init__( if self.return_sentence_level_score: self.add_state("sentence_chrf_score", [], dist_reduce_fx="cat") - @deprecated( - args_mapping={"hypothesis_corpus": "preds", "reference_corpus": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", - ) def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: # type: ignore """Compute Precision Scores. diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 4e8a00b7e87..0254df75a39 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -130,6 +130,13 @@ def update(self, preds: Sequence[str], target: Sequence[Sequence[str]]) -> None: Args: preds: An iterable of machine translated corpus target: An iterable of iterables of reference corpus + + .. deprecated:: v0.7 + Args: + translate_corpus: + This argument is deprecated in favor of `preds` and will be removed in v0.8. + reference_corpus: + This argument is deprecated in favor of `target` and will be removed in v0.8. """ self.preds_len, self.target_len = _bleu_score_update( preds, diff --git a/torchmetrics/text/wil.py b/torchmetrics/text/wil.py index 8570ff25d7a..d3940d5db8e 100644 --- a/torchmetrics/text/wil.py +++ b/torchmetrics/text/wil.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import Any, Callable, List, Optional, Union -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.wil import _wil_compute, _wil_update @@ -85,13 +83,6 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing Word Information Lost scores. 
diff --git a/torchmetrics/text/wip.py b/torchmetrics/text/wip.py index e1ef1dfb44d..1a0351999ac 100644 --- a/torchmetrics/text/wip.py +++ b/torchmetrics/text/wip.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging from typing import Any, Callable, List, Optional, Union -from deprecate import deprecated from torch import Tensor, tensor from torchmetrics.functional.text.wip import _wip_compute, _wip_update @@ -85,13 +83,6 @@ def __init__( self.add_state("target_total", tensor(0.0), dist_reduce_fx="sum") self.add_state("preds_total", tensor(0.0), dist_reduce_fx="sum") - @deprecated( - args_mapping={"predictions": "preds", "references": "target"}, - target=True, - stream=logging.warning, - deprecated_in="0.7", - remove_in="0.8", - ) def update(self, preds: Union[str, List[str]], target: Union[str, List[str]]) -> None: # type: ignore """Store predictions/references for computing word Information Preserved scores. From 8f43eb31e0820a09fa837f1a4051f2bb6b225b20 Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 12 Jan 2022 22:19:27 +0100 Subject: [PATCH 21/22] Change stream --- torchmetrics/functional/text/bert.py | 12 ++++++------ torchmetrics/functional/text/bleu.py | 4 ++-- torchmetrics/functional/text/sacre_bleu.py | 3 +-- torchmetrics/text/bert.py | 4 ++-- torchmetrics/text/bleu.py | 4 ++-- torchmetrics/text/sacre_bleu.py | 4 ++-- 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index e3a23c37796..9da552722f9 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import csv -import logging import math import urllib -import warnings from collections import Counter, defaultdict +from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from warnings import warn import torch from deprecate import deprecated @@ -428,7 +428,7 @@ def _load_baseline( baseline = _read_csv_from_url(baseline_url) else: baseline = None - warnings.warn("Baseline was not successfully loaded. No baseline is going to be used.") + warn("Baseline was not successfully loaded. 
No baseline is going to be used.") return baseline @@ -454,7 +454,7 @@ def _rescale_metrics_with_baseline( @deprecated( args_mapping={"predictions": "preds", "references": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) @@ -599,7 +599,7 @@ def bert_score( f"Please use num_layers <= {model.config.num_hidden_layers}" # type: ignore ) except AttributeError: - warnings.warn("It was not possible to retrieve the parameter `num_layers` from the model specification.") + warn("It was not possible to retrieve the parameter `num_layers` from the model specification.") _are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (preds, target)) _are_valid_lists = all( @@ -609,7 +609,7 @@ def bert_score( isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (preds, target) ) if _are_empty_lists: - warnings.warn("Predictions and references are empty.") + warn("Predictions and references are empty.") output_dict: Dict[str, Union[List[float], str]] = { "precision": [0.0], "recall": [0.0], diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index fa4eb20cf25..576f61dcda5 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -16,8 +16,8 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -import logging from collections import Counter +from functools import partial from typing import Callable, Sequence, Tuple, Union from warnings import warn @@ -148,7 +148,7 @@ def _bleu_score_compute( @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index a0fd7446281..7576988e48b 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -37,7 +37,6 @@ # MIT License # Copyright (c) 2017 - Shujian Huang -import logging import re from functools import partial from typing import Sequence @@ -281,7 +280,7 @@ def _lower(line: str, lowercase: bool) -> str: @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index b0ff2c8e610..04cedde323c 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging +from functools import partial from typing import Any, Callable, Dict, List, Optional, Union from warnings import warn @@ -196,7 +196,7 @@ def __init__( @deprecated( args_mapping={"predictions": "preds", "references": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index 0784a0ba945..b1e08fca27f 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,7 +16,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -import logging +from functools import partial from typing import Any, Callable, Optional, Sequence from warnings import warn @@ -100,7 +100,7 @@ def __init__( @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 0254df75a39..102abff1972 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -17,7 +17,7 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -import logging +from functools import partial from typing import Any, Callable, Optional, Sequence from warnings import warn @@ -120,7 +120,7 @@ def __init__( @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=logging.warning, + stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) From 0a0d80a4d2b8d2ded5ccdd4ce5dc0d33fb4f7dfd Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Wed, 12 Jan 2022 23:54:12 +0100 Subject: [PATCH 22/22] Switch to default stream method --- torchmetrics/functional/text/bert.py | 2 -- torchmetrics/functional/text/bleu.py | 2 -- torchmetrics/functional/text/sacre_bleu.py | 1 - torchmetrics/text/bert.py | 2 -- torchmetrics/text/bleu.py | 2 -- torchmetrics/text/sacre_bleu.py | 2 -- 6 files changed, 11 deletions(-) diff --git a/torchmetrics/functional/text/bert.py b/torchmetrics/functional/text/bert.py index 9da552722f9..98e387c6f98 100644 --- a/torchmetrics/functional/text/bert.py +++ b/torchmetrics/functional/text/bert.py @@ -15,7 +15,6 @@ import math import urllib from collections import Counter, defaultdict -from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from warnings import warn @@ -454,7 +453,6 @@ def _rescale_metrics_with_baseline( @deprecated( args_mapping={"predictions": "preds", "references": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/functional/text/bleu.py b/torchmetrics/functional/text/bleu.py index 576f61dcda5..c301d32cbba 100644 --- a/torchmetrics/functional/text/bleu.py +++ b/torchmetrics/functional/text/bleu.py @@ -17,7 +17,6 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter -from functools import partial from typing import Callable, Sequence, Tuple, Union from warnings import warn @@ -148,7 +147,6 @@ def _bleu_score_compute( @deprecated( args_mapping={"translate_corpus": "preds", 
"reference_corpus": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/functional/text/sacre_bleu.py b/torchmetrics/functional/text/sacre_bleu.py index 7576988e48b..41754c060e2 100644 --- a/torchmetrics/functional/text/sacre_bleu.py +++ b/torchmetrics/functional/text/sacre_bleu.py @@ -280,7 +280,6 @@ def _lower(line: str, lowercase: bool) -> str: @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/bert.py b/torchmetrics/text/bert.py index 04cedde323c..31cd24c48fb 100644 --- a/torchmetrics/text/bert.py +++ b/torchmetrics/text/bert.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Any, Callable, Dict, List, Optional, Union from warnings import warn @@ -196,7 +195,6 @@ def __init__( @deprecated( args_mapping={"predictions": "preds", "references": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/bleu.py b/torchmetrics/text/bleu.py index b1e08fca27f..5edd0c9b5ee 100644 --- a/torchmetrics/text/bleu.py +++ b/torchmetrics/text/bleu.py @@ -16,7 +16,6 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from functools import partial from typing import Any, Callable, Optional, Sequence from warnings import warn @@ -100,7 +99,6 @@ def __init__( @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", ) diff --git a/torchmetrics/text/sacre_bleu.py b/torchmetrics/text/sacre_bleu.py index 102abff1972..a66019fd97b 100644 --- a/torchmetrics/text/sacre_bleu.py +++ b/torchmetrics/text/sacre_bleu.py @@ -17,7 +17,6 @@ # Authors: torchtext authors and @sluks # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score -from functools import partial from typing import Any, Callable, Optional, Sequence from warnings import warn @@ -120,7 +119,6 @@ def __init__( @deprecated( args_mapping={"translate_corpus": "preds", "reference_corpus": "target"}, target=True, - stream=partial(warn, category=FutureWarning), deprecated_in="0.7", remove_in="0.8", )