def test_speculative_sampling_target_distribution(self):
    """
    Asserts that the target distribution is preserved.
    Should help with catching issues like #32867.

    The same speculative sampling call is repeated many times. On every draw the first two candidate tokens
    must be accepted, and the third validated token must be one of the tokens whose logit in the third
    `new_logits` row is finite. Across all draws, tokens with a larger logit in that row must be observed
    more often than tokens with a smaller one.
    """
    # Function-scope import so the test does not depend on a module-level `import collections`.
    import collections

    num_draws = 10_000

    # assume vocab size 10, input length 5 + 3 generated candidates
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # 5 input tokens, then candidates 1, 4, 5
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3
    inf = float("inf")
    new_logits = torch.tensor(
        [
            [
                # accepts 1:
                [-inf, 10.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
                # accepts 4:
                [-inf, -inf, -inf, -inf, 10.0, -inf, -inf, -inf, -inf, -inf],
                # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value:
                [-inf, 2.0, -inf, 1.0, -inf, -inf, -inf, -0.01, 2.0, -inf],
                # N/A:
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            ]
        ]
    )
    last_assistant_token_is_eos = False

    # Token ids with a finite logit in the third `new_logits` row.
    allowed_last_tokens = (1, 3, 7, 8)

    last_validated_tokens = []
    for _ in range(num_draws):
        validated_tokens, n_matches = _speculative_sampling(
            candidate_input_ids,
            candidate_logits,
            candidate_length,
            new_logits,
            last_assistant_token_is_eos,
        )
        self.assertEqual(n_matches.item(), 2)
        # Convert to a Python list once per draw instead of once per assertion.
        tokens = validated_tokens.tolist()[0]
        self.assertEqual(tokens[0], 1)
        self.assertEqual(tokens[1], 4)
        self.assertIn(tokens[2], allowed_last_tokens)
        last_validated_tokens.append(tokens[2])

    # check that the most likely tokens are selected more often than the less likely ones
    last_token_counts = collections.Counter(last_validated_tokens)
    self.assertGreater(last_token_counts[1], last_token_counts[3])
    self.assertGreater(last_token_counts[3], last_token_counts[7])
    self.assertGreater(last_token_counts[7], 0)
    self.assertGreater(last_token_counts[8], last_token_counts[3])