Skip to content
Merged
53 changes: 53 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.


import collections
import copy
import gc
import inspect
Expand Down Expand Up @@ -2448,6 +2449,58 @@ def test_speculative_sampling(self):
self.assertTrue(n_matches.item() == 2)
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])

def test_speculative_sampling_target_distribution(self):
    """
    Asserts that speculative sampling preserves the target model's distribution.

    Regression test for #32867: the first two candidate tokens are forced by
    near-one-hot logits and must always be accepted (n_matches == 2); the token
    resampled after the rejection must follow the relative probabilities of the
    target (`new_logits`) distribution at that position.
    """
    # assume vocab size 10, input length 5 + 3 generated candidates
    candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]])  # input tokens
    candidate_logits = torch.tensor(
        [
            [
                [-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 1
                [-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0],  # generated 4
                [-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0],  # generated 5
            ]
        ]
    )
    candidate_length = 3
    inf = float("inf")
    new_logits = torch.tensor(
        [
            [
                # accepts 1:
                [-inf, 10.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
                # accepts 4:
                [-inf, -inf, -inf, -inf, 10.0, -inf, -inf, -inf, -inf, -inf],
                # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value:
                [-inf, 2.0, -inf, 1.0, -inf, -inf, -inf, -0.01, 2.0, -inf],
                # N/A:
                [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
            ]
        ]
    )
    last_assistant_token_is_eos = False
    last_validated_token = []
    for _ in range(10_000):
        validated_tokens, n_matches = _speculative_sampling(
            candidate_input_ids,
            candidate_logits,
            candidate_length,
            new_logits,
            last_assistant_token_is_eos,
        )
        # Convert once per iteration instead of once per assertion; rich
        # unittest assertions report actual vs. expected values on failure.
        generated = validated_tokens.tolist()[0]
        self.assertEqual(n_matches.item(), 2)
        self.assertEqual(generated[0], 1)
        self.assertEqual(generated[1], 4)
        self.assertIn(generated[2], [1, 3, 7, 8])
        last_validated_token.append(generated[2])
    # check that the most likely tokens are selected more often than the less likely ones
    last_token_counts = collections.Counter(last_validated_token)
    self.assertGreater(last_token_counts[1], last_token_counts[3])
    self.assertGreater(last_token_counts[3], last_token_counts[7])
    self.assertGreater(last_token_counts[7], 0)
    self.assertGreater(last_token_counts[8], last_token_counts[3])


@pytest.mark.generate
@require_torch
Expand Down