Merged
60 changes: 60 additions & 0 deletions tests/generation/test_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.


import collections
import copy
import gc
import inspect
@@ -2448,6 +2449,65 @@ def test_speculative_sampling(self):
self.assertTrue(n_matches.item() == 2)
self.assertTrue(validated_tokens.tolist()[0] == [1, 4, 8])

def test_speculative_sampling_target_distribution(self):
"""
Asserts that the target distribution is preserved.
Should help with catching issues like #32867.
"""
# assume vocab size 10, input length 5 + 3 generated candidates
candidate_input_ids = torch.tensor([[8, 0, 3, 9, 8, 1, 4, 5]]) # input tokens
candidate_logits = torch.tensor(
[
[
[-10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 1
[-10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0, -10.0], # generated 4
[-10.0, -10.0, -10.0, -10.0, -10.0, 10.0, -10.0, -10.0, -10.0, -10.0], # generated 5
]
]
)
candidate_length = 3
inf = float("inf")
new_logits = torch.tensor(
[
[
[-inf, 10.0, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], # accepts 1
[-inf, -inf, -inf, -inf, 10.0, -inf, -inf, -inf, -inf, -inf], # accepts 4
[
-inf,
2.0,
-inf,
1.0,
-inf,
-inf,
-inf,
-0.01,
2.0,
-inf,
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value
@gante (Contributor), Nov 4, 2024:

Nit: let's make it in one line, so we can quickly compare indexes with other tensors.

(you'll have to remove the comma after the last -inf, otherwise the make fixup command will revert it back to this format)

@keyboardAnt (Contributor, Author), Nov 5, 2024:

I changed the formatting as requested, but ruff's formatting check then failed the CI. (make fixup still reformats it into a column, even after removing the last comma you mentioned)

@ArthurZucker (Collaborator):

you can use # fmt: off and # fmt: on
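
For illustration, a minimal sketch of that suggestion (the values below are placeholders, not the ones from this test): wrapping the literal in formatter markers keeps ruff/black from re-wrapping the hand-formatted rows.

import torch

# fmt: off
new_logits = torch.tensor(
    [[[-float("inf"), 2.0, 1.0, -float("inf")],   # row for position 0, kept on one line
      [2.0, -float("inf"), -float("inf"), 1.0]]]  # row for position 1, kept on one line
)
# fmt: on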

@keyboardAnt (Contributor, Author):

Thanks @ArthurZucker. I changed all these inline comments to block comments, and it solved the issue while keeping the ruff checks on. 👍
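
A short sketch of the block-comment layout mentioned above (values are illustrative only): the per-row explanations move to comment lines above the literal, so the formatter's own layout of the tensor is left untouched and no fmt markers are needed.

import torch

# row 0: peaks at token 1
# row 1: peaks at token 0
# row 2: peaks at token 2
candidate_logits = torch.tensor([[[0.0, 10.0, 0.0], [10.0, 0.0, 0.0], [0.0, 0.0, 10.0]]])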

[-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], # N/A
]
]
)
last_assistant_token_is_eos = False
last_validated_token = []
for _ in range(10_000):
validated_tokens, n_matches = _speculative_sampling(
candidate_input_ids,
candidate_logits,
candidate_length,
new_logits,
last_assistant_token_is_eos,
)
self.assertTrue(n_matches.item() == 2)
self.assertTrue(validated_tokens.tolist()[0][0] == 1)
self.assertTrue(validated_tokens.tolist()[0][1] == 4)
self.assertTrue(validated_tokens.tolist()[0][2] in [1, 3, 7, 8])
last_validated_token.append(validated_tokens.tolist()[0][2])
# check that the most likely tokens are selected more often than the less likely ones
last_token_counts = collections.Counter(last_validated_token)
self.assertTrue(last_token_counts[1] > last_token_counts[3] > last_token_counts[7] > 0)
self.assertTrue(last_token_counts[8] > last_token_counts[3])
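
As background for the assertions above, here is a minimal sketch of the standard speculative-sampling acceptance rule (accept a draft token with probability min(1, p/q), otherwise resample from the normalized residual max(0, p - q)). It is only an illustration of why the third validated token must follow the adjusted target distribution, not the actual _speculative_sampling implementation.

import torch

def accept_or_resample(candidate_tokens, draft_logits, target_logits):
    # candidate_tokens: (num_candidates,) tokens proposed by the draft model
    # draft_logits / target_logits: (num_candidates, vocab_size)
    q = draft_logits.softmax(dim=-1)  # draft distribution
    p = target_logits.softmax(dim=-1)  # target distribution
    validated = []
    for i, tok in enumerate(candidate_tokens.tolist()):
        # accept the draft token with probability min(1, p(tok) / q(tok))
        if torch.rand(()) < (p[i, tok] / q[i, tok]).clamp(max=1.0):
            validated.append(tok)
        else:
            # on rejection, sample from the normalized residual max(0, p - q);
            # this correction is what keeps the output distributed like the target
            residual = (p[i] - q[i]).clamp(min=0.0)
            validated.append(torch.multinomial(residual / residual.sum(), 1).item())
            break
    return validated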


@pytest.mark.generate
@require_torch