[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch#6653

Merged: Jeff-Huang merged 17 commits into develop from users/jeff-huang/ck/batch-prefill-fix-4bg-overflow on Apr 23, 2026.
@Jeff-Huang Jeff-Huang commented Apr 22, 2026

Motivation

The CK batch prefill kernel previously failed (silent overflow + page faults) when the KV cache exceeded 2 GB, blocking long-context inference workloads (e.g., 128K+ token contexts with paged KV).

Two distinct failure modes were addressed:

  1. >2 GB SRD overflow (page_size < kN0): The SRD buffer_load_dwordx4 path uses a 32-bit voffset register; for small page sizes the rebased SRD spans the full KV pool, so once the pool exceeds 2 GB (INT32_MAX bytes) the byte offset wraps and corrupts K/V loads.
  2. gfx950 page-table fault (page_size >= kN0): On CDNA4 the hardware validates the full SRD num_records range against page-table permissions (CDNA3 only checks per-instruction voffset). After per-tile SRD rebase, an un-trimmed num_records field extends past the live page and faults on freed/protected memory.

Technical Details

Two-mode tile_scatter_gather selected by the kUseGlobalLoad template parameter:

| Case | page_size | KV cache size | Mode | Load path | Addressing |
|---|---|---|---|---|---|
| 1 | >= kN0 (large pages) | any | SRD (kUseGlobalLoad=false) | buffer_load_dwordx4 | 32-bit voffset, bounded by per-page rebase |
| 2 | < kN0 (small pages) | <= 2 GB | SRD (kUseGlobalLoad=false) | buffer_load_dwordx4 | 32-bit voffset, fits in INT32 byte range |
| 3 | < kN0 (small pages) | > 2 GB | Global-load (kUseGlobalLoad=true) | async_load_tile_raw_flat (K) + load_tile_flat (V) | 64-bit |

Dispatch: the auto-gen API layer (fmha_batch_prefill.py) selects the kernel instantiation at launch from (page_block_size, num_total_pages * batch_stride_k * kElementBytes), so the small-page penalty is paid only when correctness requires it.
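The selection rule in the table can be sketched as a small predicate. This is a minimal illustration, not the generated dispatcher: the name `needs_global_load` and its parameter list are hypothetical, `kN0` is passed as an ordinary argument, and `batch_stride_k` is assumed to be the per-page stride in elements.

```cpp
#include <cstdint>
#include <limits>

// Hypothetical helper (not the generated dispatcher): returns true when the
// 64-bit global-load kernel variant is required.
constexpr bool needs_global_load(std::int32_t page_block_size,
                                 std::int32_t kN0,
                                 std::int32_t num_total_pages,
                                 std::int32_t batch_stride_k,
                                 std::int32_t element_bytes)
{
    // Compute the pool size in 64 bits so the product itself cannot wrap.
    const std::int64_t kv_pool_bytes = static_cast<std::int64_t>(num_total_pages) *
                                       static_cast<std::int64_t>(batch_stride_k) *
                                       static_cast<std::int64_t>(element_bytes);
    return page_block_size < kN0 &&
           kv_pool_bytes > std::numeric_limits<std::int32_t>::max();
}

// Case 1: large pages stay on the SRD path regardless of pool size.
static_assert(!needs_global_load(1024, 128, 4000000, 1024, 2), "case 1");
// Case 2: small pages, 1 GiB pool (2^15 pages * 2^14 elements * 2 bytes).
static_assert(!needs_global_load(16, 128, 32768, 16384, 2), "case 2");
// Case 3: small pages, 4 GiB pool (2^17 pages * 2^14 elements * 2 bytes).
static_assert(needs_global_load(16, 128, 131072, 16384, 2), "case 3");
```

The three `static_assert` lines mirror Cases 1 to 3, so the slower variant is chosen only in the small-page, >2 GB combination.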

gfx950 SRD num_records trimming: in the K and V rebase lambdas of block_fmha_batch_prefill_pipeline_qr_ks_vs_async, set_bottom_tensor_view_buffer_size(page_stride_k/v) is called after each rebase to constrain num_records to the live page. Required for CDNA4 page-table validation; harmless on CDNA3.
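To make the trimming concrete, here is a toy model of rebase plus trim. It is only a sketch: `ToySrd`, `rebase_to_page`, and `end_bytes` are hypothetical names, the real SRD is a packed hardware descriptor rather than this struct, and `num_records` is counted in bytes here for simplicity.

```cpp
#include <cstdint>

// Toy stand-in for a buffer resource descriptor (hypothetical layout).
struct ToySrd
{
    std::uint64_t base_bytes;  // start address of the described range
    std::uint32_t num_records; // length of the described range, in bytes here
};

// Rebase onto one physical page and trim the range to exactly that page.
constexpr ToySrd rebase_to_page(std::uint64_t pool_base_bytes,
                                std::uint64_t page_stride_bytes,
                                std::uint32_t physical_page)
{
    return ToySrd{pool_base_bytes + physical_page * page_stride_bytes,
                  static_cast<std::uint32_t>(page_stride_bytes)};
}

// One-past-the-end address the descriptor claims to cover.
constexpr std::uint64_t end_bytes(const ToySrd& srd)
{
    return srd.base_bytes + srd.num_records;
}
```

With the trim in place, the claimed range ends where the live page ends; without it, the range would extend from the rebased base by the full pool size.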

Pipeline sync for the global-load path:

  • V uses synchronous load_tile_flat; K uses async_load_tile_raw_flat.
  • v_physical_pages_current is double-buffered so the V flat load doesn't race against the next iteration's K rebase computation.
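The double-buffering idea can be sketched on the host as a snapshot taken before each prefetch. The struct below is a hypothetical stand-in, not the pipeline code: `VPageState`, `save_and_prefetch`, and the sub-tile width of 4 are assumptions made for illustration.

```cpp
#include <array>
#include <cstdint>

constexpr int kPagesPerSubTile = 4; // hypothetical sub-tile width

using PageArray = std::array<std::int32_t, kPagesPerSubTile>;

// `prefetched` is overwritten by each prefetch; `current` is the snapshot
// that the pending V load reads.
struct VPageState
{
    PageArray prefetched{};
    PageArray current{};

    // Snapshot sub-tile N's pages, then accept sub-tile N+1's pages.
    void save_and_prefetch(const PageArray& next_pages)
    {
        current    = prefetched; // keep sub-tile N's pages for the pending load
        prefetched = next_pages; // sub-tile N+1's pages arrive early
    }
};
```

The pending load reads `current`, so an early prefetch cannot change the pages it uses.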

Arch guards: global_load_lds intrinsics are gated to __gfx94__ / __gfx950__ (CDNA3+). Other architectures hit a dependent_false static_assert with a descriptive message.

Device-side assertion convention: SRD setters use __builtin_assume(cond) (hint-only) rather than <cassert>'s assert(). The latter introduces an __assert_fail call whose register pressure scatters the K-SRD scalar register window across conditional branches, corrupting buffer_load_dwordx4 on gfx950.
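A hint-only precondition might look like the sketch below. The macro name `SKETCH_ASSUME` and the `ToyBufferView` struct are hypothetical; `__builtin_assume` is a Clang builtin, so other compilers fall back to a no-op here.

```cpp
#include <cstdint>

// Hypothetical macro name; expands to the Clang builtin when available.
#if defined(__clang__)
#define SKETCH_ASSUME(cond) __builtin_assume(cond)
#else
#define SKETCH_ASSUME(cond) ((void)0)
#endif

struct ToyBufferView
{
    std::int32_t buffer_size = 0;

    // Optimizer hint only: no call to an assert handler is emitted.
    void set_buffer_size(std::int32_t size)
    {
        SKETCH_ASSUME(size > 0);
        buffer_size = size;
    }
};
```

The trade-off is that a violated precondition is no longer reported at run time; it becomes undefined behaviour instead.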

Test Plan

Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper test suite. All coverage lives in op_tests/test_batch_prefill.py:

  • Functional matrix (96 cases): test_batch_prefill covers page_size ∈ {1, 16, 1024} × kv_layout ∈ {linear, vectorized} × dtype ∈ {bf16, fp8 quant variants} × causal × soft_cap × LSE × batch_size ∈ {1, 4} (parametrized to exercise per-sequence SRD rebase across batch boundaries).
  • >2 GB coverage: test_batch_prefill_large_kvcache is extended to allocate a 5 GB+ KV cache pool and exercise both kUseGlobalLoad=true (small-page) and kUseGlobalLoad=false (large-page rebase) paths. Includes both single-batch and multi-batch (batch_size=4) cases to exercise per-sequence SRD rebase across the >2 GB pool.
  • Numerical reference: PyTorch SDPA, per-batch loop with atol / rtol from the existing batch prefill test harness.

Test Result

| Arch | test_batch_prefill | test_batch_prefill_large_kvcache (>2 GB) |
|---|---|---|
| MI308 (gfx942) | All passed | Passed |
| MI355 (gfx950) | All passed | Passed |

Performance impact (gfx950, hot SRD path):

  • +2.67% kernel-time on seqlen=1024 / page_sz=1024 / bf16 / sglang / causal / soft_cap=30, attributable in full to the two set_bottom_tensor_view_buffer_size calls in the K/V rebase lambdas (5-run median, signal/noise ≈ 9×).
  • This cost is mandatory for gfx950 correctness on >2 GB workloads — removing the setters re-introduces page-faults.
  • gfx942: 0 regressions in the same range (all configs ≤ +0.97%).

Submission Checklist

…ze < kN0)

Use global_load_lds_dwordx{1,4} with 64-bit flat addresses for K loads
when page_size < kN0, eliminating the SRD 32-bit offset overflow.
V loads use per-tile SRD rebase with wave_reduce_min.

Previously, V loads used SRD buffer_load with int32 voffset which
overflows when pages within a tile span >2GB. K was already fixed
to use flat 64-bit addressing (global_load_lds), but V still used
SRD rebase.

Changes:
- Add load_flat() to tile_scatter_gather: flat load to VGPRs using
  64-bit pointer arithmetic (base + page*stride + within_page_offset)
- Add load_tile_flat() free function in load_tile.hpp
- Change V kv_offset_array_transform to store within-page offset only
  (matching K), instead of full global offset that overflows int32
- Remove V compute_min_and_adjust_offsets (no longer needed with flat)
- Pipeline: use load_tile_flat for V when page_size < kN0

This fixes scattered page allocation where adjacent logical tokens
map to physically distant pages (>2GB apart within a tile).

The V flat load path (page_size < kN0) requires both physical_pages[]
and page_idx_ to correspond to the same sub-tile. However, the pipeline
prefetches the NEXT sub-tile's physical pages before the CURRENT
sub-tile's load_tile_flat executes, causing address computation to mix
within-page offsets from sub-tile N with physical pages from sub-tile
N+1.

Fix: save v_physical_pages to v_physical_pages_current before each
prefetch_v_physical_pages() call, and use the saved copy in all 3
load_tile_flat() call sites. This preserves the pipeline prefetch
overlap (pages loaded during GEMM0/softmax) while providing correct
data to load_tile_flat.

The fix is guarded by if constexpr(kPageBlockSize < kN0), so the SRD
rebase path (page_size >= kN0) has zero overhead.

Add kUse64BitLoad template parameter to select between SRD buffer_load
(fast, <4GB) and flat 64-bit loads (correct, >4GB) at kernel launch time.

For page_size < kN0 (128), the kernel generates two variants:
- kUse64BitLoad=false: original full-offset buffer_load path (zero regression)
- kUse64BitLoad=true: flat load with V double-buffer (15-20% slower, >4GB safe)

For page_size >= kN0, SRD rebase handles >4GB natively via 64-bit pointer
arithmetic — no flat load variant needed.

Runtime dispatch in mha_fwd_batch_prefill.cu checks max_page_byte_offset
against INT32_MAX and selects the appropriate variant automatically.
…bal load)

Simplify tile_scatter_gather to two clean modes controlled by kUseFlatLoad_:
- SRD mode (kUseFlatLoad=false): buffer_load(SRD, page_idx_[i] + coord)
- Global load mode (kUseFlatLoad=true): flat_load(base + physical_pages_[i] * stride + page_idx_[i] + coord)
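The two addressing formulas above can be written out as plain integer arithmetic. This is a sketch under assumptions: `srd_offset` and `global_offset` are hypothetical names, offsets are counted in elements, and `page_idx` is taken to hold the within-page offset as the earlier commit describes.

```cpp
#include <cstdint>

// SRD mode: a 32-bit offset relative to the descriptor base.
constexpr std::int32_t srd_offset(std::int32_t page_idx, std::int32_t coord)
{
    return page_idx + coord;
}

// Global-load mode: a 64-bit element offset measured from the pool base.
constexpr std::int64_t global_offset(std::int32_t physical_page,
                                     std::int64_t page_stride_elements,
                                     std::int32_t page_idx,
                                     std::int32_t coord)
{
    return static_cast<std::int64_t>(physical_page) * page_stride_elements + page_idx + coord;
}
```

For example, page 300000 with a stride of 16384 elements gives an offset near 4.9 billion, which no 32-bit offset can represent.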

Changes:
- kv_offset_array_transform: eliminate 6-branch K/V duplication into unified loop
- tile_scatter_gather: add kUseFlatLoad_ template param, physical_pages_ and
  page_stride_elements_ members, flat load branches in load() and async_load_raw()
- Remove load_flat(), async_load_raw_flat(), load_tile_flat(), async_load_tile_raw_flat()
- Pipeline: replace ~10 if constexpr(kUseFlatLoad) load branches with
  update_physical_pages() + unified load_tile()/async_load_tile_raw() calls
- Remove v_physical_pages_current double-buffer variable (now managed internally)

Net: -235 lines, zero functional change confirmed across page_size 1/16/1024,
bf16/fp8, linear/vectorized layouts, and >4GB overflow boundary tests.

Add three-layer architecture protection for the kUseFlatLoad path which
requires the global_load_lds instruction (CDNA3+: gfx940/gfx950 only):

1. Codegen #if guard (fmha_batch_prefill.py): Wrap kUse64BitLoad=true
   kernel instantiation with #if defined(__gfx94__) || defined(__gfx950__).
   Uses ArchTrait pattern consistent with fmha_fwd.py.

2. static_assert in tile_scatter_gather.hpp: Prevents kUseFlatLoad_=true
   instantiation on non-CDNA3 architectures at compile time.

3. static_assert in async_global_load_lds_dwordxn: Prevents the
   global_load_lds intrinsic from being instantiated on unsupported
   architectures.

Verified: cross-compilation with --offload-arch=gfx90a (CDNA2) and
--offload-arch=gfx1100 (RDNA3) succeeds with kernel body excluded.
…fix)

After SRD rebase to a physical page, num_records was left at the full
buffer size. This caused the SRD to claim validity for a range
[page_base, page_base + full_buffer_size) that extends far beyond the
allocated buffer when rebased to high pages.

On gfx942 (CDNA3), the hardware only checks voffset < num_records per
buffer_load instruction, so the extended range is harmless.

On gfx950 (CDNA4), the hardware appears to validate the full SRD range
against page table permissions. When the extended range covers freed or
protected memory, this causes VM_L2_PROTECTION_FAULT (PERMISSION_FAULTS
with MAPPING_ERROR=0).

Fix: set buffer_size to page_stride (one page worth of elements) before
init_raw() after each SRD rebase. This scopes the SRD to exactly the
page being accessed.

Verified: 80 passed on both gfx942 (MI308X) and gfx950 (MI355X).

This function was added for a per-tile SRD rebase approach that was
later replaced by template dispatch. No callers remain.
…ync_global_load_lds_dwordxn

The previous static_assert(false) fires unconditionally during template
parsing on newer compilers (CWG 2518), even for never-instantiated branches.
Wrap it in a dependent expression so the assertion only fires when an
unsupported num_dwords is actually instantiated.

Found during batch prefill template dispatch review.
…offset

tile_scatter_gather: divide buffer_size override by PackedSize to match
buffer_view ctor convention (raw element count in, packed count stored).
Without this, packed types (FP4 / int4, PackedSize=2) would over-report
num_records by 2x and silently mask OOB reads. batch_prefill does not
exercise the packed-type path today, but this is generic infrastructure
and must honor the same invariant. Also narrow the signature from
long_index_t to index_t since SRD num_records is hardware 32-bit.
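The raw-in, packed-stored convention amounts to one division. The helper below is a hypothetical sketch, not the `tile_scatter_gather` code; `packed_buffer_size` is an illustrative name.

```cpp
#include <cstdint>

// Convert a raw element count into the packed count stored as num_records.
// PackedSize = 2 models two logical elements per storage unit.
template <std::int32_t PackedSize>
constexpr std::int32_t packed_buffer_size(std::int32_t raw_elements)
{
    static_assert(PackedSize > 0, "PackedSize must be positive");
    return raw_elements / PackedSize;
}
```

With `PackedSize = 2`, skipping the division would store twice the valid count, which is the 2x over-report described above.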

block_fmha_batch_prefill_pipeline_qr_ks_vs_async: remove misleading
static_cast<long_index_t> on the SRD voffset path. The 32-bit limit on
this branch comes from CDNA3 MUBUF voffset hardware format, not from an
implementation choice — widening would not lift the 2GB ceiling because
the hardware truncates regardless. The kUseFlatLoad_ template path
handles the >2GB case via 64-bit global_load_lds_*. Added a comment
making this explicit so the next reader doesn't propose the same fix.

Found during batch prefill template dispatch review.

The KV cache overflow threshold is 2GB (INT32_MAX byte offset for SRD
voffset), matching CK's existing TwoGB convention in
transform_conv_fwd_to_gemm.hpp. Previous comments said "4GB" which is
incorrect — SRD voffset is signed-32-bit-effectively, not unsigned.

Updated:
- codegen comment + use_64bit_load field doc
- BlockFmhaBatchPrefillPipelineProblem::kUse64BitLoad doc, with explicit
  note about INT32_MAX / TwoGB convention

Found during batch prefill template dispatch review.
…cher

The runtime `bool use_64bit_load` field on `fmha_batch_prefill_traits`
forced wrappers to encode each kernel arm's compile-time `bn0` and
per-dtype element size to decide whether KV cache exceeds 2GB. That
leaked codegen detail and required updating wrappers when new tile
configs were added.

Move the decision into the auto-generated `fmha_batch_prefill_api.cpp`
dispatcher, where each arm already knows its own `{F_bn0}` and dtype.
Each per-dtype scope now emits `constexpr int kElementBytes` from a new
`DTYPE_BYTES` map, and the inner dispatch predicate evaluates
`(a.page_block_size < {F_bn0} && num_total_pages * batch_stride_k *
kElementBytes > INT32_MAX) == {F_use_64bit_load}` per arm. The C++
template parameter `kUse64BitLoad_` (and both kernel ELFs) stays — only
the runtime trait field is removed.

The >2GB KV cache code path used two historical names for the same
concept across layers (kUse64BitLoad at the kernel/Problem level,
kUseFlatLoad inside the batch_prefill pipeline and tile_scatter_gather).
Both mean "use global_load_lds_* instead of SRD buffer_load_*". Unify
on kUseGlobalLoad everywhere — kernel template params, pipeline
traits, scatter-gather op, codegen Python (F_use_global_load), and
generated kernel filename suffix (64bit_ → globalload_).

Also collapse the two-level structure in the batch_prefill pipeline:
the derived kUseFlatLoad = Problem::kUse64BitLoad && (kPageBlockSize
< kN0) becomes kUseGlobalLoad = Problem::kUseGlobalLoad directly,
with a static_assert backstop for the page_size < kN0 invariant that
codegen already guarantees.

Verified on both archs (no behavior change vs prior Option 3 baseline):
- gfx942 (smc300x-clt): test_batch_prefill.py 512/640, 4gb_small_page
  12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE
- gfx950 (smci355-gfx950): test_batch_prefill.py 384/768, 4gb_small_page
  12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE

tile_scatter_gather.hpp: replace cassert assert(size > 0) in
set_bottom_tensor_view_buffer_size with __builtin_assume. The cassert
form expands to an __assert_fail call whose SGPR pressure forces the
LLVM AMDGPU register allocator to reuse the K-SRD scalar register
window (s24-s27) as scratch for the assert-PC literal, scattering the
4 K-SRD writes across two conditional branches. gfx950 buffer_load_dwordx4
does not tolerate the staggered SRD setup; gfx942 (4x scalar buffer_load)
absorbs it. Reproduced as 95.2% mismatch on MI355X for hd=256, ps=1024,
linear, vectorized, bf16, causal, soft_cap=30. __builtin_assume preserves
the optimizer hint without emitting the assert handler.

block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: restore review
follow-ups lost in an earlier reset:
* Four-case addressing-strategy comment block above kNeedFullOffset
  (Case 1-3 mechanisms + Case 4 codegen backstop reference).
* readfirstlane "wave-uniform -> SGPR" rationale on K and V rebase sites.
* v_offsets semantics comment enumerating Cases 1/2/3, naming
  kNeedFullOffset as the Case-3 selector.
* Remove redundant outer if constexpr(kPageBlockSize >= kN0) around
  rebase_v_window initial call (single source of truth in lambda).
* save_and_prefetch_v_pages lambda encapsulates the
  update_physical_pages -> prefetch_v_physical_pages ordering invariant;
  3 in-loop sites collapsed to single calls.

Verified: test_batch_prefill_4gb_small_page.py 12/12 on both archs;
test_batch_prefill_4gb_repro.py --num_blocks 5000 9/9 bf16 + 6/6 FP8
KV_BLOCKSCALE on gfx942 (smc300x-clt) and gfx950 (smci355-gfx950).
…guards

amd_buffer_addressing_builtins.hpp:
* Replace the !num_dwords CWG-2518 workaround in the non-CDNA3+ guard
  with a named file-local helper impl::global_load_lds_arch_unreachable_v.
  The old form had a misleading edge case: !num_dwords is true only when
  num_dwords == 0, so the assert silently passed on num_dwords==0 wrong-
  arch instantiations and the second assert (num_dwords == 1 || == 4)
  fired with the wrong diagnostic. The named helper makes the intent
  unambiguous and the dependent-false pattern self-documenting.
* Reword the num_dwords == 1 || num_dwords == 4 static_assert message
  to distinguish hardware reality from policy: 2 dwords does not exist
  on any supported arch; 3 dwords only on CDNA4 and unused in FMHA
  pipeline. Prevents a future contributor from assuming 2/3 are
  deliberately blacklisted by software.

tile_scatter_gather.hpp:
* async_global_load_lds_dwordxn callsite: remove the redundant
  reinterpret_cast<const void*>(addr). addr is already const DataType*
  and converts implicitly to const void*. Comment clarifies that
  global_load_lds takes a byte address, which is what the implicit
  conversion produces.
* Add positive static_assert(kUseGlobalLoad_, ...) to update_physical_pages
  and set_page_stride_elements. Both fields (physical_pages_,
  page_stride_elements_) only participate in the global-load addressing
  path; calling these setters in SRD mode is silently a no-op that hides
  the misuse. The compile-time guard turns the misuse into a build error
  and locks down the invariant.

Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds, no new
warnings, and test_batch_prefill_4gb_small_page.py 12/12 pass with the
two new positive setter asserts in place (codegen only emits
kUseGlobalLoad=true arms when the setters are reachable, so neither
fires in practice).

@Jeff-Huang Jeff-Huang requested a review from a team as a code owner April 22, 2026 05:00
@Jeff-Huang Jeff-Huang changed the title from "[CK_TILE] fix(fmha): support >4GB KV cache in batch prefill via template dispatch" to "[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch" on Apr 22, 2026
@asleepzzz
Contributor

physical_pages_ and page_stride_elements_ are always present in tile_scatter_gather, even when kUseGlobalLoad=false

Contributor

@poyenc poyenc left a comment

replace_bottom_tensor_view (line 1478) does not deduce the new YsGatherDims or kUseGlobalLoad_ template parameters — passing a kUseGlobalLoad=true window is a compile error. Not reachable today, but if this needs fixing later, the function also needs to forward physical_pages_ and page_stride_elements_ (not just page_idx_ and valids_).

Comment thread projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp Outdated
Comment thread projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Outdated

Address all 5 reviewer asks on the >2GB KV cache batch-prefill series, plus
two self-found polish items surfaced by an internal CK-aware review pass.

Task #71 — bool kUseGlobalLoad_ -> BlockAttentionKVCacheLoadModeEnum kKVLoadMode_
(poyenc, tile_fmha_traits.hpp:62):
  Adjacent traits-template params (kKVMemoryLayout_, kKVLookupTable_) are
  already BlockAttention*Enum types; the binary kUseGlobalLoad_ stuck out
  as a bool exception. Convert to a 2-value enum {BUFFER_LOAD = 0,
  GLOBAL_LOAD_LDS = 1} living in a new ops/fmha/block/ header so it sits
  alongside its siblings. Touch sites:
  * include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp
    (NEW): the enum class.
  * include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp: rename last
    template param + static member alias.
  * include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp:
    mirror alias rename.
  * include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:
    add enum header include; class declares static auto kKVLoadMode plus
    derived static bool kUseGlobalLoad = (kKVLoadMode == GLOBAL_LOAD_LDS).
    All 10 internal `if constexpr(kUseGlobalLoad)` sites unchanged so the
    bool boundary is local to one TU. The standalone helper
    kv_offset_array_transform keeps its bool template param (private
    inline; intentional — keeps core/ tile_scatter_gather.hpp out of the
    enum's blast radius).
  * example/ck_tile/01_fmha/fmha_fwd.hpp: fmha_fwd_batch_prefill_traits_
    last template param renamed; static member alias kUseGlobalLoad ->
    kKVLoadMode (default BUFFER_LOAD).
  * include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp:
    comment-only update.
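The enum-plus-derived-bool arrangement described in Task #71 can be sketched as follows. The enum name and enumerators follow the commit message; the `ToyTraits` struct is a hypothetical stand-in for the real traits template, whose full parameter list is not shown here.

```cpp
// Enum name and enumerators as given in the commit message.
enum class BlockAttentionKVCacheLoadModeEnum
{
    BUFFER_LOAD     = 0,
    GLOBAL_LOAD_LDS = 1,
};

// Hypothetical traits struct with BUFFER_LOAD as the default mode.
template <BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
              BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct ToyTraits
{
    static constexpr auto kKVLoadMode = kKVLoadMode_;
    // Derived bool so existing `if constexpr(kUseGlobalLoad)` sites keep working.
    static constexpr bool kUseGlobalLoad =
        (kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS);
};
```

Deriving the bool in one place lets the public parameter become an enum without touching the internal branch sites.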

Task #70 — explicit constructor mem-init for tile_scatter_gather
(asleepzzz, tile_scatter_gather.hpp:1241, comment #3125912056):
  physical_pages_ and page_stride_elements_ were silently zero-initialized
  in the BUFFER_LOAD arm. Today safe (Task #71's positive setter asserts
  prevent misuse), but a future kUseGlobalLoad=true caller that misses a
  setter would get silent data corruption with no compile error. Make
  both fields explicit in the mem-init list so the contract is visible
  at the constructor boundary.

Task #72 — extract dispatcher overflow predicate to a named helper
(poyenc, fmha_batch_prefill.py:225):
  Move the (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) decision
  out of the codegen template into a free helper:
    fmha_batch_prefill_select_kv_load_mode(page_block_size, kN0,
                                           num_total_pages, batch_stride_k,
                                           element_bytes)
  in example/ck_tile/01_fmha/fmha_fwd.hpp. The codegen-emitted dispatcher
  arms now call it with their compile-time kN0/element_bytes substituted,
  so the formula has exactly one source of truth.

Task #73 — symmetric gload/bload kernel-name suffix
(poyenc, fmha_batch_prefill.py:282):
  Match the existing CK convention (e.g., causal/ncausal, sink/nsink) by
  emitting a non-empty token in BOTH branches: '-gload' / '-bload' on
  FmhaFwdApiTrait.name, 'gload_' / 'bload_' on FmhaFwdKernel.name. The
  prior blank-default made it impossible to tell, when grepping JIT
  blob/ output 6 months later, whether a missing marker meant
  'BUFFER_LOAD variant' or 'old codegen revision before the gload branch
  existed'.

Task #74 — replace single-use dependent-false with reusable always_false_v
(poyenc, amd_buffer_addressing_builtins.hpp:1324):
  Promote impl::global_load_lds_arch_unreachable_v from a file-local
  helper into a generic ck_tile::always_false_v utility in
  core/utility/type_traits.hpp. Use it at the original site. The
  variable-template form defers evaluation to instantiation time, so a
  bare `static_assert(false, ...)` would (per CWG-2518 / current Clang)
  fire at parse time and break the whole TU even on never-instantiated
  arches.
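A dependent-false helper of this kind is commonly written as below. This is a sketch, not the header's contents: the exact declaration of `ck_tile::always_false_v` is not shown in this thread and may differ, and `bytes_per_load` is a hypothetical function used only to exercise the pattern (C++17).

```cpp
#include <cstdint>
#include <type_traits>

// Always false, but dependent on its template arguments, so a static_assert
// using it is evaluated only on instantiation.
template <typename...>
inline constexpr bool always_false_v = false;

// Hypothetical function: bytes moved per lane for a given dword count.
template <std::int32_t NumDwords>
constexpr std::int32_t bytes_per_load()
{
    if constexpr(NumDwords == 1 || NumDwords == 4)
    {
        return NumDwords * 4;
    }
    else
    {
        // Fires only if this branch is actually instantiated.
        static_assert(always_false_v<std::integral_constant<std::int32_t, NumDwords>>,
                      "unsupported num_dwords");
        return 0;
    }
}
```

Instantiating `bytes_per_load<2>()` would fail to compile with the message above, while the supported counts compile normally.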

Polish I-1 — umbrella header completeness:
  include/ck_tile/ops/fmha.hpp now pulls in the new
  block_attention_kv_load_mode_enum.hpp alongside the other
  BlockAttention*Enum siblings. Without this, downstream consumers that
  rely solely on the umbrella header would miss the enum.

Polish I-2 — overflow-cast robustness in fmha_batch_prefill_select_kv_load_mode:
  Promote every operand of the kv_pool_bytes multiplication to
  long_index_t individually instead of relying on left-to-right
  associativity to widen the chain. A future operand reorder would
  silently truncate; the per-operand cast makes overflow impossible
  regardless of order.
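The reordering hazard can be shown with three variants of the same product. This sketch uses unsigned 32-bit operands so the wrap is well defined; the real operands are signed index types, where the same overflow would be undefined behaviour. All function names are hypothetical.

```cpp
#include <cstdint>

// Widening only the first operand works while it stays first, because the
// chain is evaluated left to right.
constexpr std::int64_t widen_first(std::uint32_t a, std::uint32_t b, std::uint32_t c)
{
    return static_cast<std::int64_t>(a) * b * c;
}

// After a reorder, b * c is computed in 32 bits and wraps before widening.
constexpr std::int64_t widen_last(std::uint32_t a, std::uint32_t b, std::uint32_t c)
{
    return b * c * static_cast<std::int64_t>(a);
}

// Widening every operand gives the same result in any order.
constexpr std::int64_t widen_all(std::uint32_t a, std::uint32_t b, std::uint32_t c)
{
    return static_cast<std::int64_t>(b) * static_cast<std::int64_t>(c) *
           static_cast<std::int64_t>(a);
}
```

With a = 2, b = 2^17 and c = 2^15, the first and third variants return 2^33, while the reordered one returns 0 because b * c wraps to zero in 32 bits.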

Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds; full
op_tests/test_batch_prefill.py sweep passes 30,720 / 30,720 (10,016
skipped, 0 failed) in 30:40 wall. Codegen identifier changes only affect
the renamed template parameter; no register-allocation perturbation
expected on either gfx942 or gfx950 (confirmed by the cross-arch sweep).

…UseGlobalLoad

Replace the unconditional `physical_pages_` and `page_stride_elements_` members with `std::conditional_t` + `[[no_unique_address]]` so they collapse to zero-byte placeholders in the SRD instantiation (kUseGlobalLoad=false).

Why:
  Reviewer asleepzzz observed that these fields were always present even
  when SRD-mode kernels never read them. The previous fix (Task #70,
  explicit mem-init) addressed the *form* of the concern (silent
  zero-init -> explicit zero-init) but not the *substance* (wasted
  storage in SRD-mode instantiations). This commit makes the fields
  literally disappear in SRD mode.

How it works:
  - Empty placeholder `gl_field_empty_t` introduced inside the class.
  - Both fields wrapped in `std::conditional_t<kUseGlobalLoad_, T, gl_field_empty_t>`
    with `[[no_unique_address]]` so the SRD-mode layout drops them.
  - All access sites (lines 520, 523, 758, 761) are already inside
    `if constexpr(kUseGlobalLoad_)` arms, so they compile out cleanly.
  - The setter `update_physical_pages` keeps its `static_assert(kUseGlobalLoad_)`
    guard; combined with lazy template member-function instantiation,
    the body is never instantiated for SRD callers.
  - Constructor mem-init stays type-agnostic via value-init `{}`; the
    `page_stride_elements_` assignment is gated by `if constexpr` in the
    body so the SRD arm only sees the empty struct.

AP-7 (codegen-hash) note:
  Class layout changes only on the kUseGlobalLoad=true instantiation
  (where layout is identical: one PageIdxArray + one index_t). The
  kUseGlobalLoad=false instantiation now has *less* state, but adjacent
  fields' offsets shift only if the compiler chose not to merge the
  `[[no_unique_address]]` placeholder. Verified by remote re-test on
  both gfx942 and gfx950.
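The layout trick can be tried in isolation with a toy class. This is a sketch, not the `tile_scatter_gather` class: `ToyScatterGather` and its four-entry page array are hypothetical, and the member names are reused only for readability.

```cpp
#include <array>
#include <cstdint>
#include <type_traits>

template <bool kUseGlobalLoad_>
struct ToyScatterGather
{
    using PageIdxArray = std::array<std::int32_t, 4>;
    struct gl_field_empty_t {}; // empty placeholder for SRD mode

    PageIdxArray page_idx_{};

    [[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
        physical_pages_{};
    [[no_unique_address]] std::conditional_t<kUseGlobalLoad_, std::int32_t, gl_field_empty_t>
        page_stride_elements_{};

    // Member functions of a class template are instantiated only when used,
    // so this assert cannot fire for SRD-mode objects that never call it.
    void update_physical_pages(const PageIdxArray& pages)
    {
        static_assert(kUseGlobalLoad_, "physical pages exist only in global-load mode");
        physical_pages_ = pages;
    }
};
```

Comparing `sizeof(ToyScatterGather<false>)` with `sizeof(ToyScatterGather<true>)` shows the SRD-mode object is smaller; how far the placeholders collapse depends on the compiler and ABI.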
@Jeff-Huang Jeff-Huang force-pushed the users/jeff-huang/ck/batch-prefill-fix-4bg-overflow branch from 819731a to a2692f8 Compare April 23, 2026 06:30
@Jeff-Huang
Contributor Author

> physical_pages_ and page_stride_elements_ are always present in tile_scatter_gather, even when kUseGlobalLoad=false

@asleepzzz Addressed in a2692f8a3d. The two members are now wrapped in std::conditional_t + [[no_unique_address]] so they collapse to zero bytes in the SRD-mode (kUseGlobalLoad=false) instantiation:

struct gl_field_empty_t {};  // empty placeholder for SRD-mode

[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, PageIdxArray, gl_field_empty_t>
    physical_pages_;

[[no_unique_address]] std::conditional_t<kUseGlobalLoad_, index_t, gl_field_empty_t>
    page_stride_elements_;

All access sites were already inside if constexpr(kUseGlobalLoad_) arms, so no caller changes were required. The update_physical_pages setter keeps its static_assert(kUseGlobalLoad_) guard — combined with lazy template member-function instantiation, the body is never instantiated for SRD callers, so the empty-struct arm never sees an assignment.

Verified on three servers — test_batch_prefill_large_kvcache reports 160 passed / 32 skipped / 0 failed on each (smci355-gfx950 / hjbog-srdc-39 / gbt350-gfx950). The 32 skips are all page_size=1 + vectorized config-validity skips, not regressions.

@Jeff-Huang Jeff-Huang merged commit 1df887e into develop Apr 23, 2026
32 checks passed
@Jeff-Huang Jeff-Huang deleted the users/jeff-huang/ck/batch-prefill-fix-4bg-overflow branch April 23, 2026 23:08
assistant-librarian Bot pushed a commit to ROCm/composable_kernel that referenced this pull request Apr 23, 2026
[CK_TILE] fix(fmha): support >2GB KV cache in batch prefill
 via template dispatch (#6653)
aledudek pushed a commit that referenced this pull request May 20, 2026
…ate dispatch (#6653)
