Conversation

@AKKamath
Contributor

@AKKamath AKKamath commented Nov 12, 2025

Co-authored-by: @Edenzzzz

📌 Description

Fixes #1022. Unlike #1231, this splits the inputs into separate prefill and decode inputs. It should probably be possible to handle this splitting automatically in Python so that callers can simply provide a single batch of requests.

To run the benchmark for this PR: python benchmarks/bench_mixed_attention.py
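
For reference, here is a minimal, self-contained sketch (plain PyTorch, no FlashInfer calls) of how the paged-KV metadata can be built separately for the prefill (_p) and decode (_d) halves of a mixed batch. The helper and variable names only loosely follow benchmarks/bench_mixed_attention.py, and a page size of 16 is used purely for illustration (the benchmark currently uses page size 1); these per-phase indptr/last-page tensors are what the new batch wrapper's plan() consumes.

import torch

page_block_size = 16
p_kv_lens = [2048, 2048]   # prefill requests (cf. Benchmark 1)
d_kv_lens = [2048] * 128   # decode requests

def paged_kv_meta(kv_lens, page_size):
    lens = torch.tensor(kv_lens, dtype=torch.int32)
    num_blocks = (lens + page_size - 1) // page_size   # pages per request, rounded up
    kv_indptr = torch.cat(
        [torch.zeros(1, dtype=torch.int32), torch.cumsum(num_blocks, 0, dtype=torch.int32)]
    )
    last_page_len = (lens - 1) % page_size + 1         # from true lengths, not block counts
    return kv_indptr, last_page_len

kv_indptr_p, last_page_len_p = paged_kv_meta(p_kv_lens, page_block_size)
kv_indptr_d, last_page_len_d = paged_kv_meta(d_kv_lens, page_block_size)
print(kv_indptr_p.tolist(), last_page_len_p.tolist())  # [0, 128, 256] [16, 16]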

Performance:
===== Benchmark 1: (kv_len, qo_len) set =====
Prefill = 2 requests, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.65 ms
Elapsed time (Batched POD Attention): 0.46 ms
Elapsed time (Persistent BatchAttention): 0.56 ms
Batch POD speedup over Persistent BatchAttention: 1.22x

===== Benchmark 2: (kv_len, qo_len) set =====
Prefill = 1 request, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.55 ms
Elapsed time (Batched POD Attention): 0.41 ms
Elapsed time (POD Attention): 0.41 ms
Elapsed time (Sequential two kernels): 0.51 ms
Elapsed time (Persistent BatchAttention): 0.45 ms
Batch POD speedup over Persistent BatchAttention: 1.11x

===== Benchmark 3: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 4096 KV len
Elapsed time (Batched Prefill): 1.27 ms
Elapsed time (Batched POD Attention): 0.86 ms
Elapsed time (POD Attention): 0.82 ms
Elapsed time (Sequential two kernels): 1.15 ms
Elapsed time (Persistent BatchAttention): 1.08 ms
Batch POD speedup over Persistent BatchAttention: 1.26x

===== Benchmark 4: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.15 ms
Elapsed time (Batched POD Attention): 1.52 ms
Elapsed time (POD Attention): 1.54 ms
Elapsed time (Sequential two kernels): 1.82 ms
Elapsed time (Persistent BatchAttention): 1.76 ms
Batch POD speedup over Persistent BatchAttention: 1.16x

===== Benchmark 5: (kv_len, qo_len) set =====
Prefill = 1 request, 6000 Q len, 7000 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.86 ms
Elapsed time (Batched POD Attention): 2.03 ms
Elapsed time (POD Attention): 1.95 ms
Elapsed time (Sequential two kernels): 2.52 ms
Elapsed time (Persistent BatchAttention): 2.45 ms
Batch POD speedup over Persistent BatchAttention: 1.20x

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a batched prefill+decode attention path with a public batch-oriented POD wrapper and JIT module export.
  • Performance

    • Benchmarks extended to include batched-path timings, memory bandwidth, elapsed-time and comparative speedup metrics across expanded prefill/decode scenarios.
  • API

    • Runtime binding for batched KV‑cache execution added; planning APIs now accept an optional colocated-CTA parameter that influences scheduling.

@coderabbitai
Contributor

coderabbitai bot commented Nov 12, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a two-phase batched POD attention implementation: new host/JIT bindings, a CUDA kernel and dispatch path, JIT codegen, and a Python wrapper for batched prefill+decode with paged KV-cache. It also introduces a num_colocated_ctas planner parameter and updates the benchmarks to exercise and time the batched path.

Changes

  • Batch POD host & dispatcher (csrc/batch_pod.cu): New host entry batch_pod_with_kv_cache_tensor: validates tensors/strides, assembles Prefill/Decode params and plan_info, prepares buffers/indptrs/masks, selects the dispatch variant, and launches the dispatched BatchPOD flow (SM‑aware, CUDA‑graph aware).
  • JIT binding (CUDA ↔ runtime) in csrc/batch_pod_jit_binding.cu, csrc/batch_prefill_jit_binding.cu: Adds the exported FFI binding batch_pod_with_kv_cache_tensor(...) and updates the BatchPrefillWithKVCachePlan binding to accept num_colocated_ctas. Exports for TVM/FFI.
  • Kernel & device dispatch (include/flashinfer/attention/batch_pod.cuh): New global kernel BatchPODWithKVCacheTensorKernel and host template BatchPODWithKVCacheTensorDispatched: two‑phase PREFILL/DECODE scheduling, SM‑aware atomic block assignment, occupancy/shared‑memory calculations, split‑KV handling, and post‑kernel merge/aggregation.
  • JIT config & instantiation templates (csrc/batch_pod_customize_config.jinja, csrc/batch_pod_kernel_inst.jinja): New Jinja config with DType/Id aliases, head/feature constexprs, Prefill/Decode typedefs, and a DISPATCH_context macro; the kernel-instantiation template emits explicit instantiations across CTA_TILE_Q and mask modes.
  • Planner / scheduler changes (include/flashinfer/attention/scheduler.cuh, csrc/batch_prefill.cu): PrefillPlan / BatchPrefillWithKVCachePlan gain int64_t num_colocated_ctas; the scheduler uses it to reduce the available CTAs/grid size, which affects split‑KV decisions.
  • Python JIT generation & modules (flashinfer/jit/attention/modules.py, flashinfer/jit/__init__.py, flashinfer/jit/attention/__init__.py): Adds gen_batch_pod_module and the batch customization flow: generates batch_pod_config and mask‑mode kernel variants, and exposes gen_batch_pod_module in the jit package.
  • Python wrapper & API surface (flashinfer/pod.py, flashinfer/__init__.py): Adds get_batch_pod_module and the new BatchPODWithPagedKVCacheWrapper (plan/run lifecycle), exposes it at the package root, and wires the batched run to the JIT symbol batch_pod_with_kv_cache_tensor.
  • Plan call updates for Python clients (flashinfer/decode.py, flashinfer/prefill.py, flashinfer/sparse.py): FA2/fast_decode plan invocations now pass an extra trailing 0 argument representing num_colocated_ctas for specific backend plan calls.
  • Benchmarks (benchmarks/bench_mixed_attention.py): Adds a Batched POD Attention path: constructs batched indptrs/plan info, runs the batch wrapper, validates outputs against the baseline, records timings and bandwidth, and prints comparisons against the persistent and other paths.
  • Minor doc/comment (csrc/pod_jit_binding.cu): Updated descriptive comment to reflect "Single prefill, Batch-request decode attention with KV-Cache operator" (no API change).

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Bench as Benchmark
    participant Wrapper as BatchPODWithPagedKVCacheWrapper
    participant JIT as JIT Module (batch_pod_with_kv_cache_tensor)
    participant Host as batch_pod_with_kv_cache_tensor (host)
    participant Kernel as BatchPOD Kernel
    participant Post as Split-KV Postproc

    Bench->>Wrapper: plan(prefill_plans, decode_plans, ...)
    Bench->>Wrapper: run(q_p, paged_kv_p, ..., q_d, paged_kv_d, ...)
    Wrapper->>JIT: call run_tensor -> batch_pod_with_kv_cache_tensor(...)
    JIT->>Host: validate/prepare PrefillParams & DecodeParams, bind buffers
    Host->>Kernel: launch BatchPODWithKVCacheTensorKernel (SM-aware dispatch)
    Kernel->>Kernel: assign blocks PREFILL / DECODE, compute attention, write tmp buffers
    Kernel-->>Host: kernel complete
    alt split-KV present
        Host->>Post: merge/aggregate VariableLength results
        Post-->>Host: merged outputs
    end
    Host-->>Wrapper: return outputs + timings
    Wrapper-->>Bench: validated result + timings

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Files warranting extra attention:
    • include/flashinfer/attention/batch_pod.cuh — SM‑aware scheduling, occupancy/shared‑memory sizing, split‑KV postprocessing.
    • csrc/batch_pod.cu and csrc/batch_pod_jit_binding.cu — parameter packing, stride/shape validation, DISPATCH correctness, CUDA‑graph flags.
    • ABI/FFI surface and lifetime of temporary buffers (tmp_v*/tmp_s*).
    • Propagation and semantics of num_colocated_ctas through planner and user call sites.

Possibly related PRs

Suggested reviewers

  • bkryu
  • cyx-6
  • djmmoss
  • yzh119
  • nvmbreughe
  • joker-eph
  • wenscarl

Poem

🐇
I hop through kernels, swift and neat,
Two‑phase prefill then decode — what a feat.
Pages of KV tumble, blocks take their cue,
Batched attention hums, the timings ring true.
The rabbit claps, "Benchmark — hop to you!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 57.14%, which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (4 passed)
  • Title check ✅ Passed: The title '[Feature] Support batch prefill for POD Attention' accurately describes the main feature addition: batched prefill support for POD Attention, which is the primary change across the changeset.
  • Description check ✅ Passed: The PR description provides comprehensive context: references issue #1022, explains the implementation approach with separate prefill/decode inputs, includes benchmark results showing performance improvements, and confirms completed pre-commit and testing checklists.
  • Linked Issues check ✅ Passed: Changes directly address issue #1022: implement batch prefill for POD Attention to improve performance versus BatchPrefillWithPagedKVCache, with benchmarks showing 1.11x-1.26x speedups over Persistent BatchAttention and fixes for mixed prefill/decode workloads.
  • Out of Scope Changes check ✅ Passed: All changes are scoped to the batch POD Attention implementation: new batch_pod modules, templates, wrappers, kernel implementations, and related plan parameter updates. No unrelated modifications or extraneous changes detected.


@gemini-code-assist
Contributor

Summary of Changes

Hello @AKKamath, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the FlashInfer library by integrating batch prefill capabilities into the POD Attention mechanism. This allows for more efficient processing of batched requests by separating and optimizing the prefill and decode stages. The changes include new CUDA kernels and a Python API, which collectively deliver measurable performance gains over previous attention implementations, as evidenced by the provided benchmarks.

Highlights

  • Batch POD Attention Support: Introduced comprehensive support for batch prefill operations within the POD Attention mechanism, allowing for distinct handling of prefill and decode inputs.
  • Performance Improvements: Benchmarks demonstrate significant speedups ranging from 1.11x to 1.26x over the existing Persistent BatchAttention implementation across various configurations.
  • New CUDA Kernels and Python API: Added new CUDA C++ kernels (batch_pod.cu, batch_pod.cuh) and a Python wrapper (BatchPODWithPagedKVCacheWrapper) to facilitate the new batch POD attention functionality.
  • Dynamic SM-Aware Scheduling: Implemented an SM-aware CTA scheduler within the CUDA kernel to dynamically assign thread blocks to either prefill or decode operations, optimizing resource utilization.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for batched prefill in POD Attention, which is a significant feature enhancement. The implementation includes new CUDA kernels, JIT compilation logic, and a Python wrapper. The benchmark results look promising. I've found a few critical issues related to correctness in the benchmark and Python wrapper code, as well as a potential memory leak and thread-safety issue in the CUDA kernel implementation. Additionally, there are several opportunities for code cleanup and improving maintainability by removing dead code, TODOs, and refactoring duplicated logic. Addressing these points will improve the robustness and quality of this new feature.

Comment on lines +161 to +315
PrefillParams prefill_params;
DTypeO* tmp_v_p = nullptr;
float* tmp_s_p = nullptr;
{
PrefillParams& params = prefill_params;
params.q = static_cast<DTypeQ*>(q_p.data_ptr());
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads_p, page_size_p, HEAD_DIM_VO, batch_size_p, kv_layout_p,
static_cast<DTypeKV*>(paged_k_cache_p.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache_p.data_ptr()), kv_cache_strides_p,
static_cast<IdType*>(paged_kv_indices_p.data_ptr()),
static_cast<IdType*>(paged_kv_indptr_p.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len_p.data_ptr()));
params.paged_kv = paged_kv;
params.q_indptr = static_cast<IdType*>(qo_indptr_p.data_ptr());
params.o = static_cast<DTypeO*>(o_p.data_ptr());

params.lse = maybe_lse_p.has_value() ? static_cast<float*>(maybe_lse_p.value().data_ptr())
: nullptr;
params.num_qo_heads = num_qo_heads;
params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
params.q_stride_n = q_stride_n_p;
params.q_stride_h = q_stride_h_p;
params.window_left = window_left_p;

params.request_indices = nullptr;
params.qo_tile_indices = nullptr;
params.kv_tile_indices = nullptr;
params.merge_indptr = nullptr;
params.o_indptr = nullptr;
params.kv_chunk_size_ptr = nullptr;
params.block_valid_mask = nullptr;
params.total_num_rows = nullptr;
params.max_total_num_rows = 0;
params.padded_batch_size = 0;
params.partition_kv = false;

params.maybe_mask_indptr =
maybe_mask_indptr_p.has_value()
? static_cast<int32_t*>(maybe_mask_indptr_p.value().data_ptr())
: nullptr;
params.maybe_alibi_slopes =
maybe_alibi_slopes_p.has_value()
? static_cast<float*>(maybe_alibi_slopes_p.value().data_ptr())
: nullptr;
params.logits_soft_cap = logits_soft_cap_p;
params.sm_scale = sm_scale_p;
params.rope_rcp_scale = rope_rcp_scale_p;
params.rope_rcp_theta = rope_rcp_theta_p;

params.request_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.request_indices_offset);
params.qo_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.qo_tile_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.kv_tile_indices_offset);
params.o_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.o_indptr_offset);
params.kv_chunk_size_ptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.kv_chunk_size_ptr_offset);
if (plan_info_p.split_kv) {
params.merge_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_p, plan_info_p.merge_indptr_offset);
tmp_v_p = GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr_p, plan_info_p.v_offset);
tmp_s_p = GetPtrFromBaseOffset<float>(float_buffer_ptr_p, plan_info_p.s_offset);
if (plan_info_p.enable_cuda_graph) {
params.block_valid_mask =
GetPtrFromBaseOffset<bool>(int_buffer_ptr_p, plan_info_p.block_valid_mask_offset);
}
}
params.padded_batch_size = plan_info_p.padded_batch_size;
params.max_total_num_rows = plan_info_p.total_num_rows;
if (plan_info_p.enable_cuda_graph) {
params.total_num_rows =
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr_p, plan_info_p.total_num_rows_offset);
}
}

DecodeParams decode_params;
DTypeO* tmp_v_d = nullptr;
float* tmp_s_d = nullptr;
{
DecodeParams& params = decode_params;
params.q = static_cast<DTypeQ*>(q_d.data_ptr());
paged_kv_t<DTypeKV, IdType> paged_kv(
num_kv_heads_d, page_size_d, HEAD_DIM_VO, batch_size_d, kv_layout_d,
static_cast<DTypeKV*>(paged_k_cache_d.data_ptr()),
static_cast<DTypeKV*>(paged_v_cache_d.data_ptr()), kv_cache_strides_d,
static_cast<IdType*>(paged_kv_indices_d.data_ptr()),
static_cast<IdType*>(paged_kv_indptr_d.data_ptr()),
static_cast<IdType*>(paged_kv_last_page_len_d.data_ptr()));
params.paged_kv = paged_kv;
params.q_indptr = static_cast<IdType*>(qo_indptr_d.data_ptr());
params.o = static_cast<DTypeO*>(o_d.data_ptr());

params.lse = maybe_lse_d.has_value() ? static_cast<float*>(maybe_lse_d.value().data_ptr())
: nullptr;
params.num_qo_heads = num_qo_heads;
params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads);
params.q_stride_n = q_stride_n_d;
params.q_stride_h = q_stride_h_d;
params.window_left = window_left_d;

params.request_indices = nullptr;
params.qo_tile_indices = nullptr;
params.kv_tile_indices = nullptr;
params.merge_indptr = nullptr;
params.o_indptr = nullptr;
params.kv_chunk_size_ptr = nullptr;
params.block_valid_mask = nullptr;
params.total_num_rows = nullptr;
params.max_total_num_rows = 0;
params.padded_batch_size = 0;
params.partition_kv = false;

params.maybe_mask_indptr =
maybe_mask_indptr_d.has_value()
? static_cast<int32_t*>(maybe_mask_indptr_d.value().data_ptr())
: nullptr;
params.maybe_alibi_slopes =
maybe_alibi_slopes_d.has_value()
? static_cast<float*>(maybe_alibi_slopes_d.value().data_ptr())
: nullptr;
params.logits_soft_cap = logits_soft_cap_d;
params.sm_scale = sm_scale_d;
params.rope_rcp_scale = rope_rcp_scale_d;
params.rope_rcp_theta = rope_rcp_theta_d;

params.request_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.request_indices_offset);
params.qo_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.qo_tile_indices_offset);
params.kv_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.kv_tile_indices_offset);
params.o_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.o_indptr_offset);
params.kv_chunk_size_ptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.kv_chunk_size_ptr_offset);
if (plan_info_d.split_kv) {
params.merge_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr_d, plan_info_d.merge_indptr_offset);
tmp_v_d = GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr_d, plan_info_d.v_offset);
tmp_s_d = GetPtrFromBaseOffset<float>(float_buffer_ptr_d, plan_info_d.s_offset);
if (plan_info_d.enable_cuda_graph) {
params.block_valid_mask =
GetPtrFromBaseOffset<bool>(int_buffer_ptr_d, plan_info_d.block_valid_mask_offset);
}
}
params.padded_batch_size = plan_info_d.padded_batch_size;
params.max_total_num_rows = plan_info_d.total_num_rows;
if (plan_info_d.enable_cuda_graph) {
params.total_num_rows =
GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr_d, plan_info_d.total_num_rows_offset);
}
}

medium

There is a significant amount of duplicated code for setting up prefill_params and decode_params. The logic within the two blocks is nearly identical. To improve maintainability and reduce redundancy, this should be refactored into a single helper function. Since PrefillParams and DecodeParams are both typedefs for BatchPrefillPagedParams, a single function can handle both setup processes.

namespace flashinfer {
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;
// Not sure about the below declaration

medium

This comment indicates uncertainty. It's best to confirm the logic and remove such comments to improve code clarity. If POS_ENCODING_MODE is intentionally hardcoded, a brief explanation would be more helpful for future maintainers.

int linear_bid;
// SM-aware CTA scheduler
if (threadIdx.x == 0) {
// TODO_AK: If num_threads dont match, use virtual sub-CTAs.

medium

This file contains several TODO comments that point to limitations or areas needing fixes. For instance:

  • Line 62: // TODO_AK: If num_threads dont match, use virtual sub-CTAs.
  • Lines 223, 251: // TODO(Zihao): fix the following computation

These should be addressed to improve the robustness and correctness of the implementation.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
benchmarks/bench_mixed_attention.py (1)

25-44: Fix block/last-page calculations for general page sizes.

The new Batch POD path works only because page_block_size is hard-coded to 1. As soon as we exercise a larger page size (which BatchPOD kernels should support), the block counts and last-page lengths become wrong—the integer division truncates and the modulo is applied to “number of blocks” instead of the real token lengths. That mis-specifies kv_indptr/last_page_len, producing incorrect plans and silently skewing the benchmark once we move past the 1-token page case. Please compute block counts with a ceil and derive last-page lengths from the original sequence lengths.

-    p_seq_lens_blocks = (
-        torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size
-    ).int()
-    d_seq_lens_blocks = (
-        torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
-    ).int()
+    p_seq_lens_blocks = torch.ceil(
+        torch.tensor(p_kv_lens, dtype=torch.float32) / page_block_size
+    ).to(torch.int32)
+    d_seq_lens_blocks = torch.ceil(
+        torch.tensor(d_kv_lens, dtype=torch.float32) / page_block_size
+    ).to(torch.int32)
@@
-    last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
-    last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    last_page_len_d = (
+        torch.tensor(d_kv_lens, dtype=torch.int32) - 1
+    ) % page_block_size + 1
+    last_page_len_p = (
+        torch.tensor(p_kv_lens, dtype=torch.int32) - 1
+    ) % page_block_size + 1
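
A quick standalone check of the arithmetic (not benchmark code, just PyTorch): for a 17-token sequence with 16-token pages, deriving the last-page length from the block count reports 2 tokens, while deriving it from the true length gives the correct 1.

import torch

kv_len, page_size = 17, 16
blocks = torch.ceil(torch.tensor([kv_len], dtype=torch.float32) / page_size).to(torch.int32)
from_blocks = (blocks - 1) % page_size + 1                                      # wrong
from_lengths = (torch.tensor([kv_len], dtype=torch.int32) - 1) % page_size + 1  # correct
print(int(blocks), int(from_blocks), int(from_lengths))  # 2 2 1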
csrc/batch_pod.cu (1)

345-353: tbAssign must be per-device (and error-checked)

tbAssign is a static pointer allocated on whichever GPU first hits this path. On a multi-GPU process, subsequent launches on a different device reuse the same pointer, so cudaMemset and the kernel see an address from the wrong device and fail. We should key the scratch allocation by device (and wrap cudaMalloc/cudaMemset with the usual error checks).

-  static int* tbAssign = nullptr;
-  if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
-  cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));
+  static std::unordered_map<int, int*> tb_assign_per_device;
+  auto& tbAssign = tb_assign_per_device[dev_id];
+  if (tbAssign == nullptr) {
+    FLASHINFER_CUDA_CALL(cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)));
+  }
+  FLASHINFER_CUDA_CALL(cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)));
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11177e8 and 42424b8.

📒 Files selected for processing (12)
  • benchmarks/bench_mixed_attention.py (6 hunks)
  • csrc/batch_pod.cu (1 hunks)
  • csrc/batch_pod_customize_config.jinja (1 hunks)
  • csrc/batch_pod_jit_binding.cu (1 hunks)
  • csrc/batch_pod_kernel_inst.jinja (1 hunks)
  • csrc/pod_jit_binding.cu (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/attention/__init__.py (1 hunks)
  • flashinfer/jit/attention/modules.py (3 hunks)
  • flashinfer/pod.py (3 hunks)
  • include/flashinfer/attention/batch_pod.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
flashinfer/jit/attention/__init__.py (1)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/jit/attention/modules.py (3)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
build_backend.py (1)
  • write_if_different (78-82)
flashinfer/jit/attention/utils.py (1)
  • generate_additional_params (20-81)
csrc/batch_pod_jit_binding.cu (2)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-336)
  • batch_pod_with_kv_cache_tensor (41-60)
flashinfer/jit/__init__.py (1)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/pod.py (5)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-336)
  • batch_pod_with_kv_cache_tensor (41-60)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
flashinfer/utils.py (9)
  • _check_kv_layout (152-154)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
benchmarks/bench_mixed_attention.py (2)
flashinfer/pod.py (5)
  • BatchPODWithPagedKVCacheWrapper (621-1157)
  • plan (262-430)
  • plan (751-953)
  • run (434-614)
  • run (957-1153)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/__init__.py (1)
flashinfer/pod.py (1)
  • BatchPODWithPagedKVCacheWrapper (621-1157)
csrc/batch_pod.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
🪛 Ruff (0.14.4)
flashinfer/pod.py

765-765: Unused method argument: causal_p

(ARG002)


986-986: Unused method argument: args

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
benchmarks/bench_mixed_attention.py (1)

113-114: Critical: Incorrect last_page_len calculation still not fixed.

This issue was already flagged in a previous review but remains unfixed. The calculation uses the number of blocks (d_seq_lens_blocks, p_seq_lens_blocks) instead of the actual sequence lengths (d_kv_lens, p_kv_lens). This only works in the benchmark because page_block_size = 1, but will fail for any other page size.

Apply this diff:

-    last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
-    last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    last_page_len_d = (torch.tensor(d_kv_lens, device=device) - 1) % page_block_size + 1
+    last_page_len_p = (torch.tensor(p_kv_lens, device=device) - 1) % page_block_size + 1
🧹 Nitpick comments (1)
benchmarks/bench_mixed_attention.py (1)

286-295: Consider testing with page_block_size > 1.

The expanded configurations provide good coverage of different prefill/decode scenarios. However, since page_block_size is hardcoded to 1 (lines 21, 297), the benchmark doesn't catch bugs that only manifest with larger page sizes, such as the last_page_len calculation issue.

Consider adding at least one benchmark configuration with page_block_size = 16 or page_block_size = 32 to validate correctness with realistic page sizes.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 42424b8 and 74fefdd.

📒 Files selected for processing (1)
  • benchmarks/bench_mixed_attention.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_mixed_attention.py (1)
flashinfer/pod.py (5)
  • BatchPODWithPagedKVCacheWrapper (621-1157)
  • plan (262-430)
  • plan (751-953)
  • run (434-614)
  • run (957-1153)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
benchmarks/bench_mixed_attention.py (3)

26-31: LGTM! Block calculation correctly uses torch.ceil.

The calculation now properly rounds up when computing the number of blocks needed, addressing the previous review feedback.


38-43: LGTM! Indptr construction follows standard pattern.

The prefill indptr arrays are constructed correctly using cumulative sum with an initial zero, consistent with the decode indptr construction.


105-161: Overall structure is correct, but depends on fixing the last_page_len bug.

The batched POD attention implementation follows a logical structure:

  • Correctly splits inputs into decode and prefill portions
  • Properly concatenates outputs in the right order (decode, then prefill)
  • Includes verification against baseline with appropriate tolerances

However, the critical bug in lines 113-114 (already flagged) must be fixed for this code to work correctly with page_block_size > 1.

Aditya K Kamath and others added 5 commits November 12, 2025 00:42
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
benchmarks/bench_mixed_attention.py (1)

113-114: Compute last_page_len_* from true sequence lengths

last_page_len_d / last_page_len_p are still derived from the block counts. For any page_block_size > 1, a length like 17 with page size 16 produces (2 - 1) % 16 + 1 == 2 even though only one element lives in the last page. That mis-sizes the tail page and can drive the kernel to read stale or out-of-bounds data. Please base these tensors on the original *_kv_lens, as done for the combined path above.

-    last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
-    last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    last_page_len_d = (torch.tensor(d_kv_lens, device=device) - 1) % page_block_size + 1
+    last_page_len_p = (torch.tensor(p_kv_lens, device=device) - 1) % page_block_size + 1
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 74fefdd and 3a396f3.

📒 Files selected for processing (2)
  • benchmarks/bench_mixed_attention.py (5 hunks)
  • flashinfer/pod.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/bench_mixed_attention.py (2)
flashinfer/pod.py (5)
  • BatchPODWithPagedKVCacheWrapper (621-1142)
  • plan (262-430)
  • plan (751-953)
  • run (434-614)
  • run (957-1138)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
flashinfer/pod.py (7)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-336)
  • batch_pod_with_kv_cache_tensor (41-60)
flashinfer/utils.py (8)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
flashinfer/quantization.py (1)
  • packbits (45-76)
🪛 Ruff (0.14.4)
flashinfer/pod.py

765-765: Unused method argument: causal_p

(ARG002)


975-975: Unused method argument: args

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

…to Python instead. Also remove q_scale, k_scale from prefill path.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a396f3 and 5694da7.

📒 Files selected for processing (5)
  • csrc/batch_pod.cu (1 hunks)
  • csrc/batch_pod_jit_binding.cu (1 hunks)
  • csrc/batch_pod_kernel_inst.jinja (1 hunks)
  • flashinfer/pod.py (3 hunks)
  • include/flashinfer/attention/batch_pod.cuh (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/batch_pod.cu
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/pod.py (7)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-348)
  • batch_pod_with_kv_cache_tensor (41-60)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
flashinfer/utils.py (9)
  • _check_kv_layout (152-154)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • _check_pos_encoding_mode (147-149)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
flashinfer/quantization.py (1)
  • packbits (45-76)
csrc/batch_pod_jit_binding.cu (1)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-348)
  • batch_pod_with_kv_cache_tensor (41-60)
🪛 GitHub Actions: pre-commit
include/flashinfer/attention/batch_pod.cuh

[error] 1-1: clang-format check failed. Files were modified by this hook. Re-run pre-commit locally to apply formatting fixes.

🪛 Ruff (0.14.4)
flashinfer/pod.py

771-771: Unused method argument: causal_p

(ARG002)


982-982: Unused method argument: args

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/attention/batch_pod.cuh (2)

222-224: Address TODO: Fix shared memory computation.

The shared memory calculation has a TODO indicating it needs to be fixed. Please verify the correctness of this computation or implement the proper fix.


243-247: Address TODO: Fix shared memory computation for decode.

Similar to the prefill path, the decode shared memory calculation has a TODO. Please verify or fix this computation.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (4)
csrc/batch_pod.cu (1)

161-315: Refactor duplicated parameter setup logic.

The prefill and decode parameter setup blocks (lines 164-237 and 242-315) contain nearly identical logic with only naming differences (_p vs _d suffixes). This duplication makes maintenance error-prone.

Consider extracting a template helper function:

template<typename Params, typename DTypeQ, typename DTypeKV, typename DTypeO, typename IdType>
void SetupBatchParams(
    Params& params,
    const PrefillPlanInfo& plan_info,
    TensorView q, TensorView paged_k_cache, TensorView paged_v_cache,
    /* ... other common parameters ... */
    void* float_buffer_ptr, void* int_buffer_ptr,
    DTypeO*& tmp_v, float*& tmp_s) {
  // Common setup logic here
}

Then call it twice with the appropriate parameters.

include/flashinfer/attention/batch_pod.cuh (3)

62-65: Address the TODO or add runtime validation.

The TODO indicates that the current implementation assumes matching thread counts between prefill and decode operations. If KTraits_P::NUM_THREADS != KTraits_D::NUM_THREADS, the kernel may behave incorrectly since blk_factor_p and blk_factor_d are both hardcoded to 1.

Consider adding a runtime assertion to validate this assumption:

     // TODO_AK: If num_threads dont match, use virtual sub-CTAs.
     // Requires changing block-level sync in main prefill/decode kernels.
+    static_assert(KTraits_P::NUM_THREADS == KTraits_D::NUM_THREADS,
+                  "Current implementation requires matching thread counts for prefill and decode");
     constexpr int blk_factor_p = 1;
     constexpr int blk_factor_d = 1;

221-224: Remove dead code.

These lines compute num_ctas_per_sm and max_smem_per_threadblock but the variables are never used. The same computation is performed later at lines 251-253 with the _p suffix, and those values are actually used.

Apply this diff to remove the dead code:

-  // we expect each sm execute two threadblocks
-  // TODO(Zihao): fix the following computation
-  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_D) * 16) ? 2 : 1;
-  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
-

243-253: Address TODO comments regarding shared memory computations.

Multiple TODO comments (lines 243, 250) indicate that the shared memory and occupancy calculations need to be fixed. These computations are critical for kernel correctness and performance, as incorrect shared memory sizing can lead to launch failures or poor occupancy.

Would you like me to help verify the current calculations against the kernel's actual shared memory requirements, or open an issue to track this work?

🧹 Nitpick comments (1)
include/flashinfer/attention/batch_pod.cuh (1)

274-275: Consider refactoring nested dispatch (optional).

The nested DISPATCH_NUM_MMA_KV is necessary because prefill and decode have independent configuration spaces, but it does increase compile time and code complexity. Consider extracting the inner dispatch into a separate helper function to improve readability.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5694da7 and 53233ef.

📒 Files selected for processing (2)
  • csrc/batch_pod.cu (1 hunks)
  • include/flashinfer/attention/batch_pod.cuh (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/batch_pod.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
csrc/batch_pod.cu (3)

113-127: Good validation checks.

The assertions ensure that query and KV head counts match between prefill and decode operations, which is a requirement for the current kernel implementation. The error messages clearly explain the constraints.


334-336: Good buffer validation.

The assertion correctly validates that the sm_aware_sched tensor has the required shape (num_sm + 2,) and is of type int32, matching the kernel's requirements for SM-aware scheduling.
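
For reference, the Python side can size this buffer from the device's SM count; the snippet below is an illustrative sketch using standard PyTorch device queries, not the wrapper's actual constructor code.

import torch

device = torch.device("cuda:0")
num_sm = torch.cuda.get_device_properties(device).multi_processor_count
# One int32 entry per SM plus two extra slots, matching the kernel's (num_sm + 2,) requirement.
sm_aware_sched = torch.zeros(num_sm + 2, dtype=torch.int32, device=device)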


337-338: Verify hardcoded decode CTA_TILE_Q value.

The prefill path uses a dynamic cta_tile_q from plan_info_p, but the decode path hardcodes CTA_TILE_Q_D = 16. This asymmetry suggests either:

  1. Decode operations always use tile size 16 (should be documented), or
  2. The decode plan_info_d.cta_tile_q is being ignored (potential bug).

Please clarify if this is intentional. If decode should support dynamic tile sizes like prefill, the code should be:

-          constexpr size_t CTA_TILE_Q_D = 16;
+          DISPATCH_CTA_TILE_Q(plan_info_d.cta_tile_q, CTA_TILE_Q_D, {
           cudaError_t status = flashinfer::BatchPODWithKVCacheTensorDispatched<
               HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, CTA_TILE_Q_P,
               MASK_MODE_P, CTA_TILE_Q_D, MASK_MODE_D, PrefillAttentionVariant,
               DecodeAttentionVariant>(prefill_params, tmp_v_p, tmp_s_p, decode_params, tmp_v_d,
                                       tmp_s_d, enable_pdl, stream,
                                       static_cast<int*>(sm_aware_sched.data_ptr()));
+          });
include/flashinfer/attention/batch_pod.cuh (2)

335-339: No action required—buffer size validation is properly enforced.

The caller in csrc/batch_pod.cu already validates that sm_aware_sched has exactly num_sm + 2 entries (line 335-336) before dispatching the kernel. The assertion ensures the buffer meets size requirements.


69-73: Assessment of the %nsmid portability concern, based on a review of the file and the surrounding codebase:

Key findings:

  • FlashInfer supports NVIDIA SM architectures 75 and higher, meaning the library is not limited to A100/H100
  • The %nsmid register requires PTX ISA 20 and SM_35+, so it's not exclusive to A100/H100 either
  • The existing codebase patterns show architecture checks via cudaDeviceGetAttribute for compute capability, but the inline assembly in batch_pod.cuh uses hardware-specific PTX registers rather than CUDA APIs

Assessment:
The review comment raises a valid portability concern—the inline assembly code is indeed architecture-specific and only tested on A100/H100. However, the codebase already demonstrates patterns for architecture validation elsewhere. The concern about "silent failure" is reasonable for untested architectures, though %nsmid/%smid are standard PTX registers available on SM_35+ (not just A100/H100).

The suggestion to add runtime GPU architecture validation is constructive and aligns with existing codebase practices. However, without confirming whether this code path is gated by architecture checks at the call site, it's impossible to determine if the concern is critical or if protections already exist.


Add runtime GPU architecture validation before using %nsmid register.

The inline assembly using %nsmid is only tested on A100/H100 per the warning comment. While %nsmid is available on SM_35+ architectures, behavior on untested GPU models is unverified. Consider adding a compute capability check (similar to patterns elsewhere in the codebase) or guarding this code path with architecture validation to catch incompatibilities at runtime rather than allowing silent failures or incorrect SM scheduling on untested GPUs.
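
One possible shape for such a guard, sketched at the Python wrapper level (the function name, the allowed-architecture set, and the warning text are all assumptions; an equivalent check could instead live in the C++ host code via cudaDeviceGetAttribute):

import warnings
import torch

def check_batch_pod_arch(device: torch.device) -> None:
    # %smid/%nsmid-based scheduling in this PR is only validated on A100 (sm_80) / H100 (sm_90).
    major, minor = torch.cuda.get_device_capability(device)
    if (major, minor) not in ((8, 0), (9, 0)):
        warnings.warn(
            f"Batch POD SM-aware scheduling is untested on sm_{major}{minor}; "
            "results may be incorrect or suboptimal."
        )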

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 53233ef and f275715.

📒 Files selected for processing (2)
  • csrc/batch_pod_kernel_inst.jinja (1 hunks)
  • flashinfer/pod.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/pod.py (4)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
flashinfer/utils.py (10)
  • _check_kv_layout (152-154)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • _check_pos_encoding_mode (147-149)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
flashinfer/pod.py (4)

50-54: LGTM! Clean module accessor implementation.

The new get_batch_pod_module function follows the same pattern as the existing get_pod_module, using functools.cache for memoization and returning a SimpleNamespace with the appropriate run_tensor method.


687-748: LGTM! Constructor properly initializes separate buffers.

The constructor correctly:

  • Splits the float workspace buffer between prefill and decode paths (50/50 split is a reasonable default)
  • Allocates separate int workspace and pinned memory buffers for each path
  • Creates an SM-aware scheduling buffer sized appropriately for the device

754-963: LGTM! Plan method correctly handles batch prefill and decode.

The plan() method properly:

  • Accepts separate parameter sets for prefill and decode phases
  • Prepares and caches buffers and metadata for both paths
  • Calls the underlying batch prefill module's plan() for both prefill and decode with appropriate parameters
  • Stores all necessary state (_pos_encoding_mode, _window_left, etc.) for subsequent run() calls

967-1145: LGTM! Run method correctly orchestrates batch POD execution.

The run() method properly:

  • Unpacks KV caches for both prefill and decode paths
  • Retrieves cached parameters from the prior plan() call (encoding mode, window, scales, etc.)
  • Computes default values for optional parameters (sm_scale, rope_scale, etc.)
  • Handles custom mask packing when needed
  • Invokes the batch POD module with all parameters in the correct order matching the C++ FFI signature
  • Applies post-processing (v_scale) and returns the appropriate tuple based on return_lse
csrc/batch_pod_kernel_inst.jinja (3)

1-14: LGTM! Includes and namespace setup are correct.

The file properly includes all necessary FlashInfer headers, the generated config, and sets up the namespace.


21-30: LGTM! Template instantiation loop is well-structured.

The for-loop correctly instantiates BatchPODWithKVCacheTensorDispatched for three CTA_TILE_Q configurations (16, 64, 128), parameterized appropriately with head dimensions, mask modes, variants, and buffer types. The function signature matches the expected parameters including workspace buffers, PDL flag, stream, and SM-aware scheduling buffer.


31-31: LGTM! Namespace is properly closed.

The flashinfer namespace is correctly closed with } (the syntax error }; from an earlier version has been fixed).

Aditya K Kamath and others added 2 commits November 12, 2025 03:22
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
flashinfer/prefill.py (2)

1713-1720: Fix undefined qo_indptr_host/total_num_rows when max_token_per_sequence is provided.

qo_indptr_host and total_num_rows are only defined in the else-branch. Both are used later unconditionally in plan args, causing runtime errors if max_token_per_sequence is not None. Define them before the branch and compute _max_q_len accordingly.

Apply this diff:

-        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        if max_token_per_sequence is not None:
-            self._max_q_len = max_token_per_sequence
-        else:
-            qo_indptr_host = qo_indptr.to("cpu")
-            self._max_q_len = max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()
-            total_num_rows = int(qo_indptr_host[-1])
+        # NOTE(Zihao): always materialize qo_indptr_host; total_num_rows is used in plan args
+        qo_indptr_host = qo_indptr.to("cpu")
+        total_num_rows = int(qo_indptr_host[-1])
+        if max_token_per_sequence is not None:
+            self._max_q_len = max_token_per_sequence
+        else:
+            self._max_q_len = max(qo_indptr_host[1:] - qo_indptr_host[:-1]).item()

Also applies to: 1883-1906


673-714: Align fake-op paged_run signature with the real op (add sinks).

The fake-op lacks the sinks argument that the real custom op accepts, risking tracing/compile-time breakage.

Apply this diff:

 def _fake_paged_run(
@@
-        cum_seq_lens_kv: Optional[torch.Tensor] = None,
+        cum_seq_lens_kv: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
 ) -> None:
     pass
flashinfer/sparse.py (2)

319-324: Harden block index validation (fix off-by-one and divisibility).

Current check misses equality and doesn’t validate N % C == 0. This can permit OOB access.

Apply this diff:

-        if indices.max().item() * C > N:
-            raise ValueError("indices out of bound")
+        if N % C != 0:
+            raise ValueError("N must be divisible by C")
+        if indices.max().item() >= (N // C):
+            raise ValueError("indices out of bound")

473-479: FA2 plan call missing num_colocated_ctas trailing arg.

Binding expects the new trailing parameter; omit leads to arity mismatch at runtime.

Apply this diff:

             if self._backend == "fa2":
                 args.append(-1)  # fixed_split_size
                 args.append(False)  # disable_split_kv
+                args.append(0)  # num_colocated_ctas
             self._plan_info = self._cached_module.plan(
                 *args,
             )
flashinfer/decode.py (1)

1045-1064: Fix stale fast_decode_plan comment and add missing num_colocated_ctas to sparse FA2.

decode.py line 1044 is correct. However, line 2748 comment incorrectly states "exactly 16 arguments for tensor core version" when the code has 19 arguments (including num_colocated_ctas). Additionally, sparse.py line 476 FA2 is missing the num_colocated_ctas argument that all other FA2 tensor-core paths include.

  • decode.py line 2748: Update comment from "exactly 16 arguments" to reflect 19 arguments
  • sparse.py line 476: Add args.append(0) # num_colocated_ctas after the disable_split_kv append to match pod/prefill FA2 consistency
include/flashinfer/attention/scheduler.cuh (1)

718-723: Clamp grid budget after subtracting colocated CTAs

Subtracting num_colocated_ctas directly from num_blocks_per_sm * num_sm can drive max_grid_size negative whenever decode already consumes at least the whole device. A negative int silently wraps when later assigned to the uint32_t-typed variables (e.g. max_batch_size_if_split), turning the grid budget into a huge positive number, so the planner ends up believing it has practically infinite CTAs. That causes catastrophic overscheduling and broken chunking in precisely the situations this knob is meant to handle.

Please keep the arithmetic in 64-bit and clamp at zero before reusing the value. One way:

-  int max_grid_size = num_blocks_per_sm * num_sm - num_colocated_ctas;
+  int64_t available_ctas =
+      static_cast<int64_t>(num_blocks_per_sm) * num_sm - num_colocated_ctas;
+  int max_grid_size = static_cast<int>(std::max<int64_t>(0, available_ctas));

This keeps the grid budget sane even when decode takes the entire device.

♻️ Duplicate comments (1)
benchmarks/bench_mixed_attention.py (1)

113-115: Use true sequence lengths when deriving last page lengths

last_page_len_d/p are still derived from *_seq_lens_blocks. That happens to work only while page_block_size == 1, but once we run with bigger pages the value is wrong again (e.g. a 17-token sequence at page size 16 now reports a last-page length of 2). This issue was already flagged in an earlier review, and the fix needs to operate on the underlying kv lengths:

-    last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
-    last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    last_page_len_d = (torch.tensor(d_kv_lens, device=device, dtype=torch.int32) - 1) % page_block_size + 1
+    last_page_len_p = (torch.tensor(p_kv_lens, device=device, dtype=torch.int32) - 1) % page_block_size + 1

That keeps the benchmark (and the kernel inputs it prepares) correct for arbitrary page sizes.

🧹 Nitpick comments (3)
flashinfer/prefill.py (1)

1903-1906: FA2 plan: num_colocated_ctas wiring looks correct.

Passing a trailing 0 keeps behavior unchanged and aligns with the extended binding. Optionally expose this as a tunable knob later.
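
If the knob is exposed later, one hypothetical shape would be a small helper (names and default are illustrative, not existing FlashInfer API) that callers thread through to the plan call:

# Hypothetical sketch only: expose num_colocated_ctas as a keyword and forward it as the
# trailing FA2 plan argument instead of a hard-coded 0.
def build_fa2_plan_args(args: list, num_colocated_ctas: int = 0) -> list:
    args.append(-1)                  # fixed_split_size
    args.append(False)               # disable_split_kv
    args.append(num_colocated_ctas)  # CTAs reserved for a colocated kernel (e.g. POD prefill)
    return args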

flashinfer/sparse.py (1)

990-992: Avoid unconditional device-wide synchronize in plan.

This serialize-all can hurt perf. Prefer stream semantics or remove if not strictly required for correctness.

flashinfer/decode.py (1)

2748-2769: Update stale comment: argument count changed.

Comment claims “exactly 16 arguments” but an extra num_colocated_ctas is now passed.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f275715 and 77d20fd.

📒 Files selected for processing (9)
  • benchmarks/bench_mixed_attention.py (7 hunks)
  • csrc/batch_prefill.cu (2 hunks)
  • csrc/batch_prefill_jit_binding.cu (1 hunks)
  • flashinfer/decode.py (3 hunks)
  • flashinfer/pod.py (4 hunks)
  • flashinfer/prefill.py (2 hunks)
  • flashinfer/sparse.py (1 hunks)
  • include/flashinfer/attention/batch_pod.cuh (1 hunks)
  • include/flashinfer/attention/scheduler.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.563Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/attention/batch_pod.cuh
🧬 Code graph analysis (2)
flashinfer/pod.py (5)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-350)
  • batch_pod_with_kv_cache_tensor (41-60)
flashinfer/utils.py (9)
  • _check_kv_layout (152-154)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
benchmarks/bench_mixed_attention.py (2)
flashinfer/pod.py (5)
  • BatchPODWithPagedKVCacheWrapper (622-1155)
  • plan (262-431)
  • plan (755-969)
  • run (435-615)
  • run (973-1151)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/batch_prefill_jit_binding.cu (1)

22-29: Binding signature update looks good.

New num_colocated_ctas is threaded; aligns with Python call sites.

Compile CI should validate ABI. If failures arise, search for any remaining plan call sites missing the new arg.

@yzh119
Copy link
Collaborator

yzh119 commented Nov 13, 2025

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !134 has been created, and the CI pipeline #38392237 is currently running. I'll report back once the pipeline job completes.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
flashinfer/pod.py (1)

751-755: Consider documenting the workspace buffer splitting strategy.

The workspace buffer is split 50/50 between prefill and decode. While this is a reasonable default, it may not be optimal for all workload distributions (e.g., workloads with very small prefill and large decode, or vice versa).

Consider adding a comment explaining this choice, or noting it in the docstring as a potential future enhancement.
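
If the split ever becomes configurable, it could be expressed as a ratio; a minimal sketch, assuming a flat 1-D byte workspace (the prefill_fraction knob is hypothetical and not part of the current API):

import torch

def split_workspace(buf: torch.Tensor, prefill_fraction: float = 0.5):
    # prefill_fraction=0.5 matches today's torch.chunk(buf, 2, dim=0) behavior.
    assert buf.dim() == 1, "workspace buffer is expected to be flat (1-D)"
    split = int(buf.numel() * prefill_fraction)
    return buf[:split], buf[split:]

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8)
prefill_ws, decode_ws = split_workspace(workspace, prefill_fraction=0.5)
print(prefill_ws.numel(), decode_ws.numel())   # 67108864 67108864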

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6e4fedb and 970c480.

📒 Files selected for processing (1)
  • flashinfer/pod.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/pod.py (8)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-350)
  • batch_pod_with_kv_cache_tensor (41-60)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
flashinfer/utils.py (9)
  • _check_kv_layout (152-154)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/prefill.py (9)
  • plan (1523-1919)
  • plan (2496-2785)
  • get_batch_prefill_module (365-724)
  • run (1950-1962)
  • run (1965-1977)
  • run (1979-2213)
  • run (2815-2825)
  • run (2828-2838)
  • run (2840-2986)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
flashinfer/quantization.py (1)
  • packbits (45-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
flashinfer/pod.py (6)

24-24: LGTM!

The import addition for gen_batch_pod_module follows the established pattern and is correctly placed alongside the existing gen_pod_module import.


50-53: LGTM!

The get_batch_pod_module function correctly mirrors the pattern of get_pod_module and appropriately caches the expensive module build operation.


422-422: LGTM!

Adding num_colocated_ctas=0 parameter maintains backward compatibility while enabling the new batch processing optimization in the parallel code path.


1057-1059: LGTM!

The prefill softmax scale correctly omits quantization factors, which is the proper behavior for mixed-precision POD where decode uses quantization but prefill remains in FP16/BF16.


1136-1185: LGTM!

The run_tensor call correctly threads all prefill and decode parameters, including the packed_custom_mask_p which is properly passed instead of None.


1187-1190: LGTM!

The return logic correctly handles both LSE and non-LSE cases, and the v_scale is appropriately applied only to the decode output.

... [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0
... ).int()
>>> kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)
>>> last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
Copy link
Contributor


⚠️ Potential issue | 🟠 Major

Fix the last page length calculation formula in the example.

The formula (p_seq_lens_blocks - 1) % page_block_size + 1 is incorrect. While it happens to work for page_block_size = 1 (used in this example), it fails for other page sizes.

For example, with kv_lens = [100] and page_block_size = 16:

  • p_seq_lens_blocks = [7] (ceil(100/16))
  • Current formula: (7 - 1) % 16 + 1 = 7
  • Correct value: 100 - 6*16 = 4

Apply this diff to fix the example:

-    >>> last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    >>> last_page_len_p = torch.tensor(
+    ...     [(kv_len - 1) % page_block_size + 1 for kv_len in p_kv_lens],
+    ...     dtype=torch.int32
+    ... )

The same issue exists on line 673 for the decode path.

🤖 Prompt for AI Agents
In flashinfer/pod.py around line 661 (and likewise update line 673 for the
decode path), the example uses an incorrect formula for last_page_len_p: replace
the modulo-based expression with a calculation that subtracts the tokens consumed
by the preceding full pages from the total sequence length (i.e.,
last_page_len = total_len - (num_blocks - 1) * page_block_size). Update both
occurrences so the last page length is computed as total sequence length minus
(blocks-1)*page_block_size instead of using (blocks - 1) % page_block_size + 1.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
flashinfer/pod.py (2)

455-524: Decode-side run() options are still silently ignored in PODWithPagedKVCacheWrapper

PODWithPagedKVCacheWrapper.run() accepts several decode options (pos_encoding_mode_d, sm_scale_d, window_left_d, rope_scale_d, rope_theta_d etc.), but they are immediately overwritten from the cached plan state (self._pos_encoding_mode, self._window_left, self._logits_soft_cap, self._sm_scale, self._rope_scale, self._rope_theta at lines 519‑524). This makes the decode arguments effectively dead and can confuse callers who expect per‑run overrides.

Consider either:

  • honoring the decode arguments (falling back to the cached values when None, as sketched below), or
  • removing these decode options from the run() signature and documenting that decode configuration is plan‑time only.
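
A minimal sketch of the first option, with hypothetical names (plan_state stands in for the cached self._* attributes; the real wrapper keeps them on the object):

from types import SimpleNamespace

def resolve_decode_options(plan_state, sm_scale_d=None, window_left_d=None,
                           rope_scale_d=None, rope_theta_d=None):
    # Prefer per-run overrides; fall back to the plan-time values when None.
    pick = lambda override, cached: cached if override is None else override
    return SimpleNamespace(
        sm_scale=pick(sm_scale_d, plan_state.sm_scale),
        window_left=pick(window_left_d, plan_state.window_left),
        rope_scale=pick(rope_scale_d, plan_state.rope_scale),
        rope_theta=pick(rope_theta_d, plan_state.rope_theta),
    )

plan_state = SimpleNamespace(sm_scale=1.0, window_left=-1,
                             rope_scale=1.0, rope_theta=1e4)
opts = resolve_decode_options(plan_state, sm_scale_d=0.5)
print(opts.sm_scale, opts.window_left)   # 0.5 -1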

646-673: Fix last_page_len computation in BatchPOD example

In the example, last_page_len_p / last_page_len_d are computed from the number of blocks:

last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
...
last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1

This is incorrect for page_block_size > 1 (it ignores the true sequence length). It can mislead users copying the example.

Suggest computing last-page lengths directly from p_kv_lens / d_kv_lens:

-    >>> last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    >>> last_page_len_p = torch.tensor(
+    ...     [(kv_len - 1) % page_block_size + 1 for kv_len in p_kv_lens],
+    ...     dtype=torch.int32,
+    ... )

...
-    >>> last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
+    >>> last_page_len_d = torch.tensor(
+    ...     [(kv_len - 1) % page_block_size + 1 for kv_len in d_kv_lens],
+    ...     dtype=torch.int32,
+    ... )
🧹 Nitpick comments (3)
flashinfer/pod.py (3)

751-781: Clarify workspace split and use explicit int32 for SM scheduling buffer

The workspace split and SM scheduling buffer wiring look correct overall, but two small improvements would make the behavior clearer and more robust:

  • Explicitly require or assert that float_workspace_buffer is 1‑D and “large enough” before torch.chunk(float_workspace_buffer, 2, dim=0), so mis‑shapen buffers fail fast instead of being silently split along an unexpected dimension.
  • Use torch.int32 instead of torch.int for _sm_aware_sched to match the C++ int* expectation and avoid potential platform‑dependent surprises:
-        self._sm_aware_sched = torch.empty(
-            (dev_prop.multi_processor_count + 2), dtype=torch.int, device=self.device
-        )
+        self._sm_aware_sched = torch.empty(
+            (dev_prop.multi_processor_count + 2),
+            dtype=torch.int32,
+            device=self.device,
+        )

889-1001: num_colocated_ctas heuristic and magic threshold

The new planner logic:

  • derives num_colocated_ctas from the decode plan (self._plan_info_d[0]), and
  • disables prefill colocation when total_num_rows_p > 1536.

This matches the intent to avoid oversplitting large prefill workloads and reuses the decode planner’s choice where appropriate. To make this more maintainable:

  • Extract 1536 into a named constant with a brief comment (e.g. “empirically chosen token threshold beyond which prefill becomes memory‑bound when colocated”); a sketch follows below.
  • Optionally log or expose num_colocated_ctas in debug builds to help tune this heuristic with future benchmarks.
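
A sketch of what that extraction could look like; the constant and helper names are placeholders, while the 1536 threshold and the “reuse the decode plan's CTA count” behavior come from the code under review:

# Empirically chosen token threshold beyond which prefill becomes
# memory-bound when colocated; revisit as new benchmarks land.
_MAX_COLOCATED_PREFILL_ROWS = 1536

def _choose_num_colocated_ctas(total_num_rows_p: int, decode_plan_ctas: int) -> int:
    # Reuse the decode planner's choice unless the prefill side is too large.
    if total_num_rows_p > _MAX_COLOCATED_PREFILL_ROWS:
        return 0    # disable colocation for large prefill workloads
    return decode_plan_ctas

print(_choose_num_colocated_ctas(1024, decode_plan_ctas=132))   # 132
print(_choose_num_colocated_ctas(4096, decode_plan_ctas=132))   # 0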

1012-1190: BatchPOD run wiring matches C++ binding; decode supports quantization but not custom masks

The BatchPODWithPagedKVCacheWrapper.run() path looks sound:

  • Prefill/decode q/kv dtypes are checked against the plan cache via _check_cached_qkv_data_type.
  • JIT specialization via get_batch_pod_module matches gen_batch_pod_module’s signature (shared q/kv/o dtypes, head_dim, pos‑enc flags, idx dtype, decode flags).
  • The run_tensor call matches the C++ batch_pod_with_kv_cache_tensor signature, including separate prefill/decode workspace buffers, plan infos, indptr/indices/last_page_len, optional LSE, and SM‑aware scheduling buffer.

Behavioral notes (likely intentional but worth documenting):

  • Decode always uses MaskMode.NON_CAUSAL and does not accept a decode custom_mask; callers needing decode‑side custom masks can’t express them through this wrapper today.
  • Quantization scales (q_scale, k_scale, v_scale) are applied only on the decode path (via sm_scale_d / v_scale), leaving prefill in FP16/BF16 space, which aligns with the single‑request POD wrapper’s design.

Overall the batch path looks correct and consistent with the underlying kernel.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 970c480 and 64721eb.

📒 Files selected for processing (1)
  • flashinfer/pod.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/pod.py (7)
flashinfer/jit/attention/modules.py (1)
  • gen_batch_pod_module (633-695)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/batch_pod.cu (2)
  • batch_pod_with_kv_cache_tensor (41-350)
  • batch_pod_with_kv_cache_tensor (41-60)
csrc/batch_pod_jit_binding.cu (1)
  • batch_pod_with_kv_cache_tensor (22-41)
flashinfer/utils.py (8)
  • canonicalize_torch_dtype (240-248)
  • PosEncodingMode (31-34)
  • device_support_pdl (569-573)
  • _unpack_paged_kv_cache (168-188)
  • _check_cached_qkv_data_type (258-268)
  • MaskMode (37-41)
  • TensorLayout (44-46)
  • _get_cache_alibi_slopes_buf (229-237)
flashinfer/page.py (1)
  • get_seq_lens (212-235)
flashinfer/quantization.py (1)
  • packbits (45-76)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/pod.py (2)

24-53: Cached batch POD JIT helper looks consistent with gen_batch_pod_module

get_batch_pod_module mirrors get_pod_module and matches the gen_batch_pod_module signature (q/kv/o dtypes, head_dim, pos-enc + sliding window flags, idx dtype, decode flags). The SimpleNamespace(run_tensor=module.batch_pod_with_kv_cache_tensor) shape aligns with the C++ binding and keeps the public surface minimal.


422-423: Explicitly disabling CTA colocation keeps legacy POD semantics stable

Passing num_colocated_ctas=0 into the existing PODWithPagedKVCacheWrapper.plan() path preserves previous behavior while accommodating the new planner signature. This is a reasonable conservative default for the non‑batched wrapper; tuning CTA colocation there can be done in a follow‑up if desired.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #38392237: 13/17 passed

Copy link
Collaborator

@yzh119 yzh119 left a comment


LGTM

@yzh119 yzh119 merged commit 636a3ab into flashinfer-ai:main Nov 14, 2025
4 checks passed
qsang-nv pushed a commit to qsang-nv/flashinfer that referenced this pull request Nov 18, 2025


Development

Successfully merging this pull request may close these issues.

Low performance of POD Attention compared to BatchPrefillWithPagedKVCache
