From 3b717c8e60fdfdb33bb47f9285a36b6f99d79cf4 Mon Sep 17 00:00:00 2001 From: Nadav Timor Date: Mon, 25 Nov 2024 23:29:57 -0500 Subject: [PATCH] fix test_generated_length_assisted_generation --- tests/generation/test_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0605ea793971..867df0b004a9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3365,7 +3365,14 @@ def test_generated_length_assisted_generation(self): assistant_model=assistant, min_new_tokens=10, ) - self.assertTrue((input_length + 10) <= out.shape[-1] <= 20) + self.assertTrue((input_length + 10) <= out.shape[-1]) + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=7, + ) + self.assertTrue(out.shape[-1] <= (input_length + 7)) def test_model_kwarg_assisted_decoding_decoder_only(self): # PT-only test: TF doesn't support assisted decoding yet.