[MISC] Add strict contiguity check for FlashInfer attention tensors#32008
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Code Review
This pull request introduces an `is_strictly_contiguous` utility function to perform a stricter tensor contiguity check, addressing potential memory access issues in FlashInfer CUDA kernels caused by degenerate strides. The new check is correctly applied to the FlashInfer TRTLLM prefill attention path, and the implementation of the utility is robust. However, a similar vulnerability seems to exist in the TRTLLM decode path, which has not been addressed in this PR; I've added a critical comment to highlight this omission.
```python
else:
    assert self.o_sf_scale is None
    out = output[num_decode_tokens:]
```
Should we also check if out is contiguous?
```python
else:
    assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
    # prefill_query may be non-contiguous
    prefill_query = prefill_query.contiguous()
```
I did some tests, and if a torch tensor's `is_contiguous()` returns True where `is_strictly_contiguous` returns False, `tensor.contiguous()` actually doesn't make it a strictly contiguous tensor. E.g.:
```python
t_base = torch.randn(16, 8, 32)
t2 = t_base.as_strided(size=(16, 1, 8, 32), stride=(256, 1, 32, 1))
t5 = t2.contiguous()
```
here, the result is:

```
t2 -> Shape: torch.Size([16, 1, 8, 32]), Stride: (256, 1, 32, 1)
      torch.is_contiguous(): True
      is_strictly_contiguous: False
      Expected canonical stride: (256, 256, 32, 1)

t5 -> Shape: torch.Size([16, 1, 8, 32]), Stride: (256, 1, 32, 1)
      torch.is_contiguous(): True
      is_strictly_contiguous: False
```
`t2.contiguous()` didn't actually convert it to a strictly contiguous tensor. So while the assertion works in detecting a non-contiguous tensor, the earlier `prefill_query.contiguous()` may be insufficient. Wondering if we need to do anything additional?
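For reference, a strict check along these lines can be sketched as follows. This is an illustration operating on plain shape/stride tuples, not the actual `vllm/utils/torch_utils.py` implementation (which takes a tensor directly and may differ in details):

```python
def is_strictly_contiguous(shape, strides):
    """True only if strides exactly match the canonical row-major
    (C-contiguous) layout, including size-1 dimensions."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if stride != expected:
            return False
        expected *= size
    return True

# For a torch tensor t, this corresponds to
# is_strictly_contiguous(tuple(t.shape), t.stride()).
```

Under this check, the degenerate layout `(256, 1, 32, 1)` for shape `(16, 1, 8, 32)` fails, because the canonical strides are `(256, 256, 32, 1)`.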
we may need to do:

```python
prefill_query_contiguous = torch.empty(
    prefill_query.shape, dtype=prefill_query.dtype, device=prefill_query.device
)
prefill_query_contiguous.copy_(prefill_query)
```
And similarly for the other places where we squeeze / unsqueeze / slice.
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Hi @vadiklyutiy, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
…llm-project#32008) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
…llm-project#32008) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Adds an early check for a potential error, as in #30842. See also #31617 and flashinfer-ai/flashinfer#2232.
Updates the FlashInfer attention path to use a stricter contiguity check, preventing potential CUDA kernel memory access issues.

Introduces an `is_strictly_contiguous()` utility to detect tensors with degenerate strides that PyTorch's `is_contiguous()` reports as contiguous.

Note: strengthens memory layout validation for FlashInfer TRTLLM attention.

- Adds `is_strictly_contiguous(t)` in `vllm/utils/torch_utils.py` to verify canonical contiguous strides and catch degenerate-stride tensors.
- Replaces `is_contiguous()` assertions with `is_strictly_contiguous()` in the `flashinfer.py` TRTLLM (HND) prefill and decode paths for `query`, `kv_cache_permute`, the workspace buffer, `block_tables`, and `seq_lens`.

Written by Cursor Bugbot for commit b1e334e.
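The summary above hinges on the gap between PyTorch's relaxed contiguity check and the strict one. A simplified pure-Python mimic of the relaxed rule (the helper name is illustrative and not part of the PR; empty-tensor edge cases are ignored):

```python
def torch_style_is_contiguous(shape, strides):
    """Simplified mimic of torch.Tensor.is_contiguous(): strides of
    size-1 dimensions are not constrained, which is why a degenerate
    layout like (256, 1, 32, 1) for shape (16, 1, 8, 32) still passes."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue  # torch ignores strides of size-1 dims
        if stride != expected:
            return False
        expected *= size
    return True
```

This is exactly the behavior the review thread demonstrates: the degenerate layout passes the relaxed check, so `.contiguous()` treats the tensor as already contiguous and returns it unchanged.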