Skip to content

feat: Add Option For Fixed cta_tile_q#2830

Open
frankwang28 wants to merge 8 commits into
flashinfer-ai:mainfrom
frankwang28:add-fixed-cta-tile-q
Open

feat: Add Option For Fixed cta_tile_q#2830
frankwang28 wants to merge 8 commits into
flashinfer-ai:mainfrom
frankwang28:add-fixed-cta-tile-q

Conversation

@frankwang28
Copy link
Copy Markdown

@frankwang28 frankwang28 commented Mar 19, 2026

📌 Description

This PR adds functionality for a caller to set a fixed cta_tile_q size.

This is mainly a use case for batch invariance as dynamically chosen cta_tile_q values can lead to variant outputs.

🔍 Related Issues

Fixes #2424

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Also tested integration with vLLM using a slightly modified test_logprobs_bitwise_batch_invariance_bs1_vs_bsN which uses Qwen/Qwen3-1.7B (gqa_group_size=2).

vLLM with 0.6.6 FlashInfer:

CUDA_VISIBLE_DEVICES=4 pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASHINFER] -s

...

FAILED tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASHINFER] - RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {}
================================================== 1 failed, 19 warnings in 14.64s ===================================================

Logging the args of the failing request to FlashInfer's BatchPrefillWithPagedKVCacheWrapper plan:

Single request:

{
    "qo_indptr": tensor([0, 372], dtype=torch.int32),
    "paged_kv_indptr": tensor([0, 24], dtype=torch.int32),
    "paged_kv_indices": tensor(
        [
            57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        ],
        device="cuda:0",
        dtype=torch.int32,
    ),
    "paged_kv_last_page_len": tensor([4], dtype=torch.int32),
    "num_qo_heads": 16,
    "num_kv_heads": 8,
    "head_dim_qk": 128,
    "page_size": 16,
    "causal": True,
    "sm_scale": 0.08838834764831845,
    "window_left": -1,
    "logits_soft_cap": None,
    "q_data_type": torch.bfloat16,
    "kv_data_type": torch.bfloat16,
    "o_data_type": torch.bfloat16,
    "fixed_split_size": 4096,
    "disable_split_kv": True,
}

avg_packed_qo_len = 372 * 2 = 744 and so FA2DetermineCtaTileQ -> 128

Batched with other requests:

{
    "qo_indptr": tensor(
        [ 
            0, 8, 19, 27, 37, 50, 57, 73, 83, 96, 108, 120, 133, 149, 165, 182, 192, 202, 212, 223, 233, 242, 258, 266, 278, 291, 302, 312, 684,
        ],
        dtype=torch.int32,
    ),
    "paged_kv_indptr": tensor(
        [ 
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 53,
        ],
        dtype=torch.int32,
    ),
    "paged_kv_indices": tensor(
        [ 
            82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134,
        ],
        device="cuda:0",
        dtype=torch.int32,
    ),
    "paged_kv_last_page_len": tensor(
        [ 
            8, 11, 8, 10, 13, 7, 16, 10, 13, 12, 12, 13, 16, 16, 1, 10, 10, 10, 11, 10, 9, 16, 8, 12, 13, 11, 10, 4,
        ],
        dtype=torch.int32,
    ),
    "num_qo_heads": 16,
    "num_kv_heads": 8,
    "head_dim_qk": 128,
    "page_size": 16,
    "causal": True,
    "sm_scale": 0.08838834764831845,
    "window_left": -1,
    "logits_soft_cap": None,
    "q_data_type": torch.bfloat16,
    "kv_data_type": torch.bfloat16,
    "o_data_type": torch.bfloat16,
    "fixed_split_size": 4096,
    "disable_split_kv": True,
}

avg_packed_qo_len = 684 * 2 / 28 = 48.8571428571 and so FA2DetermineCtaTileQ -> 64

vLLM with FlashInfer built off of this branch:

CUDA_VISIBLE_DEVICES=4 pytest tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASHINFER] -s

...

================================================== 1 passed, 19 warnings in 28.28s ===================================================

Logging the args of the previously failing request sent to FlashInfer's BatchPrefillWithPagedKVCacheWrapper plan:

{
    "qo_indptr": tensor([0, 372], dtype=torch.int32),
    "paged_kv_indptr": tensor([0, 24], dtype=torch.int32),
    "paged_kv_indices": tensor(
        [ 
            57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        ],
        device="cuda:0",
        dtype=torch.int32,
    ),
    "paged_kv_last_page_len": tensor([4], dtype=torch.int32),
    "num_qo_heads": 16,
    "num_kv_heads": 8,
    "head_dim_qk": 128,
    "page_size": 16,
    "causal": True,
    "sm_scale": 0.08838834764831845,
    "window_left": -1,
    "logits_soft_cap": None,
    "q_data_type": torch.bfloat16,
    "kv_data_type": torch.bfloat16,
    "o_data_type": torch.bfloat16,
    "fixed_split_size": 4096,
    "disable_split_kv": True,
    "fixed_cta_tile_q": 128,
}

avg_packed_qo_len = 372 * 2 = 744, which would typically cause FA2DetermineCtaTileQ -> 128 but doesn't matter as fixed_cta_tile_q overrides to 128 anyways.

Batched with other requests:

{
    "qo_indptr": tensor(
        [ 
            0, 8, 18, 31, 38, 54, 64, 77, 89, 101, 114, 130, 146, 163, 173, 183, 193, 204, 214, 223, 239, 247, 259, 272, 283, 293, 665,
        ],
        dtype=torch.int32,
    ),
    "paged_kv_indptr": tensor(
        [ 
            3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 53,
        ],
        dtype=torch.int32,
    ),
    "paged_kv_indices": tensor(
        [ 
            82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134,
        ],
        device="cuda:0",
        dtype=torch.int32,
    ),
    "paged_kv_last_page_len": tensor(
        [ 
            8, 10, 13, 7, 16, 10, 13, 12, 12, 13, 16, 16, 1, 10, 10, 10, 11, 10, 9, 16, 8, 12, 13, 11, 10, 4,
        ],
        dtype=torch.int32,
    ),
    "num_qo_heads": 16,
    "num_kv_heads": 8,
    "head_dim_qk": 128,
    "page_size": 16,
    "causal": True,
    "sm_scale": 0.08838834764831845,
    "window_left": -1,
    "logits_soft_cap": None,
    "q_data_type": torch.bfloat16,
    "kv_data_type": torch.bfloat16,
    "o_data_type": torch.bfloat16,
    "fixed_split_size": 4096,
    "disable_split_kv": True,
    "fixed_cta_tile_q": 128,
}

avg_packed_qo_len = 665 * 2 / 26 = 51.1538461538, which would typically cause FA2DetermineCtaTileQ -> 64 but fixed_cta_tile_q overrides to 128.

Summary by CodeRabbit

  • New Features

    • Added a configurable fixed_cta_tile_q planning option (16, 64, 128; default auto) with validation, clear user-facing errors for unsupported values, backend restriction (only allowed for the fa2 tensor-core backend), and a restriction against 128 with large head dimensions. Integrated into prefill, decode, POD and sparse planning flows.
  • Tests

    • Added tests covering valid values, invalid-value errors, backend restrictions, and large-head-dim incompatibility.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 19, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 84eb5c95-d0e4-47cb-b596-387a036c004e

📥 Commits

Reviewing files that changed from the base of the PR and between 1bfb3b3 and b36e76a.

📒 Files selected for processing (1)
  • flashinfer/decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/decode.py

📝 Walkthrough

Walkthrough

Adds a new planning parameter fixed_cta_tile_q across Python APIs, JIT bindings, and the CUDA scheduler; validates allowed values and head_dim compatibility; forwards the parameter into prefill planning and kernel-selection logic.

Changes

Batch Prefill / Decode End-to-End

Layer / File(s) Summary
Public API Signatures
flashinfer/decode.py, flashinfer/prefill.py
Added fixed_cta_tile_q: Optional[int] = None to plan(...) and fast_decode_plan(...) signatures; docstrings updated.
Validation Helper
flashinfer/utils.py
Added _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim) to normalize None -> -1, enforce allowed set {16,64,128}, and reject 128 when head_dim >= 256.
Python Plan Wiring
flashinfer/decode.py, flashinfer/prefill.py, flashinfer/pod.py, flashinfer/sparse.py
Validate/normalize fixed_cta_tile_q in plan paths; require tensor-core decode and fa2 backend when normalized value != -1; append fixed_cta_tile_q into FA2 plan argument lists (positioned after disable_split_kv).
JIT Binding
csrc/batch_prefill_jit_binding.cu
Updated BatchPrefillWithKVCachePlan declaration to include fixed_cta_tile_q parameter (shifts subsequent arg positions).
CUDA Scheduler / Core
csrc/batch_prefill.cu, include/flashinfer/attention/scheduler.cuh
Extended PrefillPlan / PrefillSplitQOKVIndptr signatures to accept fixed_cta_tile_q; selection logic now uses fixed value when >0 (validate in {16,64,128}), reject 128 for head_dim >= 256; reorganized total tile computations and added explicit error messages.
Call sites / Forwarding
csrc/batch_prefill.cu, ...scheduler.cuh
Forwarded the new fixed_cta_tile_q argument into lower-level PrefillPlan / PrefillSplitQOKVIndptr calls and updated call sites to match new parameter ordering.
Tests
tests/attention/test_batch_invariant_fa2.py, tests/attention/test_batch_prefill.py, tests/attention/test_tensor_cores_decode.py
Parametrized existing tests to include fixed_cta_tile_q values (16,64,128); added unit tests asserting ValueError for unsupported tile values, head_dim incompatibility, and non-fa2 backend; skip invalid combinations in parametrized tests.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Python as Python Plan Caller
  participant Wrapper as Wrapper (.plan)
  participant JIT as JIT Binding
  participant GPU as CUDA Scheduler
  Python->>Wrapper: .plan(..., fixed_cta_tile_q)
  Wrapper->>Python: validate via _validate_fixed_cta_tile_q
  Wrapper->>JIT: cached_module.plan(..., fixed_cta_tile_q, ...)
  JIT->>GPU: BatchPrefillWithKVCachePlan(..., fixed_cta_tile_q, ...)
  GPU->>GPU: PrefillPlan(..., fixed_cta_tile_q) / PrefillSplitQOKVIndptr(...)
  GPU-->>JIT: plan array / tile-size decisions
  JIT-->>Wrapper: return plan/results
  Wrapper-->>Python: propagate plan outcome
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested labels

run-ci, op: attention

Suggested reviewers

  • yzh119
  • aleozlx
  • sricketts
  • cyx-6
  • bkryu
  • kahyunnam
  • jimmyzho
  • nvmbreughe

Poem

🐰 I hopped through code to set a tile,
Fixed CTA sizes keep batching in style,
Sixteen, sixty-four, or one-twenty-eight,
I validated, rejected, and updated the slate,
Hooray for steady tiles and tests that celebrate! 🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 52.94% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title 'feat: Add Option For Fixed cta_tile_q' clearly summarizes the main change: adding a new optional parameter for fixed cta_tile_q configuration.
Description check ✅ Passed The PR description covers all template requirements: clear description of changes, related issue link, pre-commit checks marked complete, tests added/passing, and detailed reviewer notes with reproduction logs.
Linked Issues check ✅ Passed The PR fully addresses issue #2424 by implementing a fixed cta_tile_q option that provides qo_len-invariant CTA determination, enabling batch-invariant behavior as requested.
Out of Scope Changes check ✅ Passed All changes are directly scoped to adding the fixed_cta_tile_q parameter across the codebase: C++ bindings, Python wrappers, validation logic, and tests—nothing extraneous.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances FlashInfer's attention planning by providing a mechanism to enforce a fixed CTA tile size. This change is crucial for achieving deterministic behavior and batch invariance, particularly when dealing with varying batch sizes or CUDA graph optimizations, by preventing the dynamic selection of tile sizes that could lead to inconsistent outputs.

Highlights

  • Fixed CTA Tile Q Option: Introduced a new fixed_cta_tile_q parameter across various FlashInfer attention planning functions, allowing callers to explicitly set a fixed CTA tile size for improved batch invariance.
  • Batch Invariance Improvement: This feature addresses issues where dynamically chosen cta_tile_q values could lead to variant outputs, especially in scenarios involving CUDA graphs, ensuring more consistent results.
  • Input Validation: Added validation logic for fixed_cta_tile_q to ensure it's one of the supported values (16, 64, 128) and to prevent incompatible configurations, such as fixed_cta_tile_q=128 with head_dim >= 256.
  • Expanded Test Coverage: New test cases were added to verify the correct behavior and error handling of the fixed_cta_tile_q parameter, including tests for invalid values and incompatible head dimensions.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an option to set a fixed cta_tile_q size, which is a valuable feature for ensuring batch invariance and deterministic outputs. The changes are well-implemented across the C++ backend and Python wrappers, and the inclusion of new tests for validation and invariance is commendable.

My main feedback is regarding code duplication in the Python validation logic for fixed_cta_tile_q. I've left specific comments suggesting refactoring this into a shared helper function to improve maintainability. Other than that, the changes look solid.

Comment thread flashinfer/decode.py Outdated
Comment thread flashinfer/prefill.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
flashinfer/decode.py (1)

2675-2684: Consider extracting duplicated validation to a helper function.

The validation logic for fixed_cta_tile_q is duplicated between plan() (lines 988-997) and fast_decode_plan() (lines 2675-2684). Consider extracting this to a small helper function to improve maintainability.

♻️ Suggested helper extraction
def _validate_fixed_cta_tile_q(fixed_cta_tile_q: Optional[int], head_dim: int) -> int:
    """Validate and normalize fixed_cta_tile_q parameter.
    
    Returns -1 for auto heuristic, or the validated value.
    """
    if fixed_cta_tile_q is None:
        return -1
    if fixed_cta_tile_q not in (16, 64, 128):
        raise ValueError(
            f"fixed_cta_tile_q should be one of {{16, 64, 128}}, got {fixed_cta_tile_q}"
        )
    if fixed_cta_tile_q == 128 and head_dim >= 256:
        raise ValueError(
            f"fixed_cta_tile_q=128 is not supported with head_dim={head_dim} (requires head_dim < 256)"
        )
    return fixed_cta_tile_q
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 2675 - 2684, Extract the duplicated
fixed_cta_tile_q validation into a helper function (e.g.,
_validate_fixed_cta_tile_q) that accepts fixed_cta_tile_q and head_dim and
returns -1 when fixed_cta_tile_q is None or the validated value otherwise; move
the three checks (None -> -1, membership in (16,64,128), and the 128 vs head_dim
>= 256 error) into that helper, and replace the inline logic in both plan() and
fast_decode_plan() with a call to this helper to normalize/validate the value
and preserve the same ValueError messages and semantics.
flashinfer/prefill.py (1)

1818-1827: Scope fixed_cta_tile_q validation to FA2-resolved plans.

fixed_cta_tile_q is documented as FA2-specific, but current validation executes unconditionally before backend resolution. This can reject non-FA2 plans for a parameter that is otherwise ignored there. Consider validating compatibility only when the resolved backend is fa2.

Also applies to: 2809-2818

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1818 - 1827, The validation for
fixed_cta_tile_q currently runs unconditionally; restrict this check so it only
executes for FA2-resolved plans by wrapping the existing checks for
fixed_cta_tile_q and the head_dim_vo compatibility check in a conditional that
first verifies the plan/backend is FA2 (e.g., resolved_backend == "fa2" or
plan.is_fa2) at the point after backend resolution; apply the same change to the
other duplicate validation block (the one around the second occurrence) so
non-FA2 backends won’t be rejected for this FA2-specific parameter.
tests/attention/test_batch_invariant_fa2.py (1)

56-67: Consider reducing the new Cartesian test expansion.

Adding fixed_cta_tile_q as a full-axis multiplier triples already-large matrices and may make GPU CI much slower/flakier. Prefer a smaller targeted matrix (or dedicated focused cases) for fixed_cta_tile_q coverage.

Also applies to: 198-210

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_batch_invariant_fa2.py` around lines 56 - 67, The
Cartesian expansion from adding full-axis parametrize for fixed_cta_tile_q is
inflating test matrix size; restrict it by replacing
pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) with a much smaller
set (e.g., a single representative value like [64] or two targeted values like
[16, 64]) or move fixed_cta_tile_q into a separate, focused test marked slow so
it doesn't multiply all other params; update the parametrize for
fixed_cta_tile_q and/or add a dedicated test function (or pytest.mark.slow) to
cover the remaining tiles without expanding the entire grid, ensuring references
to the same parameter name fixed_cta_tile_q are adjusted accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@flashinfer/decode.py`:
- Around line 2675-2684: Extract the duplicated fixed_cta_tile_q validation into
a helper function (e.g., _validate_fixed_cta_tile_q) that accepts
fixed_cta_tile_q and head_dim and returns -1 when fixed_cta_tile_q is None or
the validated value otherwise; move the three checks (None -> -1, membership in
(16,64,128), and the 128 vs head_dim >= 256 error) into that helper, and replace
the inline logic in both plan() and fast_decode_plan() with a call to this
helper to normalize/validate the value and preserve the same ValueError messages
and semantics.

In `@flashinfer/prefill.py`:
- Around line 1818-1827: The validation for fixed_cta_tile_q currently runs
unconditionally; restrict this check so it only executes for FA2-resolved plans
by wrapping the existing checks for fixed_cta_tile_q and the head_dim_vo
compatibility check in a conditional that first verifies the plan/backend is FA2
(e.g., resolved_backend == "fa2" or plan.is_fa2) at the point after backend
resolution; apply the same change to the other duplicate validation block (the
one around the second occurrence) so non-FA2 backends won’t be rejected for this
FA2-specific parameter.

In `@tests/attention/test_batch_invariant_fa2.py`:
- Around line 56-67: The Cartesian expansion from adding full-axis parametrize
for fixed_cta_tile_q is inflating test matrix size; restrict it by replacing
pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) with a much smaller
set (e.g., a single representative value like [64] or two targeted values like
[16, 64]) or move fixed_cta_tile_q into a separate, focused test marked slow so
it doesn't multiply all other params; update the parametrize for
fixed_cta_tile_q and/or add a dedicated test function (or pytest.mark.slow) to
cover the remaining tiles without expanding the entire grid, ensuring references
to the same parameter name fixed_cta_tile_q are adjusted accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fc264a2a-c5dc-487a-85e0-2de88b01bbbd

📥 Commits

Reviewing files that changed from the base of the PR and between 6f0928c and 056d60a.

📒 Files selected for processing (9)
  • csrc/batch_prefill.cu
  • csrc/batch_prefill_jit_binding.cu
  • flashinfer/decode.py
  • flashinfer/pod.py
  • flashinfer/prefill.py
  • flashinfer/sparse.py
  • include/flashinfer/attention/scheduler.cuh
  • tests/attention/test_batch_invariant_fa2.py
  • tests/attention/test_batch_prefill.py

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Line 849: The parameter fixed_cta_tile_q is currently ignored in
non-tensor-core paths; mirror the fixed_split_size guard by rejecting/raising
when fixed_cta_tile_q is non-None and use_tensor_cores is False in the public
plan() (validate-and-drop) and similarly drop/raise in fast_decode_plan() where
parameters are pre-resolution. Additionally, after backend resolution (the code
path that inspects the resolved tensor-core backend and chooses FA2), add a
check that if the chosen tensor-core path is not "fa2" and fixed_cta_tile_q is
non-None, reject it (raise or error) so fixed_cta_tile_q is only accepted when
the final backend is fa2; update the same checks where fixed_split_size is
handled to cover fixed_cta_tile_q as well.

In `@flashinfer/prefill.py`:
- Line 1679: After resolving the effective backend (i.e., after handling
backend="auto") add a fail-fast check: if fixed_cta_tile_q is not None and the
resolved backend is not "fa2", raise a clear ValueError indicating
fixed_cta_tile_q is only supported for the fa2 backend. Locate the check and
insertion point inside plan() where self._backend == "fa2" is handled and
move/duplicate the validation so it runs after backend resolution (not before),
and mirror the same change in the other affected call sites referenced (around
the fixed_cta_tile_q occurrences and plan-like flows at the other locations) so
non-None values are rejected unless backend == "fa2".

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 979bb9a2-f818-4ca8-8b09-204e5e45ea05

📥 Commits

Reviewing files that changed from the base of the PR and between 056d60a and 44af458.

📒 Files selected for processing (3)
  • flashinfer/decode.py
  • flashinfer/prefill.py
  • flashinfer/utils.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/utils.py

Comment thread flashinfer/decode.py
Comment thread flashinfer/prefill.py
@frankwang28 frankwang28 changed the title Add Option For Fixed cta_tile_q feat: Add Option For Fixed cta_tile_q Mar 25, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] qo_len Invariant CTA Sizes

2 participants