diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 02979574af..8571016a15 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -42,7 +42,9 @@ def get_seed_and_offset( seed, offset = state.view(torch.int64) offset += (increment + 3) // 4 * 4 generator.set_state( - torch.tensor([seed, offset], dtype=torch.int64).view(torch.uint8) + torch.tensor( + [seed, offset], dtype=torch.int64, device=torch.device("cpu") + ).view(torch.uint8) ) return int(seed), int(offset) diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py index 99ff6a3e2b..fdd09c79c4 100644 --- a/tests/utils/test_sampling.py +++ b/tests/utils/test_sampling.py @@ -930,6 +930,47 @@ def test_int64_indices_sampling(batch_size, vocab_size, sampling_type, indices_d assert torch.all(samples < vocab_size) and torch.all(samples >= 0) +@pytest.mark.parametrize("batch_size", [1, 19, 99]) +@pytest.mark.parametrize("vocab_size", [111, 32000]) +def test_sampling_with_default_device_cuda(batch_size, vocab_size): + """Test that sampling works correctly when torch.set_default_device("cuda") is set. + + This is a regression test for issue #2333 where generator.set_state() would fail + with "RNG state must be a torch.ByteTensor" error when the default device is CUDA. + """ + torch.manual_seed(42) + original_device = torch.get_default_device() + try: + # Set default device to CUDA + torch.set_default_device("cuda") + + # Create logits and test top_k_top_p_sampling_from_logits + logits = torch.randn(batch_size, vocab_size, device="cuda:0") + + # This should not raise "RNG state must be a torch.ByteTensor" error + samples = flashinfer.sampling.top_k_top_p_sampling_from_logits( + logits, top_k=100, top_p=0.9 + ) + + assert samples.shape == (batch_size,) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + + # Also test other sampling functions + probs = torch.softmax(logits, dim=-1) + + samples = flashinfer.sampling.sampling_from_probs(probs) + assert samples.shape == (batch_size,) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + + samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9) + assert samples.shape == (batch_size,) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + + finally: + # Restore original default device + torch.set_default_device(original_device) + + if __name__ == "__main__": # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5) test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))