
Commit ea813d3

yuxqiu authored and facebook-github-bot committed
fix: correct reference length calculation (#195)
Summary:
This PR fixes how the brevity penalty (specifically, the effective reference corpus length) is calculated in BLEU. Previously, `len_reference` was computed as `min([len(ref) for ref in references_tokenized])`. This is incorrect: according to the paper, we need the "best match length" (the reference length closest to the candidate length), not the minimum reference length. For more information, see [wikipedia - brevity penalty](https://en.wikipedia.org/wiki/BLEU#Brevity_penalty) and the [nltk implementation](https://www.nltk.org/_modules/nltk/translate/bleu_score.html#closest_ref_length).

Pull Request resolved: #195

Test Plan:
Added a unit test to `test_bleu.py` and compared the results against `nltk.translate.bleu_score.corpus_bleu` to verify that the implementation is correct.

Reviewed By: galrotem

Differential Revision: D56846091

Pulled By: JKSenthil

fbshipit-source-id: 2bf1cd0ba169535a118222e60f4264259248f1fd
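The "best match length" selection described above can be sketched in isolation. This is a minimal, stdlib-only sketch mirroring NLTK's `closest_ref_length` behavior (the function name and standalone form are illustrative, not TorchEval's API): pick the reference length with the smallest absolute distance to the candidate length, breaking ties toward the shorter reference.

```python
def closest_ref_length(candidate_len: int, ref_lens: list[int]) -> int:
    # Choose the reference length closest to the candidate length.
    # On a tie in distance, the tuple key falls back to ref_len itself,
    # so the shorter reference wins (matching NLTK's tie-breaking).
    return min(ref_lens, key=lambda ref_len: (abs(ref_len - candidate_len), ref_len))


print(closest_ref_length(6, [6, 8, 1]))   # exact match -> 6
print(closest_ref_length(5, [4, 6, 10]))  # tie between 4 and 6 -> shorter wins -> 4
```

With the old `min(...)` logic, the first call would have returned 1 ("hi"-style short references would dominate), inflating the brevity penalty.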
1 parent cb6bc39 commit ea813d3

File tree

2 files changed: +33, −1 lines changed


tests/metrics/text/test_bleu.py (+29)

@@ -107,3 +107,32 @@ def test_bleu_multiple_examples_per_update(self) -> None:
             num_total_updates=2,
             num_processes=2,
         )
+
+    def test_bleu_brevity(self) -> None:
+        candidates = [["the squirrel is eating the nut"], ["the cat is on mat"]]
+        references = [
+            [
+                [
+                    "a squirrel is eating a nut",
+                    "the squirrel is eating a tasty nut",
+                    "hi",
+                ]
+            ],
+            [["there is a cat on the mat", "a cat is on the mat"]],
+        ]
+        self.run_class_implementation_tests(
+            metric=BLEUScore(n_gram=4),
+            state_names={
+                "input_len",
+                "target_len",
+                "matches_by_order",
+                "possible_matches_by_order",
+            },
+            update_kwargs={
+                "input": candidates,
+                "target": references,
+            },
+            compute_result=torch.tensor(0.41650065, dtype=torch.float64),
+            num_total_updates=2,
+            num_processes=2,
+        )

torcheval/metrics/functional/text/bleu.py (+4, −1)

@@ -88,7 +88,10 @@ def _bleu_score_update(
     references_tokenized = [ref.split() for ref in references]

     len_candidate = len(candidate_tokenized)
-    len_reference = min([len(ref) for ref in references_tokenized])
+    len_reference = min(
+        [len(ref) for ref in references_tokenized],
+        key=lambda ref_len: (abs(ref_len - len_candidate), ref_len),
+    )
     input_len += len_candidate
     target_len += len_reference
