
bugfix: use torch cached default generators#2295

Merged
yzh119 merged 2 commits into flashinfer-ai:main from cyx-6:torch-generator-seed
Jan 7, 2026

Conversation

@cyx-6 (Collaborator) commented Jan 6, 2026

📌 Description

Fixes #2284. Before #1641, FlashInfer used the torch default generator (at::cuda::detail::getDefaultCUDAGenerator()), whereas #1641 created a new generator instance on every call. This PR restores the use of the cached default generators from torch.
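The failure mode can be illustrated with Python's stdlib RNG as a stand-in for torch.Generator (a minimal sketch for intuition, not FlashInfer code):

```python
import random

# Two freshly constructed, identically seeded generators replay the same
# sequence -- analogous to creating a new generator instance on every call.
a = random.Random(0)
b = random.Random(0)
assert a.random() == b.random()

# A single shared (cached) generator advances its state between calls,
# so consecutive draws differ -- the behavior this PR restores.
g = random.Random(0)
assert g.random() != g.random()
```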

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Sampling now uses a device-aware default random generator, ensuring consistent and correct sampling behavior across CPU and GPU when no generator is provided.
  • Chores

    • Small public API update to accept a device context so sampling routines derive RNG state from the correct device.


@gemini-code-assist (Contributor)

Summary of Changes

Hello @cyx-6, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a bug related to the initialization of the default PyTorch CUDA generator. Previously, the generator might not have been properly seeded, leading to predictable or non-random outputs in scenarios where true randomness was required. The change ensures that whenever a default torch.Generator is instantiated, its seed is explicitly set to a non-deterministic value, thereby guaranteeing proper random number generation for sampling operations.

Highlights

  • Bug Fix: Resolved an issue where the default torch.Generator was not consistently seeded, potentially leading to deterministic random number generation when non-deterministic output was expected.
  • Randomness Initialization: Explicitly added a call to generator.seed() when a new default torch.Generator is created, ensuring it is initialized with a non-deterministic seed.


@coderabbitai (Contributor)

coderabbitai bot commented Jan 6, 2026

📝 Walkthrough

Walkthrough

get_seed_and_offset now accepts an optional device and uses get_default_generators(device) when no generator is provided; the device argument is propagated to sampling call sites. A new cached get_default_generators helper was added to resolve per-device default RNGs.

Changes

  • Sampling call sites (flashinfer/sampling.py): Propagated new device: Optional[torch.device] argument to calls that obtain RNG state; get_seed_and_offset signature updated to accept device. Multiple sampling functions updated to pass device when requesting seed/offset.
  • RNG helpers (flashinfer/utils.py): Added get_default_generators(device: torch.device) (cached) to return per-CUDA-device default generators; imported and used by get_seed_and_offset.
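The caching contract described above can be sketched with functools.cache (illustrative stand-in; the real helper in flashinfer/utils.py returns torch.cuda.default_generators[device.index]):

```python
import functools

# Stand-in sketch: functools.cache returns the same object for the same
# device key, so every caller shares one generator per device.
@functools.cache
def get_default_generators(device_index: int) -> object:
    # FlashInfer returns a per-device torch CUDA default generator here;
    # a plain object() stands in to demonstrate the caching behavior.
    return object()

assert get_default_generators(0) is get_default_generators(0)
assert get_default_generators(0) is not get_default_generators(1)
```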

Sequence Diagram(s)

(omitted — changes are focused on RNG selection/propagation, not multi-component control flow)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • bkryu
  • nvmbreughe

Poem

🐰 I hopped through seeds and tiny threads,
I nudged the generators in their beds,
Per-device whispers, tidy and neat,
Now randomness dances on nimble feet. 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 8.33%, which is insufficient (required threshold: 80.00%). You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title 'bugfix: use torch cached default generators' accurately describes the main change: introducing cached default generators from torch to fix a seeding issue.
  • Description check ✅ Passed: The PR description provides context by referencing issue #2284 and explaining the recovery of torch default generators, though it lacks specifics on testing status and implementation details.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/sampling.py (1)

39-39: Fix typo in comment.

Minor typo: "multi-trheading" should be "multi-threading".

🔎 Proposed fix
-    # add mutex if multi-trheading needed
+    # add mutex if multi-threading needed
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a8bcf6 and 9bb2e04.

📒 Files selected for processing (1)
  • flashinfer/sampling.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/sampling.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/sampling.py (1)

35-38: LGTM! Non-deterministic seeding restored as intended.

The explicit call to generator.seed() correctly restores non-deterministic behavior when no generator is provided, addressing issue #2284 as described in the PR objectives.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly identifies and fixes an issue where sampling could be deterministic if a torch.Generator was not provided. By adding generator.seed(), a non-deterministic seed is now used for newly created generators. However, my review found a critical pre-existing issue in the get_seed_and_offset function that will cause it to fail on CUDA devices. I've left a detailed comment on how to address this.

@yzh119

yzh119 commented Jan 6, 2026

/bot run

@flashinfer-bot

GitLab MR !225 has been created, and the CI pipeline #41234020 is currently running. I'll report back once the pipeline job completes.

@yzh119

yzh119 commented Jan 7, 2026

/bot run

@cyx-6 cyx-6 changed the title from "bugfix: set seed for default torch generator" to "bugfix: use torch cached default generators" on Jan 7, 2026
@flashinfer-bot

GitLab MR !225 has been updated with latest changes, and the CI pipeline #41270449 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 left a comment

LGTM

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/sampling.py (1)

33-47: Document or validate the device parameter requirement.

The function signature allows device: Optional[torch.device] = None, but if both generator and device are None, line 39 will call get_default_generators(None), which will fail.

While all current call sites correctly provide the device parameter, the type hint is misleading. Consider either:

  1. Making device required when generator is None
  2. Adding validation that raises a clear error message
  3. Updating the docstring to clarify the requirement
🔧 Option 1: Add validation with clear error message
 def get_seed_and_offset(
     increment: int,
     generator: Optional[torch.Generator] = None,
     device: Optional[torch.device] = None,
 ) -> Tuple[int, int]:
     if generator is None:
+        if device is None:
+            raise ValueError("device must be provided when generator is None")
         generator = get_default_generators(device)
     # add mutex if multi-trheading needed
     state = generator.get_state()
     seed, offset = state.view(torch.int64)
     offset += (increment + 3) // 4 * 4
     generator.set_state(
         torch.tensor([seed, offset], dtype=torch.int64).view(torch.uint8)
     )
     return int(seed), int(offset)
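As a standalone illustration of the offset arithmetic in the snippet above: the `(increment + 3) // 4 * 4` term rounds the increment up to the next multiple of 4 (presumably matching the Philox RNG's four 32-bit outputs per counter step; a sketch, not FlashInfer code):

```python
# Sketch of the offset update in get_seed_and_offset: advance the RNG
# offset by `increment` rounded up to a multiple of 4.
def advance_offset(offset: int, increment: int) -> int:
    return offset + (increment + 3) // 4 * 4

assert advance_offset(0, 1) == 4   # 1 rounds up to 4
assert advance_offset(0, 4) == 4   # already a multiple of 4
assert advance_offset(10, 5) == 18 # 5 rounds up to 8
```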
🤖 Fix all issues with AI agents
In @flashinfer/utils.py:
- Around line 1187-1190: Validate the incoming torch.device in
get_default_generators: ensure device.type == "cuda", that device.index is not
None, call torch.cuda.init() and then confirm device.index is within range (0 <=
device.index < torch.cuda.device_count()); if any check fails raise a ValueError
with a clear message describing the problem (e.g. non-CUDA device, unspecified
index, or index out of range), otherwise return
torch.cuda.default_generators[device.index]; keep the @functools.cache decorator
and reference get_default_generators and torch.cuda.default_generators in your
changes.
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9bb2e04 and b8c9e71.

📒 Files selected for processing (2)
  • flashinfer/sampling.py
  • flashinfer/utils.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/sampling.py
  • flashinfer/utils.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation

Applied to files:

  • flashinfer/utils.py
🧬 Code graph analysis (2)
flashinfer/sampling.py (1)
flashinfer/utils.py (1)
  • get_default_generators (1188-1190)
flashinfer/utils.py (2)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
flashinfer/comm/nvshmem.py (1)
  • init (52-55)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/sampling.py (1)

105-107: LGTM: Device-aware RNG initialization.

All call sites correctly extract the device from the input tensor and propagate it to get_seed_and_offset, ensuring that device-specific default generators are used. This pattern is consistently applied across all sampling functions.

Also applies to: 146-146, 192-192, 239-239, 288-288, 326-326, 486-488

Comment on lines +1187 to +1190
@functools.cache
def get_default_generators(device: torch.device):
    torch.cuda.init()
    return torch.cuda.default_generators[device.index]

⚠️ Potential issue | 🟠 Major

Add device validation to prevent cryptic runtime errors.

The function lacks validation for the device parameter, which can lead to confusing errors:

  • If device.type != "cuda", accessing torch.cuda.default_generators will fail
  • If device.index is None (e.g., torch.device("cuda")), the indexing operation will fail
  • If device.index exceeds the number of CUDA devices, it will raise an IndexError
🛡️ Proposed fix with device validation
 @functools.cache
 def get_default_generators(device: torch.device):
+    if device.type != "cuda":
+        raise ValueError(f"Device must be a CUDA device, got {device.type}")
+    if device.index is None:
+        device = torch.device("cuda", torch.cuda.current_device())
     torch.cuda.init()
     return torch.cuda.default_generators[device.index]

Committable suggestion skipped: line range outside the PR's diff.


@yzh119 yzh119 mentioned this pull request Jan 7, 2026
5 tasks
@yzh119

yzh119 commented Jan 7, 2026

The CI failure looks like network connection issues, bypass and merge.

@yzh119 yzh119 merged commit ede764f into flashinfer-ai:main Jan 7, 2026
3 of 4 checks passed
@cyx-6 cyx-6 deleted the torch-generator-seed branch February 3, 2026 08:06