fix Bug in BertScore calculation: pred target misalignment #2347
Conversation
Here is the test code:

```python
from torchmetrics.text.bert import BERTScore

score_model = BERTScore(model_name_or_path='roberta-large', batch_size=2)
text1 = ["Claim A from machine", "Claim A from machine"]
text2 = ["Claim A from machine", "Claim B"]
similarities = score_model(text1, text2)
print(similarities)
```
@stancld could you help here, pls?
Good catch! I can reproduce the bug using your test code snippet. Although the actual metric values seem to be correct, the ordering is not always valid.
It might be nice to somehow integrate this case with the current test suite, e.g., an assertion that reversing the targets/preds also reverses the scores.
@gxy-gxy can you pls add it as a test?
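A minimal sketch of such a test, assuming the metric returns a dict with "precision", "recall", and "f1" entries holding one value per input pair (the test name and tolerance are illustrative, not part of the actual test suite):

```python
import torch
from torchmetrics.text.bert import BERTScore


def test_bertscore_order_is_preserved():
    # Hypothetical test sketch: per-pair scores should follow the input order,
    # so reversing both preds and targets should simply reverse the scores.
    metric = BERTScore(model_name_or_path="roberta-large", batch_size=2)
    preds = ["Claim A from machine", "Claim A from machine"]
    target = ["Claim A from machine", "Claim B"]

    forward_scores = metric(preds, target)
    reversed_scores = metric(preds[::-1], target[::-1])

    for key in ("precision", "recall", "f1"):
        assert torch.allclose(
            torch.as_tensor(forward_scores[key]),
            torch.as_tensor(reversed_scores[key]).flip(0),
            atol=1e-4,
        )
```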
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:

```
@@            Coverage Diff            @@
##           master   #2347      +/-   ##
=========================================
- Coverage      69%     39%      -30%
=========================================
  Files         316     316
  Lines       17878   17874        -4
=========================================
- Hits        12329    7030     -5299
- Misses       5549   10844     +5295
```
adding test - #2347 (comment)
LGTM! Thanks for the addition. I can quickly add the test.
@baskrahmer, would you mind also adding an entry to the changelog?
* fix pred target misalignment
* add test

---------

Co-authored-by: Xinyan Guan <[email protected]>
Co-authored-by: Bas Krahmer <[email protected]>
(cherry picked from commit 75c33ea)
fix Bug in BertScore calculation: pred target misalignment
Fixes a bug in the BertScore calculation.
This pull request addresses a bug identified in the BertScore calculation within the TextDataset class in `src/torchmetrics/functional/text/helper_embedding_metric.py`. The class's preprocess function automatically sorts the input text by length to optimize batch-encoding efficiency. However, this behavior introduces an issue during the BertScore calculation, because predictions (preds) and targets (targets) are initialized as separate datasets and sorted independently, so the text pairs end up in mismatched order; this is problematic given the pairwise nature of BertScore's calculation. To ensure accurate scoring, the datasets must be re-aligned to their original order before the scores are computed. The proposed fix ensures that the prediction and target datasets maintain their original pairing throughout the calculation.
Here is the fixed code:
This change is essential for preserving the integrity of the BertScore evaluation, ensuring that each prediction is accurately compared against its corresponding target.
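The actual fix lives in this PR's diff; purely as a hypothetical sketch of the general re-alignment idea (the variable names and placeholder scores below are illustrative, not the torchmetrics internals), restoring the original order after length-sorted batching can look like this:

```python
import torch

# Illustrative sketch only (not the actual torchmetrics patch): sort texts by
# length for efficient batching, compute per-text results on the sorted order,
# then scatter them back to the original positions so preds and targets stay paired.
texts = ["Claim A from machine", "Claim B", "a much longer claim about something else"]
lengths = torch.tensor([len(t.split()) for t in texts])

sort_idx = torch.argsort(lengths, descending=True)  # order used for batching
sorted_texts = [texts[int(i)] for i in sort_idx]

# Placeholder for the real embedding/score computation over sorted_texts.
sorted_scores = torch.arange(len(sorted_texts), dtype=torch.float)

# Undo the sort: the j-th sorted result goes back to original position sort_idx[j].
restored = torch.empty_like(sorted_scores)
restored[sort_idx] = sorted_scores
print(restored)
```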
📚 Documentation preview 📚: https://torchmetrics--2347.org.readthedocs.build/en/2347/