fix: fix illegal memory access for NaN input in sampling kernels#2456

Merged
yzh119 merged 5 commits into flashinfer-ai:main from zack041:fix-nan-sampling
Mar 13, 2026
Conversation

@zack041
Contributor

@zack041 zack041 commented Jan 31, 2026

📌 Description

Summary

Fix illegal memory access when input probabilities contain NaN values. Added a valid output tensor so callers can distinguish failed sampling from legitimately sampling token 0. Also added the missing @register_fake_op for min_p_sampling_from_probs to support torch.compile.

Changes

API

  • New return_valid: bool = False parameter for all sampling functions
  • When True, returns (samples, valid) tuple

Kernel

  • Added bool* valid output to sampling kernels
  • Initialize last_valid_id = -1 before sampling
  • valid[bx] = false when no valid token found (NaN input)
  • valid[bx] = true when valid
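
The kernel changes above can be modeled in pure Python. The sketch below is a simplified single-row stand-in for the CUDA inverse-CDF sampling loop, not the actual kernel; the function name sample_row and its shape are illustrative:

```python
import math

def sample_row(probs, u):
    """Simplified model of the kernel's sampling with the last_valid_id
    fallback. Returns (token, valid)."""
    last_valid_id = -1          # initialized before sampling, as in the fix
    cdf = 0.0
    for i, p in enumerate(probs):
        if not math.isnan(p) and p > 0.0:
            last_valid_id = i   # last index that carried valid mass
            cdf += p
            if u < cdf:
                return i, True
    if last_valid_id == -1:     # all-NaN row: defined fallback, no OOB read
        return 0, False
    # u fell past the accumulated mass (u very close to 1): reuse last valid id
    return last_valid_id, True

token, valid = sample_row([float("nan")] * 4, u=0.5)
# token == 0, valid == False
```

Without the `last_valid_id = -1` initialization, the all-NaN path would read whatever stale value shared memory happened to hold, which is the source of the illegal memory access.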

PyTorch Integration

  • Add missing @register_fake_op for min_p_sampling_from_probs

Affected Functions

  • sampling_from_probs
  • top_k_sampling_from_probs
  • top_p_sampling_from_probs
  • min_p_sampling_from_probs
  • top_k_top_p_sampling_from_probs

Usage

# Default
samples = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)

# With validity check
samples, valid = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50, return_valid=True)
# valid[i] == False means NaN or invalid input for row i
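
On the caller side, the mask makes failures detectable. A minimal sketch of consuming it, using plain Python lists in place of tensors (the helper split_by_validity is hypothetical, not part of the flashinfer API):

```python
def split_by_validity(samples, valid):
    """Partition sampled tokens by the per-row validity mask. Rows with
    valid == False carry the placeholder token 0 and should be retried or
    surfaced as errors rather than decoded."""
    ok = [(i, s) for i, (s, v) in enumerate(zip(samples, valid)) if v]
    failed = [i for i, v in enumerate(valid) if not v]
    return ok, failed

ok, failed = split_by_validity([7, 0, 3], [True, False, True])
# ok == [(0, 7), (2, 3)], failed == [1]
```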

Fixes #2402

🔍 Related Issues

#2402

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

I'm currently using 0 as the placeholder token; let me know if a different value is preferred. I will add the optional NaN-counting feature in a follow-up PR.

Summary by CodeRabbit

  • New Features

    • Sampling APIs can optionally return a per-sample boolean "valid" mask alongside outputs (use a new return flag).
  • Bug Fixes

    • Per-iteration state reset and improved handling of invalid/NaN or out-of-range probability rows: such cases now produce a neutral zero sample and mark valid=false.
  • Tests

    • Added tests verifying behavior with NaN/invalid probability inputs and the new valid-mask return.

@gemini-code-assist
Contributor

Summary of Changes

Hello @zack041, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical memory access bug in the TopK and TopP sampling kernels that occurred when input probability distributions contained NaN values. By introducing robust checks and a default output for rows with entirely NaN probabilities, the change significantly enhances the stability and reliability of the sampling functions, preventing potential crashes and ensuring predictable behavior in edge cases.

Highlights

  • Memory Access Fix: Initialized temp_storage.last_valid_id to -1 in TopKSamplingFromProbKernel and TopPSamplingFromProbKernel to ensure a known state before sampling.
  • NaN Input Handling: Implemented a safeguard in both sampling kernels to detect cases where no valid token is found (i.e., last_valid_id remains -1) due to NaN inputs, preventing illegal memory access.
  • Default Output for NaN Cases: When no valid token is found, the output for the affected row is now explicitly set to 0, providing a graceful fallback instead of a crash.
  • New Test Case: Added a new test case, test_sampling_nan_input, to tests/utils/test_sampling.py to specifically validate the handling of NaN probabilities in both TopK and TopP sampling.


@coderabbitai
Contributor

coderabbitai bot commented Jan 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a public per-request validity mask and per-iteration last_valid_id tracking to sampling kernels and host bindings, propagating a bool* valid through CUDA kernels and Python APIs; kernels now early-exit and write a defined output when no valid index exists. Adds a test exercising NaN rows and the new valid outputs.

Changes

  • Sampling kernel header: include/flashinfer/sampling.cuh
    Added last_valid_id handling in SamplingTempStorage and a bounds check when assigning last_valid_id; adjusted the device sampling helper to respect bounds.
  • CUDA kernels: include/flashinfer/sampling.cuh (SamplingFromProbKernel, TopKSamplingFromProbKernel, TopPSamplingFromProbKernel, MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel)
    Kernel signatures extended with bool* valid and, where applicable, IdType* indices; last_valid_id is initialized to -1 per iteration; an early-exit fallback writes 0 and sets valid=false at tx==0 when no valid id is found; valid[bx]=true on success, with output written only when tx==0.
  • CUDA host bindings: csrc/flashinfer_sampling_binding.cu, csrc/sampling.cu
    Public host API signatures updated to accept a valid TensorView argument; static_cast<bool*>(valid.data_ptr()) is passed into the kernels; call sites adjusted to the new parameter order.
  • Python API: flashinfer/sampling.py
    Added return_valid: bool = False to the public sampling functions; a boolean valid tensor is allocated and passed to the backend when requested; (samples, valid) is returned when return_valid is True.
  • Tests: tests/utils/test_sampling.py
    Added test_sampling_nan_input(batch_size, vocab_size), which injects NaNs into probability rows and asserts that NaN rows produce sample 0 with valid=False while other rows produce valid indices with valid=True; exercises all sampling variants.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

op: misc

Suggested reviewers

  • yzh119
  • cyx-6
  • bkryu
  • jiahanc
  • nvmbreughe
  • kahyunnam
  • djmmoss

Poem

🐰 I sniffed the probs where NaNs like to hide,
I tracked the last hop and kept logic wide.
If nothing's valid, I pick zero, not crash—
A careful little hop, tidy and brash.
🥕 Safe sampling done, now onward I dash.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 24.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: the PR title clearly and concisely summarizes the main fix, addressing illegal memory access in sampling kernels when NaN input is provided.
  • Description check ✅ Passed: the PR description comprehensively addresses all required sections from the template, with detailed information about changes, related issues, and completed checklist items.



@zack041 zack041 mentioned this pull request Jan 31, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a critical illegal memory access issue in TopKSamplingFromProbKernel and TopPSamplingFromProbKernel when handling NaN inputs. The fix, which involves initializing last_valid_id and checking for it before use, is sound and well-tested by the new unit test.

While this PR fixes the immediate issue, it's worth noting that other sampling kernels like SamplingFromProbKernel, MinPSamplingFromProbKernel, and TopKTopPSamplingFromProbKernel might have similar (though not identical) issues with all-NaN inputs, potentially leading to incorrect output. It would be beneficial to address these in a follow-up to ensure the robustness of the entire sampling module.

I've added one suggestion to enhance the new test case to be more comprehensive.

Comment on lines +1122 to +1126
result_top_k = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)
assert result_top_k[0].item() == 0

result_top_p = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p=0.9)
assert result_top_p[0].item() == 0

medium

The test correctly asserts the behavior for the row with NaN inputs. To make it more robust, consider adding assertions to verify that the sampling for other valid rows in the batch is unaffected. This ensures that the fix for NaN handling doesn't introduce side effects for normal inputs.

Suggested change
result_top_k = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)
assert result_top_k[0].item() == 0
if batch_size > 1:
    assert torch.all(result_top_k[1:] >= 0)
    assert torch.all(result_top_k[1:] < vocab_size)

result_top_p = flashinfer.sampling.top_p_sampling_from_probs(probs, top_p=0.9)
assert result_top_p[0].item() == 0
if batch_size > 1:
    assert torch.all(result_top_p[1:] >= 0)
    assert torch.all(result_top_p[1:] < vocab_size)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/sampling.cuh (1)

784-792: ⚠️ Potential issue | 🔴 Critical

Apply same defensive initialization to SamplingFromProbKernel, TopKTopPSamplingFromProbKernel, and ChainSpeculativeSampling.

This vulnerability is systemic across multiple sampling kernels. SamplingFromProbKernel at line 789 uses temp_storage.last_valid_id without initializing it, and identical patterns exist in:

  • TopKTopPSamplingFromProbKernel at line 1168
  • ChainSpeculativeSampling at line 1882

In contrast, TopKSamplingFromProbKernel (line 823) and TopPSamplingFromProbKernel (line 945) correctly initialize last_valid_id = -1 and add defensive checks before use. Apply the same pattern to all affected kernels:

Fix pattern for all three kernels
  temp_storage.sampled_id = d;
+ temp_storage.last_valid_id = -1;
  __syncthreads();

Then before using last_valid_id, add:

  if (sampled_id == d) {
+   if (temp_storage.last_valid_id == -1) {
+     if (tx == 0) {
+       output[bx] = 0;  // or appropriate fallback
+     }
+     return;
+   }
    sampled_id = temp_storage.last_valid_id;
  }
🧹 Nitpick comments (1)
tests/utils/test_sampling.py (1)

1113-1127: Test validates NaN handling correctly, but consider expanding coverage.

The test correctly validates that top-k and top-p sampling return 0 for all-NaN rows. However, other sampling functions (sampling_from_probs, min_p_sampling_from_probs, top_k_top_p_sampling_from_probs) use similar patterns with last_valid_id in their kernels and may benefit from similar testing.

Consider whether NaN resilience should be validated across all sampling methods for consistency, or if this scope is intentionally limited to the methods mentioned in issue #2402.

@yzh119
Collaborator

yzh119 commented Feb 1, 2026

Hi @zack041, can you add a unit test for NaN inputs? We might also fix the behavior of SamplingFromProbKernel.

@zack041
Contributor Author

zack041 commented Feb 1, 2026

Hi @zack041, can you add a unit test for NaN inputs? We might also fix the behavior of SamplingFromProbKernel.

Hi @yzh119, just looking for some clarification. I've added the unit test in the PR (the function is test_sampling_nan_input()). Are you looking for some additional checks?

I'll add the same fix to SamplingFromProbKernel and include a test for it. For the other sampling kernels, should I also fix them in this PR or leave them in a follow-up?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/sampling.cuh (1)

134-138: ⚠️ Potential issue | 🟠 Major

Initialize/guard last_valid_id in all kernels that rely on it.

last_valid_id is now a shared field, but only some kernels reset/guard it. MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel, and ChainSpeculativeSampling still read last_valid_id when sampled_id == d without initialization or a -1 guard, so an all-NaN/empty row can leave stale shared memory and reintroduce invalid reads (e.g., TopKTopP’s pivot_0 access) or invalid IDs. Please mirror the reset + guard pattern there as well.

🛠️ Example fix pattern to mirror in other kernels
@@ MinPSamplingFromProbKernel(...)
-  temp_storage.sampled_id = d;
+  temp_storage.sampled_id = d;
+  temp_storage.last_valid_id = -1;
   __syncthreads();
@@
   if (sampled_id == d) {
     // NOTE(Zihao): this would happen when u is very close to 1
     // and the sum of probabilities is smaller than u
     // In this case, we use the last valid index as the sampled id
-    sampled_id = temp_storage.last_valid_id;
+    if (temp_storage.last_valid_id == -1) {
+      if (tx == 0) output[bx] = 0;
+      return;
+    }
+    sampled_id = temp_storage.last_valid_id;
   }
@@ TopKTopPSamplingFromProbKernel(...)
   do {
     temp_storage.sampled_id = d;
+    temp_storage.last_valid_id = -1;
     __syncthreads();
@@
     if (sampled_id == d) {
       // NOTE(Zihao): this would happen when u is very close to 1
       // and the sum of probabilities is smaller than u
       // In this case, we use the last valid index as the sampled id
-      sampled_id = temp_storage.last_valid_id;
+      if (temp_storage.last_valid_id == -1) {
+        if (tx == 0) output[bx] = 0;
+        return;
+      }
+      sampled_id = temp_storage.last_valid_id;
     }
@@ ChainSpeculativeSampling(...)
-  temp_storage.sampled_id = d;
+  temp_storage.sampled_id = d;
+  temp_storage.last_valid_id = -1;
   __syncthreads();
@@
   if (sampled_id == d) {
     // NOTE(Zihao): this would happen when u is very close to 1
     // and the sum of probabilities is smaller than u
     // In this case, we use the last valid index as the sampled id
-    sampled_id = temp_storage.last_valid_id;
+    sampled_id = (temp_storage.last_valid_id == -1) ? 0 : temp_storage.last_valid_id;
   }
🧹 Nitpick comments (1)
include/flashinfer/sampling.cuh (1)

763-795: Document the NaN fallback choice (token 0).

The new early return is an important behavioral choice; please add a brief note explaining the NaN/empty-row fallback and why token 0 is used.

✍️ Suggested comment tweak
-    // In this case, we use the last valid index as the sampled id
+    // In this case, we use the last valid index as the sampled id.
+    // If no valid token exists (e.g., all-NaN row), fall back to token 0.

As per coding guidelines, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@include/flashinfer/sampling.cuh`:
- Around line 624-626: The kernel DeviceSamplingFromProb only sets
temp_storage->last_valid_id when a valid index exists, which can leave it
uninitialized for all-NaN/all-nonpositive rows; initialize
temp_storage->last_valid_id = -1 at the start of DeviceSamplingFromProb (before
the tx==0 conditional) so it has a defined fallback, and add the same
initialization or a fallback guard in kernels that read last_valid_id such as
MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel, and
ChainSpeculativeSampling to prevent OOB when sampled_id == d.
🧹 Nitpick comments (1)
include/flashinfer/sampling.cuh (1)

764-795: Document the all-NaN fallback choice (output = 0).

This fallback is a special algorithmic choice in a hot path; a short rationale comment will prevent future confusion and clarifies why index 0 is preferred over other sentinels.

Suggested inline comment
-    if (temp_storage.last_valid_id == -1) {
+    // All-NaN / no-valid row: fall back to token 0 to avoid OOB access.
+    if (temp_storage.last_valid_id == -1) {

As per coding guidelines: For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.

@zack041
Contributor Author

zack041 commented Feb 1, 2026

Hi @yzh119, I fixed SamplingFromProbKernel and I'm back with some updates. Several sampling methods (including SamplingFromProbKernel) did not cause illegal memory access but instead assigned a garbage value to .last_valid_id. I had trouble capturing this case with the previous approach, so I added the constraint max_valid_index < (int)d to guard the assignment of .last_valid_id in DeviceSamplingFromProb. I'm not sure whether this goes beyond the scope of the issue; what are our next steps?
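
The guarded assignment described in the comment above can be sketched in plain Python (illustrative only; the sketch also rejects negative indices, which goes slightly beyond the max_valid_index < (int)d constraint mentioned):

```python
def update_last_valid_id(last_valid_id, max_valid_index, d):
    """Model of the guarded assignment in DeviceSamplingFromProb: accept
    max_valid_index only if it is a real in-range index, so a garbage value
    can never become last_valid_id."""
    if 0 <= max_valid_index < d:
        return max_valid_index
    return last_valid_id  # keep the previous value (possibly the -1 sentinel)

# A garbage index (e.g. from uninitialized shared memory) is rejected:
assert update_last_valid_id(-1, 10**9, 128) == -1
```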

@yzh119
Collaborator

yzh119 commented Feb 19, 2026

Thanks for the fix @zack041, the direction looks right. A few things to address before we merge:

  1. TopKTopPSamplingFromProbKernel is not fixed. The combined top-k + top-p kernel has the same pattern but wasn't touched — it will have the same crash on all-NaN input. Please add the same fix there.

  2. Test coverage for non-NaN rows. The test only sets row 0 to NaN. Please also assert that the other rows still produce valid token indices (i.e., NaN in one request doesn't affect other requests in the batch).

  3. Set success output to false for NaN rows. Currently NaN rows silently output index 0, which is indistinguishable from legitimately sampling token 0. The success output tensor should be set to false for these rows so callers can detect the failure.

Once these are addressed, let's get this merged.
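
The three points above translate into an assertion pattern like the following (a plain-Python sketch, not the actual pytest code; check_batch and its inputs are illustrative):

```python
def check_batch(samples, valid, nan_rows, vocab_size):
    """NaN rows must yield the placeholder 0 with valid=False, so the failure
    is detectable; every other row must yield an in-range token with
    valid=True, i.e. a NaN request must not affect its batch neighbors."""
    for i, (s, v) in enumerate(zip(samples, valid)):
        if i in nan_rows:
            assert s == 0 and v is False
        else:
            assert 0 <= s < vocab_size and v is True

check_batch(samples=[0, 17, 3], valid=[False, True, True],
            nan_rows={0}, vocab_size=128)
```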


@coderabbitai coderabbitai bot left a comment


🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@include/flashinfer/sampling.cuh`:
- Around line 1186-1213: The TopKTopPSamplingFromProbKernel is missing the
NaN-safe initialization and guard: initialize temp_storage.last_valid_id = -1
before the sampling loop (same place other kernels set last_valid_id) and ensure
the kernel checks for last_valid_id == -1 after the search/aggregation and
before reading probs[row_idx * d + sampled_id]; if last_valid_id is still -1,
bail out or set sampled_id to a safe value to avoid the OOB access. Ensure the
per-element logic that sets last_valid_id when a valid (non-NaN) index is seen
remains in DeviceSamplingFromProb usage so the fallback works.
- Around line 1894-1937: Initialize temp_storage.last_valid_id to -1 before the
sampling loop and ensure the post-sampling guard handles the uninitialized case:
after reading sampled_id = temp_storage.sampled_id, if sampled_id == d then
check if temp_storage.last_valid_id == -1 and in that case set sampled_id to a
safe fallback (e.g., d-1 or clamp to the last valid index), otherwise use
temp_storage.last_valid_id; make sure the code path that updates last_valid_id
(inside DeviceSamplingFromProb or the loop where relu_q_minus_p_vec is
processed) remains unchanged so last_valid_id is valid when any positive mass
exists.
- Around line 1126-1150: The kernel fails to initialize
temp_storage.last_valid_id so an all-NaN row can leave it uninitialized; to fix,
in MinPSamplingFromProbKernel initialize temp_storage.last_valid_id = -1 before
the sampling loop, and after the reduction where sampled_id is set (the block
using temp_storage.sampled_id), add an explicit guard: if sampled_id == d then
check temp_storage.last_valid_id; if last_valid_id != -1 use it, otherwise set
sampled_id to a safe sentinel (e.g., keep d or set to -1 per calling convention)
before writing to output[bx]; this ensures no uninitialized read and matches the
TopKTopPSamplingFromProbKernel pattern.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/sampling.py (1)

855-864: ⚠️ Potential issue | 🟡 Minor

Public API type hints/docs still claim torch.Tensor only.
With return_valid, these functions can return (samples, valid); please align annotations and docstrings to avoid API confusion.

✍️ Example signature update (apply similarly to other APIs)
-def sampling_from_probs(..., return_valid: bool = False) -> torch.Tensor:
+def sampling_from_probs(..., return_valid: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:

Also applies to: 947-958, 1059-1070, 1171-1182, 1425-1438

csrc/sampling.cu (1)

104-137: ⚠️ Potential issue | 🟠 Major

Add validation for the new valid output tensor.
valid is not checked for shape/device/dtype; a mismatch can cause illegal writes in the kernels. Please add the same input checks you apply to output/probs in each sampling_*_from_probs entry point.

🔒 Suggested checks (apply to each sampling_*_from_probs)
 void sampling_from_probs(TensorView probs, TensorView output, TensorView valid,
                          Optional<TensorView> maybe_indices, bool deterministic,
                          Optional<TensorView> maybe_seed_arr, uint64_t seed_val,
                          Optional<TensorView> maybe_offset_arr, uint64_t offset_val) {
   CHECK_INPUT(probs);
+  CHECK_INPUT(valid);
+  CHECK_DEVICE(valid, probs);
+  CHECK_DIM(1, valid);
   CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)

Also applies to: 139-177, 179-220, 222-263, 265-310


Comment on lines +374 to +391
@register_fake_op("flashinfer::min_p_sampling_from_probs")
def _fake_min_p_sampling_from_probs(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
    return_valid: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    batch_size = indices.size(0) if indices is not None else probs.size(0)
    out_dtype = indices.dtype if indices is not None else torch.int32
    if return_valid:
        return (
            torch.empty(batch_size, dtype=out_dtype, device=probs.device),
            torch.empty(batch_size, dtype=torch.bool, device=probs.device),
        )
    return torch.empty(batch_size, dtype=out_dtype, device=probs.device)

⚠️ Potential issue | 🟡 Minor

Ruff ARG001: unused parameters in _fake_min_p_sampling_from_probs.
Rename unused args with a leading underscore (or add # noqa: ARG001) to keep lint clean.

✅ Minimal fix
-def _fake_min_p_sampling_from_probs(
-    probs: torch.Tensor,
-    indices: Optional[torch.Tensor],
-    maybe_min_p_arr: Optional[torch.Tensor],
-    min_p_val: float,
-    deterministic: bool,
-    generator: Optional[torch.Generator],
-    return_valid: bool = False,
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+def _fake_min_p_sampling_from_probs(
+    probs: torch.Tensor,
+    indices: Optional[torch.Tensor],
+    _maybe_min_p_arr: Optional[torch.Tensor],
+    _min_p_val: float,
+    _deterministic: bool,
+    _generator: Optional[torch.Generator],
+    return_valid: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
🧰 Tools
🪛 Ruff (0.15.1)
[warning] 378: Unused function argument: maybe_min_p_arr (ARG001)
[warning] 379: Unused function argument: min_p_val (ARG001)
[warning] 380: Unused function argument: deterministic (ARG001)
[warning] 381: Unused function argument: generator (ARG001)


@zack041
Contributor Author

zack041 commented Feb 19, 2026

Thanks for the fix @zack041, the direction looks right. A few things to address before we merge:

  1. TopKTopPSamplingFromProbKernel is not fixed. The combined top-k + top-p kernel has the same pattern but wasn't touched — it will have the same crash on all-NaN input. Please add the same fix there.
  2. Test coverage for non-NaN rows. The test only sets row 0 to NaN. Please also assert that the other rows still produce valid token indices (i.e., NaN in one request doesn't affect other requests in the batch).
  3. Set success output to false for NaN rows. Currently NaN rows silently output index 0, which is indistinguishable from legitimately sampling token 0. The success output tensor should be set to false for these rows so callers can detect the failure.

Once these are addressed, let's get this merged.

Hi @yzh119, the fixes have been added, and the detailed description has been updated in the PR description.
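For reference, the valid-flag semantics discussed above (initialize `last_valid_id = -1`, write `valid[bx] = false` when a row contains no usable probability, and emit token 0 as a placeholder) can be sketched as a host-side Python emulation. This is an illustrative model of the kernel's inverse-CDF walk, not the actual CUDA code; the function name and structure are assumptions for exposition only.

```python
import math

def sample_row(probs, u):
    """Emulate one block's inverse-transform sampling with the NaN guard:
    skip non-finite or negative entries, track the last valid index, and
    report (token 0, valid=False) when the whole row is unusable."""
    last_valid_id = -1
    cdf = 0.0
    for i, p in enumerate(probs):
        if not math.isfinite(p) or p < 0.0:  # NaN/inf/negative: skip entry
            continue
        last_valid_id = i
        cdf += p
        if u < cdf:
            return i, True
    if last_valid_id == -1:      # all entries invalid (e.g. all-NaN row)
        return 0, False          # placeholder token, flagged invalid
    return last_valid_id, True   # u fell past the numerical tail

nan = float("nan")
print(sample_row([nan, nan, nan], 0.5))   # (0, False): caller can detect failure
print(sample_row([0.1, 0.2, 0.7], 0.25))  # (1, True): ordinary sample
```

With this shape, a caller checking the second element of the pair can distinguish a NaN row from a legitimate sample of token 0, which is exactly the ambiguity the `valid` output tensor resolves.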

@yzh119
Collaborator

yzh119 commented Feb 20, 2026

@flashinfer-bot run

@yzh119
Collaborator

yzh119 commented Feb 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !331 has been created, and the CI pipeline #44428153 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44428153: 9/20 passed

@yzh119
Copy link
Collaborator

yzh119 commented Feb 22, 2026

@flashinfer-bot run

@yzh119 yzh119 added the run-ci label Feb 22, 2026
@yzh119
Collaborator

yzh119 commented Feb 23, 2026

Hi @kahyunnam, would you take a look at this PR as well? Relevant issue: #2455.

@vadiklyutiy

Could you give an update? Are you continuing to work on this PR?

@yzh119 yzh119 merged commit f6ec0d8 into flashinfer-ai:main Mar 13, 2026
47 of 50 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2456)

<!-- .github/pull_request_template.md -->

## 📌 Description

### Summary

Fix illegal memory access when input probabilities contain NaN values.
Added `valid` output tensor so callers can distinguish failed sampling
from legitimately sampling token 0. Also added missing
`@register_fake_op` for `min_p_sampling_from_probs` to support
`torch.compile`.

### Changes

#### API
- New `return_valid: bool = False` parameter for all sampling functions
- When `True`, returns `(samples, valid)` tuple

#### Kernel
- Added `bool* valid` output to sampling kernels
- Initialize `last_valid_id = -1` before sampling
- `valid[bx] = false` when no valid token found (NaN input)
- `valid[bx] = true` when valid

#### PyTorch Integration
- Add missing `@register_fake_op` for `min_p_sampling_from_probs`

#### Affected Functions
- `sampling_from_probs`
- `top_k_sampling_from_probs`
- `top_p_sampling_from_probs`
- `min_p_sampling_from_probs`
- `top_k_top_p_sampling_from_probs`

### Usage
```python
# Default
samples = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50)

# With validity check
samples, valid = flashinfer.sampling.top_k_sampling_from_probs(probs, top_k=50, return_valid=True)
# valid[i] == False means NaN or invalid input for row i
```

Fixes flashinfer-ai#2402

## 🔍 Related Issues

flashinfer-ai#2402

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

I'm currently using 0 as the placeholder token; let me know if a
different value is preferred. I will add the optional NaN-counting
feature in a follow-up PR.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Sampling APIs can optionally return a per-sample boolean "valid" mask
alongside outputs (use a new return flag).

* **Bug Fixes**
* Per-iteration state reset and improved handling of invalid/NaN or
out-of-range probability rows: such cases now produce a neutral zero
sample and mark valid=false.

* **Tests**
* Added tests verifying behavior with NaN/invalid probability inputs and
the new valid-mask return.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…shinfer-ai#2456)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

bug: improve sampling kernels robustness when inputs include NaN

4 participants