Unify preds, target input arguments for text metrics [1of2] bert, bleu, chrf, sacre_bleu, wip, wil #723

Merged — 30 commits from text-preds-target, Jan 13, 2022

Changes from 19 commits

Commits (30)
e6f1736
[WIP] Unify some text metrics
stancld Jan 6, 2022
6e690be
Removed input_order from text unit tests (#717)
jscottcronin Jan 5, 2022
ad72e68
Fix some unchanged names
stancld Jan 6, 2022
b3fd97b
Unify some other stuff
stancld Jan 6, 2022
5ff3534
Fix flake8 + unwanted typo
stancld Jan 6, 2022
f7c65eb
Fix an import in bert doc
stancld Jan 6, 2022
8c7b9e0
Some nits
stancld Jan 7, 2022
8b330ea
Handle BC and add warnings
stancld Jan 9, 2022
ab36e72
Apply suggestions from code review
stancld Jan 10, 2022
8818569
Add one unsaved change
stancld Jan 10, 2022
5547156
Fix doc indentation for wip/wil
stancld Jan 10, 2022
f8b8b01
Merge branch 'master' into text-preds-target
stancld Jan 10, 2022
71e82c1
Set preds, target = None
stancld Jan 10, 2022
0466c51
Merge branch 'master' into text-preds-target
stancld Jan 10, 2022
2ebc20f
Add ignore statements for mypy
stancld Jan 10, 2022
9b01cc7
Merge branch 'master' into text-preds-target
stancld Jan 10, 2022
904916d
Merge branch 'master' into text-preds-target
stancld Jan 11, 2022
26f7fd0
Use deprecate package
stancld Jan 11, 2022
93cceb3
Add pyDeprecate==0.3.* to the requirements
stancld Jan 11, 2022
c8a3463
Apply suggestions from code review
Borda Jan 12, 2022
3ecb965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
e632c1d
Apply some suggestions from code review
stancld Jan 12, 2022
a93705e
Merge branch 'master' into text-preds-target
stancld Jan 12, 2022
0e4d98d
Drop indentation
stancld Jan 12, 2022
47fd3b9
Drop deprecated warning where not needed + add deprecated info to doc
stancld Jan 12, 2022
21fc7fd
Merge branch 'master' into text-preds-target
SkafteNicki Jan 12, 2022
339e746
Merge branch 'master' into text-preds-target
stancld Jan 12, 2022
8f43eb3
Change stream
stancld Jan 12, 2022
c2d45bc
Merge branch 'master' into text-preds-target
Borda Jan 12, 2022
0a0d80a
Switch to default stream method
stancld Jan 12, 2022
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
numpy>=1.17.2
torch>=1.3.1
pyDeprecate==0.3.*
packaging
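
The pyDeprecate pin added above backs the backward-compatibility handling noted in the commit log ("Handle BC and add warnings", "Use deprecate package"): the old keyword names are kept as deprecated aliases of the new preds/target arguments. As a rough sketch only — the PR itself delegates this to pyDeprecate, and the names below (_map_deprecated_kwargs, dummy_text_metric) are hypothetical illustration rather than code from the diff — the equivalent behaviour in plain Python looks like this:

import warnings
from typing import Any, Dict, List, Optional, Sequence


def _map_deprecated_kwargs(kwargs: Dict[str, Any], mapping: Dict[str, str]) -> Dict[str, Any]:
    """Rename deprecated keyword arguments to their new names with a FutureWarning."""
    for old_name, new_name in mapping.items():
        if old_name in kwargs:
            warnings.warn(
                f"Argument `{old_name}` is deprecated in favor of `{new_name}`"
                " and will be removed in a future release.",
                FutureWarning,
            )
            kwargs.setdefault(new_name, kwargs.pop(old_name))
    return kwargs


def dummy_text_metric(
    preds: Optional[Sequence[str]] = None,
    target: Optional[Sequence[List[str]]] = None,
    **kwargs: Any,
) -> int:
    """Placeholder metric that just counts prediction/target pairs."""
    mapped = _map_deprecated_kwargs(dict(kwargs), {"predictions": "preds", "references": "target"})
    preds = preds if preds is not None else mapped.get("preds")
    target = target if target is not None else mapped.get("target")
    return len(list(zip(preds, target)))


# Old keyword names still work but emit a FutureWarning:
dummy_text_metric(predictions=["a cat"], references=[["a cat", "the cat"]])
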
142 changes: 73 additions & 69 deletions tests/text/test_bertscore.py
@@ -25,7 +25,7 @@
"The victim's brother said he cannot imagine anyone who would want to harm him,\"Finally, it went uphill again at "
'him."',
]
refs = [
targets = [
"28-Year-Old Chef Found Dead at San Francisco Mall",
"A 28-year-old chef who had recently moved to San Francisco was found dead in the stairwell of a local mall this "
"week.",
@@ -39,9 +39,9 @@
MODEL_NAME = "albert-base-v2"


def _assert_list(preds: Any, refs: Any, threshold: float = 1e-8):
def _assert_list(preds: Any, targets: Any, threshold: float = 1e-8):
"""Assert two lists are equal."""
assert np.allclose(preds, refs, atol=threshold, equal_nan=True)
assert np.allclose(preds, targets, atol=threshold, equal_nan=True)


def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:
@@ -51,91 +51,93 @@ def _parse_original_bert_score(score: torch.Tensor) -> Dict[str, List[float]]:


preds_batched = [preds[0:2], preds[2:]]
refs_batched = [refs[0:2], refs[2:]]
targets_batched = [targets[0:2], targets[2:]]


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn(preds, refs):
def test_score_fn(preds, targets):
"""Tests for functional."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3
preds, targets, model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_with_idf(preds, refs):
def test_score_fn_with_idf(preds, targets):
"""Tests for functional with IDF rescaling."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=12, idf=True, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=12, idf=True, batch_size=3)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path=MODEL_NAME, num_layers=12, idf=True, batch_size=3
preds, targets, model_name_or_path=MODEL_NAME, num_layers=12, idf=True, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers(preds, refs):
def test_score_fn_all_layers(preds, targets):
"""Tests for functional and all layers."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
original_score = original_bert_score(
preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3
preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers_with_idf(preds, refs):
def test_score_fn_all_layers_with_idf(preds, targets):
"""Tests for functional and all layers with IDF rescaling."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
original_score = _parse_original_bert_score(original_score)

metrics_score = metrics_bert_score(
preds, refs, model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3
preds, targets, model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3
)

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_all_layers_rescale_with_baseline(preds, refs):
def test_score_fn_all_layers_rescale_with_baseline(preds, targets):
"""Tests for functional with baseline rescaling."""
original_score = original_bert_score(
preds,
refs,
targets,
model_type=MODEL_NAME,
lang="en",
num_layers=8,
@@ -147,7 +149,7 @@ def test_score_fn_all_layers_rescale_with_baseline(preds, refs):

metrics_score = metrics_bert_score(
preds,
refs,
targets,
model_name_or_path=MODEL_NAME,
lang="en",
num_layers=8,
Expand All @@ -161,15 +163,15 @@ def test_score_fn_all_layers_rescale_with_baseline(preds, refs):


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_fn_rescale_with_baseline(preds, refs):
def test_score_fn_rescale_with_baseline(preds, targets):
"""Tests for functional with baseline rescaling with all layers."""
original_score = original_bert_score(
preds,
refs,
targets,
model_type=MODEL_NAME,
lang="en",
all_layers=True,
@@ -181,7 +183,7 @@ def test_score_fn_rescale_with_baseline(preds, refs):

metrics_score = metrics_bert_score(
preds,
refs,
targets,
model_name_or_path=MODEL_NAME,
lang="en",
all_layers=True,
@@ -195,124 +197,126 @@ def test_score_fn_rescale_with_baseline(preds, refs):


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score(preds, refs):
def test_score(preds, targets):
"""Tests for metric."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Scorer.update(preds=preds, target=targets)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_with_idf(preds, refs):
def test_score_with_idf(preds, targets):
"""Tests for metric with IDF rescaling."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=True, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=True, batch_size=3)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=True, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Scorer.update(preds=preds, target=targets)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_all_layers(preds, refs):
def test_score_all_layers(preds, targets):
"""Tests for metric and all layers."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
original_score = original_bert_score(
preds, targets, model_type=MODEL_NAME, all_layers=True, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=False, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Scorer.update(preds=preds, target=targets)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_score_all_layers_with_idf(preds, refs):
def test_score_all_layers_with_idf(preds, targets):
"""Tests for metric and all layers with IDF rescaling."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path=MODEL_NAME, all_layers=True, idf=True, batch_size=3)
Scorer.update(predictions=preds, references=refs)
Scorer.update(preds=preds, target=targets)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


@pytest.mark.parametrize(
"preds,refs",
[(preds_batched, refs_batched)],
"preds,targets",
[(preds_batched, targets_batched)],
)
@pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score")
def test_accumulation(preds, refs):
def test_accumulation(preds, targets):
"""Tests for metric works with accumulation."""
original_score = original_bert_score(
sum(preds, []), sum(refs, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
sum(preds, []), sum(targets, []), model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3
)
original_score = _parse_original_bert_score(original_score)

Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
for p, r in zip(preds, refs):
Scorer.update(predictions=p, references=r)
for p, r in zip(preds, targets):
Scorer.update(preds=p, target=r)
metrics_score = Scorer.compute()

for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])


def _bert_score_ddp(rank, world_size, preds, refs, original_score):
def _bert_score_ddp(rank, world_size, preds, targets, original_score):
"""Define a DDP process for BERTScore."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("gloo", rank=rank, world_size=world_size)
Scorer = BERTScore(model_name_or_path=MODEL_NAME, num_layers=8, idf=False, batch_size=3, max_length=128)
Scorer.update(preds, refs)
Scorer.update(preds, targets)
metrics_score = Scorer.compute()
for metric in _METRICS:
_assert_list(metrics_score[metric], original_score[metric])
dist.destroy_process_group()


def _test_score_ddp_fn(rank, world_size, preds, refs):
def _test_score_ddp_fn(rank, world_size, preds, targets):
"""Core functionality for the `test_score_ddp` test."""
original_score = original_bert_score(preds, refs, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = original_bert_score(preds, targets, model_type=MODEL_NAME, num_layers=8, idf=False, batch_size=3)
original_score = _parse_original_bert_score(original_score)
_bert_score_ddp(rank, world_size, preds, refs, original_score)
_bert_score_ddp(rank, world_size, preds, targets, original_score)


@pytest.mark.parametrize(
"preds,refs",
[(preds, refs)],
"preds,targets",
[(preds, targets)],
)
@pytest.mark.skipif(not (_BERTSCORE_AVAILABLE and dist.is_available()), reason="test requires bert_score")
def test_score_ddp(preds, refs):
def test_score_ddp(preds, targets):
"""Tests for metric using DDP."""
world_size = 2
mp.spawn(_test_score_ddp_fn, args=(world_size, preds, refs), nprocs=world_size, join=False)
mp.spawn(_test_score_ddp_fn, args=(world_size, preds, targets), nprocs=world_size, join=False)
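
Outside the test suite, the unified signature exercised above maps onto user code roughly as follows. This is a hedged sketch rather than part of the diff: it assumes the BERTScore import path current at the time of this PR, and it needs the optional bert_score dependency plus a downloaded HuggingFace model, exactly like the tests.

from torchmetrics.text.bert import BERTScore  # import path assumed for torchmetrics ~0.7

preds = ["hello there", "general kenobi"]
target = ["hello there", "master kenobi"]

# Keyword names are now `preds` and `target` (previously `predictions` and `references`);
# the old names are still accepted during the deprecation window but emit a warning.
bertscore = BERTScore(model_name_or_path="albert-base-v2", num_layers=8, idf=False, batch_size=3)
bertscore.update(preds=preds, target=target)
score = bertscore.compute()  # dict with per-sentence "precision", "recall" and "f1" scores
print(score)
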
30 changes: 13 additions & 17 deletions tests/text/test_bleu.py
@@ -27,15 +27,11 @@
smooth_func = SmoothingFunction().method2


def _compute_bleu_metric_nltk(hypotheses, list_of_references, weights, smoothing_function, **kwargs):
hypotheses_ = [hypothesis.split() for hypothesis in hypotheses]
list_of_references_ = [[line.split() for line in ref] for ref in list_of_references]
def _compute_bleu_metric_nltk(preds, targets, weights, smoothing_function, **kwargs):
preds_ = [pred.split() for pred in preds]
targets_ = [[line.split() for line in target] for target in targets]
return corpus_bleu(
list_of_references=list_of_references_,
hypotheses=hypotheses_,
weights=weights,
smoothing_function=smoothing_function,
**kwargs
list_of_references=targets_, hypotheses=preds_, weights=weights, smoothing_function=smoothing_function, **kwargs
)


@@ -100,20 +96,20 @@ def test_bleu_empty_functional():


def test_no_4_gram_functional():
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu_score(hyps, refs) == tensor(0.0)
preds = ["My full pytorch-lightning"]
targets = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu_score(preds, targets) == tensor(0.0)


def test_bleu_empty_class():
bleu = BLEUScore()
hyp = [[]]
ref = [[[]]]
assert bleu(hyp, ref) == tensor(0.0)
preds = [[]]
targets = [[[]]]
assert bleu(preds, targets) == tensor(0.0)


def test_no_4_gram_class():
bleu = BLEUScore()
hyps = ["My full pytorch-lightning"]
refs = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(hyps, refs) == tensor(0.0)
preds = ["My full pytorch-lightning"]
targets = [["My full pytorch-lightning test", "Completely Different"]]
assert bleu(preds, targets) == tensor(0.0)
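
The same rename applies to BLEU. A minimal usage sketch mirroring test_no_4_gram_class above, assuming the public torchmetrics import paths of the time; the example values come straight from the test and yield a zero score because there is no 4-gram overlap.

from torchmetrics import BLEUScore
from torchmetrics.functional import bleu_score

preds = ["My full pytorch-lightning"]
targets = [["My full pytorch-lightning test", "Completely Different"]]

# Functional API: positional order is now (preds, targets).
print(bleu_score(preds, targets))  # tensor(0.)

# Class-based API with the same inputs.
bleu = BLEUScore()
print(bleu(preds, targets))  # tensor(0.)
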