trtllm non causal support #3020

Merged
saltyminty merged 6 commits into main from fix/mingyangw/2826-trtllm-gen-non-causal-support
May 1, 2026

Conversation

@saltyminty
Collaborator

@saltyminty saltyminty commented Apr 9, 2026

📌 Description

Adds non-causal (dense-mask) support to trtllm_batch_context_with_kv_cache.

NOTE TO REVIEWER: the new "causal" input is inserted in the middle of the public API, which could cause regressions for users who pass positional arguments. I think the current ordering next to window_left makes more sense, but reviewers should double-check whether it should instead move to the end.
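
To make the positional-argument concern concrete, here is a minimal sketch. The three functions are hypothetical stand-ins with abbreviated parameter lists, not the FlashInfer signatures. It compares a legacy positional call against a signature with the flag inserted before out and against one with a trailing keyword-only flag.

```python
# Hypothetical stand-ins; the real API takes many more parameters.

def prefill_old(query, window_left=-1, out=None):
    # Signature before the change: no causal flag, behaviour is always causal.
    return {"window_left": window_left, "out": out, "causal": True}

def prefill_mid(query, window_left=-1, causal=True, out=None):
    # New flag inserted in the middle, next to window_left.
    return {"window_left": window_left, "out": out, "causal": causal}

def prefill_tail(query, window_left=-1, out=None, *, causal=True):
    # New flag appended as a keyword-only parameter.
    return {"window_left": window_left, "out": out, "causal": causal}

buffer = [0.0]                     # stands in for a preallocated output tensor
legacy_args = ("q", -1, buffer)    # an existing caller passing out positionally

old = prefill_old(*legacy_args)
mid = prefill_mid(*legacy_args)    # buffer is silently bound to causal
tail = prefill_tail(*legacy_args)  # buffer is still bound to out

print(mid["out"], mid["causal"])
print(tail["out"], tail["causal"])
```

In the mid-signature variant the caller's buffer lands in causal and out stays None, with no error raised. The keyword-only variant leaves legacy calls unchanged.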

🔍 Related Issues

#2826

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Public "causal" option for batched prefill so users can select causal vs non-causal attention; runtime execution respects this choice.
  • Bug Fixes / Validation

    • Added runtime validation: non-causal mode disallowed with sliding-window left offset and requires logits soft-cap == 0.0 (errors raised when violated).
  • Tests

    • Added and parameterized non-causal attention tests; updated batch-prefill tests and added runtime skips for unsupported combinations.
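
The validation rules summarized above can be sketched as a small helper. The function name validate_prefill_options is hypothetical, and window_left == -1 is assumed to mean "no sliding window". The thread does not say whether the soft-cap check applies only in non-causal mode, so this sketch applies it unconditionally.

```python
def validate_prefill_options(causal: bool, window_left: int, logits_soft_cap: float) -> None:
    """Hypothetical helper mirroring the constraints described in the summary."""
    # Assumption: the soft-cap restriction applies regardless of the causal flag.
    if logits_soft_cap != 0.0:
        raise ValueError("logits_soft_cap must be 0.0 on this path")
    # Assumption: window_left == -1 disables the sliding window.
    if not causal and window_left != -1:
        raise ValueError("non-causal attention cannot be combined with a sliding window")

# Accepted combinations: causal with a window, non-causal without one.
validate_prefill_options(causal=True, window_left=128, logits_soft_cap=0.0)
validate_prefill_options(causal=False, window_left=-1, logits_soft_cap=0.0)
```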

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Propagates a causal flag through the paged-attention stack: CUDA launcher/context signatures gain an is_causal parameter, Python prefill APIs accept/validate and forward a public causal argument, and tests/benchmarks updated to exercise causal and non‑causal flows.

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| CUDA Kernel Launcher | csrc/trtllm_fmha_kernel_launcher.cu | Added bool is_causal to trtllm_paged_attention_launcher and trtllm_paged_attention_context; context now selects Causal vs Dense based on is_causal; decode still calls with true; added runtime check disallowing non-causal + sliding-window. |
| Python API & Wrapper | flashinfer/prefill.py | Added public causal: bool = True to trtllm_batch_context_with_kv_cache; _paged_run and paged path accept/forward is_causal; TRT-LLM paged_run computes is_causal from mask_mode; tightened validations (logits_soft_cap constraint and non-causal + window_left check). |
| Tests | tests/attention/test_trtllm_gen_attention.py | Threaded causal through _test_trtllm_batch_prefill, added skip for unsupported non-causal+sliding-window, added test_trtllm_batch_prefill_non_causal, and extracted shared shapes/dtypes constants. |
| Benchmarks | benchmarks/routines/attention.py | TRT-LLM native backend call updated to pass causal=causal into the prefill wrapper. |
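
As a reference for what the Causal vs Dense selection means, the sketch below builds the two masks in plain Python. The helper build_mask is hypothetical. It assumes that, when the KV length exceeds the query length, the queries correspond to the last q_len positions (bottom-right alignment), which the thread does not spell out.

```python
from typing import List

def build_mask(q_len: int, kv_len: int, is_causal: bool) -> List[List[bool]]:
    """Return mask[i][j] == True when query i may attend to key j."""
    if not is_causal:
        # Dense mask: every query sees every key.
        return [[True] * kv_len for _ in range(q_len)]
    # Assumption: queries are aligned to the end of the KV sequence.
    offset = kv_len - q_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]

causal = build_mask(q_len=2, kv_len=4, is_causal=True)
dense = build_mask(q_len=2, kv_len=4, is_causal=False)

for row in causal:
    print(row)
```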

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Client
    participant Py as flashinfer.prefill
    participant Kernel as trtllm_paged_attention_launcher (CUDA)

    Test->>Py: trtllm_batch_context_with_kv_cache(..., causal)
    Py->>Py: validate causal, window_left, logits_soft_cap
    Py->>Kernel: trtllm_paged_attention_context(..., is_causal = causal)
    Kernel-->>Py: attention outputs / updated KV cache
    Py-->>Test: return logits / updated KV cache

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119
  • cyx-6
  • jimmyzho
  • kahyunnam
  • nv-yunzheq
  • samuellees

Poem

🐰 I nudged a causal flag from Python down to CUDA,

KV caches hummed as code found its way,
Tests hopped in to prove both paths true,
Benchmarks waved a carrot for the play,
A little rabbit patch — hooray!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 23.53% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding non-causal support to trtllm functionality. |
| Description check | ✅ Passed | The description addresses the main objective and includes context about API concerns, but pre-commit and test checklist items are left unchecked. |

@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 55071d9 to bfa018c Compare April 9, 2026 00:09
@saltyminty
Collaborator Author

/bot run

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)

1269-1283: ⚠️ Potential issue | 🟡 Minor

Non-causal support is still blocked for trtllm-native in this benchmark flow.

causal=causal is forwarded correctly here, but the earlier backend filter still removes trtllm-native when causal=False, so non-causal benchmarking for this path remains unreachable.

Suggested fix
-    if "trtllm-native" in backends:
-        remove_trtllm_native = False
-        if not causal:
-            print("[INFO] trtllm-native backend currently requires causal = True")
-            remove_trtllm_native = True
-        if remove_trtllm_native:
-            backends.remove("trtllm-native")
+    # Keep trtllm-native enabled for both causal and non-causal prefill.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1269 - 1283, The benchmark
currently forwards causal to
flashinfer.prefill.trtllm_batch_context_with_kv_cache but an earlier backend
filter incorrectly excludes the "trtllm-native" backend when causal is False,
preventing non-causal runs; update the backend-selection logic (the filter that
builds the backends/selected_backends list) to allow "trtllm-native" for
non-causal runs (or add an explicit exception for "trtllm-native") so that when
causal=False the trtllm-native path still runs and reaches
trtllm_batch_context_with_kv_cache with causal=False.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 1996-2001: Replace the runtime assertion on logits_soft_cap with
an explicit parameter validation that raises a ValueError with a clear message;
locate the check where logits_soft_cap == 0.0 (and the surrounding conditional
using causal and window_left) and change the assert to: if logits_soft_cap !=
0.0: raise ValueError("logits_soft_cap must be 0.0 for trtllm-gen paged KV
cache") so the public API validates reliably even when Python is run with -O.
- Line 3747: The new boolean parameter causal was inserted between window_left
and out which can break callers using positional args; move the causal parameter
to the end of the function signature (after uses_shared_paged_kv_idx) in the
affected function(s) in flashinfer/prefill.py so it becomes a trailing
keyword-only parameter, update any internal calls accordingly (use named
argument where needed), and add a brief note in the function's docstring or
changelog if you prefer to keep the current ordering; reference the function
signature containing window_left, out, causal, and uses_shared_paged_kv_idx to
locate the change.
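
The point about assert versus an explicit exception can be checked with a short sketch. It compiles a hypothetical check function at two optimization levels using the built-in compile (level 1 corresponds to running Python with -O) and compares it with an explicit ValueError. The function names and the return value are illustrative only.

```python
SOURCE = '''
def check(logits_soft_cap):
    assert logits_soft_cap == 0.0, "logits_soft_cap must be 0.0"
    return "accepted"
'''

def load_check(optimize: int):
    """Compile SOURCE at the given optimization level and return its check function."""
    namespace = {}
    exec(compile(SOURCE, "<sketch>", "exec", optimize=optimize), namespace)
    return namespace["check"]

def validate(logits_soft_cap: float) -> str:
    # An explicit raise is kept at every optimization level.
    if logits_soft_cap != 0.0:
        raise ValueError("logits_soft_cap must be 0.0 for trtllm-gen paged KV cache")
    return "accepted"

strict = load_check(optimize=0)     # asserts kept
optimized = load_check(optimize=1)  # asserts stripped, as under -O

try:
    strict(1.0)
    strict_result = "accepted"
except AssertionError:
    strict_result = "rejected"

optimized_result = optimized(1.0)   # the invalid value slips through

try:
    validate(1.0)
    explicit_result = "accepted"
except ValueError:
    explicit_result = "rejected"

print(strict_result, optimized_result, explicit_result)
```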

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bc7eddd5-e9f0-4637-926c-6b266b97ef0e

📥 Commits

Reviewing files that changed from the base of the PR and between c2b4db2 and bfa018cfa5015ddde2812a0f17b8bf980cc696d6.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/prefill.py
  • tests/attention/test_trtllm_gen_attention.py

Comment thread flashinfer/prefill.py Outdated
Comment thread flashinfer/prefill.py Outdated
@flashinfer-bot
Collaborator

GitLab MR !527 has been created, and the CI pipeline #48068274 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request adds support for non-causal (dense/bidirectional) attention to the trtllm-gen backend for paged KV cache, restricted to cases where window_left is -1. Key changes include modifying the CUDA kernel launcher to toggle between causal and dense mask types, updating the Python prefill interface to accept a causal flag, and adding a new test suite for non-causal prefill. Feedback indicates that the placement of the new causal parameters in flashinfer/prefill.py breaks backward compatibility for positional arguments and should be moved to the end of the function signatures.

Comment thread flashinfer/prefill.py Outdated
Comment thread flashinfer/prefill.py Outdated
Comment thread tests/attention/test_trtllm_gen_attention.py Outdated
Comment thread tests/attention/test_trtllm_gen_attention.py Outdated
@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from bfa018c to 413b821 Compare April 9, 2026 17:56
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 378-380: The public FFI entrypoint was changed by inserting
is_causal before workspace_size, shifting positional parameters; revert to a
backward-compatible signature by moving is_causal to the end of the parameter
list (or add an overload/shim that preserves the original ordering) and give it
a default value so existing positional callers are unaffected; update the
function that contains the parameters enable_pdl, workspace_size,
attention_sinks, key_block_scales, value_block_scales,
skip_softmax_threshold_scale_factor and is_causal (or add a wrapper with the old
ordering that calls the new implementation) to maintain positional compatibility
for the exported context API.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2e785937-cb2a-4a24-bbf3-656a14f62200

📥 Commits

Reviewing files that changed from the base of the PR and between bfa018cfa5015ddde2812a0f17b8bf980cc696d6 and 413b821.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/prefill.py
  • tests/attention/test_trtllm_gen_attention.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/routines/attention.py

Comment thread csrc/trtllm_fmha_kernel_launcher.cu Outdated
@saltyminty saltyminty requested a review from yzh119 April 9, 2026 18:02
Collaborator

@bkryu bkryu left a comment

PR looks concise, but breaks API compat. Otherwise looks good to me. Can you check the comments?

Comment thread flashinfer/prefill.py Outdated
Comment thread flashinfer/prefill.py Outdated
@bkryu
Collaborator

bkryu commented Apr 10, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #48169618 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 491-492: Before calling the launcher that forwards is_causal (the
call passing uses_shared_paged_kv_idx_value, sm_count, enable_pdl, is_causal,
workspace_size, ...), add a fast-fail guard that detects the unsupported
combination: if is_causal is false AND window_left is finite (i.e., not the
sentinel meaning “infinite”/unbounded), immediately return an error (e.g.,
cudaErrorInvalidValue / appropriate error code or set a failing status) and log
a clear message; do not proceed to invoke the launcher. Refer to the local
variables is_causal and window_left and place this check just before the
launcher invocation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 340674e6-aedc-4904-9f74-8899a37ed397

📥 Commits

Reviewing files that changed from the base of the PR and between 413b821 and 6de575d.

📒 Files selected for processing (2)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/prefill.py

Comment thread csrc/trtllm_fmha_kernel_launcher.cu Outdated
@bkryu
Collaborator

bkryu commented Apr 10, 2026

@saltyminty, can you check the internal CI results?

@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #48227207 is currently running. I'll report back once the pipeline job completes.

@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 78328a2 to 9362cce Compare April 10, 2026 18:05
@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 712496e to 637fe72 Compare April 29, 2026 16:40
saltyminty added a commit that referenced this pull request Apr 29, 2026
The prefill helper now accepts an explicit causal flag, but the head-dim 512 wrapper still used the old positional layout and the non-causal coverage was duplicated in a separate test. Folding causal=True/False into the main prefill matrix keeps the broad coverage in one wrapper and fixes the missing-argument CI failure.

Constraint: PR #3020 needs both causal and non-causal TRTLLM-gen paged prefill coverage without duplicating the same parameter matrix.

Rejected: Keep a separate non-causal wrapper | duplicates the main prefill matrix and let one helper call drift out of sync.

Confidence: high

Scope-risk: narrow

Tested: python3 -m py_compile tests/attention/test_trtllm_gen_attention.py

Tested: AST check that all _test_trtllm_batch_prefill calls pass the expected positional argument count

Not-tested: Full GPU pytest matrix locally
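
An AST check of the kind mentioned in the commit message could look like the following sketch. It is not the script the author ran. The helper name, the sample source, and the expected count of four are assumptions made for illustration.

```python
import ast

def positional_arg_counts(source: str, func_name: str) -> list:
    """Return sorted (line, positional_arg_count) pairs for direct calls to func_name."""
    counts = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == func_name
        ):
            # A starred argument (*args) counts as a single entry here.
            counts.append((node.lineno, len(node.args)))
    return sorted(counts)

# Hypothetical test file: the second call omits the new trailing argument.
SAMPLE = '''
def test_a():
    _test_trtllm_batch_prefill("HND", 4, 16, True)

def test_b():
    _test_trtllm_batch_prefill("HND", 4, 16)
'''

EXPECTED = 4  # assumed positional-argument count for this sample
counts = positional_arg_counts(SAMPLE, "_test_trtllm_batch_prefill")
mismatches = [(line, n) for line, n in counts if n != EXPECTED]
print(mismatches)
```

A check like this runs without a GPU, which fits the "Not-tested: Full GPU pytest matrix locally" note, but it only verifies call shape, not behaviour.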
@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 637fe72 to 5221174 Compare April 29, 2026 16:47
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #49825288 is currently running. I'll report back once the pipeline job completes.

@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #49835735 is currently running. I'll report back once the pipeline job completes.

@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 2a616b9 to 38ac474 Compare April 30, 2026 22:59
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #49960840 is currently running. I'll report back once the pipeline job completes.

@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 38ac474 to 681685f Compare May 1, 2026 05:58
@saltyminty
Collaborator Author

/bot run

The paged TRTLLM launcher now carries causal state for context kernels. Keeping the new boolean adjacent to the stream argument avoids shifting the existing workspace and stride argument group in the middle of a long internal helper signature.

Constraint: The exported TVM FFI and Python APIs already carry causal as a trailing/defaulted argument.

Rejected: Leave is_causal before workspace_size | makes future audits of same-typed launcher arguments more error-prone.

Confidence: high

Scope-risk: narrow

Tested: git diff --check

Tested: python3 -m py_compile flashinfer/prefill.py tests/attention/test_trtllm_gen_attention.py benchmarks/routines/attention.py

Tested: repository search found only the launcher definition and two call sites

Not-tested: Full CUDA build locally
@saltyminty saltyminty force-pushed the fix/mingyangw/2826-trtllm-gen-non-causal-support branch from 681685f to 646dd0f Compare May 1, 2026 05:58
@flashinfer-bot
Collaborator

GitLab MR !527 has been updated with latest changes, and the CI pipeline #49984029 is currently running. I'll report back once the pipeline job completes.

@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !527 has been created, and the CI pipeline #49984128 is currently running. I'll report back once the pipeline job completes.

@saltyminty saltyminty merged commit e1e6714 into main May 1, 2026
30 checks passed
@saltyminty saltyminty deleted the fix/mingyangw/2826-trtllm-gen-non-causal-support branch May 1, 2026 18:13
4 participants