Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
e913534
fix(fmha): global_load_lds flat addressing for >4GB KV cache (page_si…
Jeff-Huang Mar 30, 2026
f96c1d1
fix(fmha): V flat 64-bit load for >4GB KV cache (page_size < kN0)
Jeff-Huang Apr 10, 2026
88508a7
fix(fmha): double-buffer v_physical_pages for flat load pipeline sync
Jeff-Huang Apr 14, 2026
78203c6
feat(fmha): template dispatch for >4GB KV cache in batch prefill
Jeff-Huang Apr 14, 2026
b87084c
refactor(fmha): unify tile_scatter_gather to two-mode design (SRD/Glo…
Jeff-Huang Apr 16, 2026
b0a6bd6
feat(fmha): add CDNA3+ arch guards for global_load_lds in batch prefill
Jeff-Huang Apr 17, 2026
83d6de7
fix(fmha): limit SRD num_records to page_stride after rebase (gfx950 …
Jeff-Huang Apr 18, 2026
bcead4b
cleanup: remove unused wave_reduce_min from utility.hpp
Jeff-Huang Apr 19, 2026
62b15aa
fix(buffer): use dependent assertion for unsupported num_dwords in as…
Jeff-Huang Apr 20, 2026
e1f80b1
refactor(fmha): tighten batch prefill SRD types and document 32-bit v…
Jeff-Huang Apr 20, 2026
b475690
docs(fmha): correct >2GB threshold wording across batch prefill
Jeff-Huang Apr 20, 2026
0fcb74e
refactor(fmha): move use_64bit_load decision into auto-gen API dispat…
Jeff-Huang Apr 20, 2026
66efbd5
refactor(fmha): unify kUse64BitLoad/kUseFlatLoad → kUseGlobalLoad
Jeff-Huang Apr 21, 2026
9f098e2
fix(fmha): use __builtin_assume + restore batch prefill review polish
Jeff-Huang Apr 21, 2026
72a5da7
refactor(fmha): batch prefill review polish — assert helper + setter …
Jeff-Huang Apr 21, 2026
bda6142
refactor(fmha): apply PR #6653 review feedback (Tasks #70-#74)
Jeff-Huang Apr 23, 2026
a2692f8
refactor(fmha): make tile_scatter_gather page fields conditional on k…
Jeff-Huang Apr 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,16 @@
QSCALE_CHECK_MAP,
QSCALE_MAP,
)
from codegen.arch import ArchTrait
from codegen.utils import update_file

# Architecture trait for kernels requiring global_load_lds (CDNA3+).
# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic.
#
# The `preprocessor_check` string is what gets substituted for `F_arch_check`
# in the generated kernel body's
#   #if !defined(__HIP_DEVICE_COMPILE__) || (<F_arch_check>)
# guard when a kernel is generated with F_use_global_load=True; kernels
# generated without it substitute the literal "true" instead, so their guard
# is always satisfied.
CDNA3_PLUS_ARCH = ArchTrait(
"cdna3_plus",
preprocessor_check="defined(__gfx94__) || defined(__gfx950__)",
)

DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
Expand All @@ -34,6 +42,10 @@
"bf8": 8,
}

# Per-dtype element size in bytes, derived from DTYPE_BITS by integer division
# by 8. The generated dispatcher emits this value as `kElementBytes` and feeds
# it to the per-arm kv_load_mode selection (total KV cache bytes vs INT32_MAX).
# NOTE(review): a dtype narrower than 8 bits would map to 0 here — confirm no
# such entry exists in DTYPE_BITS.
DTYPE_BYTES = {dtype_name: bit_width // 8 for dtype_name, bit_width in DTYPE_BITS.items()}

K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}

SUPPORTED_PAGE_SIZE = [1, 16, 1024]
Expand All @@ -47,6 +59,10 @@
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
}
# Translates the Python-side `use_global_load` flag into the C++ enumerator
# text that is spliced into generated kernels and dispatcher arms.
KV_LOAD_MODE_ENUM_MAP = dict(
    [
        (False, "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD"),
        (True, "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS"),
    ]
)


FMHA_BATCH_PREFILL_PIPELINE_MAP = {
Expand All @@ -61,6 +77,8 @@
"""

FMHA_FWD_KERNEL_BODY = """
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})

using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
Expand All @@ -87,7 +105,8 @@
{F_sink},
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
{F_kv_lookup_table},
{F_kv_load_mode}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

Expand Down Expand Up @@ -125,7 +144,7 @@
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;

#include <iostream>

Expand All @@ -140,10 +159,13 @@
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check})
"""

FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
FMHA_FWD_API = """
#include <cstdint>
#include <cstdio>

namespace {{
Expand Down Expand Up @@ -194,6 +216,7 @@
"""

FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
constexpr int kElementBytes = {F_element_bytes};
{F_hdim_case}
}}
"""
Expand All @@ -203,8 +226,8 @@
"""

FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
Expand Down Expand Up @@ -253,12 +276,14 @@ class FmhaFwdApiTrait:
kv_memory_layout: str
kv_lookup_table: str
page_size: int = 1 # page block size
use_global_load: bool = False # use global_load_lds_* for >2GB KV cache

@property
def name(self) -> str:
    """Return a dash-separated identifier built from this trait's fields.

    Field order: hdim, dtype, mode, the six tile sizes
    (bm0, bn0, bk0, bn1, bk1, bk0max), vlayout, logits, mask, bias, lse,
    dropout, qscale, the four padding flags, kv_memory_layout,
    kv_lookup_table, ``ps<page_size>``, and finally ``gload`` or ``bload``
    depending on ``use_global_load``.

    The fourth tile slot is ``bn1``. It previously repeated ``bn0``, so two
    traits that differed only in ``bn1`` produced the same name; the tile
    order now matches the ``fmha_fwd_batch_prefill_traits_`` argument order
    used by the generated code.
    """
    tile_part = (
        f"{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}"
    )
    load_suffix = "-gload" if self.use_global_load else "-bload"
    return (
        f"{self.hdim}-{self.dtype}-{self.mode}-{tile_part}-"
        + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
        + load_suffix
    )

@property
Expand Down Expand Up @@ -481,14 +506,18 @@ def api(self) -> str:
],
F_page_size=trait.page_size,
F_sink=BOOL_MAP[trait.sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
F_if=if_i,
F_dtype=dtype,
F_element_bytes=DTYPE_BYTES[dtype],
F_hdim_case=per_hdim_case,
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
Expand Down Expand Up @@ -539,6 +568,7 @@ class FmhaFwdKernel:
F_pipeline: FmhaFwdPipeline
mask_impl: str
F_page_size: int = 1 # page block size
F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache

@property
def template(self) -> str:
Expand Down Expand Up @@ -588,13 +618,18 @@ def template(self) -> str:
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load],
F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check
if self.F_use_global_load
else "true",
)

@property
def name(self) -> str:
# TODO: we don't encode idx here
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
+ ("gload_" if self.F_use_global_load else "bload_")
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
Expand Down Expand Up @@ -632,6 +667,7 @@ def api_trait(self) -> FmhaFwdApiTrait:
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
page_size=self.F_page_size,
use_global_load=self.F_use_global_load,
)


Expand Down Expand Up @@ -714,8 +750,11 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:


def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
targets: Optional[List[str]] = None
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
targets: Optional[List[str]] = None,
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
Expand Down Expand Up @@ -837,6 +876,25 @@ def get_fwd_blobs(
api_pool.register_traits(k.api_trait())
gen.append(k)

# For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS
# variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD
# buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_*
# (slower, handles >2GB).
if page_size < tile.F_bn0:
k_global_load = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
F_page_size=page_size,
F_use_global_load=True,
)
api_pool.register_traits(k_global_load.api_trait())
gen.append(k_global_load)

return (api_pool, gen)


Expand All @@ -856,7 +914,9 @@ def write_blobs(
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
api_pool, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
Expand All @@ -871,7 +931,9 @@ def list_blobs(
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
_, kernels = get_fwd_blobs(
kernel_filter, receipt, optdim_list, mask_impl, targets
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
32 changes: 31 additions & 1 deletion projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,33 @@ struct fmha_batch_prefill_args
ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension
};

// Picks the KV-cache load mode for one batch-prefill dispatch arm.
//
//   GLOBAL_LOAD_LDS  returned only when BOTH conditions hold:
//                      * page_block_size < kN0, i.e. a page is smaller than
//                        one K/V tile, and
//                      * num_total_pages * batch_stride_k * element_bytes
//                        is greater than INT32_MAX.
//   BUFFER_LOAD      returned in every other case.
//
// Every argument is a plain integer and the helper has no template
// parameter, so each codegen-emitted dispatcher arm can call it with that
// arm's compile-time kN0 / element size written in as constants.
inline ck_tile::BlockAttentionKVCacheLoadModeEnum
fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size,
                                       ck_tile::index_t kN0,
                                       ck_tile::index_t num_total_pages,
                                       ck_tile::index_t batch_stride_k,
                                       ck_tile::index_t element_bytes)
{
    using load_mode_t = ck_tile::BlockAttentionKVCacheLoadModeEnum;
    using wide_t      = ck_tile::long_index_t;

    // Widen each factor on its own before multiplying, so the product is
    // formed in long_index_t no matter how the factors are later reordered.
    const wide_t page_count  = num_total_pages;
    const wide_t page_stride = batch_stride_k;
    const wide_t elem_size   = element_bytes;
    const wide_t pool_bytes  = page_count * page_stride * elem_size;

    const bool page_smaller_than_tile = page_block_size < kN0;
    const bool beyond_int32_range     = pool_bytes > INT32_MAX;

    if(page_smaller_than_tile && beyond_int32_range)
    {
        return load_mode_t::GLOBAL_LOAD_LDS;
    }
    return load_mode_t::BUFFER_LOAD;
}

template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
Expand Down Expand Up @@ -1457,7 +1484,9 @@ template <ck_tile::index_t HDim_,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D,
ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ =
ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD>
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
DataType_,
kIsGroupMode_,
Expand Down Expand Up @@ -1486,6 +1515,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
static constexpr auto kKVLoadMode = kKVLoadMode_;
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,87 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

// Flat async load from global memory to LDS using 64-bit global addressing.
// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds
// INT32_MAX (2GB) byte offset on the SRD voffset path.
//
// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!!
//
// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3:
// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`).
// M0 does NOT appear as an operand of these instructions or of the inline
// asm below — the compiler cannot see the dependency. Caller must:
//
// 1. Initialize M0 once before the load loop:
// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));`
// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to
// enforce this. Direct VALU writes to M0 are illegal.
//
// 2. Advance M0 between successive issues:
// `m0_inc_with_memory(size_per_issue);`
// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path
// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently
// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses
// M0[15:0] as a raw byte offset).
//
// 3. Never bundle `m0_inc_with_memory` and the next call to this
// function into a single inline asm. The compiler auto-inserts a
// hazard NOP between an SALU write to M0 and the consuming
// `global_load_lds_*`; bundling bypasses that and may read stale M0.
//
// The "memory" clobber on this asm is load-bearing: it prevents the
// compiler from reordering this load across other M0-touching helpers
// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered).
//
// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950):
// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000
// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both
// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but
// supported by the LLVM AMDGPU backend.
//
// Availability: only when `__gfx94__` or `__gfx950__` is defined (enforced by
// the guard at the top of the body).
// NOTE(review): the previous wording listed "MI300, MI355, MI350 series" all
// under CDNA3; MI350/MI355 are believed to be gfx950 (CDNA4) parts — confirm
// the product mapping before restoring product names here.
//
// Template parameters:
//   num_dwords - selects the instruction: 1 -> `global_load_lds_dword`,
//                4 -> `global_load_lds_dwordx4`. Any other value is rejected
//                by the static_assert in the body.
//   pre_nop    - when true, `s_nop 4` is emitted in the same asm statement,
//                immediately before the load.
// Parameters:
//   smem        - bound only as the asm "=r" output (scheduling anchor). It is
//                 not used to form the LDS address, and because it is passed
//                 by value the binding cannot change the caller's pointer.
//   global_addr - source address, passed to the instruction through the "v"
//                 (VGPR) input constraint.
//   (unnamed)   - `bool_constant<pre_nop>` tag; lets a call site select
//                 `pre_nop` by passing a tag value instead of spelling the
//                 template argument.
template <unsigned num_dwords, bool pre_nop = false>
CK_TILE_DEVICE void
async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant<pre_nop> = {})
{
// Arch guard. The assertion depends on `num_dwords`, so it is evaluated only
// when this template is actually instantiated on an unsupported target.
#if !defined(__gfx94__) && !defined(__gfx950__)
static_assert(always_false_v<integral_constant<unsigned, num_dwords>>,
"global_load_lds requires CDNA3+ (gfx940/gfx950). "
"Ensure kKVLoadMode is BUFFER_LOAD on this architecture.");
#endif

// Width guard: only the two widths handled by the `if constexpr` chain at the
// end of this function are accepted.
static_assert(num_dwords == 1 || num_dwords == 4,
"global_load_lds supports num_dwords == 1 or 4 only "
"(2 dwords does not exist on any supported arch; "
"3 dwords only on CDNA4 and unused in FMHA pipeline)");

// Inline asm: only the global address is an explicit operand. The LDS
// destination is implicit via M0 (see contract above). `"=r"(smem)` is a
// SSA scheduling anchor only — `smem` is NOT written by this asm; the
// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`.
//
// The macro is function-local (it is #undef'd before the closing brace) and
// expands to an `if constexpr(pre_nop) ... else ...` pair so that exactly one
// of the two asm statements is emitted per instantiation. `instr` must be a
// string literal: it is pasted into the asm template by literal concatenation.
#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \
if constexpr(pre_nop) \
asm volatile("s_nop 4\n" instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/); \
else \
asm volatile(instr " %1, off offset:0" \
: "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \
: "v"(global_addr) \
: "memory" /*prevents reorder across m0_{set,inc}*/);

// Compile-time selection of the instruction mnemonic by load width.
if constexpr(num_dwords == 1)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword");
}
else if constexpr(num_dwords == 4)
{
CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4");
}
#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR
}

template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE thread_buffer<int8_t, N>
Expand Down
Loading
Loading