diff --git a/tests/test_generate.py b/tests/test_generate.py index 9214281cd..4f5bb4c91 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -402,6 +402,25 @@ def test_batch_generate_with_logits_processors(self): self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) self.assertEqual(responses[uid2].logprobs[2].item(), 0.0) + def test_batch_generate_processor_tokens_match_prompt_on_first_step(self): + prompt = self.tokenizer.encode("hello") + seen = [] + + def processor(tokens, logits): + seen.append(tokens) + return logits + + batch_gen = BatchGenerator( + self.model, + max_tokens=1, + logits_processors=[processor], + ) + batch_gen.insert([prompt]) + batch_gen.next_generated() + + self.assertTrue(hasattr(seen[0], "shape")) + self.assertEqual(seen[0].tolist(), prompt) + def test_batch_generate_function_with_logits_processors(self): """Test that batch_generate function with logits_processors produces correct results.""" logit_bias = {0: 2000.0, 1: -2000.0}