feat: add seed offset args to sampler to allow cuda graph support #2132

yzh119 merged 3 commits into flashinfer-ai:main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Seven public sampling functions in `flashinfer/sampling.py` gain optional `seed` and `offset` parameters, allowing callers to bypass the internal `get_seed_and_offset` call when both are supplied.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant API as Sampling API
    participant RNG as get_seed_and_offset
    participant Kernel as GPU Kernel Wrapper
    Caller->>API: call sampling_* (..., seed=?, offset=?)
    alt seed/offset provided
        API->>RNG: (skip) use supplied seed/offset
    else seed/offset not provided
        API->>RNG: compute seed/offset (with increment)
        RNG-->>API: return seed/offset
    end
    API->>Kernel: invoke kernel wrapper with seed/offset
    Kernel->>Kernel: initialize RNG and sample
    Kernel-->>API: return sampled indices
    API-->>Caller: return results
```
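The dispatch shown in the diagram can be sketched in plain Python. This is an illustrative stand-in, not the library's implementation: `get_seed_and_offset` here is a dummy that mimics advancing a Philox-style offset by `increment` per call.

```python
from itertools import count
from typing import Optional, Tuple

_calls = count()  # dummy generator state; the real one lives in torch/flashinfer

def get_seed_and_offset(increment: int) -> Tuple[int, int]:
    # Hypothetical stand-in: returns a fixed seed and an offset that
    # advances by `increment` RNG states per call.
    return 42, next(_calls) * increment

def resolve_rng_state(seed: Optional[int], offset: Optional[int],
                      increment: int) -> Tuple[int, int]:
    # Mirrors the diagram: use the caller-supplied pair when both are
    # present; otherwise fall back to the generator-derived values.
    if seed is None or offset is None:
        seed, offset = get_seed_and_offset(increment)
    return seed, offset

print(resolve_rng_state(7, 100, increment=10))    # caller-supplied pair wins
print(resolve_rng_state(None, None, increment=10))  # falls back to generator
```

With both values supplied, no generator state is touched, which is the property that makes the call safe to capture in a CUDA graph.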
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ failed checks (1 warning); ✅ passed checks (2 passed).
Summary of Changes

Hello @ksukrit, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the FlashInfer sampling library by adding explicit seed and offset parameters to all sampling functions. This change is crucial for integrating these sampling operations into CUDA graphs, as it prevents dynamic calls to random-number-generator state within the graph, thereby improving performance and enabling more efficient GPU execution for generative models.
Code Review
This pull request introduces seed and offset parameters to various sampling functions to enable CUDA graph support by avoiding calls to get_seed_and_offset. The changes are consistent and well-implemented across the file. My review includes two main suggestions for improvement:
- Enhancing the robustness of the API by explicitly checking that `seed` and `offset` are either both provided or both `None`, raising a `ValueError` otherwise. This prevents unexpected behavior where a partially provided pair is silently ignored.
- Improving the clarity of the docstrings for `seed` and `offset` to explicitly state this requirement.
These changes will make the API more user-friendly and prevent potential misuse. Overall, this is a good addition to the library.
```python
if seed is None or offset is None:
    seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator)
```
The current logic to conditionally get the seed and offset can be improved for robustness. If a user provides only one of seed or offset, it will be silently ignored and both will be regenerated. It's better to enforce that either both or neither are provided by raising a ValueError if only one is given. This makes the API less error-prone.
This comment applies to all similar changes in this file (e.g., lines 137-138, 181-182, 225-226, etc.).
```python
if (seed is None) != (offset is None):
    raise ValueError("Both seed and offset must be provided, or neither.")
if seed is None:
    seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator)
```

```
seed: Optional[int]
    seed value to use for the rng during the sampling operation.
offset: Optional[int]
    offset value to use for the rng during the sampling operation.
```
The docstrings for seed and offset could be more explicit about the requirement to provide both or neither. This helps prevent misuse of the API and clarifies the behavior when seed and offset are partially provided.
This comment applies to all similar docstring additions in this file.
```
seed: Optional[int]
    The seed for the random number generator. If provided, `offset` must also be provided.
offset: Optional[int]
    The offset for the random number generator. If provided, `seed` must also be provided.
```
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/sampling.py (1)
112-121: Update fake-op signatures to accept `seed`/`offset` parameters to match real op signatures.

All six `@register_fake_op` kernels (`_fake_sampling_from_logits`, `_fake_sampling_from_probs`, `_fake_top_p_sampling_from_probs`, `_fake_top_k_sampling_from_probs`, `_fake_top_k_top_p_sampling_from_probs`, `_fake_chain_speculative_sampling`) are missing the `seed` and `offset` parameters that their corresponding `@register_custom_op` implementations have. This signature mismatch will cause dispatcher mismatches in meta/torch.compile execution paths. Update the six fake kernels to accept the additional parameters (which can be safely ignored since they just return empty tensors):

- `_fake_sampling_from_logits` (line 113): add `seed: Optional[int] = None, offset: Optional[int] = None`
- `_fake_sampling_from_probs` (line 151): add `seed: Optional[int] = None, offset: Optional[int] = None`
- `_fake_top_p_sampling_from_probs` (line 195): add `seed: Optional[int] = None, offset: Optional[int] = None`
- `_fake_top_k_sampling_from_probs` (line 239): add `seed: Optional[int] = None, offset: Optional[int] = None`
- `_fake_top_k_top_p_sampling_from_probs` (line 325): add `seed: Optional[int] = None, offset: Optional[int] = None`
- `_fake_chain_speculative_sampling` (line 468): add `seed: Optional[int] = None, offset: Optional[int] = None`
🧹 Nitpick comments (4)
flashinfer/sampling.py (4)
84-110: RNG plumbing via `seed`/`offset` is consistent, but partial specification is silently ignored.

The new `seed`/`offset` parameters are threaded consistently through all custom-op entry points and only trigger `get_seed_and_offset(...)` when at least one of them is `None`, which is exactly what you want for CUDA graph capture (no generator interaction if both are provided). However, if a caller passes only one of `seed` or `offset`, the current `if seed is None or offset is None:` branch discards the specified value and recomputes both from the generator. That's a bit surprising and easy to misuse. I'd recommend either:

- Enforcing "all-or-nothing" with an explicit check, e.g. raising on partial specification, or
- Documenting clearly that both must be provided together and treating partial specification as an error.

A simple pattern you can reuse across these kernels:

```diff
- if seed is None or offset is None:
-     seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator)
+ if (seed is None) ^ (offset is None):
+     raise ValueError("seed and offset must be both None or both specified.")
+ if seed is None and offset is None:
+     seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator)
```

(and similarly for the other custom ops, with their respective `increment` expressions).

Also applies to: 124-147, 163-193, 209-237, 254-284, 288-323, 425-466
589-652: High-level `sampling_from_logits`/`sampling_from_probs` wrappers correctly expose seed/offset.

The public APIs for `sampling_from_logits` and `sampling_from_probs` now accept `seed` and `offset`, document them, and pass them through directly to the underlying custom ops. The default behaviour (no `seed`/`offset` provided) remains unchanged, and backward compatibility is preserved since the new parameters are appended at the end of the signature. Assuming you add the optional "all-or-nothing" validation mentioned earlier, these wrappers look solid.

Also applies to: 654-723
725-819: Top-p / top-k / min-p wrappers propagate RNG state correctly.

For `top_p_sampling_from_probs`, `top_k_sampling_from_probs`, and `min_p_sampling_from_probs`, the new `seed`/`offset` parameters are:

- Added to the public signatures and docstrings.
- Passed through the `_to_tensor_scalar_tuple(...)` dance into the custom ops with correct ordering.
- Forwarded consistently from higher-level code paths (e.g., from `top_k_top_p_sampling_from_logits` and `top_k_top_p_sampling_from_probs`).

The change preserves previous behaviour when `seed`/`offset` are not specified and enables explicit RNG control when they are. No issues beyond the partial-seed corner case already mentioned.

Also applies to: 821-915, 917-1007
432-454: `chain_speculative_sampling` seed/offset wiring looks correct end-to-end.

For speculative sampling:

- The custom op `chain_speculative_sampling` accepts `generator, seed, offset`, lazily computes them via `get_seed_and_offset(...)` only when needed, and forwards `seed`/`offset` to the JIT module.
- The public Python wrapper now exposes `seed`/`offset` in its signature and docstring and forwards them to `get_sampling_module().chain_speculative_sampling(...)`.

This achieves the desired "inject precomputed RNG state" behaviour for speculative decoding as well. The only remaining suggestion is the same partial-seed check, if you decide to enforce the all-or-nothing contract consistently.

Also applies to: 1456-1574
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/sampling.py (37 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/sampling.py (1)
flashinfer/logits_processor/operators.py (1)
`_to_tensor_scalar_tuple` (28-34)
🔇 Additional comments (1)
flashinfer/sampling.py (1)
1009-1139: `top_k_top_p_*` variants correctly thread seed/offset through both branches.

In both `top_k_top_p_sampling_from_logits` and `top_k_top_p_sampling_from_probs`:

- The `"top_k_first"` branch delegates to `top_p_sampling_from_probs(...)` and passes `seed` and `offset` by keyword, so the lower-level RNG logic is reused as intended.
- The `"joint"` branch calls the fused `top_k_top_p_sampling_from_probs` custom op and passes `generator, seed, offset` in the correct order.

This ensures consistent RNG handling regardless of `filter_apply_order` and keeps the new API surface coherent.

Also applies to: 1142-1265
yzh119
left a comment
Hi @ksukrit it's a good feature to have.
Would you mind adding unittest for this?
Also, it would be good to support an array of seed/offset, where each of them is an int64 array describing per-batch seed/offset. (This requires some modification to the kernels.)
Sure thing, will add the unit tests for this. But I had a quick question: right now the kernels just take a single int offset/seed value, right? The main purpose of this was to avoid the `get_seed_and_offset` call inside the graph. Is it okay if I take up the seed/offset array changes for the batch in a separate PR @yzh119?
Sure, we can do that in a separate PR.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)
787-792: Remove redundant computation.
`samples_offset1` is computed with the same parameters as `samples_seed1` (seed=12345, offset=0), making it redundant. You can reuse `samples_seed1` instead. Apply this diff:

```diff
- samples_offset1 = flashinfer.sampling.sampling_from_probs(
-     normalized_prob, seed=12345, offset=0
- )
  samples_offset2 = flashinfer.sampling.sampling_from_probs(
      normalized_prob, seed=12345, offset=1000
  )
  seed_match_rate = (samples_seed1 == samples_seed2).float().mean().item()
- offset_match_rate = (samples_offset1 == samples_offset2).float().mean().item()
+ offset_match_rate = (samples_seed1 == samples_offset2).float().mean().item()
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/utils/test_sampling.py (1 hunk)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (4)
`sampling_from_probs` (125-147), `sampling_from_probs` (654-722), `sampling_from_logits` (86-110), `sampling_from_logits` (589-651)
🔇 Additional comments (3)
tests/utils/test_sampling.py (3)
729-749: LGTM! Reproducibility test is well-structured. The test correctly verifies that supplying the same seed and offset produces identical samples, which is essential for CUDA graph replay scenarios.
751-770: LGTM! Logits reproducibility test mirrors the probs test appropriately. The test correctly validates reproducibility for the logits-based sampling function.
729-805: Consider testing seed/offset with other sampling functions.

The PR adds seed/offset parameters to all sampling functions (top_k, top_p, min_p, etc.), but the tests only cover `sampling_from_probs` and `sampling_from_logits`. Consider adding similar reproducibility tests for at least one of the filtered sampling functions to ensure the seed/offset parameters propagate correctly through the entire sampling stack. You can add a test similar to this:

```python
@pytest.mark.parametrize("batch_size", [1, 99])
@pytest.mark.parametrize("vocab_size", [111, 32000])
@pytest.mark.parametrize("k", [10, 100])
def test_top_k_sampling_seed_offset_reproducibility(batch_size, vocab_size, k):
    """Test that explicit seed/offset produces reproducible results for top_k sampling."""
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    seed, offset = 12345, 0
    samples1 = flashinfer.sampling.top_k_sampling_from_probs(
        normalized_prob, k, seed=seed, offset=offset
    )
    samples2 = flashinfer.sampling.top_k_sampling_from_probs(
        normalized_prob, k, seed=seed, offset=offset
    )
    assert torch.all(samples1 == samples2), (
        "Same seed/offset should produce identical samples in top_k sampling"
    )
```
```python
assert seed_match_rate < 1, (
    f"Different seeds should produce mostly different samples, "
    f"got {seed_match_rate:.2%} match rate"
)
assert offset_match_rate < 1, (
    f"Different offsets should produce mostly different samples, "
    f"got {offset_match_rate:.2%} match rate"
)
```
🛠️ Refactor suggestion | 🟠 Major
Strengthen the assertion to verify substantial randomness.
The current assertions only check that match_rate < 1.0, meaning at least one sample differs out of 1000. This is too weak—even broken randomness could pass. With batch_size=1000 and different seeds/offsets, you should expect a much lower match rate (e.g., < 0.1 or < 0.2 depending on vocab_size).
Apply this diff to add a more meaningful threshold:
```diff
+ # With proper randomness and large batch size, we expect low coincidental match rate
+ # The exact threshold depends on vocab_size, but for large vocabs, it should be very low
+ max_expected_match_rate = 0.1  # Allow up to 10% coincidental matches
+
- assert seed_match_rate < 1, (
+ assert seed_match_rate < max_expected_match_rate, (
      f"Different seeds should produce mostly different samples, "
      f"got {seed_match_rate:.2%} match rate"
  )
- assert offset_match_rate < 1, (
+ assert offset_match_rate < max_expected_match_rate, (
      f"Different offsets should produce mostly different samples, "
      f"got {offset_match_rate:.2%} match rate"
  )
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# With proper randomness and large batch size, we expect low coincidental match rate
# The exact threshold depends on vocab_size, but for large vocabs, it should be very low
max_expected_match_rate = 0.1  # Allow up to 10% coincidental matches

assert seed_match_rate < max_expected_match_rate, (
    f"Different seeds should produce mostly different samples, "
    f"got {seed_match_rate:.2%} match rate"
)
assert offset_match_rate < max_expected_match_rate, (
    f"Different offsets should produce mostly different samples, "
    f"got {offset_match_rate:.2%} match rate"
)
```
🤖 Prompt for AI Agents
In tests/utils/test_sampling.py around lines 797 to 804, the assertions only
check match_rate < 1 (at least one differing sample) which is too weak; change
the assertions to require substantially lower match rates (for example assert
seed_match_rate < 0.2 and assert offset_match_rate < 0.2) and update the failure
messages to reflect the new threshold (e.g., "Different seeds/offsets should
produce substantially different samples, got {seed_match_rate:.2%} match rate").
Ensure the threshold value is easy to adjust (use a named constant if preferred)
and keep the existing formatting of the f-string messages.
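The match-rate statistic being thresholded reduces to a simple element-wise comparison. A torch-free sketch (the sample vectors and values here are illustrative, not from the actual test run):

```python
def match_rate(a, b):
    # Fraction of positions where two equal-length sample vectors agree,
    # mirroring (samples1 == samples2).float().mean().item() in the tests.
    assert len(a) == len(b) and len(a) > 0
    return sum(x == y for x, y in zip(a, b)) / len(a)

samples_seed1 = [3, 7, 7, 1, 9, 0, 2, 5, 5, 8]
samples_seed2 = [4, 7, 2, 1, 0, 6, 2, 3, 1, 8]  # different seed: mostly different
rate = match_rate(samples_seed1, samples_seed2)
print(f"{rate:.2%}")  # 40.00% on this tiny toy vector
```

With batch_size=1000 and a large vocabulary, coincidental matches are rare, which is why a threshold like 0.1 is a much stronger check than `< 1`.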
/bot run
[FAILED] Pipeline #39161762: 16/20 passed
📌 Description
This PR adds optional seed/offset args to all the sampler functions to prevent calling the `get_seed_and_offset` function. If that function is not called, we can potentially make the sampler forward call part of a CUDA graph and use that to replay it.

We can directly compute the seed/offset values before launching the graph, in a similar way to how it is done in the current method, and pass them when making the flashinfer call.
🔍 Related Issues
#978 : top_k_top_p_sampling_from_logits incompatible with torch.compile + CUDAGraph
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features

- Optional seed and offset parameters added to sampling APIs to enable deterministic RNG control while remaining optional.

Tests

- New tests verify reproducible sampling when using the same seed/offset and variability when different values are used.
✏️ Tip: You can customize this high-level summary in your review settings.