From cfc15177be9862afd809b2f86217d1dde95e174e Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 19 Sep 2024 18:25:12 +0000
Subject: [PATCH] almost zero is not zero

---
 tests/generation/test_utils.py | 40 ++++++++++++++++++++++++++++------------
 1 file changed, 28 insertions(+), 12 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 413e920609a8..26ece9c25d06 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1647,26 +1647,42 @@ def test_generate_from_inputs_embeds_decoder_only(self):
                 continue
 
             # Traditional way of generating text
-            outputs_from_ids = model.generate(input_ids, max_new_tokens=5)
-            self.assertEqual(outputs_from_ids.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
+            outputs_from_ids = model.generate(
+                input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
+            self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
 
             # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
             inputs_embeds = model.get_input_embeddings()(input_ids)
-            outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, max_new_tokens=5)
-            self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+            outputs_from_embeds = model.generate(
+                input_ids,
+                inputs_embeds=inputs_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
 
-            # But if we pass different inputs_embeds, we should get different outputs
-            torch.manual_seed(0)
+            # But if we pass different inputs_embeds, we should get different outputs (the output text may be the
+            # same, but the logits will almost surely be different)
             random_embeds = torch.rand_like(inputs_embeds)
-            outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds, max_new_tokens=5)
-            with self.assertRaises(AssertionError):
-                self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
+            outputs_from_rand_embeds = model.generate(
+                input_ids,
+                inputs_embeds=random_embeds,
+                max_new_tokens=5,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+            for i in range(len(outputs_from_rand_embeds.scores)):
+                self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
 
             # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same
-            outputs_from_embeds_wo_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=5)
+            outputs_from_embeds_wo_ids = model.generate(
+                inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
+            )
             self.assertListEqual(
-                outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
-                outputs_from_embeds_wo_ids.tolist(),
+                outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(),
+                outputs_from_embeds_wo_ids.sequences.tolist(),
             )
 
     @pytest.mark.generate
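
Note on the idea behind the change, with a minimal standalone sketch (not part of the patch; plain PyTorch, made-up illustrative values rather than real model outputs): two per-step score tensors can decode to exactly the same tokens while still differing numerically, which is why the updated test compares the `scores` returned by `generate` with `torch.allclose` instead of asserting that the generated token ids differ.

    import torch

    # Hypothetical per-step logits for two generations; values are illustrative only.
    scores_a = torch.tensor([[2.00, 1.00, 0.50]])
    scores_b = scores_a + 1e-3  # a tiny perturbation: "almost zero is not zero"

    # Greedy decoding picks the same token from both sets of logits...
    assert torch.equal(scores_a.argmax(dim=-1), scores_b.argmax(dim=-1))
    # ...yet the logits themselves are measurably different.
    assert not torch.allclose(scores_a, scores_b)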