bugfix: use torch cached default generators #2295
Summary of Changes
This pull request addresses a bug related to the initialization of the default PyTorch CUDA generator. Previously, the generator might not have been properly seeded, leading to predictable or non-random outputs in scenarios where true randomness was required. The change ensures that whenever a default
📝 Walkthrough
get_seed_and_offset now accepts an optional
Changes
Sequence Diagram(s): omitted; changes are focused on RNG selection/propagation, not multi-component control flow.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/sampling.py (1)
39-39: Fix typo in comment.
Minor typo: "multi-trheading" should be "multi-threading".
🔎 Proposed fix
- # add mutex if multi-trheading needed
+ # add mutex if multi-threading needed
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/sampling.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use `@flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/sampling.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/sampling.py (1)
35-38: LGTM! Non-deterministic seeding restored as intended.
The explicit call to `generator.seed()` correctly restores non-deterministic behavior when no generator is provided, addressing issue #2284 as described in the PR objectives.
Code Review
This pull request correctly identifies and fixes an issue where sampling could be deterministic if a torch.Generator was not provided. By adding generator.seed(), a non-deterministic seed is now used for newly created generators. However, my review found a critical pre-existing issue in the get_seed_and_offset function that will cause it to fail on CUDA devices. I've left a detailed comment on how to address this.
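The determinism described in this review can be illustrated with a minimal sketch. It uses CPU generators only, so it runs without a GPU, and it is an illustration of the general `torch.Generator` behavior rather than the FlashInfer code path itself. The variable names are hypothetical.

```python
import torch

# Two freshly constructed generators start from the same default seed,
# so they produce identical random streams.
a = torch.Generator()
b = torch.Generator()
same_default_seed = a.initial_seed() == b.initial_seed()
same_draws = torch.equal(
    torch.rand(4, generator=a),
    torch.rand(4, generator=b),
)

# Calling seed() reseeds the generator from a non-deterministic source
# and returns the new seed.
c = torch.Generator()
default_seed = c.initial_seed()
new_seed = c.seed()

print(same_default_seed, same_draws)
print(new_seed != default_seed)
```

Creating a new generator on every call without reseeding it therefore repeats the same stream, which matches the symptom this review describes.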
/bot run
/bot run
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/sampling.py (1)
33-47: Document or validate the device parameter requirement.
The function signature allows `device: Optional[torch.device] = None`, but if both `generator` and `device` are None, line 39 will call `get_default_generators(None)`, which will fail.
While all current call sites correctly provide the device parameter, the type hint is misleading. Consider either:
- Making device required when generator is None
- Adding validation that raises a clear error message
- Updating the docstring to clarify the requirement
🔧 Option 1: Add validation with clear error message
 def get_seed_and_offset(
     increment: int,
     generator: Optional[torch.Generator] = None,
     device: Optional[torch.device] = None,
 ) -> Tuple[int, int]:
     if generator is None:
+        if device is None:
+            raise ValueError("device must be provided when generator is None")
         generator = get_default_generators(device)
     # add mutex if multi-trheading needed
     state = generator.get_state()
     seed, offset = state.view(torch.int64)
     offset += (increment + 3) // 4 * 4
     generator.set_state(
         torch.tensor([seed, offset], dtype=torch.int64).view(torch.uint8)
     )
     return int(seed), int(offset)
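The offset arithmetic in the snippet above, `(increment + 3) // 4 * 4`, rounds `increment` up to the next multiple of 4 before advancing the offset. A minimal pure-Python sketch of that step follows; the helper names `round_up_to_multiple_of_4` and `advance_offset` are hypothetical and do not appear in the PR.

```python
def round_up_to_multiple_of_4(increment: int) -> int:
    """Round a non-negative increment up to the next multiple of 4."""
    return (increment + 3) // 4 * 4


def advance_offset(offset: int, increment: int) -> int:
    """Return the offset after consuming `increment` random values (hypothetical helper)."""
    return offset + round_up_to_multiple_of_4(increment)


for inc in (0, 1, 4, 5, 8):
    print(inc, "->", round_up_to_multiple_of_4(inc))
```

Because every advance is a multiple of 4, an offset that starts at a multiple of 4 stays aligned to 4 across calls.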
🤖 Fix all issues with AI agents
In @flashinfer/utils.py:
- Around line 1187-1190: Validate the incoming torch.device in
get_default_generators: ensure device.type == "cuda", that device.index is not
None, call torch.cuda.init() and then confirm device.index is within range (0 <=
device.index < torch.cuda.device_count()); if any check fails raise a ValueError
with a clear message describing the problem (e.g. non-CUDA device, unspecified
index, or index out of range), otherwise return
torch.cuda.default_generators[device.index]; keep the @functools.cache decorator
and reference get_default_generators and torch.cuda.default_generators in your
changes.
📜 Review details
📒 Files selected for processing (2)
flashinfer/sampling.py
flashinfer/utils.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use `@flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/sampling.py
flashinfer/utils.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Applied to files:
flashinfer/utils.py
🧬 Code graph analysis (2)
flashinfer/sampling.py (1)
flashinfer/utils.py (1)
get_default_generators(1188-1190)
flashinfer/utils.py (2)
include/flashinfer/trtllm/common.h (1)
device(83-90)
flashinfer/comm/nvshmem.py (1)
init(52-55)
🔇 Additional comments (1)
flashinfer/sampling.py (1)
105-107: LGTM: Device-aware RNG initialization.
All call sites correctly extract the device from the input tensor and propagate it to `get_seed_and_offset`, ensuring that device-specific default generators are used. This pattern is consistently applied across all sampling functions.
Also applies to: 146-146, 192-192, 239-239, 288-288, 326-326, 486-488
@functools.cache
def get_default_generators(device: torch.device):
    torch.cuda.init()
    return torch.cuda.default_generators[device.index]
Add device validation to prevent cryptic runtime errors.
The function lacks validation for the device parameter, which can lead to confusing errors:
- If `device.type != "cuda"`, accessing `torch.cuda.default_generators` will fail
- If `device.index is None` (e.g., `torch.device("cuda")`), the indexing operation will fail
- If `device.index` exceeds the number of CUDA devices, it will raise an IndexError
🛡️ Proposed fix with device validation
 @functools.cache
 def get_default_generators(device: torch.device):
+    if device.type != "cuda":
+        raise ValueError(f"Device must be a CUDA device, got {device.type}")
+    if device.index is None:
+        device = torch.device("cuda", torch.cuda.current_device())
     torch.cuda.init()
     return torch.cuda.default_generators[device.index]
Committable suggestion skipped: line range outside the PR's diff.
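The three checks raised in this thread (CUDA device type, explicit index, index within range) can be sketched without a GPU. This is only an illustration under stated assumptions: `validate_cuda_index` and `DeviceLike` are hypothetical names, the device count is passed in explicitly instead of being read from `torch.cuda.device_count()`, and the function accepts any object with `type` and `index` attributes, such as a `torch.device`.

```python
from types import SimpleNamespace
from typing import Optional, Protocol


class DeviceLike(Protocol):
    """Anything exposing `type` and `index`, for example a torch.device."""

    type: str
    index: Optional[int]


def validate_cuda_index(device: DeviceLike, device_count: int) -> int:
    """Return the validated device index or raise a descriptive ValueError."""
    if device.type != "cuda":
        raise ValueError(f"expected a CUDA device, got {device.type!r}")
    if device.index is None:
        raise ValueError("device index is unspecified; use e.g. 'cuda:0'")
    if not 0 <= device.index < device_count:
        raise ValueError(
            f"device index {device.index} out of range for {device_count} device(s)"
        )
    return device.index


# Stand-in objects used here so the sketch runs without torch or CUDA.
ok_device = SimpleNamespace(type="cuda", index=1)
validated_index = validate_cuda_index(ok_device, device_count=2)
print(validated_index)
```

Note that this sketch rejects an unspecified index, as the agent prompt in this thread suggests, whereas the proposed diff above falls back to the current device instead; either choice is a design decision for the maintainers.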
The CI failure looks like a network connection issue; bypassing and merging.
📌 Description
Fixes #2284. Before #1641, FlashInfer used the torch default generator `at::cuda::detail::getDefaultCUDAGenerator()`, whereas #1641 creates a new generator instance each time. This PR restores use of the default generator from torch.
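The difference between creating a new generator on every call and reusing one cached instance can be shown with a toy, pure-Python sketch. `ToyGenerator`, `fresh_generator`, and `cached_generator` are hypothetical stand-ins for illustration only; they are not torch or FlashInfer APIs.

```python
import functools
import itertools


class ToyGenerator:
    """Hypothetical stand-in for an RNG: yields 0, 1, 2, ... from a fixed start."""

    def __init__(self) -> None:
        self._counter = itertools.count()

    def draw(self) -> int:
        return next(self._counter)


def fresh_generator(device_index: int) -> ToyGenerator:
    # A new instance per call always restarts from the same initial state.
    return ToyGenerator()


@functools.cache
def cached_generator(device_index: int) -> ToyGenerator:
    # One shared instance per device index; its state persists across calls.
    return ToyGenerator()


fresh_draws = [fresh_generator(0).draw() for _ in range(3)]
cached_draws = [cached_generator(0).draw() for _ in range(3)]
print(fresh_draws)   # [0, 0, 0]
print(cached_draws)  # [0, 1, 2]
```

The fresh instances repeat the same first value, while the cached instance advances its state between calls, which is the behavior a shared default generator provides.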
Summary by CodeRabbit
Bug Fixes
Chores