Merged
4 changes: 3 additions & 1 deletion flashinfer/sampling.py
@@ -42,7 +42,9 @@ def get_seed_and_offset(
     seed, offset = state.view(torch.int64)
     offset += (increment + 3) // 4 * 4
     generator.set_state(
-        torch.tensor([seed, offset], dtype=torch.int64).view(torch.uint8)
+        torch.tensor(
+            [seed, offset], dtype=torch.int64, device=torch.device("cpu")
+        ).view(torch.uint8)
     )
     return int(seed), int(offset)
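The one-line fix above deserves a word of explanation: `generator.set_state()` expects a CPU `torch.ByteTensor`, but `torch.tensor(...)` without an explicit `device=` follows `torch.set_default_device()`, which is what produced the "RNG state must be a torch.ByteTensor" failure in issue #2333. A minimal sketch of the pinned-device packing, where the helper name `pack_rng_state` is illustrative and the "meta" device stands in for "cuda" so the snippet runs without a GPU:

```python
import torch

def pack_rng_state(seed: int, offset: int) -> torch.Tensor:
    # Pin the device explicitly: generator.set_state() wants a CPU
    # ByteTensor, and without device= this tensor would follow
    # torch.set_default_device() onto the accelerator.
    return torch.tensor(
        [seed, offset], dtype=torch.int64, device=torch.device("cpu")
    ).view(torch.uint8)

# Simulate a non-CPU default device without needing a GPU: the "meta"
# device is used as a stand-in for "cuda" (illustration only).
torch.set_default_device("meta")
try:
    state = pack_rng_state(42, 128)
finally:
    torch.set_default_device("cpu")

# Despite the non-CPU default device, the packed state stays a
# 16-byte CPU uint8 tensor (two int64 values reinterpreted as bytes).
print(state.device, state.dtype, state.numel())
```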

41 changes: 41 additions & 0 deletions tests/utils/test_sampling.py
@@ -930,6 +930,47 @@ def test_int64_indices_sampling(batch_size, vocab_size, sampling_type, indices_d
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)


@pytest.mark.parametrize("batch_size", [1, 19, 99])
Contributor

Severity: medium

The new test test_sampling_with_default_device_cuda will fail if run on a machine without a CUDA device, since torch.set_default_device("cuda") will cause subsequent tensor allocations to raise an error. It is good practice to add a pytest.mark.skipif decorator so that tests requiring specific hardware are skipped when it is not available. This makes the test suite more robust across different environments.

Suggested change
-@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device not available")
+@pytest.mark.parametrize("batch_size", [1, 19, 99])

@pytest.mark.parametrize("vocab_size", [111, 32000])
def test_sampling_with_default_device_cuda(batch_size, vocab_size):
"""Test that sampling works correctly when torch.set_default_device("cuda") is set.

This is a regression test for issue #2333 where generator.set_state() would fail
with "RNG state must be a torch.ByteTensor" error when the default device is CUDA.
"""
torch.manual_seed(42)
original_device = torch.get_default_device()
try:
# Set default device to CUDA
torch.set_default_device("cuda")

# Create logits and test top_k_top_p_sampling_from_logits
logits = torch.randn(batch_size, vocab_size, device="cuda:0")

# This should not raise "RNG state must be a torch.ByteTensor" error
samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
logits, top_k=100, top_p=0.9
)

assert samples.shape == (batch_size,)
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

# Also test other sampling functions
probs = torch.softmax(logits, dim=-1)

samples = flashinfer.sampling.sampling_from_probs(probs)
assert samples.shape == (batch_size,)
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9)
assert samples.shape == (batch_size,)
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

finally:
# Restore original default device
torch.set_default_device(original_device)
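The try/finally restore in the test above can also be written as a reusable context manager, so other tests that touch the global default device get the same cleanup guarantee. A minimal sketch; the helper name `default_device` is illustrative, not part of the PR:

```python
import contextlib
import torch

@contextlib.contextmanager
def default_device(device):
    # Temporarily switch torch's default device and restore the previous
    # one even if the body raises, matching the test's try/finally block.
    # get_default_device() is missing on older PyTorch builds, so fall
    # back to CPU (a simplifying assumption for this sketch).
    previous = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
    torch.set_default_device(device)
    try:
        yield
    finally:
        torch.set_default_device(previous)

# Usage sketch: exercised on CPU so it runs without a GPU.
with default_device("cpu"):
    t = torch.ones(2)
```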

Comment on lines +933 to +972
Contributor
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

torch.set_default_device PyTorch version introduced when

💡 Result:

torch.set_default_device was introduced as a (beta) feature in PyTorch 2.0. [1][2]

Sources:

  • PyTorch 2.0 release notes (mentions torch.set_default_device). [1]
  • torch.set_default_device documentation in PyTorch docs. [2]

🏁 Scripts executed (repository: flashinfer-ai/flashinfer; outputs collapsed):

# Check repo's minimum PyTorch requirement
find . -name "setup.py" -o -name "setup.cfg" -o -name "pyproject.toml" -o -name "requirements*.txt" | head -10

# Look at conftest.py to understand existing test guards
cat tests/conftest.py 2>/dev/null | head -100

# Check other tests in test_sampling.py to see existing patterns for CUDA/arch guards
rg "pytest.skip|torch.cuda.is_available|get_compute_capability" tests/utils/test_sampling.py -A 2 -B 2

# Search for uses of flashinfer.utils in tests for arch-based skips
rg "flashinfer.utils|is_sm90a_supported|is_sm100a_supported|get_compute_capability" tests/ -l

# Check minimum PyTorch version requirement
cat pyproject.toml | grep -A 20 -i "torch" | head -40

# Check beginning of test_sampling.py for module-level setup and imports
head -50 tests/utils/test_sampling.py

# Check if there's a module-level pytestmark or skip decorator
rg "pytestmark|skip_if" tests/utils/test_sampling.py | head -20

# Look for flashinfer.utils usage in tests to understand the pattern
rg "from flashinfer.utils import|flashinfer.utils\." tests/ -A 2 | head -50

# Check the last part of conftest to see if there's automatic CUDA/arch skipping
tail -100 tests/conftest.py

# Try to find PyTorch version requirement - check pyproject.toml more carefully
cat pyproject.toml | head -150

# Also check if there's any torch version check decorator or pattern used elsewhere
rg "torch.*version|TorchVersion" tests/ -B 2 -A 2 | head -60

# Check how other tests handle CUDA availability - look for cuda checks
rg "cuda.is_available|device.*cuda" tests/utils/test_sampling.py -B 2 -A 2 | head -40

# Look for decorators or patterns in tests for architecture-specific skips
rg "@pytest.mark|skipif|skip_if" tests/utils/ -A 1 | grep -A 1 "cuda\|arch\|sm90\|compute" | head -30

# Check requirements.txt for minimum torch version
cat requirements.txt

# Look at how test_logging.py uses the CUDA skip pattern since it was referenced
rg -B 5 -A 10 "skipif.*cuda.is_available" tests/utils/test_logging.py

# Check if test_sampling.py has any module-level markers or decorators
head -100 tests/utils/test_sampling.py | rg "pytestmark|import pytest|@pytest"

# Verify the exact PyTorch version when set_default_device was introduced with hasattr pattern
rg "hasattr.*set_default_device|getattr.*set_default_device" . --type py

Add CUDA and API version guards for the default-device CUDA regression test

This test manipulates global PyTorch state and requires APIs introduced in PyTorch 2.0. Since requirements.txt specifies no minimum PyTorch version, add explicit guards to skip cleanly on unsupported environments:

Recommended changes
 def test_sampling_with_default_device_cuda(batch_size, vocab_size):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for this test")
+    if not (hasattr(torch, "set_default_device") and hasattr(torch, "get_default_device")):
+        pytest.skip("torch.set_default_device / torch.get_default_device not available (requires PyTorch 2.0+)")
+
     torch.manual_seed(42)
     original_device = torch.get_default_device()

The pattern mirrors existing tests in the repo (e.g., tests/utils/test_logging.py). The API guard is important because these functions were only introduced as beta features in PyTorch 2.0.
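Both review suggestions above amount to the same environment guard, so they can be folded into a single reusable skipif marker applied alongside the parametrize decorators. A minimal sketch using only public pytest and torch APIs; the marker name `requires_cuda_default_device` is illustrative, not from the PR:

```python
import pytest
import torch

# Collect every environment requirement in one skipif marker: a visible
# CUDA device, plus the default-device APIs (torch.set_default_device
# appeared in PyTorch 2.0; torch.get_default_device in a later release).
requires_cuda_default_device = pytest.mark.skipif(
    not torch.cuda.is_available()
    or not hasattr(torch, "set_default_device")
    or not hasattr(torch, "get_default_device"),
    reason="needs CUDA and torch default-device APIs",
)

@requires_cuda_default_device
@pytest.mark.parametrize("batch_size", [1, 19, 99])
def test_sampling_with_default_device_cuda(batch_size):
    ...  # body unchanged from the PR
```

Defining the marker once at module level keeps the skip condition in sync if more default-device tests are added later.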


if __name__ == "__main__":
# test_sampling_freq(128256, gumbel_distribution(0.1), 0.5)
test_sampling_from_logits_freq(128256, gumbel_distribution(0.1))