add rouge metrics (#1719)
* Add ROUGE metrics (rouge-1, rouge-2, rouge-L) using py-rouge; a brief usage sketch follows the change summary below.
* Warn if the nltk punkt tokenizer is not installed.
* Download the punkt tokenizer in our CircleCI config so ROUGE can be tested in CI.
dexterju27 committed Jun 11, 2019
1 parent e0cb245 commit 471db18
Showing 4 changed files with 62 additions and 6 deletions.
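
For context, a minimal usage sketch (not part of this commit) of the py-rouge call pattern that the new _rouge() helper below wraps. The evaluator settings mirror parlai/core/metrics.py, the punkt download mirrors the CircleCI change, and the example sentences are illustrative; exact score values depend on py-rouge's tokenization.

import nltk
import rouge

nltk.download('punkt')  # py-rouge tokenizes with nltk; without this data it raises LookupError

evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
scores = evaluator.get_scores('the cat sat on the mat', 'the cat is on the mat')
# Recall ('r') is what ParlAI reports as rouge-1 / rouge-2 / rouge-L.
print(scores['rouge-1']['r'], scores['rouge-2']['r'], scores['rouge-l']['r'])
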
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -42,6 +42,7 @@ installdeps: &installdeps
    name: Installs basic dependencies
    command: |
      python setup.py develop
      python -c "import nltk; nltk.download('punkt')"
installtorchgpu: &installtorchgpu
  run:
53 changes: 52 additions & 1 deletion parlai/core/metrics.py
@@ -13,6 +13,7 @@
from parlai.core.thread_utils import SharedTable
from parlai.core.utils import round_sigfigs, no_lock
from collections import Counter
from parlai.core.utils import warn_once

import re

@@ -23,6 +24,14 @@
    # We'll just turn off things, but we might want to warn the user
    nltkbleu = None

try:
    import rouge
except ImportError:
    # User doesn't have py-rouge installed, so we can't compute ROUGE metrics.
    # We'll just turn off things, but we might want to warn the user
    warn_once('ROUGE metrics require py-rouge. Please run `pip install py-rouge`.')
    rouge = None

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')

@@ -103,6 +112,29 @@ def _bleu(guess, answers):
    )


def _rouge(guess, answers):
    """Compute ROUGE score between guess and *any* answer. Return the best."""
    global rouge
    if rouge is None:
        return None, None, None
    evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
    try:
        scores = [evaluator.get_scores(normalize_answer(guess), normalize_answer(a))
                  for a in answers]
    except LookupError:
        warn_once(
            'ROUGE requires the nltk punkt tokenizer. Please run '
            '`python -c "import nltk; nltk.download(\'punkt\')"`'
        )
        rouge = None
        return None, None, None

    scores_rouge1 = [score['rouge-1']['r'] for score in scores]
    scores_rouge2 = [score['rouge-2']['r'] for score in scores]
    scores_rougel = [score['rouge-l']['r'] for score in scores]
    return max(scores_rouge1), max(scores_rouge2), max(scores_rougel)


def aggregate_metrics(reporters):
    """Aggregate metrics from multiple reports."""
    # reporters is a list of teachers or worlds
@@ -111,6 +143,10 @@
    sums = {'accuracy': 0, 'f1': 0, 'loss': 0, 'ppl': 0}
    if nltkbleu is not None:
        sums['bleu'] = 0
    if rouge is not None:
        sums['rouge-1'] = 0.0
        sums['rouge-2'] = 0.0
        sums['rouge-L'] = 0.0
    num_tasks = 0
    total = 0
    for i in range(len(reporters)):
@@ -146,6 +182,11 @@ def __init__(self, opt):
        if nltkbleu is not None:
            # only compute bleu if we can
            self.metrics_list.append('bleu')
        if rouge is not None:
            # only compute rouge if we can
            self.metrics_list.append('rouge-1')
            self.metrics_list.append('rouge-2')
            self.metrics_list.append('rouge-L')
        for k in self.metrics_list:
            self.metrics[k] = 0.0
            self.metrics[k + '_cnt'] = 0
@@ -219,20 +260,30 @@ def update(self, observation, labels):
            # F1, BLEU, and ROUGE metrics.
            f1 = _f1_score(prediction, labels)
            bleu = _bleu(prediction, labels)
            rouge1, rouge2, rougel = _rouge(prediction, labels)

            with self._lock():
                self.metrics['f1'] += f1
                self.metrics['f1_cnt'] += 1
                if bleu is not None:
                    self.metrics['bleu'] += bleu
                    self.metrics['bleu_cnt'] += 1
                if rouge1 is not None:
                    self.metrics['rouge-1'] += rouge1
                    self.metrics['rouge-2'] += rouge2
                    self.metrics['rouge-L'] += rougel
                    self.metrics['rouge-1_cnt'] += 1
                    self.metrics['rouge-2_cnt'] += 1
                    self.metrics['rouge-L_cnt'] += 1

        # Ranking metrics.
        self._update_ranking_metrics(observation, labels)

        # User-reported metrics
        if 'metrics' in observation:
            for k, v in observation['metrics'].items():
-                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
+                if k not in ['correct', 'f1', 'hits@k', 'bleu', 'rouge-1',
+                             'rouge-2', 'rouge-L']:
                    if k in self.metrics_list:
                        with self._lock():
                            self.metrics[k] += v
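
A sketch of how the new helper is expected to behave, assuming py-rouge and the punkt data are installed (the strings and values are illustrative; results depend on py-rouge's tokenizer and normalize_answer):

from parlai.core.metrics import _rouge

guess = 'the cat sat on the mat'
answers = ['the cat is on the mat', 'a dog in the yard']
rouge1, rouge2, rougel = _rouge(guess, answers)
# Each value is the maximum recall over the candidate answers; if py-rouge or the
# punkt tokenizer data is unavailable, the helper returns (None, None, None) instead.
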
1 change: 1 addition & 0 deletions requirements.txt
@@ -23,3 +23,4 @@ sphinx_rtd_theme
tqdm
websocket-client
websocket-server
py-rouge
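
With py-rouge now in requirements.txt, a quick local self-check (a sketch, not part of the commit) that mirrors the ImportError/LookupError handling added in metrics.py:

try:
    import rouge
    rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2).get_scores('a b c', 'a b c')
    print('py-rouge and the nltk punkt data are available')
except ImportError:
    print('missing dependency: pip install py-rouge')
except LookupError:
    print('missing data: python -c "import nltk; nltk.download(\'punkt\')"')
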
13 changes: 8 additions & 5 deletions tests/test_eval_model.py
@@ -17,7 +17,7 @@ def test_output(self):
"""Test output of running eval_model"""
parser = setup_args()
parser.set_defaults(
task='tasks.repeat:RepeatTeacher:10',
task='integration_tests',
model='repeat_label',
datatype='valid',
num_examples=5,
@@ -30,14 +30,17 @@

        # decode the output
        scores = str_output.split("\n---\n")

        for i in range(1, len(scores)):
            score = ast.literal_eval(scores[i])
            # check totals
-            self.assertTrue(score['exs'] == i,
-                            "Total is incorrect")
+            self.assertEqual(score['exs'], i, "Total is incorrect")
            # accuracy should be one
-            self.assertTrue(score['accuracy'] == 1,
-                            "accuracy != 1")
+            self.assertEqual(score['accuracy'], 1, "accuracy != 1")
            # repeat_label echoes the label, so ROUGE of an exact match is 1
            if 'rouge-1' in score:
                self.assertEqual(score['rouge-1'], 1, 'rouge-1 != 1')
                self.assertEqual(score['rouge-2'], 1, 'rouge-2 != 1')
                self.assertEqual(score['rouge-L'], 1, 'rouge-L != 1')


if __name__ == '__main__':
