fix: explicitly set device to CPU for RNG state tensor (#2344)
Conversation
When torch.set_default_device('cuda') is called, torch.tensor() creates
tensors on CUDA by default. However, PyTorch's generator.set_state()
requires the state to be a CPU ByteTensor, causing a TypeError:
'RNG state must be a torch.ByteTensor'.
This fix explicitly sets device=torch.device('cpu') when creating the
RNG state tensor in get_seed_and_offset(), ensuring compatibility
regardless of the default device setting.
Also adds a regression test to verify the fix works correctly when
torch.set_default_device('cuda') is set.
Fixes #2333
Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
**Summary of Changes** (Gemini Code Assist)

This pull request addresses a compatibility issue where PyTorch's random number generator state management conflicted with a global default device setting of CUDA. By explicitly directing the RNG state tensor to the CPU, the fix ensures that sampling operations function reliably across different device configurations, preventing runtime errors and improving robustness.
Code Review
This pull request correctly addresses the TypeError that occurs when using sampling functions with a default CUDA device. The fix, which explicitly sets the RNG state tensor to the CPU, is appropriate. The inclusion of a regression test is a great addition to prevent this issue from recurring. I have one suggestion to make the new test more robust by ensuring it is skipped in environments where a CUDA device is not available.
```python
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

@pytest.mark.parametrize("batch_size", [1, 19, 99])
```
The new test test_sampling_with_default_device_cuda will fail if run on a machine without a CUDA device, as torch.set_default_device("cuda") will raise an error. It's a good practice to add a pytest.mark.skipif decorator to skip tests that require specific hardware when it's not available. This will make your test suite more robust across different environments.
```diff
-@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device not available")
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
```
**Code Review**

**Summary**

This PR fixes a bug where sampling functions would fail with `TypeError: RNG state must be a torch.ByteTensor` when `torch.set_default_device("cuda")` is set. The fix is correct and minimal.

**Strengths**
**Code Quality**
**Potential Considerations**

Minor suggestion (optional): Consider testing `top_k_sampling_from_probs` and `sampling_from_logits` as well for complete coverage, though the current test coverage is sufficient since all sampling functions go through the same `get_seed_and_offset` code path.

Verification: I confirmed that:
**Security & Best Practices**
**Final Verdict**

LGTM - Ready to merge

This is an exemplary bug fix:
Great work on the fix and the regression test! |
**Code Review**

**Summary**

This PR fixes a critical bug (issue #2333) where sampling functions would crash with `TypeError: RNG state must be a torch.ByteTensor` when `torch.set_default_device("cuda")` is set.

**Strengths**
**Code Quality**

- ✅ **Correctness**: The fix is correct. PyTorch generators always operate on CPU regardless of the generator's device, and `set_state()` requires a CPU ByteTensor.
- ✅ **Performance**: No performance impact; the state is only a 2-element int64 tensor on CPU.
- ✅ **Backwards compatibility**: Fully backwards compatible. When the default device is CPU (the common case), this is a no-op.
- ✅ **Testing**: Comprehensive regression test prevents future breakage.
- ✅ **Style**: Follows FlashInfer conventions (proper formatting, matches existing code style).

**Observations**
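The "2-element int64 tensor" observation is easy to check directly; the packed state views as exactly 16 bytes and round-trips back to the `(seed, offset)` pair (a standalone sketch with hypothetical values, independent of FlashInfer):

```python
import torch

# Pack a (seed, offset) pair the same way the fix does: 2 x int64 on CPU.
state = torch.tensor([42, 7], dtype=torch.int64, device=torch.device("cpu"))
as_bytes = state.view(torch.uint8)

assert as_bytes.dtype == torch.uint8
assert as_bytes.numel() == 16  # two 8-byte integers

# Viewing back recovers the original values losslessly.
seed, offset = as_bytes.view(torch.int64).tolist()
assert (seed, offset) == (42, 7)
```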
**Potential Considerations**

Minor suggestion (non-blocking): Consider adding a comment in the code explaining why the state tensor is created on CPU:

```python
# PyTorch requires generator state to be a CPU ByteTensor, regardless of generator device
# See: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/Generator.h#L178
generator.set_state(
    torch.tensor(
        [seed, offset], dtype=torch.int64, device=torch.device("cpu")
    ).view(torch.uint8)
)
```

However, this is optional; the code is clear enough without it, and the test documentation explains the issue.

**Recommendation**

LGTM - Approve and merge ✅

This is a well-executed bug fix that:
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/sampling.py (1)
**33-49**: Good fix; consider avoiding `torch.tensor([seed, offset], ...)` with tensor elements.

For robustness/perf, build the state from Python ints (or a stacked tensor) rather than a list of 0-d tensors.
Proposed tweak (same behavior, avoids tensor-from-tensor list construction)
```diff
 def get_seed_and_offset(
     increment: int,
     generator: Optional[torch.Generator] = None,
     device: Optional[torch.device] = None,
 ) -> Tuple[int, int]:
 @@
     seed, offset = state.view(torch.int64)
     offset += (increment + 3) // 4 * 4
+    seed_i, offset_i = int(seed), int(offset)
     generator.set_state(
         torch.tensor(
-            [seed, offset], dtype=torch.int64, device=torch.device("cpu")
+            [seed_i, offset_i], dtype=torch.int64, device=torch.device("cpu")
         ).view(torch.uint8)
     )
-    return int(seed), int(offset)
+    return seed_i, offset_i
```
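The distinction the nitpick draws, plain Python ints versus 0-d tensors, can be sketched on its own (hypothetical values, not the FlashInfer implementation):

```python
import torch

state = torch.tensor([1234, 8], dtype=torch.int64)

# Unpacking a 1-D tensor yields 0-d tensors, not Python ints.
seed, offset = state
assert seed.ndim == 0 and offset.ndim == 0

# Converting up front gives plain ints, so the rebuilt state tensor is
# constructed from scalars rather than from a list of 0-d tensors.
seed_i, offset_i = int(seed), int(offset)
assert isinstance(seed_i, int) and isinstance(offset_i, int)

rebuilt = torch.tensor([seed_i, offset_i], dtype=torch.int64,
                       device=torch.device("cpu"))
assert rebuilt.tolist() == [1234, 8]
```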
📜 Review details

- Configuration used: defaults
- Review profile: CHILL
- Plan: Pro
- 📒 Files selected for processing (2): `flashinfer/sampling.py`, `tests/utils/test_sampling.py`
```python
@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("vocab_size", [111, 32000])
def test_sampling_with_default_device_cuda(batch_size, vocab_size):
    """Test that sampling works correctly when torch.set_default_device("cuda") is set.

    This is a regression test for issue #2333 where generator.set_state() would fail
    with "RNG state must be a torch.ByteTensor" error when the default device is CUDA.
    """
    torch.manual_seed(42)
    original_device = torch.get_default_device()
    try:
        # Set default device to CUDA
        torch.set_default_device("cuda")

        # Create logits and test top_k_top_p_sampling_from_logits
        logits = torch.randn(batch_size, vocab_size, device="cuda:0")

        # This should not raise "RNG state must be a torch.ByteTensor" error
        samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, top_k=100, top_p=0.9
        )

        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

        # Also test other sampling functions
        probs = torch.softmax(logits, dim=-1)

        samples = flashinfer.sampling.sampling_from_probs(probs)
        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

        samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9)
        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

    finally:
        # Restore original default device
        torch.set_default_device(original_device)
```
Add CUDA and API version guards for the default-device CUDA regression test
This test manipulates global PyTorch state and requires APIs introduced in PyTorch 2.0. Since requirements.txt specifies no minimum PyTorch version, add explicit guards to skip cleanly on unsupported environments:
Recommended changes:

```diff
 def test_sampling_with_default_device_cuda(batch_size, vocab_size):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for this test")
+    if not (hasattr(torch, "set_default_device") and hasattr(torch, "get_default_device")):
+        pytest.skip("torch.set_default_device / torch.get_default_device not available (requires PyTorch 2.0+)")
+
     torch.manual_seed(42)
     original_device = torch.get_default_device()
```

The pattern mirrors existing tests in the repo (e.g., `tests/utils/test_logging.py`). The API guard is important because these functions were only introduced as beta features in PyTorch 2.0.
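Both suggested guards reduce to two cheap boolean checks; a minimal sketch assuming nothing beyond `torch` itself (the suggested `pytest.skip` calls would key off these):

```python
import torch

# Hardware guard: is a usable CUDA device present?
cuda_ok = torch.cuda.is_available()

# API guard: set_default_device/get_default_device landed in PyTorch 2.0 (beta).
api_ok = hasattr(torch, "set_default_device") and hasattr(torch, "get_default_device")

print(f"cuda_ok={cuda_ok}, api_ok={api_ok}")
```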
Fixes #2333
**Summary**

Fixed the `TypeError: RNG state must be a torch.ByteTensor` error that occurred when using sampling functions with `torch.set_default_device("cuda")` enabled.

**Root Cause**

PyTorch's `generator.set_state()` requires the state tensor to be a CPU ByteTensor. When `torch.set_default_device("cuda")` is set, `torch.tensor()` creates tensors on CUDA by default, causing the error.

**Changes**

- Updated `flashinfer/sampling.py:45` to explicitly set `device=torch.device("cpu")` when creating the RNG state tensor
- Added `test_sampling_with_default_device_cuda()` in `tests/utils/test_sampling.py`

**Testing**

The new test verifies that all sampling functions work correctly when `torch.set_default_device("cuda")` is active.

🤖 Generated with Claude Code