trtllm non causal support#3020
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughPropagates a causal flag through the paged-attention stack: CUDA launcher/context signatures gain an Changes
Sequence Diagram(s)sequenceDiagram
participant Test as Test/Client
participant Py as flashinfer.prefill
participant Kernel as trtllm_paged_attention_launcher (CUDA)
Test->>Py: trtllm_batch_context_with_kv_cache(..., causal)
Py->>Py: validate causal, window_left, logits_soft_cap
Py->>Kernel: trtllm_paged_attention_context(..., is_causal = causal)
Kernel-->>Py: attention outputs / updated KV cache
Py-->>Test: return logits / updated KV cache
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
55071d9 to
bfa018c
Compare
|
/bot run |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)
1269-1283:⚠️ Potential issue | 🟡 MinorNon-causal support is still blocked for
trtllm-nativein this benchmark flow.
causal=causalis forwarded correctly here, but the earlier backend filter still removestrtllm-nativewhencausal=False, so non-causal benchmarking for this path remains unreachable.Suggested fix
- if "trtllm-native" in backends: - remove_trtllm_native = False - if not causal: - print("[INFO] trtllm-native backend currently requires causal = True") - remove_trtllm_native = True - if remove_trtllm_native: - backends.remove("trtllm-native") + # Keep trtllm-native enabled for both causal and non-causal prefill.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1269 - 1283, The benchmark currently forwards causal to flashinfer.prefill.trtllm_batch_context_with_kv_cache but an earlier backend filter incorrectly excludes the "trtllm-native" backend when causal is False, preventing non-causal runs; update the backend-selection logic (the filter that builds the backends/selected_backends list) to allow "trtllm-native" for non-causal runs (or add an explicit exception for "trtllm-native") so that when causal=False the trtllm-native path still runs and reaches trtllm_batch_context_with_kv_cache with causal=False.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1996-2001: Replace the runtime assertion on logits_soft_cap with
an explicit parameter validation that raises a ValueError with a clear message;
locate the check where logits_soft_cap == 0.0 (and the surrounding conditional
using causal and window_left) and change the assert to: if logits_soft_cap !=
0.0: raise ValueError("logits_soft_cap must be 0.0 for trtllm-gen paged KV
cache") so the public API validates reliably even when Python is run with -O.
- Line 3747: The new boolean parameter causal was inserted between window_left
and out which can break callers using positional args; move the causal parameter
to the end of the function signature (after uses_shared_paged_kv_idx) in the
affected function(s) in flashinfer/prefill.py so it becomes a trailing
keyword-only parameter, update any internal calls accordingly (use named
argument where needed), and add a brief note in the function's docstring or
changelog if you prefer to keep the current ordering; reference the function
signature containing window_left, out, causal, and uses_shared_paged_kv_idx to
locate the change.
---
Outside diff comments:
In `@benchmarks/routines/attention.py`:
- Around line 1269-1283: The benchmark currently forwards causal to
flashinfer.prefill.trtllm_batch_context_with_kv_cache but an earlier backend
filter incorrectly excludes the "trtllm-native" backend when causal is False,
preventing non-causal runs; update the backend-selection logic (the filter that
builds the backends/selected_backends list) to allow "trtllm-native" for
non-causal runs (or add an explicit exception for "trtllm-native") so that when
causal=False the trtllm-native path still runs and reaches
trtllm_batch_context_with_kv_cache with causal=False.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: bc7eddd5-e9f0-4637-926c-6b266b97ef0e
📥 Commits
Reviewing files that changed from the base of the PR and between c2b4db2 and bfa018cfa5015ddde2812a0f17b8bf980cc696d6.
📒 Files selected for processing (4)
benchmarks/routines/attention.pycsrc/trtllm_fmha_kernel_launcher.cuflashinfer/prefill.pytests/attention/test_trtllm_gen_attention.py
There was a problem hiding this comment.
Code Review
This pull request adds support for non-causal (dense/bidirectional) attention to the trtllm-gen backend for paged KV cache, restricted to cases where window_left is -1. Key changes include modifying the CUDA kernel launcher to toggle between causal and dense mask types, updating the Python prefill interface to accept a causal flag, and adding a new test suite for non-causal prefill. Feedback indicates that the placement of the new causal parameters in flashinfer/prefill.py breaks backward compatibility for positional arguments and should be moved to the end of the function signatures.
bfa018c to
413b821
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 378-380: The public FFI entrypoint was changed by inserting
is_causal before workspace_size, shifting positional parameters; revert to a
backward-compatible signature by moving is_causal to the end of the parameter
list (or add an overload/shim that preserves the original ordering) and give it
a default value so existing positional callers are unaffected; update the
function that contains the parameters enable_pdl, workspace_size,
attention_sinks, key_block_scales, value_block_scales,
skip_softmax_threshold_scale_factor and is_causal (or add a wrapper with the old
ordering that calls the new implementation) to maintain positional compatibility
for the exported context API.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2e785937-cb2a-4a24-bbf3-656a14f62200
📥 Commits
Reviewing files that changed from the base of the PR and between bfa018cfa5015ddde2812a0f17b8bf980cc696d6 and 413b821.
📒 Files selected for processing (4)
benchmarks/routines/attention.pycsrc/trtllm_fmha_kernel_launcher.cuflashinfer/prefill.pytests/attention/test_trtllm_gen_attention.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/routines/attention.py
bkryu
left a comment
There was a problem hiding this comment.
PR looks concise, but breaks API compat. Otherwise looks goot to me. Can you check the comments?
|
/bot run |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 491-492: Before calling the launcher that forwards is_causal (the
call passing uses_shared_paged_kv_idx_value, sm_count, enable_pdl, is_causal,
workspace_size, ...), add a fast-fail guard that detects the unsupported
combination: if is_causal is false AND window_left is finite (i.e., not the
sentinel meaning “infinite”/unbounded), immediately return an error (e.g.,
cudaErrorInvalidValue / appropriate error code or set a failing status) and log
a clear message; do not proceed to invoke the launcher. Refer to the local
variables is_causal and window_left and place this check just before the
launcher invocation.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 340674e6-aedc-4904-9f74-8899a37ed397
📒 Files selected for processing (2)
csrc/trtllm_fmha_kernel_launcher.cuflashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/prefill.py
|
@saltyminty, can you check the internal CI results? |
|
/bot run |
78328a2 to
9362cce
Compare
712496e to
637fe72
Compare
The prefill helper now accepts an explicit causal flag, but the head-dim 512 wrapper still used the old positional layout and the non-causal coverage was duplicated in a separate test. Folding causal=True/False into the main prefill matrix keeps the broad coverage in one wrapper and fixes the missing-argument CI failure. Constraint: PR #3020 needs both causal and non-causal TRTLLM-gen paged prefill coverage without duplicating the same parameter matrix. Rejected: Keep a separate non-causal wrapper | duplicates the main prefill matrix and let one helper call drift out of sync. Confidence: high Scope-risk: narrow Tested: python3 -m py_compile tests/attention/test_trtllm_gen_attention.py Tested: AST check that all _test_trtllm_batch_prefill calls pass the expected positional argument count Not-tested: Full GPU pytest matrix locally
637fe72 to
5221174
Compare
|
/bot run |
|
/bot run |
2a616b9 to
38ac474
Compare
|
/bot run |
38ac474 to
681685f
Compare
|
/bot run |
The prefill helper now accepts an explicit causal flag, but the head-dim 512 wrapper still used the old positional layout and the non-causal coverage was duplicated in a separate test. Folding causal=True/False into the main prefill matrix keeps the broad coverage in one wrapper and fixes the missing-argument CI failure. Constraint: PR #3020 needs both causal and non-causal TRTLLM-gen paged prefill coverage without duplicating the same parameter matrix. Rejected: Keep a separate non-causal wrapper | duplicates the main prefill matrix and let one helper call drift out of sync. Confidence: high Scope-risk: narrow Tested: python3 -m py_compile tests/attention/test_trtllm_gen_attention.py Tested: AST check that all _test_trtllm_batch_prefill calls pass the expected positional argument count Not-tested: Full GPU pytest matrix locally
The paged TRTLLM launcher now carries causal state for context kernels. Keeping the new boolean adjacent to the stream argument avoids shifting the existing workspace and stride argument group in the middle of a long internal helper signature. Constraint: The exported TVM FFI and Python APIs already carry causal as a trailing/defaulted argument. Rejected: Leave is_causal before workspace_size | makes future audits of same-typed launcher arguments more error-prone. Confidence: high Scope-risk: narrow Tested: git diff --check Tested: python3 -m py_compile flashinfer/prefill.py tests/attention/test_trtllm_gen_attention.py benchmarks/routines/attention.py Tested: repository search found only the launcher definition and two call sites Not-tested: Full CUDA build locally
681685f to
646dd0f
Compare
|
/bot run |
📌 Description
Non-causal (dense-mask) support to trtllm_batch_context_with_kv_cache
NOTE TO REVIEWER: the new "casual" input being inserted in the middle of the public API could cause API regressions to users using positional arguments. I think the current ordering next to window_left makes more sense, but for reviewer to double check if we should instead move it to the end.
🔍 Related Issues
#2826
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes / Validation
Tests