fix: explicitly set device to CPU for RNG state tensor #2344
Changes from all commits
`tests/utils/test_sampling.py`:

```diff
@@ -930,6 +930,47 @@ def test_int64_indices_sampling(batch_size, vocab_size, sampling_type, indices_d
     assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.parametrize("vocab_size", [111, 32000])
+def test_sampling_with_default_device_cuda(batch_size, vocab_size):
+    """Test that sampling works correctly when torch.set_default_device("cuda") is set.
+
+    This is a regression test for issue #2333 where generator.set_state() would fail
+    with "RNG state must be a torch.ByteTensor" error when the default device is CUDA.
+    """
+    torch.manual_seed(42)
+    original_device = torch.get_default_device()
+    try:
+        # Set default device to CUDA
+        torch.set_default_device("cuda")
+
+        # Create logits and test top_k_top_p_sampling_from_logits
+        logits = torch.randn(batch_size, vocab_size, device="cuda:0")
+
+        # This should not raise "RNG state must be a torch.ByteTensor" error
+        samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, top_k=100, top_p=0.9
+        )
+
+        assert samples.shape == (batch_size,)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+
+        # Also test other sampling functions
+        probs = torch.softmax(logits, dim=-1)
+
+        samples = flashinfer.sampling.sampling_from_probs(probs)
+        assert samples.shape == (batch_size,)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+
+        samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9)
+        assert samples.shape == (batch_size,)
+        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
+
+    finally:
+        # Restore original default device
+        torch.set_default_device(original_device)
```
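For context on why the PR title says the RNG state tensor must be pinned to CPU: `torch.Generator.set_state()` only accepts a CPU `uint8` (byte) tensor, so any state buffer allocated under `torch.set_default_device("cuda")` without an explicit `device="cpu"` would be rejected. The following is a minimal sketch of that constraint, not the PR's actual code:

```python
import torch

# A CPU generator's state is always a uint8 (byte) tensor on the CPU.
gen = torch.Generator()
state = gen.get_state()
assert state.dtype == torch.uint8 and state.device.type == "cpu"

# Round-tripping the state works only with a CPU ByteTensor;
# set_state() rejects tensors on other devices, which is why the fix
# allocates the RNG state buffer with device="cpu" explicitly.
gen.set_state(state.clone())

# Allocating the buffer with an explicit CPU device stays correct even
# when a CUDA default device is in effect (illustrative sketch).
cpu_state = torch.empty(state.numel(), dtype=torch.uint8, device="cpu")
cpu_state.copy_(state)
gen.set_state(cpu_state)
```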
**Contributor** commented on lines +933 to +972:

> **Add CUDA and API version guards for the default-device CUDA regression test.**
>
> This test manipulates global PyTorch state and requires `torch.set_default_device` / `torch.get_default_device`, which were introduced as beta features in PyTorch 2.0. Recommended changes:
>
> ```diff
>  def test_sampling_with_default_device_cuda(batch_size, vocab_size):
> +    if not torch.cuda.is_available():
> +        pytest.skip("CUDA is required for this test")
> +    if not (hasattr(torch, "set_default_device") and hasattr(torch, "get_default_device")):
> +        pytest.skip("torch.set_default_device / torch.get_default_device not available (requires PyTorch 2.0+)")
> +
>      torch.manual_seed(42)
>      original_device = torch.get_default_device()
> ```
>
> The pattern mirrors existing tests in the repo (e.g., the CUDA-availability skip in `tests/utils/test_logging.py`).
Later in the file (unchanged context):

```diff
 if __name__ == "__main__":
     # test_sampling_freq(128256, gumbel_distribution(0.1), 0.5)
     test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))
```
> The new test `test_sampling_with_default_device_cuda` will fail if run on a machine without a CUDA device, as `torch.set_default_device("cuda")` will raise an error. It is good practice to add a `pytest.mark.skipif` decorator to skip tests that require specific hardware when it is not available. This will make your test suite more robust across different environments.
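The decorator-based guard the reviewer suggests could look like the sketch below. The marker name `requires_cuda` is illustrative, not from the PR; the skip condition matches the one recommended above.

```python
import pytest
import torch

# Hypothetical reusable marker: skip the test when no CUDA device is present.
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="CUDA is required for this test"
)


@requires_cuda
def test_sampling_with_default_device_cuda_example():
    # Body elided; the marker skips this test on CPU-only machines,
    # so torch.set_default_device("cuda") is never reached there.
    pass
```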