diff --git a/tests/test_pipelines_fill_mask.py b/tests/test_pipelines_fill_mask.py index 9499b646943f..bb54c24ba98b 100644 --- a/tests/test_pipelines_fill_mask.py +++ b/tests/test_pipelines_fill_mask.py @@ -90,6 +90,20 @@ def test_torch_fill_mask_with_targets(self): for targets in invalid_targets: self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets) + @require_torch + @slow + def test_torch_fill_mask_targets_equivalence(self): + model_name = self.large_models[0] + unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt") + unmasked = unmasker(self.valid_inputs[0]) + tokens = [top_mask["token_str"] for top_mask in unmasked] + scores = [top_mask["score"] for top_mask in unmasked] + + unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens) + target_scores = [top_mask["score"] for top_mask in unmasked_targets] + + self.assertEqual(scores, target_scores) + @require_torch def test_torch_fill_mask_with_targets_and_topk(self): model_name = self.small_models[0] @@ -287,3 +301,17 @@ def test_tf_fill_mask_results(self): self.assertIn(key, result) self.assertRaises(Exception, unmasker, [None]) + + @require_tf + @slow + def test_tf_fill_mask_targets_equivalence(self): + model_name = self.large_models[0] + unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf") + unmasked = unmasker(self.valid_inputs[0]) + tokens = [top_mask["token_str"] for top_mask in unmasked] + scores = [top_mask["score"] for top_mask in unmasked] + + unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens) + target_scores = [top_mask["score"] for top_mask in unmasked_targets] + + self.assertEqual(scores, target_scores)