
[MISC] Add strict contiguity check for FlashInfer attention tensors#32008

Merged
vllm-bot merged 5 commits into vllm-project:main from CentML:vadim/trtgen-attn-stronger-asserts
Jan 10, 2026

Conversation

@vadiklyutiy
Collaborator

@vadiklyutiy vadiklyutiy commented Jan 9, 2026

Adds an early check for the potential error described in #30842. See also #31617 and flashinfer-ai/flashinfer#2232.

Updates the FlashInfer attention path to use a stricter contiguity check, preventing potential CUDA kernel memory access issues.
Introduces an is_strictly_contiguous() utility to detect tensors with degenerate strides that PyTorch's is_contiguous() reports as contiguous.


Note

Strengthens memory layout validation for FlashInfer TRTLLM attention.

  • Adds is_strictly_contiguous(t) in vllm/utils/torch_utils.py to verify canonical contiguous strides and catch degenerate-stride tensors
  • Replaces .is_contiguous() assertions with is_strictly_contiguous() in flashinfer.py TRTLLM (HND) prefill and decode paths for query, kv_cache_permute, workspace buffer, block_tables, and seq_lens
  • Aims to fail fast before launching CUDA kernels; no algorithmic changes
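The utility described above can be sketched roughly as follows. This is a hedged approximation of the stride rule, not the actual vllm/utils/torch_utils.py implementation; it takes plain shape/stride tuples rather than a tensor so the sketch stays framework-free:

```python
def canonical_strides(shape):
    # Row-major (C-contiguous) strides for a given shape:
    # innermost dim has stride 1, each outer dim multiplies up.
    strides = [0] * len(shape)
    acc = 1
    for i in range(len(shape) - 1, -1, -1):
        strides[i] = acc
        acc *= shape[i]
    return tuple(strides)


def is_strictly_contiguous(shape, strides):
    # Exact match against canonical strides. Unlike torch's
    # is_contiguous(), size-1 dimensions get no free pass here,
    # so degenerate strides are caught.
    return tuple(strides) == canonical_strides(shape)
```

For the (16, 1, 8, 32) tensor discussed later in this thread, the canonical strides are (256, 256, 32, 1), so a tensor carrying strides (256, 1, 32, 1) fails the strict check even though PyTorch reports it as contiguous.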

Written by Cursor Bugbot for commit b1e334e. This will update automatically on new commits. Configure here.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a is_strictly_contiguous utility function to perform a stricter check for tensor contiguity, addressing potential memory access issues in FlashInfer CUDA kernels caused by degenerate strides. This new check is correctly applied to the FlashInfer TRTLLM prefill attention path. The implementation of the new utility is robust. However, a similar vulnerability seems to exist in the TRTLLM decode path which has not been addressed in this PR. I've added a critical comment to highlight this omission.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Collaborator

@benchislett benchislett left a comment


LGTM!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 9, 2026
@vadiklyutiy vadiklyutiy added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 9, 2026
else:
assert self.o_sf_scale is None
out = output[num_decode_tokens:]

Collaborator


Should we also check if out is contiguous?

else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
Collaborator


I did some tests, and if a torch tensor's is_contiguous() returns True while is_strictly_contiguous() returns False, calling tensor.contiguous() does not actually normalize its strides. E.g.:

t_base = torch.randn(16, 8, 32)
t2 = t_base.as_strided(size=(16, 1, 8, 32), stride=(256, 1, 32, 1))
t5 = t2.contiguous()

Here, the result is:

t2  ->
  Shape: torch.Size([16, 1, 8, 32]), Stride: (256, 1, 32, 1)
  torch.is_contiguous(): True
  is_strictly_contiguous: False
  Expected canonical stride: (256, 256, 32, 1)
t5 ->
  Shape: torch.Size([16, 1, 8, 32]), Stride: (256, 1, 32, 1)
  torch.is_contiguous(): True
  is_strictly_contiguous: False

t2.contiguous() didn't actually convert it to a strictly contiguous tensor. While the assertion works for detecting a non-contiguous tensor, the earlier prefill_query.contiguous() may be insufficient. Do we need to do anything additional?
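For context on the behavior observed above: torch's is_contiguous() effectively skips size-1 dimensions when checking strides, which is roughly why t2 passes it despite the degenerate stride (and why contiguous() sees no work to do). A plain-Python approximation of that rule, hedged as a model rather than PyTorch's exact implementation:

```python
def torch_style_is_contiguous(shape, strides):
    # Rough model of torch.Tensor.is_contiguous(): dimensions of
    # size 1 are skipped, so their strides may be arbitrary
    # ("degenerate") without failing the check.
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue  # stride is irrelevant for a size-1 dimension
        if stride != expected:
            return False
        expected *= size
    return True
```

Under this model, shape (16, 1, 8, 32) with strides (256, 1, 32, 1) passes, matching the torch.is_contiguous(): True output shown for t2 above, while a strict canonical-stride comparison would reject it.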

Collaborator


We may need to do:

prefill_query_contiguous = torch.empty(
    prefill_query.shape, dtype=prefill_query.dtype, device=prefill_query.device
)
prefill_query_contiguous.copy_(prefill_query)

And similarly for the other tensors where we do squeeze / unsqueeze / slice.

Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
@mergify

mergify bot commented Jan 10, 2026

Hi @vadiklyutiy, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vllm-bot vllm-bot merged commit e15a5ff into vllm-project:main Jan 10, 2026
54 of 56 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 10, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…llm-project#32008)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…llm-project#32008)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…llm-project#32008)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
@vadiklyutiy vadiklyutiy deleted the vadim/trtgen-attn-stronger-asserts branch March 11, 2026 08:01

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants