diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 7617c15efabf..4aad6647aa8b 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4842,8 +4842,8 @@ def test_forward_with_num_logits_to_keep(self):
         self.assertEqual(tuple(all_logits.shape), (batch_size, sequence_length, vocab_size))
         self.assertEqual(tuple(last_token_logits.shape), (batch_size, 1, vocab_size))
 
-        # Assert the last tokens are actually the same
-        self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits))
+        # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
+        self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
 
 
 global_rng = random.Random()
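
For context, here is a minimal standalone sketch (not part of the patch, using made-up tensor shapes and a hypothetical `lm_head`) of why an absolute tolerance is needed: projecting the full sequence and then slicing the last position performs the same math as projecting only the last hidden state, but with a different order of floating-point operations, so the two results can differ by tiny amounts that break an exact `allclose` check.

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(2, 16, 64)         # (batch, seq_len, hidden_size) -- made-up shapes
lm_head = torch.nn.Linear(64, 100)      # hypothetical vocab_size = 100

all_logits = lm_head(hidden)[:, -1:, :]          # project everything, then keep the last token
last_token_logits = lm_head(hidden[:, -1:, :])   # project only the last hidden state

# Bitwise equality is not guaranteed; a small atol absorbs the FP noise.
print(torch.allclose(all_logits, last_token_logits, atol=1e-5))
```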