fix: fix illegal memory access for NaN input in sampling kernels #2456

yzh119 merged 5 commits into flashinfer-ai:main
Conversation
Summary of Changes

Hello @zack041, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical memory access bug in the sampling kernels.
📝 Walkthrough

Adds a public per-request validity mask and per-iteration `last_valid_id` tracking to sampling kernels and host bindings, propagating a `bool* valid` through CUDA kernels and Python APIs; kernels now early-exit and write a defined output when no valid index exists. Adds a test exercising NaN rows and the new valid outputs.
Code Review
This pull request correctly addresses a critical illegal memory access issue in TopKSamplingFromProbKernel and TopPSamplingFromProbKernel when handling NaN inputs. The fix, which involves initializing last_valid_id and checking for it before use, is sound and well-tested by the new unit test.
While this PR fixes the immediate issue, it's worth noting that other sampling kernels like SamplingFromProbKernel, MinPSamplingFromProbKernel, and TopKTopPSamplingFromProbKernel might have similar (though not identical) issues with all-NaN inputs, potentially leading to incorrect output. It would be beneficial to address these in a follow-up to ensure the robustness of the entire sampling module.
I've added one suggestion to enhance the new test case to be more comprehensive.
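To make the failure mode and the fix concrete, here is a plain-Python, host-side sketch of the inverse-CDF sampling loop with the `last_valid_id` guard this PR introduces. This is illustrative only; the real logic lives in the CUDA kernels, and `sample_with_fallback` is a made-up name, not a flashinfer API.

```python
def sample_with_fallback(probs, u):
    """Host-side sketch of the kernel's inverse-CDF sampling with the
    last_valid_id guard added in this PR. Returns (token_id, valid)."""
    last_valid_id = -1  # defensive initialization added by the fix
    cum = 0.0
    for i, p in enumerate(probs):
        if p == p and p > 0:  # p == p is False for NaN, so NaN entries are skipped
            last_valid_id = i
            cum += p
            if cum >= u:
                return i, True
    if last_valid_id == -1:
        # All-NaN / no-valid row: return a defined token instead of reading
        # an uninitialized index (the illegal access this PR fixes).
        return 0, False
    # u exceeded the total probability mass (e.g. u very close to 1):
    # fall back to the last valid index, as the kernels do.
    return last_valid_id, True

print(sample_with_fallback([0.2, 0.3, 0.5], 0.4))     # (1, True)
print(sample_with_fallback([float("nan")] * 3, 0.4))  # (0, False)
```

The NaN row never updates `last_valid_id`, which is exactly why the uninitialized read occurred before the `-1` guard existed.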
tests/utils/test_sampling.py (outdated suggestion)

```python
result_top_k = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)
assert result_top_k[0].item() == 0

result_top_p = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p=0.9)
assert result_top_p[0].item() == 0
```
The test correctly asserts the behavior for the row with NaN inputs. To make it more robust, consider adding assertions to verify that the sampling for other valid rows in the batch is unaffected. This ensures that the fix for NaN handling doesn't introduce side effects for normal inputs.
Suggested change:

```python
result_top_k = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)
assert result_top_k[0].item() == 0
if batch_size > 1:
    assert torch.all(result_top_k[1:] >= 0)
    assert torch.all(result_top_k[1:] < vocab_size)
result_top_p = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p=0.9)
assert result_top_p[0].item() == 0
if batch_size > 1:
    assert torch.all(result_top_p[1:] >= 0)
    assert torch.all(result_top_p[1:] < vocab_size)
```
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/sampling.cuh (1)

784-792: ⚠️ Potential issue | 🔴 Critical — Apply the same defensive initialization to `SamplingFromProbKernel`, `TopKTopPSamplingFromProbKernel`, and `ChainSpeculativeSampling`.

This vulnerability is systemic across multiple sampling kernels. `SamplingFromProbKernel` at line 789 uses `temp_storage.last_valid_id` without initializing it, and identical patterns exist in:

- `TopKTopPSamplingFromProbKernel` at line 1168
- `ChainSpeculativeSampling` at line 1882

In contrast, `TopKSamplingFromProbKernel` (line 823) and `TopPSamplingFromProbKernel` (line 945) correctly initialize `last_valid_id = -1` and add defensive checks before use. Apply the same pattern to all affected kernels.

Fix pattern for all three kernels:

```diff
   temp_storage.sampled_id = d;
+  temp_storage.last_valid_id = -1;
   __syncthreads();
```

Then before using `last_valid_id`, add:

```diff
   if (sampled_id == d) {
+    if (temp_storage.last_valid_id == -1) {
+      if (tx == 0) {
+        output[bx] = 0;  // or appropriate fallback
+      }
+      return;
+    }
     sampled_id = temp_storage.last_valid_id;
   }
```
🧹 Nitpick comments (1)

tests/utils/test_sampling.py (1)

1113-1127: Test validates NaN handling correctly, but consider expanding coverage.

The test correctly validates that top-k and top-p sampling return 0 for all-NaN rows. However, other sampling functions (`sampling_from_probs`, `min_p_sampling_from_probs`, `top_k_top_p_sampling_from_probs`) use similar patterns with `last_valid_id` in their kernels and may benefit from similar testing.

Consider whether NaN resilience should be validated across all sampling methods for consistency, or if this scope is intentionally limited to the methods mentioned in issue #2402.
Hi @zack041 can you add a unit test for NaN inputs, and we might also fix the behavior of …
Hi @yzh119, just looking for some clarification. I've added the unit test in the PR (the function is `test_sampling_nan_input()`); are you looking for some additional checks? I'll add the same fix to `SamplingFromProbKernel` and include a test for it. For the other sampling kernels, should I also fix them in this PR or leave them for a follow-up?
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)

include/flashinfer/sampling.cuh (1)

134-138: ⚠️ Potential issue | 🟠 Major — Initialize/guard `last_valid_id` in all kernels that rely on it.

`last_valid_id` is now a shared field, but only some kernels reset/guard it. `MinPSamplingFromProbKernel`, `TopKTopPSamplingFromProbKernel`, and `ChainSpeculativeSampling` still read `last_valid_id` when `sampled_id == d` without initialization or a `-1` guard, so an all-NaN/empty row can leave stale shared memory and reintroduce invalid reads (e.g., TopKTopP's `pivot_0` access) or invalid IDs. Please mirror the reset + guard pattern there as well.

🛠️ Example fix pattern to mirror in other kernels:

```diff
@@ MinPSamplingFromProbKernel(...)
-  temp_storage.sampled_id = d;
+  temp_storage.sampled_id = d;
+  temp_storage.last_valid_id = -1;
   __syncthreads();
@@
   if (sampled_id == d) {
     // NOTE(Zihao): this would happen when u is very close to 1
     // and the sum of probabilities is smaller than u
     // In this case, we use the last valid index as the sampled id
-    sampled_id = temp_storage.last_valid_id;
+    if (temp_storage.last_valid_id == -1) {
+      if (tx == 0) output[bx] = 0;
+      return;
+    }
+    sampled_id = temp_storage.last_valid_id;
   }
@@ TopKTopPSamplingFromProbKernel(...)
   do {
     temp_storage.sampled_id = d;
+    temp_storage.last_valid_id = -1;
     __syncthreads();
@@
   if (sampled_id == d) {
     // NOTE(Zihao): this would happen when u is very close to 1
     // and the sum of probabilities is smaller than u
     // In this case, we use the last valid index as the sampled id
-    sampled_id = temp_storage.last_valid_id;
+    if (temp_storage.last_valid_id == -1) {
+      if (tx == 0) output[bx] = 0;
+      return;
+    }
+    sampled_id = temp_storage.last_valid_id;
   }
@@ ChainSpeculativeSampling(...)
-  temp_storage.sampled_id = d;
+  temp_storage.sampled_id = d;
+  temp_storage.last_valid_id = -1;
   __syncthreads();
@@
   if (sampled_id == d) {
     // NOTE(Zihao): this would happen when u is very close to 1
     // and the sum of probabilities is smaller than u
     // In this case, we use the last valid index as the sampled id
-    sampled_id = temp_storage.last_valid_id;
+    sampled_id = (temp_storage.last_valid_id == -1) ? 0 : temp_storage.last_valid_id;
   }
```
🧹 Nitpick comments (1)

include/flashinfer/sampling.cuh (1)

763-795: Document the NaN fallback choice (token 0).

The new early return is an important behavioral choice; please add a brief note explaining the NaN/empty-row fallback and why token 0 is used.

✍️ Suggested comment tweak:

```diff
-  // In this case, we use the last valid index as the sampled id
+  // In this case, we use the last valid index as the sampled id.
+  // If no valid token exists (e.g., all-NaN row), fall back to token 0.
```

As per coding guidelines, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@include/flashinfer/sampling.cuh`:
- Around line 624-626: The kernel DeviceSamplingFromProb only sets
temp_storage->last_valid_id when a valid index exists, which can leave it
uninitialized for all-NaN/all-nonpositive rows; initialize
temp_storage->last_valid_id = -1 at the start of DeviceSamplingFromProb (before
the tx==0 conditional) so it has a defined fallback, and add the same
initialization or a fallback guard in kernels that read last_valid_id such as
MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel, and
ChainSpeculativeSampling to prevent OOB when sampled_id == d.
🧹 Nitpick comments (1)

include/flashinfer/sampling.cuh (1)

764-795: Document the all-NaN fallback choice (output = 0).

This fallback is a special algorithmic choice in a hot path; a short rationale comment will prevent future confusion and clarify why index 0 is preferred over other sentinels.

Suggested inline comment:

```diff
+  // All-NaN / no-valid row: fall back to token 0 to avoid OOB access.
   if (temp_storage.last_valid_id == -1) {
```

As per coding guidelines: For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.
Hi @yzh119, I fixed …
Thanks for the fix @zack041, the direction looks right. A few things to address before we merge:

Once these are addressed, let's get this merged.
Force-pushed from 9a9f0df to 241b89f.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@include/flashinfer/sampling.cuh`:
- Around line 1186-1213: The TopKTopPSamplingFromProbKernel is missing the
NaN-safe initialization and guard: initialize temp_storage.last_valid_id = -1
before the sampling loop (same place other kernels set last_valid_id) and ensure
the kernel checks for last_valid_id == -1 after the search/aggregation and
before reading probs[row_idx * d + sampled_id]; if last_valid_id is still -1,
bail out or set sampled_id to a safe value to avoid the OOB access. Ensure the
per-element logic that sets last_valid_id when a valid (non-NaN) index is seen
remains in DeviceSamplingFromProb usage so the fallback works.
- Around line 1894-1937: Initialize temp_storage.last_valid_id to -1 before the
sampling loop and ensure the post-sampling guard handles the uninitialized case:
after reading sampled_id = temp_storage.sampled_id, if sampled_id == d then
check if temp_storage.last_valid_id == -1 and in that case set sampled_id to a
safe fallback (e.g., d-1 or clamp to the last valid index), otherwise use
temp_storage.last_valid_id; make sure the code path that updates last_valid_id
(inside DeviceSamplingFromProb or the loop where relu_q_minus_p_vec is
processed) remains unchanged so last_valid_id is valid when any positive mass
exists.
- Around line 1126-1150: The kernel fails to initialize
temp_storage.last_valid_id so an all-NaN row can leave it uninitialized; to fix,
in MinPSamplingFromProbKernel initialize temp_storage.last_valid_id = -1 before
the sampling loop, and after the reduction where sampled_id is set (the block
using temp_storage.sampled_id), add an explicit guard: if sampled_id == d then
check temp_storage.last_valid_id; if last_valid_id != -1 use it, otherwise set
sampled_id to a safe sentinel (e.g., keep d or set to -1 per calling convention)
before writing to output[bx]; this ensures no uninitialized read and matches the
TopKTopPSamplingFromProbKernel pattern.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)

flashinfer/sampling.py (1)

855-864: ⚠️ Potential issue | 🟡 Minor — Public API type hints/docs still claim `torch.Tensor` only.

With `return_valid`, these functions can return `(samples, valid)`; please align annotations and docstrings to avoid API confusion.

✍️ Example signature update (apply similarly to other APIs):

```diff
-def sampling_from_probs(..., return_valid: bool = False) -> torch.Tensor:
+def sampling_from_probs(..., return_valid: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```

Also applies to: 947-958, 1059-1070, 1171-1182, 1425-1438

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/sampling.py` around lines 855-864: the public API type hints and docstring for `sampling_from_probs` are inaccurate because when `return_valid=True` the function returns `(samples, valid)`; update the function signature and docstring to reflect a return type of `Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]` (or `torch.Tensor | tuple[torch.Tensor, torch.Tensor]` for Python 3.10+), clearly documenting the conditional tuple return and the meaning of the second tensor, and adjust any Sphinx/type comments accordingly; apply the same change to the other sampling helper functions in this module so all public APIs consistently annotate and document the optional `(samples, valid)` tuple return.

csrc/sampling.cu (1)

104-137: ⚠️ Potential issue | 🟠 Major — Add validation for the new `valid` output tensor.

`valid` is not checked for shape/device/dtype; a mismatch can cause illegal writes in the kernels. Please add the same input checks you apply to `output`/`probs` in each `sampling_*_from_probs` entry point.

🔒 Suggested checks (apply to each `sampling_*_from_probs`):

```diff
 void sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
                          Optional<TensorView> maybe_indices, bool deterministic,
                          Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
                          Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
   CHECK_INPUT(probs);
+  CHECK_INPUT(valid);
+  CHECK_DEVICE(valid, probs);
+  CHECK_DIM(1, valid);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
```

Also applies to: 139-177, 179-220, 222-263, 265-310

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/sampling.cu` around lines 104-137: add validation for the `valid` output tensor in `sampling_from_probs` (and the other `sampling_*_from_probs` entry points) similar to the existing checks for `probs`/`output`: assert `valid` is a tensor input (e.g. `CHECK_INPUT(valid)`), has the expected rank/shape (batch dimension equals `output.size(0)`), lives on the same device as `probs` (same device guard), and has a boolean-compatible dtype (the kernel reads it as bool). Run these checks before using `valid`, and replicate them in the other entry points at 139-177, 179-220, 222-263, 265-310.
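The conditional return type flagged above can also be expressed precisely with `typing.overload` and `Literal`. The sketch below is a standalone stub illustrating that pattern under plain-Python types; it is not the actual flashinfer signature, and the sampling body is a placeholder.

```python
from typing import List, Literal, Tuple, Union, overload

@overload
def sampling_from_probs(probs: List[List[float]],
                        return_valid: Literal[False] = ...) -> List[int]: ...
@overload
def sampling_from_probs(probs: List[List[float]],
                        return_valid: Literal[True]) -> Tuple[List[int], List[bool]]: ...

def sampling_from_probs(
    probs: List[List[float]], return_valid: bool = False
) -> Union[List[int], Tuple[List[int], List[bool]]]:
    """Stub showing only the (samples, valid) conditional return shape."""
    samples = [0 for _ in probs]                           # placeholder "sampling"
    valid = [all(p == p for p in row) for row in probs]    # False if a row has NaN
    return (samples, valid) if return_valid else samples

print(sampling_from_probs([[0.5, 0.5]]))                              # [0]
print(sampling_from_probs([[float("nan"), 0.5]], return_valid=True))  # ([0], [False])
```

With the overloads, type checkers resolve the return type from the literal value of `return_valid` at each call site, so callers that never pass the flag keep the plain single-tensor type.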
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/sampling.py`:
- Around line 374-391: The fake operator _fake_min_p_sampling_from_probs has
unused parameters (maybe_min_p_arr, min_p_val, deterministic, generator) causing
ARG001; rename those parameters with a leading underscore (e.g.,
_maybe_min_p_arr, _min_p_val, _deterministic, _generator) in the function
signature of _fake_min_p_sampling_from_probs so the linter recognizes them as
intentionally unused, or alternatively add a per-parameter noqa comment, and
keep the rest of the implementation (batch_size, out_dtype, return_valid
behavior) unchanged.
---
Outside diff comments:
In `@csrc/sampling.cu`:
- Around line 104-137: Add validation for the `valid` output tensor in
sampling_from_probs (and the other sampling_*_from_probs entry points) similar
to the existing checks for `probs`/`output`: assert `valid` is a tensor input
(e.g. CHECK_INPUT(valid)), has the expected rank/shape (batch dimension equals
output.size(0)), lives on the same device as `probs` (same device guard), and
has a boolean-compatible dtype (the kernel reads it as bool). Update
sampling_from_probs to run these checks before using valid (and replicate the
same checks in the other entry points at 139-177, 179-220, 222-263, 265-310).
In `@flashinfer/sampling.py`:
- Around line 855-864: The public API type hints and docstring for
sampling_from_probs are inaccurate because when return_valid=True the function
returns (samples, valid); update the function signature and docstring to reflect
a return type of Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] (or
torch.Tensor | tuple[torch.Tensor, torch.Tensor] for Python 3.10+), clearly
documenting the conditional tuple return and the meaning of the second tensor,
and adjust any Sphinx/type comments accordingly; apply the same change to the
other sampling helper functions in this module (the additional sampling-related
functions flagged in the review) so all public APIs consistently annotate and
document the optional (samples, valid) tuple return.
```python
@register_fake_op("flashinfer::min_p_sampling_from_probs")
def _fake_min_p_sampling_from_probs(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
    return_valid: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    batch_size = indices.size(0) if indices is not None else probs.size(0)
    out_dtype = indices.dtype if indices is not None else torch.int32
    if return_valid:
        return (
            torch.empty(batch_size, dtype=out_dtype, device=probs.device),
            torch.empty(batch_size, dtype=torch.bool, device=probs.device),
        )
    return torch.empty(batch_size, dtype=out_dtype, device=probs.device)
```
Ruff ARG001: unused parameters in _fake_min_p_sampling_from_probs.
Rename unused args with a leading underscore (or add # noqa: ARG001) to keep lint clean.
✅ Minimal fix

```diff
-def _fake_min_p_sampling_from_probs(
-    probs: torch.Tensor,
-    indices: Optional[torch.Tensor],
-    maybe_min_p_arr: Optional[torch.Tensor],
-    min_p_val: float,
-    deterministic: bool,
-    generator: Optional[torch.Generator],
-    return_valid: bool = False,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+def _fake_min_p_sampling_from_probs(
+    probs: torch.Tensor,
+    indices: Optional[torch.Tensor],
+    _maybe_min_p_arr: Optional[torch.Tensor],
+    _min_p_val: float,
+    _deterministic: bool,
+    _generator: Optional[torch.Generator],
+    return_valid: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
```

🧰 Tools

🪛 Ruff (0.15.1)
[warning] 378-378: Unused function argument: maybe_min_p_arr
(ARG001)
[warning] 379-379: Unused function argument: min_p_val
(ARG001)
[warning] 380-380: Unused function argument: deterministic
(ARG001)
[warning] 381-381: Unused function argument: generator
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/sampling.py` around lines 374 - 391, The fake operator
_fake_min_p_sampling_from_probs has unused parameters (maybe_min_p_arr,
min_p_val, deterministic, generator) causing ARG001; rename those parameters
with a leading underscore (e.g., _maybe_min_p_arr, _min_p_val, _deterministic,
_generator) in the function signature of _fake_min_p_sampling_from_probs so the
linter recognizes them as intentionally unused, or alternatively add a
per-parameter noqa comment, and keep the rest of the implementation (batch_size,
out_dtype, return_valid behavior) unchanged.
Hi @yzh119, the fixes are added; a detailed description has been added to the PR description.
@flashinfer-bot run

/bot run

[FAILED] Pipeline #44428153: 9/20 passed

@flashinfer-bot run
Hi @kahyunnam, would you take a look at this PR as well? Relevant issue: #2455.
Could you give an update? Are you still working on this PR?
📌 Description
Summary
Fix illegal memory access when input probabilities contain NaN values. Added `valid` output tensor so callers can distinguish failed sampling from legitimately sampling token 0. Also added missing `@register_fake_op` for `min_p_sampling_from_probs` to support `torch.compile`.

Changes

API

- New `return_valid: bool = False` parameter for all sampling functions
- When `True`, returns `(samples, valid)` tuple

Kernel

- Added `bool* valid` output to sampling kernels
- Initialize `last_valid_id = -1` before sampling
- `valid[bx] = false` when no valid token found (NaN input)
- `valid[bx] = true` when valid

PyTorch Integration

- Add missing `@register_fake_op` for `min_p_sampling_from_probs`

Affected Functions

- `sampling_from_probs`
- `top_k_sampling_from_probs`
- `top_p_sampling_from_probs`
- `min_p_sampling_from_probs`
- `top_k_top_p_sampling_from_probs`

Usage

```python
# Default
samples = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)

# With validity check
samples, valid = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50, return_valid=True)
# valid[i] == False means NaN or invalid input for row i
```
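On the caller side, the `(samples, valid)` pair can be consumed along these lines. This is a plain-Python stand-in for the torch masking a real caller would use; `resample_invalid` is a hypothetical helper, not part of the flashinfer API.

```python
def resample_invalid(samples, valid, fallback_token):
    """Replace samples whose row was flagged invalid (e.g. all-NaN probs)
    with a caller-chosen fallback token."""
    return [s if ok else fallback_token for s, ok in zip(samples, valid)]

# Row 0 was all-NaN: the sampler returned placeholder token 0 with valid=False.
print(resample_invalid([0, 17, 42], [False, True, True], fallback_token=-1))
# → [-1, 17, 42]
```

Without the mask, a caller could not tell the placeholder 0 from a legitimately sampled token 0, which is exactly the ambiguity the `return_valid` flag resolves.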
Fixes #2402
🔍 Related Issues
#2402
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
I'm currently using 0 as the placeholder token, let me know if a different value is preferred. I will add the optional NaN counting feature in a follow up PR.
Summary by CodeRabbit

New Features

- Sampling APIs can optionally return a per-sample boolean "valid" mask alongside outputs (use a new return flag).

Bug Fixes

- Per-iteration state reset and improved handling of invalid/NaN or out-of-range probability rows: such cases now produce a neutral zero sample and mark valid=false.

Tests

- Added tests verifying behavior with NaN/invalid probability inputs and the new valid-mask return.