[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch#6653
Conversation
…ze < kN0)
Use global_load_lds_dwordx{1,4} with 64-bit flat addresses for K loads
when page_size < kN0, eliminating the SRD 32-bit offset overflow.
V loads use per-tile SRD rebase with wave_reduce_min.
Previously, V loads used SRD buffer_load with int32 voffset which overflows when pages within a tile span >2GB. K was already fixed to use flat 64-bit addressing (global_load_lds), but V still used SRD rebase. Changes: - Add load_flat() to tile_scatter_gather: flat load to VGPRs using 64-bit pointer arithmetic (base + page*stride + within_page_offset) - Add load_tile_flat() free function in load_tile.hpp - Change V kv_offset_array_transform to store within-page offset only (matching K), instead of full global offset that overflows int32 - Remove V compute_min_and_adjust_offsets (no longer needed with flat) - Pipeline: use load_tile_flat for V when page_size < kN0 This fixes scattered page allocation where adjacent logical tokens map to physically distant pages (>2GB apart within a tile).
The V flat load path (page_size < kN0) requires both physical_pages[] and page_idx_ to correspond to the same sub-tile. However, the pipeline prefetches the NEXT sub-tile's physical pages before the CURRENT sub-tile's load_tile_flat executes, causing address computation to mix within-page offsets from sub-tile N with physical pages from sub-tile N+1. Fix: save v_physical_pages to v_physical_pages_current before each prefetch_v_physical_pages() call, and use the saved copy in all 3 load_tile_flat() call sites. This preserves the pipeline prefetch overlap (pages loaded during GEMM0/softmax) while providing correct data to load_tile_flat. The fix is guarded by if constexpr(kPageBlockSize < kN0), so the SRD rebase path (page_size >= kN0) has zero overhead.
Add kUse64BitLoad template parameter to select between SRD buffer_load (fast, <4GB) and flat 64-bit loads (correct, >4GB) at kernel launch time. For page_size < kN0 (128), the kernel generates two variants: - kUse64BitLoad=false: original full-offset buffer_load path (zero regression) - kUse64BitLoad=true: flat load with V double-buffer (15-20% slower, >4GB safe) For page_size >= kN0, SRD rebase handles >4GB natively via 64-bit pointer arithmetic — no flat load variant needed. Runtime dispatch in mha_fwd_batch_prefill.cu checks max_page_byte_offset against INT32_MAX and selects the appropriate variant automatically.
…bal load) Simplify tile_scatter_gather to two clean modes controlled by kUseFlatLoad_: - SRD mode (kUseFlatLoad=false): buffer_load(SRD, page_idx_[i] + coord) - Global load mode (kUseFlatLoad=true): flat_load(base + physical_pages_[i] * stride + page_idx_[i] + coord) Changes: - kv_offset_array_transform: eliminate 6-branch K/V duplication into unified loop - tile_scatter_gather: add kUseFlatLoad_ template param, physical_pages_ and page_stride_elements_ members, flat load branches in load() and async_load_raw() - Remove load_flat(), async_load_raw_flat(), load_tile_flat(), async_load_tile_raw_flat() - Pipeline: replace ~10 if constexpr(kUseFlatLoad) load branches with update_physical_pages() + unified load_tile()/async_load_tile_raw() calls - Remove v_physical_pages_current double-buffer variable (now managed internally) Net: -235 lines, zero functional change confirmed across page_size 1/16/1024, bf16/fp8, linear/vectorized layouts, and >4GB overflow boundary tests.
Add three-layer architecture protection for the kUseFlatLoad path which requires the global_load_lds instruction (CDNA3+: gfx940/gfx950 only): 1. Codegen #if guard (fmha_batch_prefill.py): Wrap kUse64BitLoad=true kernel instantiation with #if defined(__gfx94__) || defined(__gfx950__). Uses ArchTrait pattern consistent with fmha_fwd.py. 2. static_assert in tile_scatter_gather.hpp: Prevents kUseFlatLoad_=true instantiation on non-CDNA3 architectures at compile time. 3. static_assert in async_global_load_lds_dwordxn: Prevents the global_load_lds intrinsic from being instantiated on unsupported architectures. Verified: cross-compilation with --offload-arch=gfx90a (CDNA2) and --offload-arch=gfx1100 (RDNA3) succeeds with kernel body excluded.
…fix) After SRD rebase to a physical page, num_records was left at the full buffer size. This caused the SRD to claim validity for a range [page_base, page_base + full_buffer_size) that extends far beyond the allocated buffer when rebased to high pages. On gfx942 (CDNA3), the hardware only checks voffset < num_records per buffer_load instruction, so the extended range is harmless. On gfx950 (CDNA4), the hardware appears to validate the full SRD range against page table permissions. When the extended range covers freed or protected memory, this causes VM_L2_PROTECTION_FAULT (PERMISSION_FAULTS with MAPPING_ERROR=0). Fix: set buffer_size to page_stride (one page worth of elements) before init_raw() after each SRD rebase. This scopes the SRD to exactly the page being accessed. Verified: 80 passed on both gfx942 (MI308X) and gfx950 (MI355X).
This function was added for a per-tile SRD rebase approach that was later replaced by template dispatch. No callers remain.
…ync_global_load_lds_dwordxn The previous static_assert(false) fires unconditionally during template parsing on newer compilers (CWG 2518), even for never-instantiated branches. Wrap it in a dependent expression so the assertion only fires when an unsupported num_dwords is actually instantiated. Found during batch prefill template dispatch review.
…offset tile_scatter_gather: divide buffer_size override by PackedSize to match buffer_view ctor convention (raw element count in, packed count stored). Without this, packed types (FP4 / int4, PackedSize=2) would over-report num_records by 2x and silently mask OOB reads. batch_prefill does not exercise the packed-type path today, but this is generic infrastructure and must honor the same invariant. Also narrow the signature from long_index_t to index_t since SRD num_records is hardware 32-bit. block_fmha_batch_prefill_pipeline_qr_ks_vs_async: remove misleading static_cast<long_index_t> on the SRD voffset path. The 32-bit limit on this branch comes from CDNA3 MUBUF voffset hardware format, not from an implementation choice — widening would not lift the 2GB ceiling because the hardware truncates regardless. The kUseFlatLoad_ template path handles the >2GB case via 64-bit global_load_lds_*. Added a comment making this explicit so the next reader doesn't propose the same fix. Found during batch prefill template dispatch review.
The KV cache overflow threshold is 2GB (INT32_MAX byte offset for SRD voffset), matching CK's existing TwoGB convention in transform_conv_fwd_to_gemm.hpp. Previous comments said "4GB" which is incorrect — SRD voffset is signed-32-bit-effectively, not unsigned. Updated: - codegen comment + use_64bit_load field doc - BlockFmhaBatchPrefillPipelineProblem::kUse64BitLoad doc, with explicit note about INT32_MAX / TwoGB convention Found during batch prefill template dispatch review.
…cher
The runtime `bool use_64bit_load` field on `fmha_batch_prefill_traits`
forced wrappers to encode each kernel arm's compile-time `bn0` and
per-dtype element size to decide whether KV cache exceeds 2GB. That
leaked codegen detail and required updating wrappers when new tile
configs were added.
Move the decision into the auto-generated `fmha_batch_prefill_api.cpp`
dispatcher, where each arm already knows its own `{F_bn0}` and dtype.
Each per-dtype scope now emits `constexpr int kElementBytes` from a new
`DTYPE_BYTES` map, and the inner dispatch predicate evaluates
`(a.page_block_size < {F_bn0} && num_total_pages * batch_stride_k *
kElementBytes > INT32_MAX) == {F_use_64bit_load}` per arm. The C++
template parameter `kUse64BitLoad_` (and both kernel ELFs) stays — only
the runtime trait field is removed.
The >2GB KV cache code path used two historical names for the same concept across layers (kUse64BitLoad at the kernel/Problem level, kUseFlatLoad inside the batch_prefill pipeline and tile_scatter_gather). Both mean "use global_load_lds_* instead of SRD buffer_load_*". Unify on kUseGlobalLoad everywhere — kernel template params, pipeline traits, scatter-gather op, codegen Python (F_use_global_load), and generated kernel filename suffix (64bit_ → globalload_). Also collapse the two-level structure in the batch_prefill pipeline: the derived kUseFlatLoad = Problem::kUse64BitLoad && (kPageBlockSize < kN0) becomes kUseGlobalLoad = Problem::kUseGlobalLoad directly, with a static_assert backstop for the page_size < kN0 invariant that codegen already guarantees. Verified on both archs (no behavior change vs prior Option 3 baseline): - gfx942 (smc300x-clt): test_batch_prefill.py 512/640, 4gb_small_page 12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE - gfx950 (smci355-gfx950): test_batch_prefill.py 384/768, 4gb_small_page 12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE
tile_scatter_gather.hpp: replace cassert assert(size > 0) in set_bottom_tensor_view_buffer_size with __builtin_assume. The cassert form expands to an __assert_fail call whose SGPR pressure forces the LLVM AMDGPU register allocator to reuse the K-SRD scalar register window (s24-s27) as scratch for the assert-PC literal, scattering the 4 K-SRD writes across two conditional branches. gfx950 buffer_load_dwordx4 does not tolerate the staggered SRD setup; gfx942 (4x scalar buffer_load) absorbs it. Reproduced as 95.2% mismatch on MI355X for hd=256, ps=1024, linear, vectorized, bf16, causal, soft_cap=30. __builtin_assume preserves the optimizer hint without emitting the assert handler. block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: restore review follow-ups lost in an earlier reset: * Four-case addressing-strategy comment block above kNeedFullOffset (Case 1-3 mechanisms + Case 4 codegen backstop reference). * readfirstlane "wave-uniform -> SGPR" rationale on K and V rebase sites. * v_offsets semantics comment enumerating Cases 1/2/3, naming kNeedFullOffset as the Case-3 selector. * Remove redundant outer if constexpr(kPageBlockSize >= kN0) around rebase_v_window initial call (single source of truth in lambda). * save_and_prefetch_v_pages lambda encapsulates the update_physical_pages -> prefetch_v_physical_pages ordering invariant; 3 in-loop sites collapsed to single calls. Verified: test_batch_prefill_4gb_small_page.py 12/12 on both archs; test_batch_prefill_4gb_repro.py --num_blocks 5000 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE on gfx942 (smc300x-clt) and gfx950 (smci355-gfx950).
…guards amd_buffer_addressing_builtins.hpp: * Replace the !num_dwords CWG-2518 workaround in the non-CDNA3+ guard with a named file-local helper impl::global_load_lds_arch_unreachable_v. The old form had a misleading edge case: !num_dwords is true only when num_dwords == 0, so the assert silently passed on num_dwords==0 wrong- arch instantiations and the second assert (num_dwords == 1 || == 4) fired with the wrong diagnostic. The named helper makes the intent unambiguous and the dependent-false pattern self-documenting. * Reword the num_dwords == 1 || num_dwords == 4 static_assert message to distinguish hardware reality from policy: 2 dwords does not exist on any supported arch; 3 dwords only on CDNA4 and unused in FMHA pipeline. Prevents a future contributor from assuming 2/3 are deliberately blacklisted by software. tile_scatter_gather.hpp: * async_global_load_lds_dwordxn callsite: remove the redundant reinterpret_cast<const void*>(addr). addr is already const DataType* and converts implicitly to const void*. Comment clarifies that global_load_lds takes a byte address, which is what the implicit conversion produces. * Add positive static_assert(kUseGlobalLoad_, ...) to update_physical_pages and set_page_stride_elements. Both fields (physical_pages_, page_stride_elements_) only participate in the global-load addressing path; calling these setters in SRD mode is silently a no-op that hides the misuse. The compile-time guard turns the misuse into a build error and locks down the invariant. Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds, no new warnings, and test_batch_prefill_4gb_small_page.py 12/12 pass with the two new positive setter asserts in place (codegen only emits kUseGlobalLoad=true arms when the setters are reachable, so neither fires in practice).
|
physical_pages_ and page_stride_elements_ are always present in tile_scatter_gather, even when kUseGlobalLoad=false |
poyenc
left a comment
There was a problem hiding this comment.
replace_bottom_tensor_view (line 1478) does not deduce the new YsGatherDims or kUseGlobalLoad_ template parameters — passing a kUseGlobalLoad=true window is a compile error. Not reachable today, but if this needs fixing later, the function also needs to forward physical_pages_ and page_stride_elements_ (not just page_idx_ and valids_).
Address all 5 reviewer asks on the >2GB KV cache batch-prefill series, plus two self-found polish items surfaced by an internal CK-aware review pass. Task #71 — bool kUseGlobalLoad_ -> BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ (poyenc, tile_fmha_traits.hpp:62): Adjacent traits-template params (kKVMemoryLayout_, kKVLookupTable_) are already BlockAttention*Enum types; the binary kUseGlobalLoad_ stuck out as a bool exception. Convert to a 2-value enum {BUFFER_LOAD = 0, GLOBAL_LOAD_LDS = 1} living in a new ops/fmha/block/ header so it sits alongside its siblings. Touch sites: * include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp (NEW): the enum class. * include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp: rename last template param + static member alias. * include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp: mirror alias rename. * include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: add enum header include; class declares static auto kKVLoadMode plus derived static bool kUseGlobalLoad = (kKVLoadMode == GLOBAL_LOAD_LDS). All 10 internal `if constexpr(kUseGlobalLoad)` sites unchanged so the bool boundary is local to one TU. The standalone helper kv_offset_array_transform keeps its bool template param (private inline; intentional — keeps core/ tile_scatter_gather.hpp out of the enum's blast radius). * example/ck_tile/01_fmha/fmha_fwd.hpp: fmha_fwd_batch_prefill_traits_ last template param renamed; static member alias kUseGlobalLoad -> kKVLoadMode (default BUFFER_LOAD). * include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp: comment-only update. Task #70 — explicit constructor mem-init for tile_scatter_gather (asleepzzz, tile_scatter_gather.hpp:1241, comment #3125912056): physical_pages_ and page_stride_elements_ were silently zero-initialized in the BUFFER_LOAD arm. Today safe (Task #71's positive setter asserts prevent misuse), but a future kUseGlobalLoad=true caller that misses a setter would get silent data corruption with no compile error. Make both fields explicit in the mem-init list so the contract is visible at the constructor boundary. Task #72 — extract dispatcher overflow predicate to a named helper (poyenc, fmha_batch_prefill.py:225): Move the (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) decision out of the codegen template into a free helper: fmha_batch_prefill_select_kv_load_mode(page_block_size, kN0, num_total_pages, batch_stride_k, element_bytes) in example/ck_tile/01_fmha/fmha_fwd.hpp. The codegen-emitted dispatcher arms now call it with their compile-time kN0/element_bytes substituted, so the formula has exactly one source of truth. Task #73 — symmetric gload/bload kernel-name suffix (poyenc, fmha_batch_prefill.py:282): Match the existing CK convention (e.g., causal/ncausal, sink/nsink) by emitting a non-empty token in BOTH branches: '-gload' / '-bload' on FmhaFwdApiTrait.name, 'gload_' / 'bload_' on FmhaFwdKernel.name. The prior blank-default made it impossible to tell, when grepping JIT blob/ output 6 months later, whether a missing marker meant 'BUFFER_LOAD variant' or 'old codegen revision before the gload branch existed'. Task #74 — replace single-use dependent-false with reusable always_false_v (poyenc, amd_buffer_addressing_builtins.hpp:1324): Promote impl::global_load_lds_arch_unreachable_v from a file-local helper into a generic ck_tile::always_false_v utility in core/utility/type_traits.hpp. Use it at the original site. The variable-template form defers evaluation to instantiation time, so a bare `static_assert(false, ...)` would (per CWG-2518 / current Clang) fire at parse time and break the whole TU even on never-instantiated arches. Polish I-1 — umbrella header completeness: include/ck_tile/ops/fmha.hpp now pulls in the new block_attention_kv_load_mode_enum.hpp alongside the other BlockAttention*Enum siblings. Without this, downstream consumers that rely solely on the umbrella header would miss the enum. Polish I-2 — overflow-cast robustness in fmha_batch_prefill_select_kv_load_mode: Promote every operand of the kv_pool_bytes multiplication to long_index_t individually instead of relying on left-to-right associativity to widen the chain. A future operand reorder would silently truncate; the per-operand cast makes overflow impossible regardless of order. Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds; full op_tests/test_batch_prefill.py sweep passes 30,720 / 30,720 (10,016 skipped, 0 failed) in 30:40 wall. Codegen identifier changes only affect the renamed template parameter; no register-allocation perturbation expected on either gfx942 or gfx950 (confirmed by the cross-arch sweep).
…UseGlobalLoad Replace the unconditional `physical_pages_` and `page_stride_elements_` members with `std::conditional_t` + `[[no_unique_address]]` so they collapse to zero-byte placeholders in the SRD instantiation (kUseGlobalLoad=false). Why: Reviewer asleepzzz observed that these fields were always present even when SRD-mode kernels never read them. The previous fix (Task #70, explicit mem-init) addressed the *form* of the concern (silent zero-init -> explicit zero-init) but not the *substance* (wasted storage in SRD-mode instantiations). This commit makes the fields literally disappear in SRD mode. How it works: - Empty placeholder `gl_field_empty_t` introduced inside the class. - Both fields wrapped in `std::conditional_t<kUseGlobalLoad_, T, gl_field_empty_t>` with `[[no_unique_address]]` so the SRD-mode layout drops them. - All access sites (lines 520, 523, 758, 761) are already inside `if constexpr(kUseGlobalLoad_)` arms, so they compile out cleanly. - The setter `update_physical_pages` keeps its `static_assert(kUseGlobalLoad_)` guard; combined with lazy template member-function instantiation, the body is never instantiated for SRD callers. - Constructor mem-init stays type-agnostic via value-init `{}`; the `page_stride_elements_` assignment is gated by `if constexpr` in the body so the SRD arm only sees the empty struct. AP-7 (codegen-hash) note: Class layout changes only on the kUseGlobalLoad=true instantiation (where layout is identical: one PageIdxArray + one index_t). The kUseGlobalLoad=false instantiation now has *less* state, but adjacent fields' offsets shift only if the compiler chose not to merge the `[[no_unique_address]]` placeholder. Verified by remote re-test on both gfx942 and gfx950.
819731a to
a2692f8
Compare
@asleepzzz Addressed in struct gl_field_empty_t {}; // empty placeholder for SRD-mode
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
physical_pages_;
[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
page_stride_elements_;All access sites were already inside Verified on three servers — |
[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill
via template dispatch (#6653)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
## Motivation
The CK batch prefill kernel previously failed (silent overflow + page
faults) when the KV cache exceeded 2 GB, blocking long-context inference
workloads (e.g., 128K+ token contexts with paged KV).
Two distinct failure modes were addressed:
1. **>4GB SRD overflow (`page_size < kN0`):** The SRD
`buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small
page sizes the rebased SRD spans the full KV pool and the offset wraps
past 2 GB, corrupting K/V loads.
2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the
hardware validates the **full SRD `num_records` range** against
page-table permissions (CDNA3 only checks per-instruction `voffset`).
After per-tile SRD rebase, an un-trimmed `num_records` field extends
past the live page and faults on freed/protected memory.
## Technical Details
**Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad`
template parameter:**
| Case | `page_size` | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase |
| 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) |
`buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range |
| 3 | `< kN0` (small pages) | `> 2 GB` | Global-load
(`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) +
`load_tile_flat` (V) | 64-bit |
**Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects
the kernel instantiation at launch from `(page_block_size,
num_total_pages * batch_stride_k * kElementBytes)`, so the small-page
penalty is paid only when correctness requires it.
**gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of
`block_fmha_batch_prefill_pipeline_qr_ks_vs_async`,
`set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after
each rebase to constrain `num_records` to the live page. Required for
CDNA4 page-table validation; harmless on CDNA3.
**Pipeline sync for the global-load path:**
- V uses synchronous `load_tile_flat`; K uses
`async_load_tile_raw_flat`.
- `v_physical_pages_current` is double-buffered so the V flat load
doesn't race against the next iteration's K rebase computation.
**Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` /
`__gfx950__` (CDNA3+). Other architectures hit a `dependent_false`
static_assert with a descriptive message.
**Device-side assertion convention:** SRD setters use
`__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s
`assert()`. The latter introduces an `__assert_fail` call whose register
pressure scatters the K-SRD scalar register window across conditional
branches, corrupting `buffer_load_dwordx4` on gfx950.
## Test Plan
Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper
test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**:
- **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈
{1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16,
fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1,
4}` (parametrized to exercise per-sequence SRD rebase across batch
boundaries).
- **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to
allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true`
(small-page) and `kUseGlobalLoad=false` (large-page rebase) paths.
Includes both single-batch and multi-batch (`batch_size=4`) cases to
exercise per-sequence SRD rebase across the >2 GB pool.
- Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol`
from the existing batch prefill test harness.
## Test Result
| Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2
GB) |
|------|----------------------|---------------------|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |
**Performance impact (gfx950, hot SRD path):**
- +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang /
causal / soft_cap=30`, attributable in full to the two
`set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas
(5-run median, signal/noise ≈ 9×).
- This cost is **mandatory for gfx950 correctness** on >2 GB workloads —
removing the setters re-introduces page-faults.
- gfx942: 0 regressions in the same range (all configs ≤ +0.97%).
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
…ate dispatch (#6653) ## Motivation The CK batch prefill kernel previously failed (silent overflow + page faults) when the KV cache exceeded 2 GB, blocking long-context inference workloads (e.g., 128K+ token contexts with paged KV). Two distinct failure modes were addressed: 1. **>4GB SRD overflow (`page_size < kN0`):** The SRD `buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small page sizes the rebased SRD spans the full KV pool and the offset wraps past 2 GB, corrupting K/V loads. 2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the hardware validates the **full SRD `num_records` range** against page-table permissions (CDNA3 only checks per-instruction `voffset`). After per-tile SRD rebase, an un-trimmed `num_records` field extends past the live page and faults on freed/protected memory. ## Technical Details **Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad` template parameter:** | Case | `page_size` | KV cache size | Mode | Load path | Addressing | |---|---|---|---|---|---| | 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase | | 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range | | 3 | `< kN0` (small pages) | `> 2 GB` | Global-load (`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) + `load_tile_flat` (V) | 64-bit | **Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects the kernel instantiation at launch from `(page_block_size, num_total_pages * batch_stride_k * kElementBytes)`, so the small-page penalty is paid only when correctness requires it. **gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of `block_fmha_batch_prefill_pipeline_qr_ks_vs_async`, `set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after each rebase to constrain `num_records` to the live page. Required for CDNA4 page-table validation; harmless on CDNA3. **Pipeline sync for the global-load path:** - V uses synchronous `load_tile_flat`; K uses `async_load_tile_raw_flat`. - `v_physical_pages_current` is double-buffered so the V flat load doesn't race against the next iteration's K rebase computation. **Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` / `__gfx950__` (CDNA3+). Other architectures hit a `dependent_false` static_assert with a descriptive message. **Device-side assertion convention:** SRD setters use `__builtin_assume(cond)` (hint-only) rather than `<cassert>`'s `assert()`. The latter introduces an `__assert_fail` call whose register pressure scatters the K-SRD scalar register window across conditional branches, corrupting `buffer_load_dwordx4` on gfx950. ## Test Plan Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**: - **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈ {1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16, fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1, 4}` (parametrized to exercise per-sequence SRD rebase across batch boundaries). - **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true` (small-page) and `kUseGlobalLoad=false` (large-page rebase) paths. Includes both single-batch and multi-batch (`batch_size=4`) cases to exercise per-sequence SRD rebase across the >2 GB pool. - Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol` from the existing batch prefill test harness. ## Test Result | Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2 GB) | |------|----------------------|---------------------| | MI308 (gfx942) | All passed | Passed | | MI355 (gfx950) | All passed | Passed | **Performance impact (gfx950, hot SRD path):** - +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang / causal / soft_cap=30`, attributable in full to the two `set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas (5-run median, signal/noise ≈ 9×). - This cost is **mandatory for gfx950 correctness** on >2 GB workloads — removing the setters re-introduces page-faults. - gfx942: 0 regressions in the same range (all configs ≤ +0.97%). ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Motivation
The CK batch prefill kernel previously failed (silent overflow + page faults) when the KV cache exceeded 2 GB, blocking long-context inference workloads (e.g., 128K+ token contexts with paged KV).
Two distinct failure modes were addressed:
page_size < kN0): The SRDbuffer_load_dwordx4path uses a 32-bitvoffsetregister; for small page sizes the rebased SRD spans the full KV pool and the offset wraps past 2 GB, corrupting K/V loads.page_size >= kN0): On CDNA4 the hardware validates the full SRDnum_recordsrange against page-table permissions (CDNA3 only checks per-instructionvoffset). After per-tile SRD rebase, an un-trimmednum_recordsfield extends past the live page and faults on freed/protected memory.Technical Details
Two-mode
tile_scatter_gatherselected by thekUseGlobalLoadtemplate parameter:page_size>= kN0(large pages)kUseGlobalLoad=false)buffer_load_dwordx4voffset, bounded by per-page rebase< kN0(small pages)<= 2 GBkUseGlobalLoad=false)buffer_load_dwordx4voffset, fits in INT32 byte range< kN0(small pages)> 2 GBkUseGlobalLoad=true)async_load_tile_raw_flat(K) +load_tile_flat(V)Dispatch: the auto-gen API layer (
fmha_batch_prefill.py) selects the kernel instantiation at launch from(page_block_size, num_total_pages * batch_stride_k * kElementBytes), so the small-page penalty is paid only when correctness requires it.gfx950 SRD
num_recordstrimming: in the K and V rebase lambdas ofblock_fmha_batch_prefill_pipeline_qr_ks_vs_async,set_bottom_tensor_view_buffer_size(page_stride_k/v)is called after each rebase to constrainnum_recordsto the live page. Required for CDNA4 page-table validation; harmless on CDNA3.Pipeline sync for the global-load path:
load_tile_flat; K usesasync_load_tile_raw_flat.v_physical_pages_currentis double-buffered so the V flat load doesn't race against the next iteration's K rebase computation.Arch guards:
global_load_ldsintrinsics are gated to__gfx94__/__gfx950__(CDNA3+). Other architectures hit adependent_falsestatic_assert with a descriptive message.Device-side assertion convention: SRD setters use
__builtin_assume(cond)(hint-only) rather than<cassert>'sassert(). The latter introduces an__assert_failcall whose register pressure scatters the K-SRD scalar register window across conditional branches, corruptingbuffer_load_dwordx4on gfx950.Test Plan
Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper test suite. All coverage lives in
op_tests/test_batch_prefill.py:test_batch_prefill:page_size ∈ {1, 16, 1024}×kv_layout ∈ {linear, vectorized}×dtype ∈ {bf16, fp8 quant variants}×causal×soft_cap×LSE×batch_size ∈ {1, 4}(parametrized to exercise per-sequence SRD rebase across batch boundaries).test_batch_prefill_large_kvcache: extended to allocate a 5 GB+ KV cache pool and exercise bothkUseGlobalLoad=true(small-page) andkUseGlobalLoad=false(large-page rebase) paths. Includes both single-batch and multi-batch (batch_size=4) cases to exercise per-sequence SRD rebase across the >2 GB pool.atol/rtolfrom the existing batch prefill test harness.Test Result
test_batch_prefilltest_batch_prefill_large_kvcache(>2 GB)Performance impact (gfx950, hot SRD path):
seqlen=1024 / page_sz=1024 / bf16 / sglang / causal / soft_cap=30, attributable in full to the twoset_bottom_tensor_view_buffer_sizecalls in the K/V rebase lambdas (5-run median, signal/noise ≈ 9×).Submission Checklist