diff --git a/.circleci/config.yml b/.circleci/config.yml
index 639db1a3e9e..302b110b8ca 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -42,6 +42,7 @@ installdeps: &installdeps
     name: Installs basic dependencies
     command: |
       python setup.py develop
+      python -c "import nltk; nltk.download('punkt')"
 
 installtorchgpu: &installtorchgpu
   run:
diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py
index 4771af6a612..6f0b10311b8 100644
--- a/parlai/core/metrics.py
+++ b/parlai/core/metrics.py
@@ -13,6 +13,7 @@
 from parlai.core.thread_utils import SharedTable
 from parlai.core.utils import round_sigfigs, no_lock
 from collections import Counter
+from parlai.core.utils import warn_once
 
 import re
 
@@ -23,6 +24,14 @@
     # We'll just turn off things, but we might want to warn the user
     nltkbleu = None
 
+try:
+    import rouge
+except ImportError:
+    # User doesn't have py-rouge installed, so we can't compute rouge metrics
+    # We'll just turn off things, but we might want to warn the user
+    warn_once('Rouge metrics require py-rouge. Please run `pip install py-rouge`.')
+    rouge = None
+
 re_art = re.compile(r'\b(a|an|the)\b')
 re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
 
@@ -103,6 +112,29 @@ def _bleu(guess, answers):
     )
 
 
+def _rouge(guess, answers):
+    """Compute ROUGE score between guess and *any* answer. Return the best."""
+    global rouge
+    if rouge is None:
+        return None, None, None
+    evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
+    try:
+        scores = [evaluator.get_scores(normalize_answer(guess), normalize_answer(a))
+                  for a in answers]
+    except LookupError:
+        warn_once(
+            'ROUGE requires nltk punkt tokenizer. Please run '
+            '`python -c "import nltk; nltk.download(\'punkt\')"`'
+        )
+        rouge = None
+        return None, None, None
+
+    scores_rouge1 = [score['rouge-1']['r'] for score in scores]
+    scores_rouge2 = [score['rouge-2']['r'] for score in scores]
+    scores_rougel = [score['rouge-l']['r'] for score in scores]
+    return max(scores_rouge1), max(scores_rouge2), max(scores_rougel)
+
+
 def aggregate_metrics(reporters):
     """Aggregate metrics from multiple reports."""
     # reporters is a list of teachers or worlds
@@ -111,6 +143,10 @@ def aggregate_metrics(reporters):
     sums = {'accuracy': 0, 'f1': 0, 'loss': 0, 'ppl': 0}
     if nltkbleu is not None:
         sums['bleu'] = 0
+    if rouge is not None:
+        sums['rouge-1'] = 0.0
+        sums['rouge-2'] = 0.0
+        sums['rouge-L'] = 0.0
     num_tasks = 0
     total = 0
     for i in range(len(reporters)):
@@ -146,6 +182,11 @@ def __init__(self, opt):
         if nltkbleu is not None:
             # only compute bleu if we can
             self.metrics_list.append('bleu')
+        if rouge is not None:
+            # only compute rouge if we can
+            self.metrics_list.append('rouge-1')
+            self.metrics_list.append('rouge-2')
+            self.metrics_list.append('rouge-L')
         for k in self.metrics_list:
             self.metrics[k] = 0.0
             self.metrics[k + '_cnt'] = 0
@@ -219,12 +260,21 @@ def update(self, observation, labels):
             # F1 and BLEU metrics.
             f1 = _f1_score(prediction, labels)
             bleu = _bleu(prediction, labels)
+            rouge1, rouge2, rougel = _rouge(prediction, labels)
+
             with self._lock():
                 self.metrics['f1'] += f1
                 self.metrics['f1_cnt'] += 1
                 if bleu is not None:
                     self.metrics['bleu'] += bleu
                     self.metrics['bleu_cnt'] += 1
+                if rouge1 is not None:
+                    self.metrics['rouge-1'] += rouge1
+                    self.metrics['rouge-2'] += rouge2
+                    self.metrics['rouge-L'] += rougel
+                    self.metrics['rouge-1_cnt'] += 1
+                    self.metrics['rouge-2_cnt'] += 1
+                    self.metrics['rouge-L_cnt'] += 1
 
         # Ranking metrics.
         self._update_ranking_metrics(observation, labels)
@@ -232,7 +282,8 @@ def update(self, observation, labels):
         # User-reported metrics
         if 'metrics' in observation:
             for k, v in observation['metrics'].items():
-                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
+                if k not in ['correct', 'f1', 'hits@k', 'bleu', 'rouge-1',
+                             'rouge-2', 'rouge-L']:
                     if k in self.metrics_list:
                         with self._lock():
                             self.metrics[k] += v
diff --git a/requirements.txt b/requirements.txt
index f96b96c2558..78b660bba43 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,4 @@ sphinx_rtd_theme
 tqdm
 websocket-client
 websocket-server
+py-rouge
diff --git a/tests/test_eval_model.py b/tests/test_eval_model.py
index 0b95b2635d5..ecf349d05c7 100644
--- a/tests/test_eval_model.py
+++ b/tests/test_eval_model.py
@@ -17,7 +17,7 @@ def test_output(self):
         """Test output of running eval_model"""
         parser = setup_args()
         parser.set_defaults(
-            task='tasks.repeat:RepeatTeacher:10',
+            task='integration_tests',
             model='repeat_label',
             datatype='valid',
             num_examples=5,
@@ -30,14 +30,17 @@ def test_output(self):
 
         # decode the output
         scores = str_output.split("\n---\n")
+
        for i in range(1, len(scores)):
             score = ast.literal_eval(scores[i])
             # check totals
-            self.assertTrue(score['exs'] == i,
-                            "Total is incorrect")
+            self.assertEqual(score['exs'], i, "Total is incorrect")
             # accuracy should be one
-            self.assertTrue(score['accuracy'] == 1,
-                            "accuracy != 1")
+            self.assertEqual(score['accuracy'], 1, "accuracy != 1")
+            if 'rouge-1' in score:
+                self.assertEqual(score['rouge-1'], 1, 'rouge-1 != 1')
+                self.assertEqual(score['rouge-2'], 1, 'rouge-2 != 1')
+                self.assertEqual(score['rouge-L'], 1, 'rouge-L != 1')
 
 
 if __name__ == '__main__':
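
For context, a minimal sketch of how the `_rouge` helper added in this diff can be exercised on a checkout of this branch. It assumes py-rouge is installed and the nltk 'punkt' tokenizer has been downloaded; the example strings and the expected output noted in the comments are illustrative only.

    # Hypothetical usage example, not part of the diff above.
    # Prerequisites (per this diff):
    #   pip install py-rouge
    #   python -c "import nltk; nltk.download('punkt')"
    from parlai.core.metrics import _rouge

    guess = 'the quick brown fox jumps over the lazy dog'
    answers = ['the quick brown fox jumps over the lazy dog', 'a cat sat on the mat']

    # _rouge returns the best (max) ROUGE-1, ROUGE-2 and ROUGE-L recall over all
    # answers, or (None, None, None) if py-rouge or the punkt tokenizer is missing.
    rouge1, rouge2, rougel = _rouge(guess, answers)
    print(rouge1, rouge2, rougel)  # expect 1.0 1.0 1.0, driven by the exact-match answer

This mirrors what the new test asserts: with the `repeat_label` model the prediction equals the label, so each rouge metric reported by eval_model should be 1.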