
fix: explicitly set device to CPU for RNG state tensor #2344

Merged
yzh119 merged 2 commits into main from
claude/issue-2333-20260112-0956
Jan 13, 2026

Conversation

Collaborator

@cyx-6 cyx-6 commented Jan 13, 2026

Fixes #2333

Summary

Fixed the TypeError: RNG state must be a torch.ByteTensor error that occurred when using sampling functions with torch.set_default_device("cuda") enabled.

Root Cause

PyTorch's generator.set_state() requires the state tensor to be a CPU ByteTensor. When torch.set_default_device("cuda") is set, torch.tensor() creates tensors on CUDA by default, causing the error.
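This CPU ByteTensor constraint can be demonstrated in isolation. The following is a minimal sketch (independent of the flashinfer code and runnable without a GPU) showing that generator.set_state() only accepts a CPU torch.uint8 tensor:

```python
import torch

gen = torch.Generator()  # a CPU generator
state = gen.get_state()  # RNG state is always a CPU uint8 (Byte) tensor
assert state.dtype == torch.uint8 and state.device.type == "cpu"

gen.set_state(state)  # round-trip with a CPU ByteTensor succeeds

# Passing a tensor of any other dtype is rejected up front with the
# "RNG state must be a torch.ByteTensor" error reported in issue #2333.
try:
    gen.set_state(torch.zeros(8, dtype=torch.int64))
    raised = False
except (TypeError, RuntimeError):
    raised = True
print(raised)  # True
```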

Changes

  • Modified flashinfer/sampling.py:45 to explicitly set device=torch.device("cpu") when creating the RNG state tensor
  • Added regression test test_sampling_with_default_device_cuda() in tests/utils/test_sampling.py

Testing

The new test verifies that all sampling functions work correctly when torch.set_default_device("cuda") is active.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Fixed device state handling in random number generation for sampling operations to ensure proper CPU/CUDA compatibility.
  • Tests

    • Added comprehensive test coverage for sampling functions with CUDA as the default device.


When torch.set_default_device('cuda') is called, torch.tensor() creates
tensors on CUDA by default. However, PyTorch's generator.set_state()
requires the state to be a CPU ByteTensor, causing a TypeError:
'RNG state must be a torch.ByteTensor'.

This fix explicitly sets device=torch.device('cpu') when creating the
RNG state tensor in get_seed_and_offset(), ensuring compatibility
regardless of the default device setting.
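The interaction being fixed can be sketched without a GPU by using the "meta" device as a stand-in for "cuda" (a hypothetical demo, not part of the patch): factory calls follow the default device unless a device= argument is given explicitly.

```python
import torch

# "meta" stands in for "cuda" so this runs on any machine: with a non-CPU
# default device active, calls without device= follow the default, while an
# explicit device= always wins; this is exactly what the fix relies on.
torch.set_default_device("meta")
try:
    implicit = torch.zeros(2, dtype=torch.int64)
    explicit = torch.zeros(2, dtype=torch.int64, device=torch.device("cpu"))
    print(implicit.device.type)  # meta
    print(explicit.device.type)  # cpu
finally:
    torch.set_default_device("cpu")  # restore, mirroring the regression test's cleanup
```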

Also adds a regression test to verify the fix works correctly when
torch.set_default_device('cuda') is set.

Fixes #2333

Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
Contributor

coderabbitai bot commented Jan 13, 2026

📝 Walkthrough

Fixed an RNG state device issue where generator.set_state() fails when default PyTorch device is CUDA. The fix explicitly creates the state tensor on CPU before uint8 conversion. Added a test verifying sampling functions operate correctly with CUDA as the default device.

Changes

  • Core RNG State Fix (flashinfer/sampling.py): Modified get_seed_and_offset() to explicitly create the state tensor on CPU before viewing it as uint8, ensuring PyTorch's RNG state validation requirements are met regardless of the default device setting.
  • Device-Aware Testing (tests/utils/test_sampling.py): Added test_sampling_with_default_device_cuda() to verify that sampling functions (top_k_top_p_sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs) operate without RNG-device errors when CUDA is the default device.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • nvmbreughe
  • bkryu
  • jimmyzho

Poem

🐰✨ A quantum leap for CUDA dreams,
No more RNG state device schemes!
CPU tensors keep their place,
While GPUs compute at full pace. 🚀

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.67%, which is below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4)

  • Title check: The title accurately summarizes the main change: explicitly setting the device to CPU for the RNG state tensor to fix a compatibility issue with torch.set_default_device('cuda').
  • Description check: The description clearly explains the issue, root cause, and solution. It maps to the template with a summary section and mentions the fixes/testing, though formal checklist items are not explicitly marked.
  • Linked Issues check: The pull request fully addresses issue #2333 by explicitly placing the RNG state tensor on CPU in flashinfer/sampling.py and adding a comprehensive regression test that verifies sampling functions work with torch.set_default_device('cuda').
  • Out of Scope Changes check: All changes are directly related to fixing the RNG state device issue: the core fix in sampling.py and a focused regression test in test_sampling.py. No unrelated modifications are present.



@gemini-code-assist
Contributor

Summary of Changes

Hello @cyx-6, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compatibility issue where PyTorch's random number generator state management conflicted with a global default device setting of CUDA. By explicitly directing the RNG state tensor to the CPU, the fix ensures that sampling operations function reliably across different device configurations, preventing runtime errors and improving robustness.

Highlights

  • Bug Fix: Resolved a TypeError: RNG state must be a torch.ByteTensor that occurred when using sampling functions with torch.set_default_device("cuda") enabled, by explicitly setting the device for the RNG state tensor to CPU.
  • Code Change: Modified flashinfer/sampling.py to ensure the RNG state tensor is created on the CPU, as required by PyTorch's generator.set_state() method.
  • Testing: Added a new regression test, test_sampling_with_default_device_cuda(), in tests/utils/test_sampling.py to verify that sampling functions work correctly when the default device is set to CUDA.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.


Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses the TypeError that occurs when using sampling functions with a default CUDA device. The fix, which explicitly sets the RNG state tensor to the CPU, is appropriate. The inclusion of a regression test is a great addition to prevent this issue from recurring. I have one suggestion to make the new test more robust by ensuring it is skipped in environments where a CUDA device is not available.

```python
    assert torch.all(samples < vocab_size) and torch.all(samples >= 0)


@pytest.mark.parametrize("batch_size", [1, 19, 99])
```
Contributor

medium

The new test test_sampling_with_default_device_cuda will fail if run on a machine without a CUDA device, as torch.set_default_device("cuda") will raise an error. It's a good practice to add a pytest.mark.skipif decorator to skip tests that require specific hardware when it's not available. This will make your test suite more robust across different environments.

Suggested change:

```diff
-@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device not available")
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
```

@claude

claude bot commented Jan 13, 2026

Code Review

Summary

This PR fixes a bug where sampling functions would fail with TypeError: RNG state must be a torch.ByteTensor when torch.set_default_device("cuda") is set. The fix is correct and minimal.


Strengths

  1. Root cause correctly identified: PyTorch's Generator.set_state() requires a CPU ByteTensor, but torch.tensor() respects the default device setting, causing GPU tensors to be created when torch.set_default_device("cuda") is active.

  2. Minimal, surgical fix: The change is a one-line addition that explicitly sets device=torch.device("cpu") in the only location where this issue occurs (flashinfer/sampling.py:45).

  3. Comprehensive test coverage: The new test test_sampling_with_default_device_cuda verifies:

    • Multiple batch sizes (1, 19, 99)
    • Multiple vocab sizes (111, 32000)
    • Three different sampling functions (top_k_top_p_sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs)
    • Proper cleanup with try/finally to restore original device
  4. Good test hygiene: The test properly restores the default device in a finally block, preventing side effects on other tests.

  5. Clear documentation: Excellent docstring explaining this is a regression test for issue #2333 ("generator.set_state raise an error in top_k_top_p_sampling_from_logits"), making the purpose crystal clear for future maintainers.


Code Quality

  • Follows FlashInfer conventions: Matches the style of other tests in the file
  • No unnecessary changes: Doesn't refactor unrelated code
  • Zero performance impact: The fix is on a non-hot path (RNG state updates)
  • Framework-agnostic: Maintains FlashInfer's principle of being compatible with PyTorch's various device configurations

Potential Considerations

Minor suggestion (optional): Consider testing top_k_sampling_from_probs and sampling_from_logits as well for complete coverage, though the current test coverage is sufficient since all sampling functions go through the same get_seed_and_offset code path.

Verification: I confirmed that:

  • This is the only location in the codebase where .view(torch.uint8) is called on a tensor that needs to be a CPU ByteTensor
  • All sampling functions eventually call get_seed_and_offset, so fixing it here fixes the issue everywhere
  • The test will run on CUDA-enabled CI systems (no special markers needed)

Security & Best Practices

  • No security concerns: This is a device placement fix
  • No backwards compatibility issues: Explicitly setting device to CPU is always correct here
  • No edge cases: PyTorch's requirement is strict - RNG state must always be CPU

Final Verdict

LGTM - Ready to merge

This is an exemplary bug fix:

  • Correct root cause analysis
  • Minimal change
  • Comprehensive test
  • Clear documentation
  • No side effects

Great work on the fix and the regression test!

@claude

claude bot commented Jan 13, 2026

Code Review

Summary

This PR fixes a critical bug (issue #2333) where sampling functions would crash with TypeError: RNG state must be a torch.ByteTensor when torch.set_default_device("cuda") is set. The fix is correct and well-tested. ✅


Strengths

  1. Correct root cause analysis: The issue correctly identifies that PyTorch's generator.set_state() requires a CPU ByteTensor, as enforced by PyTorch's C++ code.

  2. Minimal, surgical fix: The change is exactly one line - adding device=torch.device("cpu") to the tensor creation in flashinfer/sampling.py:45. This is the right approach.

  3. Excellent test coverage: The new regression test test_sampling_with_default_device_cuda() is comprehensive:

    • Tests multiple batch sizes and vocab sizes
    • Tests multiple sampling functions (top_k_top_p_sampling_from_logits, sampling_from_probs, top_p_sampling_from_probs)
    • Properly saves/restores the default device in a try-finally block
    • Includes helpful docstring explaining the issue
  4. Good documentation: The commit message and PR description clearly explain the problem, root cause, and solution.


Code Quality

Correctness: The fix is correct. PyTorch generators always operate on CPU regardless of the generator's device, and set_state() explicitly requires a CPU tensor.

Performance: No performance impact - tensor creation is only a 2-element int64 tensor on CPU.

Backwards compatibility: Fully backwards compatible. When default device is CPU (the common case), this is a no-op.

Testing: Comprehensive regression test prevents future breakage.

Style: Follows FlashInfer conventions (proper formatting, matches existing code style).


Observations

  1. Existing behavior preserved: The fix doesn't change behavior for users who don't use torch.set_default_device() - it just makes the implicit assumption explicit.

  2. Generator state handling: The generator itself (created via get_default_generators()) is correctly a CUDA generator, but its state must be set via a CPU tensor. This is a PyTorch API constraint, not a FlashInfer issue.

  3. Test quality: The test properly isolates the default device context and restores it, preventing side effects on other tests.


Potential Considerations

Minor suggestion (non-blocking): Consider adding a comment in the code explaining why device="cpu" is required:

```python
# PyTorch requires generator state to be a CPU ByteTensor, regardless of generator device
# See: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/Generator.h#L178
generator.set_state(
    torch.tensor(
        [seed, offset], dtype=torch.int64, device=torch.device("cpu")
    ).view(torch.uint8)
)
```

However, this is optional - the code is clear enough without it, and the test documentation explains the issue.


Recommendation

LGTM - Approve and merge

This is a well-executed bug fix that:

  • Solves a real user-reported issue
  • Uses the minimal correct fix
  • Includes comprehensive regression tests
  • Maintains backwards compatibility
  • Follows project conventions

Great work by @cyx-6 and @yzh119! 🎉

@yzh119 yzh119 enabled auto-merge (squash) January 13, 2026 07:13
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/sampling.py (1)

33-49: Good fix; consider avoiding torch.tensor([seed, offset], ...) with tensor elements

For robustness/perf, build the state from Python ints (or a stacked tensor) rather than a list of 0-d tensors.

Proposed tweak (same behavior, avoids tensor-from-tensor list construction)
```diff
 def get_seed_and_offset(
     increment: int,
     generator: Optional[torch.Generator] = None,
     device: Optional[torch.device] = None,
 ) -> Tuple[int, int]:
@@
     seed, offset = state.view(torch.int64)
     offset += (increment + 3) // 4 * 4
+    seed_i, offset_i = int(seed), int(offset)
     generator.set_state(
         torch.tensor(
-            [seed, offset], dtype=torch.int64, device=torch.device("cpu")
+            [seed_i, offset_i], dtype=torch.int64, device=torch.device("cpu")
         ).view(torch.uint8)
     )
-    return int(seed), int(offset)
+    return seed_i, offset_i
```
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between c4a3172 and fb67d06.

📒 Files selected for processing (2)
  • flashinfer/sampling.py
  • tests/utils/test_sampling.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/utils/test_sampling.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/sampling.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/utils/test_sampling.py
🧬 Code graph analysis (1)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (7)
  • top_k_top_p_sampling_from_logits (1027-1156)
  • softmax (57-77)
  • softmax (542-596)
  • sampling_from_probs (134-157)
  • sampling_from_probs (667-736)
  • top_p_sampling_from_probs (175-205)
  • top_p_sampling_from_probs (740-833)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Deploy Docs
  • GitHub Check: claude-review

Comment on lines +933 to +972

```python
@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("vocab_size", [111, 32000])
def test_sampling_with_default_device_cuda(batch_size, vocab_size):
    """Test that sampling works correctly when torch.set_default_device("cuda") is set.

    This is a regression test for issue #2333 where generator.set_state() would fail
    with "RNG state must be a torch.ByteTensor" error when the default device is CUDA.
    """
    torch.manual_seed(42)
    original_device = torch.get_default_device()
    try:
        # Set default device to CUDA
        torch.set_default_device("cuda")

        # Create logits and test top_k_top_p_sampling_from_logits
        logits = torch.randn(batch_size, vocab_size, device="cuda:0")

        # This should not raise "RNG state must be a torch.ByteTensor" error
        samples = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, top_k=100, top_p=0.9
        )

        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

        # Also test other sampling functions
        probs = torch.softmax(logits, dim=-1)

        samples = flashinfer.sampling.sampling_from_probs(probs)
        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

        samples = flashinfer.sampling.top_p_sampling_from_probs(probs, 0.9)
        assert samples.shape == (batch_size,)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)

    finally:
        # Restore original default device
        torch.set_default_device(original_device)
```

Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query: "torch.set_default_device PyTorch version introduced when"

💡 Result: torch.set_default_device was introduced as a (beta) feature in PyTorch 2.0, per the PyTorch 2.0 release notes and the torch.set_default_device documentation.

🏁 Scripts executed against flashinfer-ai/flashinfer to check the repository's minimum PyTorch requirement and its existing test-guard patterns:

```shell
# Check the repo's minimum PyTorch requirement
find . -name "setup.py" -o -name "setup.cfg" -o -name "pyproject.toml" -o -name "requirements*.txt" | head -10
cat requirements.txt
cat pyproject.toml | head -150

# Inspect conftest.py and existing tests for CUDA/architecture skip patterns
cat tests/conftest.py
rg "pytest.skip|torch.cuda.is_available|get_compute_capability" tests/utils/test_sampling.py -A 2 -B 2
rg "flashinfer.utils|is_sm90a_supported|is_sm100a_supported|get_compute_capability" tests/ -l
rg -B 5 -A 10 "skipif.*cuda.is_available" tests/utils/test_logging.py

# Check for module-level markers and any existing set_default_device guards
rg "pytestmark|skip_if" tests/utils/test_sampling.py | head -20
rg "hasattr.*set_default_device|getattr.*set_default_device" . --type py
```

The scripts confirmed that requirements.txt pins no minimum PyTorch version, that test_sampling.py has no module-level CUDA skip, and that tests/utils/test_logging.py already uses a skipif-on-CUDA pattern.


Add CUDA and API version guards for the default-device CUDA regression test

This test manipulates global PyTorch state and requires APIs introduced in PyTorch 2.0. Since requirements.txt specifies no minimum PyTorch version, add explicit guards to skip cleanly on unsupported environments:

Recommended changes:

```diff
 def test_sampling_with_default_device_cuda(batch_size, vocab_size):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is required for this test")
+    if not (hasattr(torch, "set_default_device") and hasattr(torch, "get_default_device")):
+        pytest.skip("torch.set_default_device / torch.get_default_device not available (requires PyTorch 2.0+)")
+
     torch.manual_seed(42)
     original_device = torch.get_default_device()
```
The pattern mirrors existing tests in the repo (e.g., tests/utils/test_logging.py). The API guard is important because these functions were only introduced as beta features in PyTorch 2.0.

@yzh119 yzh119 merged commit c546c9a into main Jan 13, 2026
8 checks passed
@yzh119 yzh119 deleted the claude/issue-2333-20260112-0956 branch January 13, 2026 09:41


Development

Successfully merging this pull request may close these issues.

generator.set_state raise an error in top_k_top_p_sampling_from_logits

2 participants