From e91353460d000b92030a99177b717ca64307c0b5 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 30 Mar 2026 13:21:10 +0800 Subject: [PATCH 01/17] fix(fmha): global_load_lds flat addressing for >4GB KV cache (page_size < kN0) Use global_load_lds_dwordx{1,4} with 64-bit flat addresses for K loads when page_size < kN0, eliminating the SRD 32-bit offset overflow. V loads use per-tile SRD rebase with wave_reduce_min. --- .../arch/amd_buffer_addressing_builtins.hpp | 38 +++++ .../include/ck_tile/core/arch/utility.hpp | 14 ++ .../include/ck_tile/core/tensor/load_tile.hpp | 18 ++ .../core/tensor/tile_scatter_gather.hpp | 116 +++++++++++++ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 161 +++++++++++++----- 5 files changed, 305 insertions(+), 42 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 8056b76af7a..f69d3c7706c 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1319,6 +1319,44 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +// Flat async load from global memory to LDS using 64-bit addressing. +// Uses global_load_lds_dwordx{1,4} which bypasses the SRD's 32-bit offset limit. +// M0 must already contain the LDS destination offset (set by caller). +// The data is loaded from global_addr to LDS at [M0]. +// +// Available on gfx940+ (CDNA3: MI300, MI355, MI350 series). +template +CK_TILE_DEVICE void +async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) +{ +// Use inline asm with VGPR pair for 64-bit flat address +#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, off offset:0" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(global_addr) \ + : "memory"); \ + else \ + asm volatile(instr " %1, off offset:0" \ + : "=r"(smem) /*dummy dependency for smem*/ \ + : "v"(global_addr) \ + : "memory"); + + if constexpr(num_dwords == 1) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4"); + } + else + { + static_assert(false, "wrong! only dword and dwordx4 supported for global_load_lds"); + } +#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR +} + template CK_TILE_DEVICE thread_buffer diff --git a/projects/composablekernel/include/ck_tile/core/arch/utility.hpp b/projects/composablekernel/include/ck_tile/core/arch/utility.hpp index 647f5b4435c..a1fb7cca0e0 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/utility.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/utility.hpp @@ -59,6 +59,20 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } +// Butterfly min-reduction across all lanes in a wave. +// Returns the minimum value, broadcast to all lanes as a uniform SGPR value. +// Used for per-tile SRD rebase: find the min physical page across all threads +// so the SRD can be rebased to that page with 64-bit pointer arithmetic. +CK_TILE_DEVICE index_t wave_reduce_min(index_t val) +{ + for(index_t offset = 1; offset < get_warp_size(); offset <<= 1) + { + const index_t other = warp_shuffle_down(val, offset); + val = min(val, other); + } + return __builtin_amdgcn_readfirstlane(val); +} + template CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) { diff --git a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp index d1c06d43780..64ea957121f 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp @@ -193,6 +193,24 @@ CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile, bool_constant{}); } +// Flat async load variant using 64-bit addressing (global_load_lds). +// For page_size < kN0 with >4GB KV cache where SRD-based loads overflow. +template +CK_TILE_DEVICE void async_load_tile_raw_flat(LdsTileWindow_&& lds_tile, + const TileWindow_& tile_window, + const PhysicalPagesArray& physical_pages, + long_index_t page_stride_bytes, + number = {}, + bool_constant = {}) +{ + tile_window.async_load_raw_flat( + lds_tile, physical_pages, page_stride_bytes, number{}, bool_constant{}); +} + CK_TILE_DEVICE void async_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index aa293458920..5fb464c2f5a 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -357,6 +357,13 @@ struct tile_scatter_gather bottom_tensor_view_.buf_.p_data_ = data; } + // Override buffer size (in elements) for SRD num_records control. + // Use to set max range when SRD is rebased per-tile (page_size < kN0 path). + CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(long_index_t size) + { + bottom_tensor_view_.buf_.buffer_size_ = size; + } + // move thread's window adaptor coordinate and bottom tensor coordinate // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] template @@ -718,6 +725,115 @@ struct tile_scatter_gather }); } + // Flat async load from global memory to LDS using 64-bit addressing. + // Replaces async_load_raw for the page_size < kN0 path where SRD-based + // buffer_load_dword...lds would overflow its 32-bit offset. + // + // Instead of using SRD + voffset, computes a 64-bit flat address per element: + // addr = base_ptr + (int64)physical_pages[idx] * page_stride_bytes + // + (coord_offset + within_page_offset) * sizeof(DataType) + // Then issues global_load_lds_dwordx{1,4} which takes a 64-bit VGPR address. + // + // M0 register management is identical to async_load_raw. + template + CK_TILE_DEVICE auto async_load_raw_flat(LdsTileWindow_&& lds_tile, + const PhysicalPagesArray& physical_pages, + long_index_t page_stride_bytes, + number = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + using DataType = typename BottomTensorView::DataType; + + static_assert(LdsTileWindow::get_num_of_dimension() == 3); + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(amd_wave_read_first_lane(m0_init_value)); + + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // Base pointer for 64-bit address computation + const auto* base_ptr = reinterpret_cast(get_bottom_tensor_view().buf_.p_data_); + + // Number of dwords per vector element + constexpr index_t vector_size = sizeof(vector_t) / sizeof(uint32_t); // dwords per vector + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = get_gather_index(idx_ys_start); + + // within-page offset from page_idx_ (set by kv_offset_array_transform) + const auto within_page_offset = page_idx_[idx_gather]; + // physical page index + const auto physical_page = physical_pages[idx_gather]; + + // Compute 64-bit flat address: + // base + phys_page * page_stride_bytes + // + (coord_offset + within_page_offset) * sizeof(DataType) + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const auto* flat_addr = + base_ptr + static_cast(physical_page) * page_stride_bytes + + static_cast(coord_offset + within_page_offset) * sizeof(DataType); + + async_global_load_lds_dwordxn(smem, flat_addr, pre_nop_); + + // move thread coordinate (same as async_load_raw) + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, + number{}); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + // TODO: fix with swizzle template = kN0) - { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. - kv_offset_vec[k0] = token_idx_in_page * stride_token; - } - else - { - // Full global offset (original code path for ps1, ps16, etc.) - const index_t physical_page = physical_pages[k0]; - kv_offset_vec[k0] = - physical_page * stride_page_block + token_idx_in_page * stride_token; - } + // Store within-page offset only. + // For kPageBlockSize >= kN0: the full page base is handled by SRD rebase. + // For kPageBlockSize < kN0: the full page base is handled by 64-bit flat + // addressing (global_load_lds) or per-tile SRD rebase. + kv_offset_vec[k0] = token_idx_in_page * stride_token; }); } else // V cache @@ -629,6 +621,31 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kVectorSize>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + // V wave min page for per-tile SRD rebase (page_size < kN0) + index_t v_wave_min_page = 0; + + // Lambda: compute wave-level min physical page and adjust offsets to be relative. + auto compute_min_and_adjust_offsets = [](auto& physical_pages, + auto& offsets, + index_t page_stride, + auto NPages) { + index_t thread_min = physical_pages[number<0>{}]; + static_for<1, decltype(NPages)::value, 1>{}( + [&](auto k) { thread_min = min(thread_min, physical_pages[k]); }); + const index_t wave_min = wave_reduce_min(thread_min); + const index_t base_offset = wave_min * page_stride; + static_for<0, decltype(NPages)::value, 1>{}([&](auto k) { offsets[k] -= base_offset; }); + return wave_min; + }; + + // K load helper: use flat 64-bit loads for page_size < kN0, + // SRD-based buffer loads for page_size >= kN0. + const long_index_t k_page_stride_bytes = + static_cast(page_stride_k) * + sizeof(typename std::remove_const::type>:: + type); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), @@ -636,14 +653,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_offsets); // K DRAM tile window for k_dram_window.init_raw(); - // SRD rebasing: move the buffer descriptor base pointer to each page's start address - // using 48-bit pointer arithmetic, so voffset only needs the small within-page offset. - // Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page). + // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0, K uses flat 64-bit loads (no SRD needed). auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { - // readfirstlane: make physical_page provably wave-uniform so the - // resulting SRD lands in SGPRs (required by buffer load instructions). physical_page = __builtin_amdgcn_readfirstlane(physical_page); const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = @@ -653,20 +667,22 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } }; + // SRD rebasing for V: works for all page sizes. + // For page_size >= kN0: rebase to the single page (readfirstlane). + // For page_size < kN0: rebase to wave-level min page (from wave_reduce_min). auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { physical_page = __builtin_amdgcn_readfirstlane(physical_page); - const auto* base_ptr = - v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; - const auto* page_ptr = - base_ptr + static_cast(physical_page) * page_stride_v; - window.set_bottom_tensor_view_data_ptr(page_ptr); - window.init_raw(); } + const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; + const auto* page_ptr = + base_ptr + static_cast(physical_page) * page_stride_v; + window.set_bottom_tensor_view_data_ptr(page_ptr); + window.init_raw(); }; - // Initial K SRD rebase + // Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead) rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); constexpr auto k_oob_ck = bool_constant{}; @@ -902,6 +918,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kVectorSize>( v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } + + // Per-tile min-reduce for V with page_size < kN0: + // adjust offsets to be relative to wave-level min page. + if constexpr(kPageBlockSize < kN0) + { + v_wave_min_page = compute_min_and_adjust_offsets( + v_physical_pages, v_offsets, page_stride_v, number{}); + } }; // Prefetch V physical pages early to hide buffer load latency @@ -917,12 +941,34 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync number<1>{}, // NumCoord VPageIndexYDims); + // For page_size < kN0, set max buffer size for V SRD + if constexpr(kPageBlockSize < kN0) + { + v_dram_window.set_bottom_tensor_view_buffer_size(0x7FFFFFFF); + } + // Initial V SRD rebase - rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + rebase_v_window(v_dram_window, + kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] : v_wave_min_page); - // prefetch K tile - async_load_tile_raw( - k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + // prefetch K tile: use flat 64-bit loads for page_size < kN0 + if constexpr(kPageBlockSize < kN0) + { + async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + k_physical_pages, + k_page_stride_bytes, + number<-1>{}, + k_pre_np); + } + else + { + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + } move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -979,11 +1025,24 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k0_loops > 1) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); + if constexpr(kPageBlockSize < kN0) + { + async_load_tile_raw_flat( + k_lds_store(number{})>{}), + k_dram_window, + k_physical_pages, + k_page_stride_bytes, + number<-1>{}, + k_pre_np); + } + else + { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + } if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -1023,7 +1082,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + rebase_v_window(v_dram_window, + kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] + : v_wave_min_page); // KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result) if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) @@ -1224,7 +1285,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + rebase_v_window(v_dram_window, + kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] + : v_wave_min_page); } __builtin_amdgcn_sched_barrier(0); @@ -1395,7 +1458,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + rebase_v_window(v_dram_window, + kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] + : v_wave_min_page); } // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 @@ -1503,11 +1568,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); + if constexpr(kPageBlockSize < kN0) + { + async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + k_physical_pages, + k_page_stride_bytes, + number<-1>{}, + k_pre_np); + } + else + { + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + } move_tile_window(k_dram_window, {0, kK0}); } // tail From f96c1d15545ee370f33f1777a57374b138d3e430 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 10 Apr 2026 20:50:06 +0800 Subject: [PATCH 02/17] fix(fmha): V flat 64-bit load for >4GB KV cache (page_size < kN0) Previously, V loads used SRD buffer_load with int32 voffset which overflows when pages within a tile span >2GB. K was already fixed to use flat 64-bit addressing (global_load_lds), but V still used SRD rebase. Changes: - Add load_flat() to tile_scatter_gather: flat load to VGPRs using 64-bit pointer arithmetic (base + page*stride + within_page_offset) - Add load_tile_flat() free function in load_tile.hpp - Change V kv_offset_array_transform to store within-page offset only (matching K), instead of full global offset that overflows int32 - Remove V compute_min_and_adjust_offsets (no longer needed with flat) - Pipeline: use load_tile_flat for V when page_size < kN0 This fixes scattered page allocation where adjacent logical tokens map to physically distant pages (>2GB apart within a tile). --- .../include/ck_tile/core/tensor/load_tile.hpp | 13 +++ .../core/tensor/tile_scatter_gather.hpp | 84 +++++++++++++++++++ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 62 ++++++++------ 3 files changed, 133 insertions(+), 26 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp index 64ea957121f..1489c95cb80 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp @@ -40,6 +40,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(number{}, bool_constant{}); } +// Flat load variant using 64-bit addressing. For scatter/gather windows with +// page_size < kN0 where SRD-based buffer_load would overflow its 32-bit voffset. +template +CK_TILE_DEVICE auto load_tile_flat(const TileWindow_& tile_window, + const PhysicalPagesArray& physical_pages, + long_index_t page_stride_bytes, + number = {}) +{ + return tile_window.load_flat(physical_pages, page_stride_bytes, number{}); +} + /** * @brief Load tile with elementwise function * diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 5fb464c2f5a..84075273826 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -526,6 +526,90 @@ struct tile_scatter_gather }); } + // Flat load from global memory to VGPRs using 64-bit addressing. + // Replaces load() for the page_size < kN0 path where SRD-based buffer_load + // would overflow its 32-bit voffset for scattered pages. + // + // Uses physical_pages array + page_stride_bytes to compute 64-bit flat + // addresses per element, then loads via pointer dereference (flat_load). + template + CK_TILE_DEVICE auto load_flat(const PhysicalPagesArray& physical_pages, + long_index_t page_stride_bytes, + number = {}) const + { + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + const auto* base_ptr = + reinterpret_cast(get_bottom_tensor_view().buf_.p_data_); + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = get_gather_index(idx_ys_start); + + // within-page offset from page_idx_ + const auto within_page_offset = page_idx_[idx_gather]; + // physical page index + const auto physical_page = physical_pages[idx_gather]; + + // 64-bit flat address: base + page*stride + (coord+within_page)*sizeof(T) + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const auto* flat_addr = + base_ptr + + static_cast(physical_page) * page_stride_bytes + + static_cast(coord_offset + within_page_offset) * + sizeof(DataType); + + // Load via pointer dereference (generates flat_load instruction) + vector_t vec_value; + __builtin_memcpy(&vec_value, flat_addr, sizeof(vector_t)); + + // Write into distributed tensor (same as load()) + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); + + // Move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, + number{}); + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + return dst_tensor; + } + template diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index f54d5f2ecbb..1d9ab519f1c 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -203,22 +203,19 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica } else { - // Full global offset (original code path for ps1, ps16, etc.) - const index_t physical_page = physical_pages[k0]; - const long_index_t page_base_offset = - static_cast(physical_page) * stride_page_block; - + // Within-page offset only: page base is handled by flat 64-bit + // addressing in load_tile_flat (via physical_pages array). if constexpr(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { const index_t token_offset = (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = page_base_offset + token_offset; + kv_offset_vec[k0] = token_offset; } else { - kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token; + kv_offset_vec[k0] = token_idx_in_page * stride_token; } } }); @@ -919,13 +916,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } - // Per-tile min-reduce for V with page_size < kN0: - // adjust offsets to be relative to wave-level min page. - if constexpr(kPageBlockSize < kN0) - { - v_wave_min_page = compute_min_and_adjust_offsets( - v_physical_pages, v_offsets, page_stride_v, number{}); - } + // For page_size < kN0: V uses flat 64-bit loads (load_tile_flat), + // so no per-tile SRD rebase or offset adjustment needed. + // v_offsets contain within-page offsets; page base is handled by + // physical_pages in load_tile_flat. }; // Prefetch V physical pages early to hide buffer load latency @@ -941,15 +935,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync number<1>{}, // NumCoord VPageIndexYDims); - // For page_size < kN0, set max buffer size for V SRD - if constexpr(kPageBlockSize < kN0) + // V page stride in bytes for flat 64-bit addressing + const long_index_t v_page_stride_bytes = + static_cast(page_stride_v) * + sizeof(typename std::remove_const::type>:: + type); + + // For page_size >= kN0, use SRD rebase (all threads on same page) + if constexpr(kPageBlockSize >= kN0) { - v_dram_window.set_bottom_tensor_view_buffer_size(0x7FFFFFFF); + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); } - - // Initial V SRD rebase - rebase_v_window(v_dram_window, - kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] : v_wave_min_page); + // For page_size < kN0, V uses flat 64-bit loads (load_tile_flat) + // instead of SRD rebase, so no init_raw/rebase needed. // prefetch K tile: use flat 64-bit loads for page_size < kN0 if constexpr(kPageBlockSize < kN0) @@ -1078,7 +1077,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + auto v_buf = [&]() { + if constexpr(kPageBlockSize < kN0) + return load_tile_flat(v_dram_window, v_physical_pages, v_page_stride_bytes); + else + return load_tile(v_dram_window, number<-1>{}, bool_constant{}); + }(); // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); @@ -1281,8 +1285,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + if constexpr(kPageBlockSize < kN0) + v_buf = load_tile_flat(v_dram_window, v_physical_pages, v_page_stride_bytes); + else + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); rebase_v_window(v_dram_window, @@ -1453,8 +1459,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + if constexpr(kPageBlockSize < kN0) + v_buf = load_tile_flat( + v_dram_window, v_physical_pages, v_page_stride_bytes); + else + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); From 88508a713a4af4282293d435df346c0b60363f30 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 14 Apr 2026 12:56:22 +0800 Subject: [PATCH 03/17] fix(fmha): double-buffer v_physical_pages for flat load pipeline sync The V flat load path (page_size < kN0) requires both physical_pages[] and page_idx_ to correspond to the same sub-tile. However, the pipeline prefetches the NEXT sub-tile's physical pages before the CURRENT sub-tile's load_tile_flat executes, causing address computation to mix within-page offsets from sub-tile N with physical pages from sub-tile N+1. Fix: save v_physical_pages to v_physical_pages_current before each prefetch_v_physical_pages() call, and use the saved copy in all 3 load_tile_flat() call sites. This preserves the pipeline prefetch overlap (pages loaded during GEMM0/softmax) while providing correct data to load_tile_flat. The fix is guarded by if constexpr(kPageBlockSize < kN0), so the SRD rebase path (page_size >= kN0) has zero overhead. --- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 49 +++++++++++++------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 1d9ab519f1c..843da88c255 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -205,6 +205,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica { // Within-page offset only: page base is handled by flat 64-bit // addressing in load_tile_flat (via physical_pages array). + // Note: physical_pages[] carries the page index separately. if constexpr(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { @@ -218,6 +219,9 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica kv_offset_vec[k0] = token_idx_in_page * stride_token; } } + // Note: K also uses within-page offsets here (line 172 above). + // Both K and V rely on physical_pages[] for the page base, + // computed via 64-bit arithmetic in load_flat/async_load_raw_flat. }); } } @@ -664,19 +668,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } }; - // SRD rebasing for V: works for all page sizes. - // For page_size >= kN0: rebase to the single page (readfirstlane). - // For page_size < kN0: rebase to wave-level min page (from wave_reduce_min). + // SRD rebasing for V: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0, V uses flat 64-bit loads (load_tile_flat) which + // compute addresses independently — no SRD rebase needed. auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { physical_page = __builtin_amdgcn_readfirstlane(physical_page); + const auto* base_ptr = + v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; + const auto* page_ptr = + base_ptr + static_cast(physical_page) * page_stride_v; + window.set_bottom_tensor_view_data_ptr(page_ptr); + window.init_raw(); } - const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; - const auto* page_ptr = - base_ptr + static_cast(physical_page) * page_stride_v; - window.set_bottom_tensor_view_data_ptr(page_ptr); - window.init_raw(); }; // Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead) @@ -818,6 +823,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // V physical pages array for use with kv_offset_array_transform // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner statically_indexed_array v_physical_pages{}; + // Double-buffer for flat loads: save current sub-tile's physical pages before + // prefetch overwrites them. load_tile_flat needs physical_pages aligned with + // page_idx_, but the pipeline prefetches the NEXT sub-tile's pages before the + // CURRENT sub-tile's flat load executes. + statically_indexed_array v_physical_pages_current{}; // Prefetch V physical pages - can be called early to hide buffer load latency auto prefetch_v_physical_pages = [&](auto k_loop_start) { @@ -939,8 +949,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const long_index_t v_page_stride_bytes = static_cast(page_stride_v) * sizeof(typename std::remove_const::type>:: - type); + decltype(v_dram_block_window_tmp.get_bottom_tensor_view() + .buf_.p_data_)>::type>::type); // For page_size >= kN0, use SRD rebase (all threads on same page) if constexpr(kPageBlockSize >= kN0) @@ -1016,6 +1026,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_descale = v_descale_ptr[scale_offset]; } + // Save current physical pages before prefetch overwrites them + // (load_tile_flat needs pages aligned with current page_idx_) + if constexpr(kPageBlockSize < kN0) + v_physical_pages_current = v_physical_pages; // Prefetch V physical pages early - overlaps with GEMM0 computation prefetch_v_physical_pages(number{}); @@ -1079,7 +1093,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto v_buf = [&]() { if constexpr(kPageBlockSize < kN0) - return load_tile_flat(v_dram_window, v_physical_pages, v_page_stride_bytes); + return load_tile_flat( + v_dram_window, v_physical_pages_current, v_page_stride_bytes); else return load_tile(v_dram_window, number<-1>{}, bool_constant{}); }(); @@ -1231,6 +1246,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { + if constexpr(kPageBlockSize < kN0) + v_physical_pages_current = v_physical_pages; prefetch_v_physical_pages(number<2 * kK1>{}); } @@ -1286,7 +1303,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... if constexpr(kPageBlockSize < kN0) - v_buf = load_tile_flat(v_dram_window, v_physical_pages, v_page_stride_bytes); + v_buf = load_tile_flat( + v_dram_window, v_physical_pages_current, v_page_stride_bytes); else v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); @@ -1461,10 +1479,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { if constexpr(kPageBlockSize < kN0) v_buf = load_tile_flat( - v_dram_window, v_physical_pages, v_page_stride_bytes); + v_dram_window, v_physical_pages_current, v_page_stride_bytes); else - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); @@ -1476,6 +1493,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { + if constexpr(kPageBlockSize < kN0) + v_physical_pages_current = v_physical_pages; prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); } From 78203c66745b20232ece724d3a83e0c04de99d4d Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Wed, 15 Apr 2026 07:01:16 +0800 Subject: [PATCH 04/17] feat(fmha): template dispatch for >4GB KV cache in batch prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add kUse64BitLoad template parameter to select between SRD buffer_load (fast, <4GB) and flat 64-bit loads (correct, >4GB) at kernel launch time. For page_size < kN0 (128), the kernel generates two variants: - kUse64BitLoad=false: original full-offset buffer_load path (zero regression) - kUse64BitLoad=true: flat load with V double-buffer (15-20% slower, >4GB safe) For page_size >= kN0, SRD rebase handles >4GB natively via 64-bit pointer arithmetic — no flat load variant needed. Runtime dispatch in mha_fwd_batch_prefill.cu checks max_page_byte_offset against INT32_MAX and selects the appropriate variant automatically. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 34 ++++- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 7 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 136 ++++++++---------- .../pipeline/block_fmha_pipeline_problem.hpp | 4 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 4 +- 5 files changed, 104 insertions(+), 81 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 7c3efb9c185..74837fbaa24 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -87,7 +87,8 @@ {F_sink}, {F_page_size}, {F_kv_memory_layout}, - {F_kv_lookup_table}>; + {F_kv_lookup_table}, + {F_use_64bit_load}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -125,7 +126,7 @@ ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_64bit_load}>; #include @@ -203,8 +204,8 @@ """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (t.use_64bit_load == {F_use_64bit_load})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_64bit_load}>; return fmha_batch_prefill_(s, a); }} """ @@ -253,12 +254,14 @@ class FmhaFwdApiTrait: kv_memory_layout: str kv_lookup_table: str page_size: int = 1 # page block size + use_64bit_load: bool = False # use flat 64-bit loads for >4GB KV cache @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" + + ("-64bit" if self.use_64bit_load else "") ) @property @@ -481,6 +484,7 @@ def api(self) -> str: ], F_page_size=trait.page_size, F_sink=BOOL_MAP[trait.sink], + F_use_64bit_load=BOOL_MAP["t" if trait.use_64bit_load else "f"], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -539,6 +543,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str F_page_size: int = 1 # page block size + F_use_64bit_load: bool = False # use flat 64-bit loads for >4GB KV cache @property def template(self) -> str: @@ -588,6 +593,7 @@ def template(self) -> str: F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_use_64bit_load=BOOL_MAP["t" if self.F_use_64bit_load else "f"], ) @property @@ -595,6 +601,7 @@ def name(self) -> str: # TODO: we don't encode idx here return ( f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + + ("64bit_" if self.F_use_64bit_load else "") + self.F_tile.name + "_" + self.F_pipeline.name @@ -632,6 +639,7 @@ def api_trait(self) -> FmhaFwdApiTrait: kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, page_size=self.F_page_size, + use_64bit_load=self.F_use_64bit_load, ) @@ -837,6 +845,24 @@ def get_fwd_blobs( api_pool.register_traits(k.api_trait()) gen.append(k) + # For page_size < kN0 (tile.F_bn0), also generate a kUse64BitLoad=true + # variant for >4GB KV cache support. The default (false) uses SRD buffer_load + # (fast, <4GB). The 64-bit variant uses flat loads (slower, handles >4GB). + if page_size < tile.F_bn0: + k_64bit = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + F_use_64bit_load=True, + ) + api_pool.register_traits(k_64bit.api_trait()) + gen.append(k_64bit) + return (api_pool, gen) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 6c842def58c..dc7821fecb5 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1457,7 +1457,8 @@ template + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + bool kUse64BitLoad_ = false> struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ + index_t kVectorSize, + bool kUseFlatLoad_ = false> CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, @@ -156,56 +157,63 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; + // Offset strategy: + // kPageBlockSize >= kN0: within-page offset (SRD rebased to page base) + // kPageBlockSize < kN0 && kUseFlatLoad_: within-page offset (flat load uses + // physical_pages[]) kPageBlockSize < kN0 && !kUseFlatLoad_: FULL offset (page * stride + + // within_page) for + // direct buffer_load with 32-bit voffset — the original code path, fast but limited to <4GB + constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseFlatLoad_; + if constexpr(kIsKcache) { // K cache: per-token lookup - // Each token may be on a different page, so we use physical_pages[k0] for each. static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - // Store within-page offset only. - // For kPageBlockSize >= kN0: the full page base is handled by SRD rebase. - // For kPageBlockSize < kN0: the full page base is handled by 64-bit flat - // addressing (global_load_lds) or per-tile SRD rebase. - kv_offset_vec[k0] = token_idx_in_page * stride_token; + if constexpr(kNeedFullOffset) + { + const index_t physical_page = physical_pages[k0]; + kv_offset_vec[k0] = + physical_page * stride_page_block + token_idx_in_page * stride_token; + } + else + { + kv_offset_vec[k0] = token_idx_in_page * stride_token; + } }); } else // V cache { - // V cache: use physical_pages[k0] for each token - // physical_pages was already populated correctly by load_physical_pages(), handling: - // - page_size=1: page_idx maps token_idx -> physical_page directly - // - V tile crosses pages: per-token page lookup - // - V tile in single page: lane0 lookup with broadcast to all lanes static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - if constexpr(kPageBlockSize >= kN0) + if constexpr(kNeedFullOffset) { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. + const index_t physical_page = physical_pages[k0]; + const long_index_t page_base = + static_cast(physical_page) * stride_page_block; + if constexpr(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { const index_t token_offset = (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = token_offset; + kv_offset_vec[k0] = page_base + token_offset; } else { - kv_offset_vec[k0] = token_idx_in_page * stride_token; + kv_offset_vec[k0] = page_base + token_idx_in_page * stride_token; } } else { - // Within-page offset only: page base is handled by flat 64-bit - // addressing in load_tile_flat (via physical_pages array). - // Note: physical_pages[] carries the page index separately. + // Within-page offset only: page base handled by SRD rebase or flat load if constexpr(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { @@ -219,9 +227,6 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica kv_offset_vec[k0] = token_idx_in_page * stride_token; } } - // Note: K also uses within-page offsets here (line 172 above). - // Both K and V rely on physical_pages[] for the page base, - // computed via 64-bit arithmetic in load_flat/async_load_raw_flat. }); } } @@ -262,11 +267,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; + static constexpr bool kUse64BitLoad = Problem::kUse64BitLoad; static constexpr index_t kVectorSize = Problem::kVectorSize; - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + // Effective condition for flat 64-bit loads: kUse64BitLoad AND page_size < kN0 + static constexpr bool kUseFlatLoad = kUse64BitLoad && (kPageBlockSize < kN0); + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; @@ -619,26 +627,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseFlatLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); - // V wave min page for per-tile SRD rebase (page_size < kN0) - index_t v_wave_min_page = 0; - - // Lambda: compute wave-level min physical page and adjust offsets to be relative. - auto compute_min_and_adjust_offsets = [](auto& physical_pages, - auto& offsets, - index_t page_stride, - auto NPages) { - index_t thread_min = physical_pages[number<0>{}]; - static_for<1, decltype(NPages)::value, 1>{}( - [&](auto k) { thread_min = min(thread_min, physical_pages[k]); }); - const index_t wave_min = wave_reduce_min(thread_min); - const index_t base_offset = wave_min * page_stride; - static_for<0, decltype(NPages)::value, 1>{}([&](auto k) { offsets[k] -= base_offset; }); - return wave_min; - }; - // K load helper: use flat 64-bit loads for page_size < kN0, // SRD-based buffer loads for page_size >= kN0. const long_index_t k_page_stride_bytes = @@ -655,7 +647,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_window.init_raw(); // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). - // For page_size < kN0, K uses flat 64-bit loads (no SRD needed). + // For page_size < kN0: either flat loads (kUseFlatLoad) or full offsets handle addressing. auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -669,8 +661,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync }; // SRD rebasing for V: only for page_size >= kN0 (all threads on same page). - // For page_size < kN0, V uses flat 64-bit loads (load_tile_flat) which - // compute addresses independently — no SRD rebase needed. + // For page_size < kN0: either flat loads (kUseFlatLoad) or full offsets handle addressing. auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -897,12 +888,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>(v_physical_pages_k2, - stride_v, - page_stride_v, - v_coord, - v_offsets_k2, - current_seq_k); + kVectorSize, + kUseFlatLoad>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; @@ -922,7 +914,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>( + kVectorSize, + kUseFlatLoad>( v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } @@ -961,7 +954,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // instead of SRD rebase, so no init_raw/rebase needed. // prefetch K tile: use flat 64-bit loads for page_size < kN0 - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) { async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, @@ -1028,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Save current physical pages before prefetch overwrites them // (load_tile_flat needs pages aligned with current page_idx_) - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) v_physical_pages_current = v_physical_pages; // Prefetch V physical pages early - overlaps with GEMM0 computation prefetch_v_physical_pages(number{}); @@ -1038,7 +1031,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k0_loops > 1) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) { async_load_tile_raw_flat( k_lds_store(number{})>{}), @@ -1092,7 +1085,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); auto v_buf = [&]() { - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) return load_tile_flat( v_dram_window, v_physical_pages_current, v_page_stride_bytes); else @@ -1101,9 +1094,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, - kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] - : v_wave_min_page); + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); // KV_BLOCKSCALE: apply k_descale to s_acc (dequantize QK result) if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::KV_BLOCKSCALE) @@ -1246,7 +1237,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) v_physical_pages_current = v_physical_pages; prefetch_v_physical_pages(number<2 * kK1>{}); } @@ -1302,16 +1293,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) v_buf = load_tile_flat( v_dram_window, v_physical_pages_current, v_page_stride_bytes); else v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, - kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] - : v_wave_min_page); + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); } __builtin_amdgcn_sched_barrier(0); @@ -1477,7 +1466,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) v_buf = load_tile_flat( v_dram_window, v_physical_pages_current, v_page_stride_bytes); else @@ -1485,15 +1474,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); - rebase_v_window(v_dram_window, - kPageBlockSize >= kN0 ? v_physical_pages[number<0>{}] - : v_wave_min_page); + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); } // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) v_physical_pages_current = v_physical_pages; prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); } @@ -1575,7 +1562,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseFlatLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); @@ -1597,7 +1585,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - if constexpr(kPageBlockSize < kN0) + if constexpr(kUseFlatLoad) { async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 87db7b85b9e..61c174c1bf6 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -117,6 +117,10 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); + // When true, use flat 64-bit loads for page_size < kN0 (handles >4GB KV cache). + // When false (default), use SRD buffer_load for all page sizes (faster, <4GB only). + static constexpr bool kUse64BitLoad = Traits_::kUse64BitLoad; + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 7df39c3d113..57ed097c24d 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -58,7 +58,8 @@ template + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + bool kUse64BitLoad_ = false> struct TileFmhaBatchPrefillTraits : public TileFmhaTraits Date: Thu, 16 Apr 2026 16:39:10 +0800 Subject: [PATCH 05/17] refactor(fmha): unify tile_scatter_gather to two-mode design (SRD/Global load) Simplify tile_scatter_gather to two clean modes controlled by kUseFlatLoad_: - SRD mode (kUseFlatLoad=false): buffer_load(SRD, page_idx_[i] + coord) - Global load mode (kUseFlatLoad=true): flat_load(base + physical_pages_[i] * stride + page_idx_[i] + coord) Changes: - kv_offset_array_transform: eliminate 6-branch K/V duplication into unified loop - tile_scatter_gather: add kUseFlatLoad_ template param, physical_pages_ and page_stride_elements_ members, flat load branches in load() and async_load_raw() - Remove load_flat(), async_load_raw_flat(), load_tile_flat(), async_load_tile_raw_flat() - Pipeline: replace ~10 if constexpr(kUseFlatLoad) load branches with update_physical_pages() + unified load_tile()/async_load_tile_raw() calls - Remove v_physical_pages_current double-buffer variable (now managed internally) Net: -235 lines, zero functional change confirmed across page_size 1/16/1024, bf16/fp8, linear/vectorized layouts, and >4GB overflow boundary tests. --- .../include/ck_tile/core/tensor/load_tile.hpp | 31 -- .../core/tensor/tile_scatter_gather.hpp | 306 ++++++------------ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 223 ++++--------- 3 files changed, 161 insertions(+), 399 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp index 1489c95cb80..d1c06d43780 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/load_tile.hpp @@ -40,19 +40,6 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(number{}, bool_constant{}); } -// Flat load variant using 64-bit addressing. For scatter/gather windows with -// page_size < kN0 where SRD-based buffer_load would overflow its 32-bit voffset. -template -CK_TILE_DEVICE auto load_tile_flat(const TileWindow_& tile_window, - const PhysicalPagesArray& physical_pages, - long_index_t page_stride_bytes, - number = {}) -{ - return tile_window.load_flat(physical_pages, page_stride_bytes, number{}); -} - /** * @brief Load tile with elementwise function * @@ -206,24 +193,6 @@ CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile, bool_constant{}); } -// Flat async load variant using 64-bit addressing (global_load_lds). -// For page_size < kN0 with >4GB KV cache where SRD-based loads overflow. -template -CK_TILE_DEVICE void async_load_tile_raw_flat(LdsTileWindow_&& lds_tile, - const TileWindow_& tile_window, - const PhysicalPagesArray& physical_pages, - long_index_t page_stride_bytes, - number = {}, - bool_constant = {}) -{ - tile_window.async_load_raw_flat( - lds_tile, physical_pages, page_stride_bytes, number{}, bool_constant{}); -} - CK_TILE_DEVICE void async_load_fence(index_t cnt = 0) { asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 84075273826..b6b27c957bb 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -45,16 +45,18 @@ template > + typename YsGatherDims = sequence<0>, + bool kUseFlatLoad_ = false> struct tile_scatter_gather { - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - using PageIdxArray = remove_cvref_t; - using ValidArray = remove_cvref_t; - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; + static constexpr bool kUseFlatLoad = kUseFlatLoad_; + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using ValidArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; using DataType = remove_cvref_t; @@ -465,7 +467,21 @@ struct tile_scatter_gather // read from bottom tensor const vector_t vec_value = [&]() { - if constexpr(std::is_same_v) + if constexpr(kUseFlatLoad_) + { + // Global load mode: 64-bit typed pointer arithmetic + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + vector_t v; + __builtin_memcpy(&v, addr, sizeof(vector_t)); + return v; + } + else if constexpr(std::is_same_v) { return get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, @@ -526,90 +542,6 @@ struct tile_scatter_gather }); } - // Flat load from global memory to VGPRs using 64-bit addressing. - // Replaces load() for the page_size < kN0 path where SRD-based buffer_load - // would overflow its 32-bit voffset for scattered pages. - // - // Uses physical_pages array + page_stride_bytes to compute 64-bit flat - // addresses per element, then loads via pointer dereference (flat_load). - template - CK_TILE_DEVICE auto load_flat(const PhysicalPagesArray& physical_pages, - long_index_t page_stride_bytes, - number = {}) const - { - constexpr auto tile_dstr = TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); - - using Traits = load_store_traits; - using vector_t = typename Traits::vector_t; - using SFC_Ys = typename Traits::SFC_Ys; - - const auto* base_ptr = - reinterpret_cast(get_bottom_tensor_view().buf_.p_data_); - - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; - auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; - - static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; - constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = get_gather_index(idx_ys_start); - - // within-page offset from page_idx_ - const auto within_page_offset = page_idx_[idx_gather]; - // physical page index - const auto physical_page = physical_pages[idx_gather]; - - // 64-bit flat address: base + page*stride + (coord+within_page)*sizeof(T) - const auto coord_offset = bottom_tensor_thread_coord.get_offset(); - const auto* flat_addr = - base_ptr + - static_cast(physical_page) * page_stride_bytes + - static_cast(coord_offset + within_page_offset) * - sizeof(DataType); - - // Load via pointer dereference (generates flat_load instruction) - vector_t vec_value; - __builtin_memcpy(&vec_value, flat_addr, sizeof(vector_t)); - - // Write into distributed tensor (same as load()) - static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { - constexpr auto idx_ys = generate_tuple( - [&](auto jj) { - return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) - : idx_ys_start[jj]; - }, - number{}); - - constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - Traits::PackedSize; - - dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / Traits::PackedSize]; - }); - - // Move thread coordinate - if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) - { - constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, - number{}); - constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), - forward_step_scatter); - - move_window_adaptor_and_bottom_tensor_thread_coordinate( - window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - } - }); - }); - return dst_tensor; - } - template @@ -771,7 +703,23 @@ struct tile_scatter_gather const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor - if constexpr(std::is_same_v) + if constexpr(kUseFlatLoad_) + { + // Global load mode: global_load_lds with 64-bit address + constexpr index_t vector_size = + sizeof(vector_t) / sizeof(uint32_t); // dwords per vector + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + // global_load_lds takes byte address + async_global_load_lds_dwordxn( + smem, reinterpret_cast(addr), pre_nop_); + } + else if constexpr(std::is_same_v) { get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); @@ -809,115 +757,6 @@ struct tile_scatter_gather }); } - // Flat async load from global memory to LDS using 64-bit addressing. - // Replaces async_load_raw for the page_size < kN0 path where SRD-based - // buffer_load_dword...lds would overflow its 32-bit offset. - // - // Instead of using SRD + voffset, computes a 64-bit flat address per element: - // addr = base_ptr + (int64)physical_pages[idx] * page_stride_bytes - // + (coord_offset + within_page_offset) * sizeof(DataType) - // Then issues global_load_lds_dwordx{1,4} which takes a 64-bit VGPR address. - // - // M0 register management is identical to async_load_raw. - template - CK_TILE_DEVICE auto async_load_raw_flat(LdsTileWindow_&& lds_tile, - const PhysicalPagesArray& physical_pages, - long_index_t page_stride_bytes, - number = {}, - bool_constant = {}) const - { - using LdsTileWindow = remove_cvref_t; - using LdsDataType = typename LdsTileWindow::DataType; - using DataType = typename BottomTensorView::DataType; - - static_assert(LdsTileWindow::get_num_of_dimension() == 3); - - const index_t size_per_buf = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType); - - const index_t size_per_wave = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<0>{}, number<1>{}, number<0>{})) * - sizeof(LdsDataType) - - size_per_buf; - - const index_t size_per_issue = - lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( - make_tuple(number<1>{}, number<0>{}, number<0>{})) * - sizeof(LdsDataType) - - size_per_buf; - - const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - m0_set_with_memory(amd_wave_read_first_lane(m0_init_value)); - - using Traits = load_store_traits; - using vector_t = typename Traits::vector_t; - using SFC_Ys = typename Traits::SFC_Ys; - - LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; - - // Base pointer for 64-bit address computation - const auto* base_ptr = reinterpret_cast(get_bottom_tensor_view().buf_.p_data_); - - // Number of dwords per vector element - constexpr index_t vector_size = sizeof(vector_t) / sizeof(uint32_t); // dwords per vector - - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; - auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; - - static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; - constexpr auto pre_nop_ = [&]() { - if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) - return bool_constant{}; - else - return bool_constant{}; - }(); - - constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); - constexpr auto idx_gather = get_gather_index(idx_ys_start); - - // within-page offset from page_idx_ (set by kv_offset_array_transform) - const auto within_page_offset = page_idx_[idx_gather]; - // physical page index - const auto physical_page = physical_pages[idx_gather]; - - // Compute 64-bit flat address: - // base + phys_page * page_stride_bytes - // + (coord_offset + within_page_offset) * sizeof(DataType) - const auto coord_offset = bottom_tensor_thread_coord.get_offset(); - const auto* flat_addr = - base_ptr + static_cast(physical_page) * page_stride_bytes + - static_cast(coord_offset + within_page_offset) * sizeof(DataType); - - async_global_load_lds_dwordxn(smem, flat_addr, pre_nop_); - - // move thread coordinate (same as async_load_raw) - if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) - { - constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); - constexpr auto forward_step_scatter = generate_tuple( - [&](auto i) { return is_gather_dim(i) ? 0 : idx_diff_ys[i]; }, - number{}); - constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), - forward_step_scatter); - - move_window_adaptor_and_bottom_tensor_thread_coordinate( - window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); - - m0_inc_with_memory(size_per_issue); - } - }); - }); - } - // TODO: fix with swizzle template == false) @@ -1339,7 +1185,24 @@ struct tile_scatter_gather // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] TileDstr tile_dstr_; + // Scatter/gather offsets for each element, set by update_page_idx(). + // SRD mode (kUseFlatLoad=false): buffer_load(SRD, page_idx_[i] + coord). + // page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base) + // page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset) + // Global load mode (kUseFlatLoad=true): page_idx_[i] = within-page offset only. + // Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord PageIdxArray page_idx_; + + // Physical page indices for global load mode (kUseFlatLoad=true only). + // Maps each gather element to its physical page in a paged memory pool. + // Updated via update_physical_pages() before each load call. + // Unused in SRD mode — SRD rebase handles page addressing externally. + PageIdxArray physical_pages_; + + // Page stride in elements for global load mode. + // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements. + index_t page_stride_elements_ = 0; + ValidArray valids_; // this contains: @@ -1378,7 +1241,8 @@ template + index_t... YsGatherDims, + bool UseFlatLoad = false> CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1387,7 +1251,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view, const StaticPageIndexArray_& page_idx, number, number, - sequence) + sequence, + bool_constant = {}) { return tile_scatter_gather, remove_cvref_t, @@ -1396,11 +1261,12 @@ make_tile_scatter_gather(const TensorView_& tensor_view, std::nullptr_t, HsGatherDim, NumCoord, - sequence>{ + sequence, + UseFlatLoad>{ tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } -// Legacy overload (compatible with original API) +// Legacy overload (compatible with original API, kUseFlatLoad=false) template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + bool_constant) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + std::nullptr_t, + 0, + 1, + sequence<0>, + UseFlatLoad>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; +} + template {}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - - if constexpr(kNeedFullOffset) + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + + // Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT) + const index_t within_page = [&]() { + if constexpr(!kIsKcache && kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - const index_t physical_page = physical_pages[k0]; - kv_offset_vec[k0] = - physical_page * stride_page_block + token_idx_in_page * stride_token; + return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + + (token_idx_in_page % kVectorSize); } else { - kv_offset_vec[k0] = token_idx_in_page * stride_token; + return token_idx_in_page * stride_token; } - }); - } - else // V cache - { - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; - - if constexpr(kNeedFullOffset) - { - const index_t physical_page = physical_pages[k0]; - const long_index_t page_base = - static_cast(physical_page) * stride_page_block; + }(); - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = page_base + token_offset; - } - else - { - kv_offset_vec[k0] = page_base + token_idx_in_page * stride_token; - } - } - else - { - // Within-page offset only: page base handled by SRD rebase or flat load - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = token_offset; - } - else - { - kv_offset_vec[k0] = token_idx_in_page * stride_token; - } - } - }); - } + // SRD + page_size < kN0: add page base to form complete voffset for buffer_load + if constexpr(kNeedFullOffset) + { + kv_offset_vec[k0] = + static_cast(physical_pages[k0]) * stride_page_block + within_page; + } + else + { + kv_offset_vec[k0] = within_page; + } + }); } // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) @@ -631,19 +597,17 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kUseFlatLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); - // K load helper: use flat 64-bit loads for page_size < kN0, - // SRD-based buffer loads for page_size >= kN0. - const long_index_t k_page_stride_bytes = - static_cast(page_stride_k) * - sizeof(typename std::remove_const::type>:: - type); - auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), k_dist, - k_offsets); // K DRAM tile window for + k_offsets, + bool_constant{}); + if constexpr(kUseFlatLoad) + { + k_dram_window.set_page_stride_elements(page_stride_k); + k_dram_window.update_physical_pages(k_physical_pages); + } k_dram_window.init_raw(); // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). @@ -814,11 +778,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // V physical pages array for use with kv_offset_array_transform // For V_KIterOuter > 1, we need V_PageIdxRepeat elements; otherwise V_KIterInner statically_indexed_array v_physical_pages{}; - // Double-buffer for flat loads: save current sub-tile's physical pages before - // prefetch overwrites them. load_tile_flat needs physical_pages aligned with - // page_idx_, but the pipeline prefetches the NEXT sub-tile's pages before the - // CURRENT sub-tile's flat load executes. - statically_indexed_array v_physical_pages_current{}; // Prefetch V physical pages - can be called early to hide buffer load latency auto prefetch_v_physical_pages = [&](auto k_loop_start) { @@ -919,10 +878,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } - // For page_size < kN0: V uses flat 64-bit loads (load_tile_flat), - // so no per-tile SRD rebase or offset adjustment needed. - // v_offsets contain within-page offsets; page base is handled by - // physical_pages in load_tile_flat. + // For page_size < kN0 with kUseFlatLoad: v_offsets contain within-page offsets; + // page base is handled by physical_pages_ in tile_scatter_gather::load(). }; // Prefetch V physical pages early to hide buffer load latency @@ -936,41 +893,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_offsets, number<1>{}, // HsGatherDim number<1>{}, // NumCoord - VPageIndexYDims); - - // V page stride in bytes for flat 64-bit addressing - const long_index_t v_page_stride_bytes = - static_cast(page_stride_v) * - sizeof(typename std::remove_const::type>::type); + VPageIndexYDims, + bool_constant{}); + if constexpr(kUseFlatLoad) + { + v_dram_window.set_page_stride_elements(page_stride_v); + v_dram_window.update_physical_pages(v_physical_pages); + } // For page_size >= kN0, use SRD rebase (all threads on same page) if constexpr(kPageBlockSize >= kN0) { rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); } - // For page_size < kN0, V uses flat 64-bit loads (load_tile_flat) - // instead of SRD rebase, so no init_raw/rebase needed. - // prefetch K tile: use flat 64-bit loads for page_size < kN0 - if constexpr(kUseFlatLoad) - { - async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - k_physical_pages, - k_page_stride_bytes, - number<-1>{}, - k_pre_np); - } - else - { - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - } + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -1019,10 +958,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_descale = v_descale_ptr[scale_offset]; } - // Save current physical pages before prefetch overwrites them - // (load_tile_flat needs pages aligned with current page_idx_) + // Save current V physical pages before prefetch overwrites them if constexpr(kUseFlatLoad) - v_physical_pages_current = v_physical_pages; + v_dram_window.update_physical_pages(v_physical_pages); // Prefetch V physical pages early - overlaps with GEMM0 computation prefetch_v_physical_pages(number{}); @@ -1031,24 +969,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k0_loops > 1) { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - if constexpr(kUseFlatLoad) - { - async_load_tile_raw_flat( - k_lds_store(number{})>{}), - k_dram_window, - k_physical_pages, - k_page_stride_bytes, - number<-1>{}, - k_pre_np); - } - else - { - async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - } + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -1084,13 +1009,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - auto v_buf = [&]() { - if constexpr(kUseFlatLoad) - return load_tile_flat( - v_dram_window, v_physical_pages_current, v_page_stride_bytes); - else - return load_tile(v_dram_window, number<-1>{}, bool_constant{}); - }(); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); // V physical pages already prefetched before GEMM0 update_v_offsets(number{}); v_dram_window.update_page_idx(v_offsets); @@ -1238,7 +1157,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k1_loops > 1) { if constexpr(kUseFlatLoad) - v_physical_pages_current = v_physical_pages; + v_dram_window.update_physical_pages(v_physical_pages); prefetch_v_physical_pages(number<2 * kK1>{}); } @@ -1293,11 +1212,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - if constexpr(kUseFlatLoad) - v_buf = load_tile_flat( - v_dram_window, v_physical_pages_current, v_page_stride_bytes); - else - v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); @@ -1466,11 +1381,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - if constexpr(kUseFlatLoad) - v_buf = load_tile_flat( - v_dram_window, v_physical_pages_current, v_page_stride_bytes); - else - v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); @@ -1481,7 +1392,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(i_k1 + 1 < k1_loops - 1) { if constexpr(kUseFlatLoad) - v_physical_pages_current = v_physical_pages; + v_dram_window.update_physical_pages(v_physical_pages); prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); } @@ -1566,6 +1477,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kUseFlatLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); + if constexpr(kUseFlatLoad) + k_dram_window.update_physical_pages(k_physical_pages); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); // After sink→window transition (i_total_loops == num_sink_loop), V window @@ -1585,23 +1498,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - if constexpr(kUseFlatLoad) - { - async_load_tile_raw_flat(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - k_physical_pages, - k_page_stride_bytes, - number<-1>{}, - k_pre_np); - } - else - { - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), - k_dram_window, - number<-1>{}, - k_oob_ck, - k_pre_np); - } + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail From b0a6bd650b275f49dbee86755065bb8b4c29911f Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 17 Apr 2026 20:33:31 +0800 Subject: [PATCH 06/17] feat(fmha): add CDNA3+ arch guards for global_load_lds in batch prefill Add three-layer architecture protection for the kUseFlatLoad path which requires the global_load_lds instruction (CDNA3+: gfx940/gfx950 only): 1. Codegen #if guard (fmha_batch_prefill.py): Wrap kUse64BitLoad=true kernel instantiation with #if defined(__gfx94__) || defined(__gfx950__). Uses ArchTrait pattern consistent with fmha_fwd.py. 2. static_assert in tile_scatter_gather.hpp: Prevents kUseFlatLoad_=true instantiation on non-CDNA3 architectures at compile time. 3. static_assert in async_global_load_lds_dwordxn: Prevents the global_load_lds intrinsic from being instantiated on unsupported architectures. Verified: cross-compilation with --offload-arch=gfx90a (CDNA2) and --offload-arch=gfx1100 (RDNA3) succeeds with kernel body excluded. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 13 +++++++++++++ .../core/arch/amd_buffer_addressing_builtins.hpp | 9 +++++++++ .../ck_tile/core/tensor/tile_scatter_gather.hpp | 11 ++++++++++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 74837fbaa24..f6c365e57ea 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -22,8 +22,16 @@ QSCALE_CHECK_MAP, QSCALE_MAP, ) +from codegen.arch import ArchTrait from codegen.utils import update_file +# Architecture trait for kernels requiring global_load_lds (CDNA3+). +# Only used for kUse64BitLoad=true variants; all other kernels are arch-agnostic. +CDNA3_PLUS_ARCH = ArchTrait( + "cdna3_plus", + preprocessor_check="defined(__gfx94__) || defined(__gfx950__)", +) + DTYPE_BITS = { "fp32": 32, "fp16": 16, @@ -61,6 +69,8 @@ """ FMHA_FWD_KERNEL_BODY = """ +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -141,6 +151,8 @@ constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) """ FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp" @@ -594,6 +606,7 @@ def template(self) -> str: F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_use_64bit_load=BOOL_MAP["t" if self.F_use_64bit_load else "f"], + F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check if self.F_use_64bit_load else "true", ) @property diff --git a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index f69d3c7706c..09793cc49e7 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1329,6 +1329,15 @@ template CK_TILE_DEVICE void async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) { +#if !defined(__gfx94__) && !defined(__gfx950__) + // global_load_lds is only available on CDNA3+ (gfx940/gfx950). + // Use !num_dwords so the assert depends on a template parameter + // and is only checked at instantiation time. + static_assert(!num_dwords, + "global_load_lds requires CDNA3+ (gfx940/gfx950). " + "Ensure kUseFlatLoad is false on this architecture."); +#endif + // Use inline asm with VGPR pair for 64-bit flat address #define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ if constexpr(pre_nop) \ diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index b6b27c957bb..c78cb0d99db 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -50,7 +50,16 @@ template ; + +#if !defined(__gfx94__) && !defined(__gfx950__) + // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950). + // On other architectures, kUseFlatLoad must be false. + static_assert(!kUseFlatLoad_, + "kUseFlatLoad requires global_load_lds (CDNA3+: gfx940/gfx950). " + "This kernel should not be instantiated on this architecture."); +#endif + + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; using PageIdxArray = remove_cvref_t; From 83d6de77d0df0a63520ed483a885f4297dd1cedf Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Sat, 18 Apr 2026 13:50:11 +0800 Subject: [PATCH 07/17] fix(fmha): limit SRD num_records to page_stride after rebase (gfx950 fix) After SRD rebase to a physical page, num_records was left at the full buffer size. This caused the SRD to claim validity for a range [page_base, page_base + full_buffer_size) that extends far beyond the allocated buffer when rebased to high pages. On gfx942 (CDNA3), the hardware only checks voffset < num_records per buffer_load instruction, so the extended range is harmless. On gfx950 (CDNA4), the hardware appears to validate the full SRD range against page table permissions. When the extended range covers freed or protected memory, this causes VM_L2_PROTECTION_FAULT (PERMISSION_FAULTS with MAPPING_ERROR=0). Fix: set buffer_size to page_stride (one page worth of elements) before init_raw() after each SRD rebase. This scopes the SRD to exactly the page being accessed. Verified: 80 passed on both gfx942 (MI308X) and gfx950 (MI355X). --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 190464fa058..6537ef3bbbd 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -620,6 +620,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_k; window.set_bottom_tensor_view_data_ptr(page_ptr); + // Limit SRD num_records to one page worth of elements. + // Without this, the SRD claims validity for [page_ptr, page_ptr + full_buffer_size), + // which extends far beyond the allocated buffer when rebased to high pages. + // On gfx950, the hardware may validate the full SRD range against page table + // permissions, causing faults on freed/protected memory beyond the buffer. + window.set_bottom_tensor_view_buffer_size(page_stride_k); window.init_raw(); } }; @@ -635,6 +641,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_v; window.set_bottom_tensor_view_data_ptr(page_ptr); + window.set_bottom_tensor_view_buffer_size(page_stride_v); window.init_raw(); } }; From bcead4be5cc4c613e6444e0db80871f4c11fa546 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Sun, 19 Apr 2026 08:50:22 +0800 Subject: [PATCH 08/17] cleanup: remove unused wave_reduce_min from utility.hpp This function was added for a per-tile SRD rebase approach that was later replaced by template dispatch. No callers remain. --- .../include/ck_tile/core/arch/utility.hpp | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/utility.hpp b/projects/composablekernel/include/ck_tile/core/arch/utility.hpp index a1fb7cca0e0..647f5b4435c 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/utility.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/utility.hpp @@ -59,20 +59,6 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) #endif } -// Butterfly min-reduction across all lanes in a wave. -// Returns the minimum value, broadcast to all lanes as a uniform SGPR value. -// Used for per-tile SRD rebase: find the min physical page across all threads -// so the SRD can be rebased to that page with 64-bit pointer arithmetic. -CK_TILE_DEVICE index_t wave_reduce_min(index_t val) -{ - for(index_t offset = 1; offset < get_warp_size(); offset <<= 1) - { - const index_t other = warp_shuffle_down(val, offset); - val = min(val, other); - } - return __builtin_amdgcn_readfirstlane(val); -} - template CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local) { From 62b15aaabc7394b51eeb2160361ad907b96bd59b Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 20 Apr 2026 17:12:54 +0800 Subject: [PATCH 09/17] fix(buffer): use dependent assertion for unsupported num_dwords in async_global_load_lds_dwordxn The previous static_assert(false) fires unconditionally during template parsing on newer compilers (CWG 2518), even for never-instantiated branches. Wrap it in a dependent expression so the assertion only fires when an unsupported num_dwords is actually instantiated. Found during batch prefill template dispatch review. --- .../arch/amd_buffer_addressing_builtins.hpp | 75 ++++++++++++++----- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 09793cc49e7..9831d778a75 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1319,10 +1319,43 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } -// Flat async load from global memory to LDS using 64-bit addressing. -// Uses global_load_lds_dwordx{1,4} which bypasses the SRD's 32-bit offset limit. -// M0 must already contain the LDS destination offset (set by caller). -// The data is loaded from global_addr to LDS at [M0]. +// Flat async load from global memory to LDS using 64-bit global addressing. +// Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds +// INT32_MAX (2GB) byte offset on the SRD voffset path. +// +// !!! M0 PRECONDITION — IMPLICIT INPUT NOT VISIBLE IN OPERAND LIST !!! +// +// The LDS destination address is taken from M0 (per AMD CDNA3 ISA §10.3: +// `LDS_ADDR = LDSbase + LDSoffset(M0[17:2] * 4) + INST.OFFSET + ThreadID*4`). +// M0 does NOT appear as an operand of these instructions or of the inline +// asm below — the compiler cannot see the dependency. Caller must: +// +// 1. Initialize M0 once before the load loop: +// `m0_set_with_memory(amd_wave_read_first_lane(lds_byte_offset));` +// M0 is SALU-only — `m0_set_with_memory` uses an "s" constraint to +// enforce this. Direct VALU writes to M0 are illegal. +// +// 2. Advance M0 between successive issues: +// `m0_inc_with_memory(size_per_issue);` +// `size_per_issue` MUST be a multiple of 4 — GLOBAL/FLAT LDS path +// only honors M0[17:2]*4 (dword-aligned), so low 2 bits are silently +// dropped (NOTE: this differs from MUBUF buffer_load_lds which uses +// M0[15:0] as a raw byte offset). +// +// 3. Never bundle `m0_inc_with_memory` and the next call to this +// function into a single inline asm. The compiler auto-inserts a +// hazard NOP between an SALU write to M0 and the consuming +// `global_load_lds_*`; bundling bypasses that and may read stale M0. +// +// The "memory" clobber on this asm is load-bearing: it prevents the +// compiler from reordering this load across other M0-touching helpers +// (`m0_set_with_memory` / `m0_inc_with_memory`, also "memory"-clobbered). +// +// Verified instruction emission (HIP 6.4 / clang 19, gfx942 + gfx950): +// `global_load_lds_dwordx4` is a single instruction (encoding 0xDDF48000 +// 0x007F0000), NOT software-expanded into 4× dword. Same encoding on both +// arches. The opcode is undocumented in CDNA3 ISA spec §13.6.2 but +// supported by the LLVM AMDGPU backend. // // Available on gfx940+ (CDNA3: MI300, MI355, MI350 series). template @@ -1338,18 +1371,24 @@ async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant "Ensure kUseFlatLoad is false on this architecture."); #endif -// Use inline asm with VGPR pair for 64-bit flat address -#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ - if constexpr(pre_nop) \ - asm volatile("s_nop 4\n" instr " %1, off offset:0" \ - : "=r"(smem) /*dummy dependency for smem*/ \ - : "v"(global_addr) \ - : "memory"); \ - else \ - asm volatile(instr " %1, off offset:0" \ - : "=r"(smem) /*dummy dependency for smem*/ \ - : "v"(global_addr) \ - : "memory"); + static_assert(num_dwords == 1 || num_dwords == 4, + "global_load_lds only supports num_dwords == 1 or 4"); + +// Inline asm: only the global address is an explicit operand. The LDS +// destination is implicit via M0 (see contract above). `"=r"(smem)` is a +// SSA scheduling anchor only — `smem` is NOT written by this asm; the +// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`. +#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); \ + else \ + asm volatile(instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); if constexpr(num_dwords == 1) { @@ -1359,10 +1398,6 @@ async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant { CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4"); } - else - { - static_assert(false, "wrong! only dword and dwordx4 supported for global_load_lds"); - } #undef CK_TILE_GLOBAL_LOAD_LDS_INSTR } From e1f80b1d2dacb7a8fa9e7581afd2c47844061db8 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 20 Apr 2026 17:13:21 +0800 Subject: [PATCH 10/17] refactor(fmha): tighten batch prefill SRD types and document 32-bit voffset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tile_scatter_gather: divide buffer_size override by PackedSize to match buffer_view ctor convention (raw element count in, packed count stored). Without this, packed types (FP4 / int4, PackedSize=2) would over-report num_records by 2x and silently mask OOB reads. batch_prefill does not exercise the packed-type path today, but this is generic infrastructure and must honor the same invariant. Also narrow the signature from long_index_t to index_t since SRD num_records is hardware 32-bit. block_fmha_batch_prefill_pipeline_qr_ks_vs_async: remove misleading static_cast on the SRD voffset path. The 32-bit limit on this branch comes from CDNA3 MUBUF voffset hardware format, not from an implementation choice — widening would not lift the 2GB ceiling because the hardware truncates regardless. The kUseFlatLoad_ template path handles the >2GB case via 64-bit global_load_lds_*. Added a comment making this explicit so the next reader doesn't propose the same fix. Found during batch prefill template dispatch review. --- .../core/tensor/tile_scatter_gather.hpp | 32 +++++++++++++------ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 21 +++++++----- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index c78cb0d99db..aac624b641e 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -60,12 +60,12 @@ struct tile_scatter_gather #endif using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - using PageIdxArray = remove_cvref_t; - using ValidArray = remove_cvref_t; - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using ValidArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; using DataType = remove_cvref_t; @@ -368,11 +368,23 @@ struct tile_scatter_gather bottom_tensor_view_.buf_.p_data_ = data; } - // Override buffer size (in elements) for SRD num_records control. - // Use to set max range when SRD is rebased per-tile (page_size < kN0 path). - CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(long_index_t size) + // Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for + // SRD num_records control. Use to set max range when SRD is rebased per-tile + // (page_size >= kN0 path): each rebased SRD only needs to cover one page; without + // this the SRD claims validity for memory beyond the allocated buffer, which can + // fault on gfx950 page-table validation. + // + // Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element + // count and is divided by PackedSize before being stored. For PackedSize=1 + // (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4) + // skipping it would over-report num_records by 2x and silently mask OOB on SRD + // reads. batch_prefill currently does not exercise the packed-type path, but this + // setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must + // honor the same invariant the ctor enforces. + CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size) { - bottom_tensor_view_.buf_.buffer_size_ = size; + using BufType = remove_cvref_t; + bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize; } // move thread's window adaptor coordinate and bottom tensor coordinate diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 6537ef3bbbd..ce989f5e8d6 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -162,7 +162,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica // kPageBlockSize < kN0 && kUseFlatLoad_: within-page offset (flat load uses // physical_pages[]) kPageBlockSize < kN0 && !kUseFlatLoad_: FULL offset (page * stride + // within_page) for - // direct buffer_load with 32-bit voffset — the original code path, fast but limited to <4GB + // direct buffer_load with 32-bit voffset — the original code path, fast but limited to <2GB constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseFlatLoad_; static_for<0, kLoopCount, 1>{}([&](auto k0) { @@ -184,11 +184,16 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica } }(); - // SRD + page_size < kN0: add page base to form complete voffset for buffer_load + // SRD + page_size < kN0: add page base to form complete voffset for buffer_load. + // + // 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF + // microcode format), so this branch is only reachable when total KV bytes fit in + // INT32_MAX. The kUseFlatLoad_ template path handles the >2GB case via 64-bit + // global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling + // because the hardware truncates voffset regardless. if constexpr(kNeedFullOffset) { - kv_offset_vec[k0] = - static_cast(physical_pages[k0]) * stride_page_block + within_page; + kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page; } else { @@ -621,10 +626,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync base_ptr + static_cast(physical_page) * page_stride_k; window.set_bottom_tensor_view_data_ptr(page_ptr); // Limit SRD num_records to one page worth of elements. - // Without this, the SRD claims validity for [page_ptr, page_ptr + full_buffer_size), - // which extends far beyond the allocated buffer when rebased to high pages. - // On gfx950, the hardware may validate the full SRD range against page table - // permissions, causing faults on freed/protected memory beyond the buffer. + // Without this, the SRD claims validity for [page_ptr, page_ptr + + // full_buffer_size), which extends far beyond the allocated buffer when rebased to + // high pages. On gfx950, the hardware may validate the full SRD range against page + // table permissions, causing faults on freed/protected memory beyond the buffer. window.set_bottom_tensor_view_buffer_size(page_stride_k); window.init_raw(); } From b475690131512ff41dbd32188e615a930db175d5 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 20 Apr 2026 17:13:45 +0800 Subject: [PATCH 11/17] docs(fmha): correct >2GB threshold wording across batch prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The KV cache overflow threshold is 2GB (INT32_MAX byte offset for SRD voffset), matching CK's existing TwoGB convention in transform_conv_fwd_to_gemm.hpp. Previous comments said "4GB" which is incorrect — SRD voffset is signed-32-bit-effectively, not unsigned. Updated: - codegen comment + use_64bit_load field doc - BlockFmhaBatchPrefillPipelineProblem::kUse64BitLoad doc, with explicit note about INT32_MAX / TwoGB convention Found during batch prefill template dispatch review. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 12 +++++++----- .../fmha/pipeline/block_fmha_pipeline_problem.hpp | 5 +++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index f6c365e57ea..cb2ef8fdf40 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -266,7 +266,7 @@ class FmhaFwdApiTrait: kv_memory_layout: str kv_lookup_table: str page_size: int = 1 # page block size - use_64bit_load: bool = False # use flat 64-bit loads for >4GB KV cache + use_64bit_load: bool = False # use flat 64-bit loads for >2GB KV cache @property def name(self) -> str: @@ -555,7 +555,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str F_page_size: int = 1 # page block size - F_use_64bit_load: bool = False # use flat 64-bit loads for >4GB KV cache + F_use_64bit_load: bool = False # use flat 64-bit loads for >2GB KV cache @property def template(self) -> str: @@ -606,7 +606,9 @@ def template(self) -> str: F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_use_64bit_load=BOOL_MAP["t" if self.F_use_64bit_load else "f"], - F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check if self.F_use_64bit_load else "true", + F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check + if self.F_use_64bit_load + else "true", ) @property @@ -859,8 +861,8 @@ def get_fwd_blobs( gen.append(k) # For page_size < kN0 (tile.F_bn0), also generate a kUse64BitLoad=true - # variant for >4GB KV cache support. The default (false) uses SRD buffer_load - # (fast, <4GB). The 64-bit variant uses flat loads (slower, handles >4GB). + # variant for >2GB KV cache support. The default (false) uses SRD buffer_load + # (fast, <2GB). The 64-bit variant uses flat loads (slower, handles >2GB). if page_size < tile.F_bn0: k_64bit = FmhaFwdKernel( F_idx=0, diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 61c174c1bf6..5f072b6efd1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -117,8 +117,9 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); - // When true, use flat 64-bit loads for page_size < kN0 (handles >4GB KV cache). - // When false (default), use SRD buffer_load for all page sizes (faster, <4GB only). + // When true, use flat 64-bit loads for page_size < kN0 (handles >2GB KV cache). + // When false (default), use SRD buffer_load for all page sizes (faster, <2GB only). + // The 2GB bound = INT32_MAX byte offset, matching CK's existing TwoGB convention. static constexpr bool kUse64BitLoad = Traits_::kUse64BitLoad; static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 From 0fcb74e251d4fd3b8f8d249b82ef7b1d72a7bf93 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Mon, 20 Apr 2026 19:40:39 +0800 Subject: [PATCH 12/17] refactor(fmha): move use_64bit_load decision into auto-gen API dispatcher MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The runtime `bool use_64bit_load` field on `fmha_batch_prefill_traits` forced wrappers to encode each kernel arm's compile-time `bn0` and per-dtype element size to decide whether KV cache exceeds 2GB. That leaked codegen detail and required updating wrappers when new tile configs were added. Move the decision into the auto-generated `fmha_batch_prefill_api.cpp` dispatcher, where each arm already knows its own `{F_bn0}` and dtype. Each per-dtype scope now emits `constexpr int kElementBytes` from a new `DTYPE_BYTES` map, and the inner dispatch predicate evaluates `(a.page_block_size < {F_bn0} && num_total_pages * batch_stride_k * kElementBytes > INT32_MAX) == {F_use_64bit_load}` per arm. The C++ template parameter `kUse64BitLoad_` (and both kernel ELFs) stays — only the runtime trait field is removed. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 13 +++++++++++-- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 3 +-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index cb2ef8fdf40..73d6b368d2e 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -42,6 +42,10 @@ "bf8": 8, } +# Element size in bytes per dtype, used by the auto-generated dispatcher to +# decide use_64bit_load per-arm (total KV cache bytes vs INT32_MAX). +DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} + K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} SUPPORTED_PAGE_SIZE = [1, 16, 1024] @@ -157,6 +161,7 @@ FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp" FMHA_FWD_API = """ +#include #include namespace {{ @@ -207,6 +212,7 @@ """ FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ + constexpr int kElementBytes = {F_element_bytes}; {F_hdim_case} }} """ @@ -216,7 +222,7 @@ """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (t.use_64bit_load == {F_use_64bit_load})) {{ + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && ((a.page_block_size < {F_bn0} && static_cast(a.num_total_pages) * a.batch_stride_k * kElementBytes > INT32_MAX) == {F_use_64bit_load})) {{ using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_64bit_load}>; return fmha_batch_prefill_(s, a); }} @@ -504,7 +510,10 @@ def api(self) -> str: ) if_i = "if" if i == 0 else "else if" per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + F_if=if_i, + F_dtype=dtype, + F_element_bytes=DTYPE_BYTES[dtype], + F_hdim_case=per_hdim_case, ) if not per_dtypes: # empty string we add some ignore to suppress warning in api diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index dc7821fecb5..9423ea3a80b 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1743,8 +1743,7 @@ struct fmha_batch_prefill_traits : public fmha_fwd_traits ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; - int page_size = 1; - bool use_64bit_load = false; + int page_size = 1; }; float fmha_batch_prefill(fmha_batch_prefill_traits, From 66efbd5178f94bbce93baed4cb2f1a26630a7e45 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 21 Apr 2026 08:57:24 +0800 Subject: [PATCH 13/17] =?UTF-8?q?refactor(fmha):=20unify=20kUse64BitLoad/k?= =?UTF-8?q?UseFlatLoad=20=E2=86=92=20kUseGlobalLoad?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The >2GB KV cache code path used two historical names for the same concept across layers (kUse64BitLoad at the kernel/Problem level, kUseFlatLoad inside the batch_prefill pipeline and tile_scatter_gather). Both mean "use global_load_lds_* instead of SRD buffer_load_*". Unify on kUseGlobalLoad everywhere — kernel template params, pipeline traits, scatter-gather op, codegen Python (F_use_global_load), and generated kernel filename suffix (64bit_ → globalload_). Also collapse the two-level structure in the batch_prefill pipeline: the derived kUseFlatLoad = Problem::kUse64BitLoad && (kPageBlockSize < kN0) becomes kUseGlobalLoad = Problem::kUseGlobalLoad directly, with a static_assert backstop for the page_size < kN0 invariant that codegen already guarantees. Verified on both archs (no behavior change vs prior Option 3 baseline): - gfx942 (smc300x-clt): test_batch_prefill.py 512/640, 4gb_small_page 12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE - gfx950 (smci355-gfx950): test_batch_prefill.py 384/768, 4gb_small_page 12/12, 4gb_repro 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 42 +++++------ .../example/ck_tile/01_fmha/fmha_fwd.hpp | 4 +- .../arch/amd_buffer_addressing_builtins.hpp | 2 +- .../core/tensor/tile_scatter_gather.hpp | 36 +++++----- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 71 ++++++++++--------- .../pipeline/block_fmha_pipeline_problem.hpp | 4 +- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 4 +- 7 files changed, 86 insertions(+), 77 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 73d6b368d2e..3f08b54dd88 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -26,7 +26,7 @@ from codegen.utils import update_file # Architecture trait for kernels requiring global_load_lds (CDNA3+). -# Only used for kUse64BitLoad=true variants; all other kernels are arch-agnostic. +# Only used for kUseGlobalLoad=true variants; all other kernels are arch-agnostic. CDNA3_PLUS_ARCH = ArchTrait( "cdna3_plus", preprocessor_check="defined(__gfx94__) || defined(__gfx950__)", @@ -43,7 +43,7 @@ } # Element size in bytes per dtype, used by the auto-generated dispatcher to -# decide use_64bit_load per-arm (total KV cache bytes vs INT32_MAX). +# decide use_global_load per-arm (total KV cache bytes vs INT32_MAX). DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} @@ -102,7 +102,7 @@ {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, - {F_use_64bit_load}>; + {F_use_global_load}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -140,7 +140,7 @@ ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_64bit_load}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_global_load}>; #include @@ -222,8 +222,8 @@ """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && ((a.page_block_size < {F_bn0} && static_cast(a.num_total_pages) * a.batch_stride_k * kElementBytes > INT32_MAX) == {F_use_64bit_load})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_64bit_load}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && ((a.page_block_size < {F_bn0} && static_cast(a.num_total_pages) * a.batch_stride_k * kElementBytes > INT32_MAX) == {F_use_global_load})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_global_load}>; return fmha_batch_prefill_(s, a); }} """ @@ -272,14 +272,14 @@ class FmhaFwdApiTrait: kv_memory_layout: str kv_lookup_table: str page_size: int = 1 # page block size - use_64bit_load: bool = False # use flat 64-bit loads for >2GB KV cache + use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" - + ("-64bit" if self.use_64bit_load else "") + + ("-globalload" if self.use_global_load else "") ) @property @@ -502,7 +502,9 @@ def api(self) -> str: ], F_page_size=trait.page_size, F_sink=BOOL_MAP[trait.sink], - F_use_64bit_load=BOOL_MAP["t" if trait.use_64bit_load else "f"], + F_use_global_load=BOOL_MAP[ + "t" if trait.use_global_load else "f" + ], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -564,7 +566,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str F_page_size: int = 1 # page block size - F_use_64bit_load: bool = False # use flat 64-bit loads for >2GB KV cache + F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def template(self) -> str: @@ -614,9 +616,9 @@ def template(self) -> str: F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], - F_use_64bit_load=BOOL_MAP["t" if self.F_use_64bit_load else "f"], + F_use_global_load=BOOL_MAP["t" if self.F_use_global_load else "f"], F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check - if self.F_use_64bit_load + if self.F_use_global_load else "true", ) @@ -625,7 +627,7 @@ def name(self) -> str: # TODO: we don't encode idx here return ( f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" - + ("64bit_" if self.F_use_64bit_load else "") + + ("globalload_" if self.F_use_global_load else "") + self.F_tile.name + "_" + self.F_pipeline.name @@ -663,7 +665,7 @@ def api_trait(self) -> FmhaFwdApiTrait: kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, page_size=self.F_page_size, - use_64bit_load=self.F_use_64bit_load, + use_global_load=self.F_use_global_load, ) @@ -869,11 +871,11 @@ def get_fwd_blobs( api_pool.register_traits(k.api_trait()) gen.append(k) - # For page_size < kN0 (tile.F_bn0), also generate a kUse64BitLoad=true + # For page_size < kN0 (tile.F_bn0), also generate a kUseGlobalLoad=true # variant for >2GB KV cache support. The default (false) uses SRD buffer_load - # (fast, <2GB). The 64-bit variant uses flat loads (slower, handles >2GB). + # (fast, <2GB). The global_load variant uses global_load_lds_* (slower, handles >2GB). if page_size < tile.F_bn0: - k_64bit = FmhaFwdKernel( + k_global_load = FmhaFwdKernel( F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -882,10 +884,10 @@ def get_fwd_blobs( F_pipeline=pipeline, mask_impl=mask_impl, F_page_size=page_size, - F_use_64bit_load=True, + F_use_global_load=True, ) - api_pool.register_traits(k_64bit.api_trait()) - gen.append(k_64bit) + api_pool.register_traits(k_global_load.api_trait()) + gen.append(k_global_load) return (api_pool, gen) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 9423ea3a80b..95a858c271b 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1458,7 +1458,7 @@ template + bool kUseGlobalLoad_ = false> struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_, - bool kUseFlatLoad_ = false> + bool kUseGlobalLoad_ = false> struct tile_scatter_gather { - static constexpr bool kUseFlatLoad = kUseFlatLoad_; + static constexpr bool kUseGlobalLoad = kUseGlobalLoad_; #if !defined(__gfx94__) && !defined(__gfx950__) // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950). - // On other architectures, kUseFlatLoad must be false. - static_assert(!kUseFlatLoad_, - "kUseFlatLoad requires global_load_lds (CDNA3+: gfx940/gfx950). " + // On other architectures, kUseGlobalLoad must be false. + static_assert(!kUseGlobalLoad_, + "kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). " "This kernel should not be instantiated on this architecture."); #endif @@ -488,7 +488,7 @@ struct tile_scatter_gather // read from bottom tensor const vector_t vec_value = [&]() { - if constexpr(kUseFlatLoad_) + if constexpr(kUseGlobalLoad_) { // Global load mode: 64-bit typed pointer arithmetic const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; @@ -724,7 +724,7 @@ struct tile_scatter_gather const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor - if constexpr(kUseFlatLoad_) + if constexpr(kUseGlobalLoad_) { // Global load mode: global_load_lds with 64-bit address constexpr index_t vector_size = @@ -1207,14 +1207,14 @@ struct tile_scatter_gather TileDstr tile_dstr_; // Scatter/gather offsets for each element, set by update_page_idx(). - // SRD mode (kUseFlatLoad=false): buffer_load(SRD, page_idx_[i] + coord). + // SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord). // page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base) // page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset) - // Global load mode (kUseFlatLoad=true): page_idx_[i] = within-page offset only. + // Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only. // Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord PageIdxArray page_idx_; - // Physical page indices for global load mode (kUseFlatLoad=true only). + // Physical page indices for global load mode (kUseGlobalLoad=true only). // Maps each gather element to its physical page in a paged memory pool. // Updated via update_physical_pages() before each load call. // Unused in SRD mode — SRD rebase handles page addressing externally. @@ -1263,7 +1263,7 @@ template + bool UseGlobalLoad = false> CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1273,7 +1273,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view, number, number, sequence, - bool_constant = {}) + bool_constant = {}) { return tile_scatter_gather, remove_cvref_t, @@ -1283,11 +1283,11 @@ make_tile_scatter_gather(const TensorView_& tensor_view, HsGatherDim, NumCoord, sequence, - UseFlatLoad>{ + UseGlobalLoad>{ tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } -// Legacy overload (compatible with original API, kUseFlatLoad=false) +// Legacy overload (compatible with original API, kUseGlobalLoad=false) template + bool UseGlobalLoad> CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, const StaticPageIndexArray_& page_idx, - bool_constant) + bool_constant) { return tile_scatter_gather, remove_cvref_t, @@ -1336,7 +1336,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view, 0, 1, sequence<0>, - UseFlatLoad>{ + UseGlobalLoad>{ tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index ce989f5e8d6..7b597f148c5 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -135,7 +135,7 @@ template + bool kUseGlobalLoad_ = false> CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, @@ -159,11 +159,11 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica // Offset strategy: // kPageBlockSize >= kN0: within-page offset (SRD rebased to page base) - // kPageBlockSize < kN0 && kUseFlatLoad_: within-page offset (flat load uses - // physical_pages[]) kPageBlockSize < kN0 && !kUseFlatLoad_: FULL offset (page * stride + + // kPageBlockSize < kN0 && kUseGlobalLoad_: within-page offset (global_load_lds uses + // physical_pages[]) kPageBlockSize < kN0 && !kUseGlobalLoad_: FULL offset (page * stride + // within_page) for // direct buffer_load with 32-bit voffset — the original code path, fast but limited to <2GB - constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseFlatLoad_; + constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_; static_for<0, kLoopCount, 1>{}([&](auto k0) { const index_t global_token_idx = @@ -188,7 +188,7 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica // // 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF // microcode format), so this branch is only reachable when total KV bytes fit in - // INT32_MAX. The kUseFlatLoad_ template path handles the >2GB case via 64-bit + // INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit // global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling // because the hardware truncates voffset regardless. if constexpr(kNeedFullOffset) @@ -238,14 +238,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; - static constexpr bool kUse64BitLoad = Problem::kUse64BitLoad; static constexpr index_t kVectorSize = Problem::kVectorSize; - // Effective condition for flat 64-bit loads: kUse64BitLoad AND page_size < kN0 - static constexpr bool kUseFlatLoad = kUse64BitLoad && (kPageBlockSize < kN0); - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + // Single load-mode flag for the whole pipeline: when true, K/V tiles use + // global_load_lds_* (handles >2GB KV cache) instead of SRD buffer_load_*. + // Codegen only emits kUseGlobalLoad=true arms when page_size < kN0; the + // static_assert is a backstop in case someone instantiates the pipeline manually. + static constexpr bool kUseGlobalLoad = Problem::kUseGlobalLoad; + static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0), + "kUseGlobalLoad is only valid when kPageBlockSize < kN0; " + "codegen should not emit this instantiation otherwise."); + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; @@ -599,7 +604,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize, - kUseFlatLoad>( + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), @@ -607,8 +612,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), k_dist, k_offsets, - bool_constant{}); - if constexpr(kUseFlatLoad) + bool_constant{}); + if constexpr(kUseGlobalLoad) { k_dram_window.set_page_stride_elements(page_stride_k); k_dram_window.update_physical_pages(k_physical_pages); @@ -616,7 +621,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_window.init_raw(); // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). - // For page_size < kN0: either flat loads (kUseFlatLoad) or full offsets handle addressing. + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -636,7 +642,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync }; // SRD rebasing for V: only for page_size >= kN0 (all threads on same page). - // For page_size < kN0: either flat loads (kUseFlatLoad) or full offsets handle addressing. + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -860,12 +867,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync false, kN0, kVectorSize, - kUseFlatLoad>(v_physical_pages_k2, - stride_v, - page_stride_v, - v_coord, - v_offsets_k2, - current_seq_k); + kUseGlobalLoad>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; @@ -886,11 +893,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync false, kN0, kVectorSize, - kUseFlatLoad>( + kUseGlobalLoad>( v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } - // For page_size < kN0 with kUseFlatLoad: v_offsets contain within-page offsets; + // For page_size < kN0 with kUseGlobalLoad: v_offsets contain within-page offsets; // page base is handled by physical_pages_ in tile_scatter_gather::load(). }; @@ -906,8 +913,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync number<1>{}, // HsGatherDim number<1>{}, // NumCoord VPageIndexYDims, - bool_constant{}); - if constexpr(kUseFlatLoad) + bool_constant{}); + if constexpr(kUseGlobalLoad) { v_dram_window.set_page_stride_elements(page_stride_v); v_dram_window.update_physical_pages(v_physical_pages); @@ -971,7 +978,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } // Save current V physical pages before prefetch overwrites them - if constexpr(kUseFlatLoad) + if constexpr(kUseGlobalLoad) v_dram_window.update_physical_pages(v_physical_pages); // Prefetch V physical pages early - overlaps with GEMM0 computation prefetch_v_physical_pages(number{}); @@ -1168,7 +1175,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { - if constexpr(kUseFlatLoad) + if constexpr(kUseGlobalLoad) v_dram_window.update_physical_pages(v_physical_pages); prefetch_v_physical_pages(number<2 * kK1>{}); } @@ -1403,7 +1410,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { - if constexpr(kUseFlatLoad) + if constexpr(kUseGlobalLoad) v_dram_window.update_physical_pages(v_physical_pages); prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); } @@ -1486,10 +1493,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync true, kN0, kVectorSize, - kUseFlatLoad>( + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); - if constexpr(kUseFlatLoad) + if constexpr(kUseGlobalLoad) k_dram_window.update_physical_pages(k_physical_pages); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 5f072b6efd1..f0a61c65f81 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -117,10 +117,10 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); - // When true, use flat 64-bit loads for page_size < kN0 (handles >2GB KV cache). + // When true, use global_load_lds_* for page_size < kN0 (handles >2GB KV cache). // When false (default), use SRD buffer_load for all page sizes (faster, <2GB only). // The 2GB bound = INT32_MAX byte offset, matching CK's existing TwoGB convention. - static constexpr bool kUse64BitLoad = Traits_::kUse64BitLoad; + static constexpr bool kUseGlobalLoad = Traits_::kUseGlobalLoad; static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 57ed097c24d..8164420ad34 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -59,7 +59,7 @@ template + bool kUseGlobalLoad_ = false> struct TileFmhaBatchPrefillTraits : public TileFmhaTraits Date: Tue, 21 Apr 2026 19:33:49 +0800 Subject: [PATCH 14/17] fix(fmha): use __builtin_assume + restore batch prefill review polish tile_scatter_gather.hpp: replace cassert assert(size > 0) in set_bottom_tensor_view_buffer_size with __builtin_assume. The cassert form expands to an __assert_fail call whose SGPR pressure forces the LLVM AMDGPU register allocator to reuse the K-SRD scalar register window (s24-s27) as scratch for the assert-PC literal, scattering the 4 K-SRD writes across two conditional branches. gfx950 buffer_load_dwordx4 does not tolerate the staggered SRD setup; gfx942 (4x scalar buffer_load) absorbs it. Reproduced as 95.2% mismatch on MI355X for hd=256, ps=1024, linear, vectorized, bf16, causal, soft_cap=30. __builtin_assume preserves the optimizer hint without emitting the assert handler. block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: restore review follow-ups lost in an earlier reset: * Four-case addressing-strategy comment block above kNeedFullOffset (Case 1-3 mechanisms + Case 4 codegen backstop reference). * readfirstlane "wave-uniform -> SGPR" rationale on K and V rebase sites. * v_offsets semantics comment enumerating Cases 1/2/3, naming kNeedFullOffset as the Case-3 selector. * Remove redundant outer if constexpr(kPageBlockSize >= kN0) around rebase_v_window initial call (single source of truth in lambda). * save_and_prefetch_v_pages lambda encapsulates the update_physical_pages -> prefetch_v_physical_pages ordering invariant; 3 in-loop sites collapsed to single calls. Verified: test_batch_prefill_4gb_small_page.py 12/12 on both archs; test_batch_prefill_4gb_repro.py --num_blocks 5000 9/9 bf16 + 6/6 FP8 KV_BLOCKSCALE on gfx942 (smc300x-clt) and gfx950 (smci355-gfx950). --- .../core/tensor/tile_scatter_gather.hpp | 9 +++ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 78 +++++++++++++------ 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index fce8c3da662..95861b33946 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -383,6 +383,15 @@ struct tile_scatter_gather // honor the same invariant the ctor enforces. CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size) { + // Hint the optimizer that size is positive without inserting a runtime + // branch. Using assert() here corrupted gfx950 batch_prefill + // output: the __assert_fail handler's SGPR pressure forced the K-SRD + // register window to be reused as scratch and scattered the SRD writes + // across two conditional branches, which gfx950's packed + // buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it + // via per-tile single-dword loads). __builtin_assume is hint-only — + // no branch, no scratch SGPRs, no codegen impact. + __builtin_assume(size > 0); using BufType = remove_cvref_t; bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize; } diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 7b597f148c5..08782424ac4 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -157,12 +157,28 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; - // Offset strategy: - // kPageBlockSize >= kN0: within-page offset (SRD rebased to page base) - // kPageBlockSize < kN0 && kUseGlobalLoad_: within-page offset (global_load_lds uses - // physical_pages[]) kPageBlockSize < kN0 && !kUseGlobalLoad_: FULL offset (page * stride + - // within_page) for - // direct buffer_load with 32-bit voffset — the original code path, fast but limited to <2GB + // Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_): + // + // Case 1: kPageBlockSize >= kN0 + // SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller). + // Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident). + // This function writes within-page offset only. + // + // Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_ + // SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full + // 64-bit address is computed by tile_scatter_gather::load() in + // include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ + + // page_stride_elements_. This function writes within-page offset only. + // + // Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true) + // SRD base is the entire KV buffer; the only place to encode page identity + // is the voffset itself. This function writes the FULL offset: + // page * stride_page_block + within_page + // Limited to <2GB total KV bytes by 32-bit voffset hardware width. + // + // Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_ + // Not emitted by codegen. Backstop static_assert in + // BlockFmhaBatchPrefillPipelineQRKSVSAsync. constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_; static_for<0, kLoopCount, 1>{}([&](auto k0) { @@ -626,6 +642,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { + // readfirstlane: make physical_page provably wave-uniform so the + // resulting SRD lands in SGPRs (required by buffer load instructions). physical_page = __builtin_amdgcn_readfirstlane(physical_page); const auto* base_ptr = k_dram_block_window.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = @@ -647,6 +665,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { + // readfirstlane: make physical_page provably wave-uniform so the + // resulting SRD lands in SGPRs (required by buffer load instructions). physical_page = __builtin_amdgcn_readfirstlane(physical_page); const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; @@ -897,8 +917,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } - // For page_size < kN0 with kUseGlobalLoad: v_offsets contain within-page offsets; - // page base is handled by physical_pages_ in tile_scatter_gather::load(). + // v_offsets semantics — see the four-case addressing-strategy block above + // kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda: + // Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD. + // Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed + // by tile_scatter_gather::load() from + // physical_pages_. + // Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset): + // FULL offset (page * stride + within), + // carried in the 32-bit voffset (<2GB cap). }; // Prefetch V physical pages early to hide buffer load latency @@ -920,11 +947,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window.update_physical_pages(v_physical_pages); } - // For page_size >= kN0, use SRD rebase (all threads on same page) - if constexpr(kPageBlockSize >= kN0) - { - rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); - } + // Initial V SRD rebase. Single source of truth: rebase_v_window's own + // `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3. + // Do not re-add an outer guard here — it would duplicate the inner check + // and drift if the lambda's gating condition ever changes. + rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + + // Save the *current* tile's V physical pages into v_dram_window before + // prefetch_v_physical_pages overwrites the v_physical_pages buffer with the + // *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read + // physical_pages_ from the window. Encapsulating the save+prefetch pair + // here makes the ordering invariant unmissable when a fourth prefetch site + // is added later. + auto save_and_prefetch_v_pages = [&](auto k_loop_start) { + if constexpr(kUseGlobalLoad) + v_dram_window.update_physical_pages(v_physical_pages); + prefetch_v_physical_pages(k_loop_start); + }; // prefetch K tile async_load_tile_raw( @@ -977,11 +1016,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_descale = v_descale_ptr[scale_offset]; } - // Save current V physical pages before prefetch overwrites them - if constexpr(kUseGlobalLoad) - v_dram_window.update_physical_pages(v_physical_pages); // Prefetch V physical pages early - overlaps with GEMM0 computation - prefetch_v_physical_pages(number{}); + save_and_prefetch_v_pages(number{}); // STAGE 1, QK gemm clear_tile(s_acc); // initialize C @@ -1175,9 +1211,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { - if constexpr(kUseGlobalLoad) - v_dram_window.update_physical_pages(v_physical_pages); - prefetch_v_physical_pages(number<2 * kK1>{}); + save_and_prefetch_v_pages(number<2 * kK1>{}); } auto m_local = block_tile_reduce( @@ -1410,9 +1444,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { - if constexpr(kUseGlobalLoad) - v_dram_window.update_physical_pages(v_physical_pages); - prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); + save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{}); } block_sync_lds(); From 72a5da713c3abdb0ae8c6d71aba6645a3be5171b Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Tue, 21 Apr 2026 21:46:16 +0800 Subject: [PATCH 15/17] =?UTF-8?q?refactor(fmha):=20batch=20prefill=20revie?= =?UTF-8?q?w=20polish=20=E2=80=94=20assert=20helper=20+=20setter=20guards?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit amd_buffer_addressing_builtins.hpp: * Replace the !num_dwords CWG-2518 workaround in the non-CDNA3+ guard with a named file-local helper impl::global_load_lds_arch_unreachable_v. The old form had a misleading edge case: !num_dwords is true only when num_dwords == 0, so the assert silently passed on num_dwords==0 wrong- arch instantiations and the second assert (num_dwords == 1 || == 4) fired with the wrong diagnostic. The named helper makes the intent unambiguous and the dependent-false pattern self-documenting. * Reword the num_dwords == 1 || num_dwords == 4 static_assert message to distinguish hardware reality from policy: 2 dwords does not exist on any supported arch; 3 dwords only on CDNA4 and unused in FMHA pipeline. Prevents a future contributor from assuming 2/3 are deliberately blacklisted by software. tile_scatter_gather.hpp: * async_global_load_lds_dwordxn callsite: remove the redundant reinterpret_cast(addr). addr is already const DataType* and converts implicitly to const void*. Comment clarifies that global_load_lds takes a byte address, which is what the implicit conversion produces. * Add positive static_assert(kUseGlobalLoad_, ...) to update_physical_pages and set_page_stride_elements. Both fields (physical_pages_, page_stride_elements_) only participate in the global-load addressing path; calling these setters in SRD mode is silently a no-op that hides the misuse. The compile-time guard turns the misuse into a build error and locks down the invariant. Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds, no new warnings, and test_batch_prefill_4gb_small_page.py 12/12 pass with the two new positive setter asserts in place (codegen only emits kUseGlobalLoad=true arms when the setters are reachable, so neither fires in practice). --- .../core/arch/amd_buffer_addressing_builtins.hpp | 14 +++++++++----- .../ck_tile/core/tensor/tile_scatter_gather.hpp | 15 +++++++++++---- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 5769a443234..feeaa7a2615 100644 --- a/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1319,6 +1319,11 @@ CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +namespace impl { +template +inline constexpr bool global_load_lds_arch_unreachable_v = false; +} // namespace impl + // Flat async load from global memory to LDS using 64-bit global addressing. // Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds // INT32_MAX (2GB) byte offset on the SRD voffset path. @@ -1363,16 +1368,15 @@ CK_TILE_DEVICE void async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) { #if !defined(__gfx94__) && !defined(__gfx950__) - // global_load_lds is only available on CDNA3+ (gfx940/gfx950). - // Use !num_dwords so the assert depends on a template parameter - // and is only checked at instantiation time. - static_assert(!num_dwords, + static_assert(impl::global_load_lds_arch_unreachable_v, "global_load_lds requires CDNA3+ (gfx940/gfx950). " "Ensure kUseGlobalLoad is false on this architecture."); #endif static_assert(num_dwords == 1 || num_dwords == 4, - "global_load_lds only supports num_dwords == 1 or 4"); + "global_load_lds supports num_dwords == 1 or 4 only " + "(2 dwords does not exist on any supported arch; " + "3 dwords only on CDNA4 and unused in FMHA pipeline)"); // Inline asm: only the global address is an explicit operand. The LDS // destination is implicit via M0 (see contract above). `"=r"(smem)` is a diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 95861b33946..1c4224c3dec 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -745,9 +745,9 @@ struct tile_scatter_gather static_cast(physical_page) * page_stride_elements_ + coord_offset + page_offset; const auto* addr = base_ptr + total_offset; - // global_load_lds takes byte address - async_global_load_lds_dwordxn( - smem, reinterpret_cast(addr), pre_nop_); + // global_load_lds takes a byte address; addr (const DataType*) + // converts implicitly to const void*, no explicit cast needed. + async_global_load_lds_dwordxn(smem, addr, pre_nop_); } else if constexpr(std::is_same_v) { @@ -1117,10 +1117,17 @@ struct tile_scatter_gather CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages) { + static_assert(kUseGlobalLoad_, + "global-load mode only; physical_pages_ is unused in SRD mode."); physical_pages_ = pages; } - CK_TILE_DEVICE void set_page_stride_elements(index_t stride) { page_stride_elements_ = stride; } + CK_TILE_DEVICE void set_page_stride_elements(index_t stride) + { + static_assert(kUseGlobalLoad_, + "global-load mode only; page_stride_elements_ is unused in SRD mode."); + page_stride_elements_ = stride; + } CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) { From bda61422b1c6db7482e190fa9a2935d40cc46f0e Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Thu, 23 Apr 2026 11:23:00 +0800 Subject: [PATCH 16/17] refactor(fmha): apply PR #6653 review feedback (Tasks #70-#74) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address all 5 reviewer asks on the >2GB KV cache batch-prefill series, plus two self-found polish items surfaced by an internal CK-aware review pass. Task #71 — bool kUseGlobalLoad_ -> BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ (poyenc, tile_fmha_traits.hpp:62): Adjacent traits-template params (kKVMemoryLayout_, kKVLookupTable_) are already BlockAttention*Enum types; the binary kUseGlobalLoad_ stuck out as a bool exception. Convert to a 2-value enum {BUFFER_LOAD = 0, GLOBAL_LOAD_LDS = 1} living in a new ops/fmha/block/ header so it sits alongside its siblings. Touch sites: * include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp (NEW): the enum class. * include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp: rename last template param + static member alias. * include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp: mirror alias rename. * include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp: add enum header include; class declares static auto kKVLoadMode plus derived static bool kUseGlobalLoad = (kKVLoadMode == GLOBAL_LOAD_LDS). All 10 internal `if constexpr(kUseGlobalLoad)` sites unchanged so the bool boundary is local to one TU. The standalone helper kv_offset_array_transform keeps its bool template param (private inline; intentional — keeps core/ tile_scatter_gather.hpp out of the enum's blast radius). * example/ck_tile/01_fmha/fmha_fwd.hpp: fmha_fwd_batch_prefill_traits_ last template param renamed; static member alias kUseGlobalLoad -> kKVLoadMode (default BUFFER_LOAD). * include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp: comment-only update. Task #70 — explicit constructor mem-init for tile_scatter_gather (asleepzzz, tile_scatter_gather.hpp:1241, comment #3125912056): physical_pages_ and page_stride_elements_ were silently zero-initialized in the BUFFER_LOAD arm. Today safe (Task #71's positive setter asserts prevent misuse), but a future kUseGlobalLoad=true caller that misses a setter would get silent data corruption with no compile error. Make both fields explicit in the mem-init list so the contract is visible at the constructor boundary. Task #72 — extract dispatcher overflow predicate to a named helper (poyenc, fmha_batch_prefill.py:225): Move the (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) decision out of the codegen template into a free helper: fmha_batch_prefill_select_kv_load_mode(page_block_size, kN0, num_total_pages, batch_stride_k, element_bytes) in example/ck_tile/01_fmha/fmha_fwd.hpp. The codegen-emitted dispatcher arms now call it with their compile-time kN0/element_bytes substituted, so the formula has exactly one source of truth. Task #73 — symmetric gload/bload kernel-name suffix (poyenc, fmha_batch_prefill.py:282): Match the existing CK convention (e.g., causal/ncausal, sink/nsink) by emitting a non-empty token in BOTH branches: '-gload' / '-bload' on FmhaFwdApiTrait.name, 'gload_' / 'bload_' on FmhaFwdKernel.name. The prior blank-default made it impossible to tell, when grepping JIT blob/ output 6 months later, whether a missing marker meant 'BUFFER_LOAD variant' or 'old codegen revision before the gload branch existed'. Task #74 — replace single-use dependent-false with reusable always_false_v (poyenc, amd_buffer_addressing_builtins.hpp:1324): Promote impl::global_load_lds_arch_unreachable_v from a file-local helper into a generic ck_tile::always_false_v utility in core/utility/type_traits.hpp. Use it at the original site. The variable-template form defers evaluation to instantiation time, so a bare `static_assert(false, ...)` would (per CWG-2518 / current Clang) fire at parse time and break the whole TU even on never-instantiated arches. Polish I-1 — umbrella header completeness: include/ck_tile/ops/fmha.hpp now pulls in the new block_attention_kv_load_mode_enum.hpp alongside the other BlockAttention*Enum siblings. Without this, downstream consumers that rely solely on the umbrella header would miss the enum. Polish I-2 — overflow-cast robustness in fmha_batch_prefill_select_kv_load_mode: Promote every operand of the kv_pool_bytes multiplication to long_index_t individually instead of relying on left-to-right associativity to widen the chain. A future operand reorder would silently truncate; the per-operand cast makes overflow impossible regardless of order. Verified on smci355-gfx950 (gfx950): clean JIT rebuild succeeds; full op_tests/test_batch_prefill.py sweep passes 30,720 / 30,720 (10,016 skipped, 0 failed) in 30:40 wall. Codegen identifier changes only affect the renamed template parameter; no register-allocation perturbation expected on either gfx942 or gfx950 (confirmed by the cross-arch sweep). --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 48 +++++++++++-------- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 32 ++++++++++++- .../arch/amd_buffer_addressing_builtins.hpp | 9 +--- .../core/tensor/tile_scatter_gather.hpp | 48 ++++++++++++------- .../ck_tile/core/utility/type_traits.hpp | 14 ++++++ .../include/ck_tile/ops/fmha.hpp | 1 + .../block_attention_kv_load_mode_enum.hpp | 17 +++++++ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 24 ++++++---- .../pipeline/block_fmha_pipeline_problem.hpp | 9 ++-- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 6 ++- 10 files changed, 147 insertions(+), 61 deletions(-) create mode 100644 projects/composablekernel/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 3f08b54dd88..8c006c09db9 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -26,7 +26,7 @@ from codegen.utils import update_file # Architecture trait for kernels requiring global_load_lds (CDNA3+). -# Only used for kUseGlobalLoad=true variants; all other kernels are arch-agnostic. +# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic. CDNA3_PLUS_ARCH = ArchTrait( "cdna3_plus", preprocessor_check="defined(__gfx94__) || defined(__gfx950__)", @@ -43,7 +43,7 @@ } # Element size in bytes per dtype, used by the auto-generated dispatcher to -# decide use_global_load per-arm (total KV cache bytes vs INT32_MAX). +# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX). DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} @@ -59,6 +59,10 @@ "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", } +KV_LOAD_MODE_ENUM_MAP = { + False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD", + True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS", +} FMHA_BATCH_PREFILL_PIPELINE_MAP = { @@ -102,7 +106,7 @@ {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, - {F_use_global_load}>; + {F_kv_load_mode}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -140,7 +144,7 @@ ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_global_load}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>; #include @@ -222,8 +226,8 @@ """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && ((a.page_block_size < {F_bn0} && static_cast(a.num_total_pages) * a.batch_stride_k * kElementBytes > INT32_MAX) == {F_use_global_load})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_use_global_load}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>; return fmha_batch_prefill_(s, a); }} """ @@ -279,7 +283,7 @@ def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" - + ("-globalload" if self.use_global_load else "") + + ("-gload" if self.use_global_load else "-bload") ) @property @@ -502,9 +506,7 @@ def api(self) -> str: ], F_page_size=trait.page_size, F_sink=BOOL_MAP[trait.sink], - F_use_global_load=BOOL_MAP[ - "t" if trait.use_global_load else "f" - ], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -616,7 +618,7 @@ def template(self) -> str: F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], - F_use_global_load=BOOL_MAP["t" if self.F_use_global_load else "f"], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load], F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check if self.F_use_global_load else "true", @@ -627,7 +629,7 @@ def name(self) -> str: # TODO: we don't encode idx here return ( f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" - + ("globalload_" if self.F_use_global_load else "") + + ("gload_" if self.F_use_global_load else "bload_") + self.F_tile.name + "_" + self.F_pipeline.name @@ -748,8 +750,11 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl, - targets: Optional[List[str]] = None + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + targets: Optional[List[str]] = None, ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with @@ -871,9 +876,10 @@ def get_fwd_blobs( api_pool.register_traits(k.api_trait()) gen.append(k) - # For page_size < kN0 (tile.F_bn0), also generate a kUseGlobalLoad=true - # variant for >2GB KV cache support. The default (false) uses SRD buffer_load - # (fast, <2GB). The global_load variant uses global_load_lds_* (slower, handles >2GB). + # For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS + # variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD + # buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_* + # (slower, handles >2GB). if page_size < tile.F_bn0: k_global_load = FmhaFwdKernel( F_idx=0, @@ -908,7 +914,9 @@ def write_blobs( optdim_list, mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + api_pool, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) @@ -923,7 +931,9 @@ def list_blobs( mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + _, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 95a858c271b..98e2df2e1ee 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -673,6 +673,33 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; +// Selects the KV-cache load mode for a batch-prefill dispatch arm. +// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile +// so per-page SRD is impossible, AND (b) the total KV-pool byte size +// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it. +// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest. +// Inputs are taken as plain integers so the helper has no template parameter +// and can be called from each codegen-emitted dispatcher arm with the arm's +// compile-time kN0 / element_bytes substituted as constants. +inline ck_tile::BlockAttentionKVCacheLoadModeEnum +fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size, + ck_tile::index_t kN0, + ck_tile::index_t num_total_pages, + ck_tile::index_t batch_stride_k, + ck_tile::index_t element_bytes) +{ + // Promote every operand to long_index_t so overflow is impossible regardless + // of multiplication order. A bare `static_cast(num_total_pages) + // * batch_stride_k * element_bytes` only works because of left-to-right + // associativity — a future reorder of the operands would silently truncate. + const auto kv_pool_bytes = static_cast(num_total_pages) * + static_cast(batch_stride_k) * + static_cast(element_bytes); + return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) + ? ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS + : ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD; +} + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -1458,7 +1485,8 @@ template + ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ -inline constexpr bool global_load_lds_arch_unreachable_v = false; -} // namespace impl - // Flat async load from global memory to LDS using 64-bit global addressing. // Bypasses the SRD's 32-bit offset limit; required when the KV cache exceeds // INT32_MAX (2GB) byte offset on the SRD voffset path. @@ -1368,9 +1363,9 @@ CK_TILE_DEVICE void async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) { #if !defined(__gfx94__) && !defined(__gfx950__) - static_assert(impl::global_load_lds_arch_unreachable_v, + static_assert(always_false_v>, "global_load_lds requires CDNA3+ (gfx940/gfx950). " - "Ensure kUseGlobalLoad is false on this architecture."); + "Ensure kKVLoadMode is BUFFER_LOAD on this architecture."); #endif static_assert(num_dwords == 1 || num_dwords == 4, diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 1c4224c3dec..e5f89aa95c3 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -244,12 +244,15 @@ struct tile_scatter_gather const BottomTensorIndex& window_origin, const TileDstr& tile_distribution, const PageIdxArray& page_idx, - const ValidArray& valids) + const ValidArray& valids, + index_t page_stride_elements = 0) : bottom_tensor_view_{bottom_tensor_view}, window_lengths_{window_lengths}, window_origin_{window_origin}, tile_dstr_{tile_distribution}, page_idx_{page_idx}, + physical_pages_{}, + page_stride_elements_{page_stride_elements}, valids_{valids}, pre_computed_coords_{} { @@ -1122,13 +1125,6 @@ struct tile_scatter_gather physical_pages_ = pages; } - CK_TILE_DEVICE void set_page_stride_elements(index_t stride) - { - static_assert(kUseGlobalLoad_, - "global-load mode only; page_stride_elements_ is unused in SRD mode."); - page_stride_elements_ = stride; - } - CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) { if constexpr(std::is_same_v == false) @@ -1236,9 +1232,11 @@ struct tile_scatter_gather // Unused in SRD mode — SRD rebase handles page addressing externally. PageIdxArray physical_pages_; - // Page stride in elements for global load mode. + // Page stride in elements for global load mode (kUseGlobalLoad=true only). // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements. - index_t page_stride_elements_ = 0; + // Set at construction time via the make_tile_scatter_gather overload that + // takes bool_constant; immutable thereafter. + index_t page_stride_elements_; ValidArray valids_; @@ -1289,7 +1287,8 @@ make_tile_scatter_gather(const TensorView_& tensor_view, number, number, sequence, - bool_constant = {}) + bool_constant = {}, + index_t page_stride_elements = 0) { return tile_scatter_gather, remove_cvref_t, @@ -1299,8 +1298,13 @@ make_tile_scatter_gather(const TensorView_& tensor_view, HsGatherDim, NumCoord, sequence, - UseGlobalLoad>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; } // Legacy overload (compatible with original API, kUseGlobalLoad=false) @@ -1330,7 +1334,11 @@ make_tile_scatter_gather(const TensorView_& tensor_view, tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; } -// Overload with kUseGlobalLoad (simple, used by K cache) +// Overload with kUseGlobalLoad (simple, used by K cache). +// page_stride_elements is forwarded to the constructor; required (non-zero) +// when UseGlobalLoad=true so that physical_pages_[i] * page_stride_elements_ +// produces a valid address. Defaulting to 0 keeps SRD-mode call sites unchanged +// (page_stride_elements_ is unread in SRD mode). template & origin, const StaticTileDistribution_& tile_distribution, const StaticPageIndexArray_& page_idx, - bool_constant) + bool_constant, + index_t page_stride_elements = 0) { return tile_scatter_gather, remove_cvref_t, @@ -1352,8 +1361,13 @@ make_tile_scatter_gather(const TensorView_& tensor_view, 0, 1, sequence<0>, - UseGlobalLoad>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; } template ` — a value-template that is always `false` but whose +// evaluation is deferred until template instantiation. The canonical use is +// inside the `else` arm of an `if constexpr` chain or under an arch-gated +// `#if` to fire a `static_assert` ONLY when the offending instantiation is +// actually requested, e.g.: +// +// if constexpr (...) { ... } +// else { static_assert(always_false_v, "unsupported T"); } +// +// A bare `static_assert(false, ...)` would fire at template-definition +// parse time on conforming compilers, breaking the whole TU. +template +inline constexpr bool always_false_v = false; + // remove_cvref_t template using remove_reference_t = typename std::remove_reference::type; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha.hpp b/projects/composablekernel/include/ck_tile/ops/fmha.hpp index 8a5d77bf462..59e868f678f 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp new file mode 100644 index 00000000000..826cd106f1b --- /dev/null +++ b/projects/composablekernel/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines. +// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool) +// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache) +enum class BlockAttentionKVCacheLoadModeEnum +{ + BUFFER_LOAD = 0, + GLOBAL_LOAD_LDS = 1, +}; + +} // namespace ck_tile diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 08782424ac4..8aa6d17dc3e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" @@ -255,13 +256,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; static constexpr index_t kVectorSize = Problem::kVectorSize; - // Single load-mode flag for the whole pipeline: when true, K/V tiles use - // global_load_lds_* (handles >2GB KV cache) instead of SRD buffer_load_*. - // Codegen only emits kUseGlobalLoad=true arms when page_size < kN0; the - // static_assert is a backstop in case someone instantiates the pipeline manually. - static constexpr bool kUseGlobalLoad = Problem::kUseGlobalLoad; + // Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V + // tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD + // buffer_load_*. The enum is named at the trait/Problem level; internally we + // derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits + // GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop. + static constexpr auto kKVLoadMode = Problem::kKVLoadMode; + static constexpr bool kUseGlobalLoad = + (kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS); static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0), - "kUseGlobalLoad is only valid when kPageBlockSize < kN0; " + "GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; " "codegen should not emit this instantiation otherwise."); static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -628,10 +632,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), k_dist, k_offsets, - bool_constant{}); + bool_constant{}, + page_stride_k); if constexpr(kUseGlobalLoad) { - k_dram_window.set_page_stride_elements(page_stride_k); k_dram_window.update_physical_pages(k_physical_pages); } k_dram_window.init_raw(); @@ -940,10 +944,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync number<1>{}, // HsGatherDim number<1>{}, // NumCoord VPageIndexYDims, - bool_constant{}); + bool_constant{}, + page_stride_v); if constexpr(kUseGlobalLoad) { - v_dram_window.set_page_stride_elements(page_stride_v); v_dram_window.update_physical_pages(v_physical_pages); } diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index f0a61c65f81..c441f57c864 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -117,10 +117,11 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); - // When true, use global_load_lds_* for page_size < kN0 (handles >2GB KV cache). - // When false (default), use SRD buffer_load for all page sizes (faster, <2GB only). - // The 2GB bound = INT32_MAX byte offset, matching CK's existing TwoGB convention. - static constexpr bool kUseGlobalLoad = Traits_::kUseGlobalLoad; + // KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via + // 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the + // <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's + // existing TwoGB convention. + static constexpr auto kKVLoadMode = Traits_::kKVLoadMode; static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 8164420ad34..e7370cdb65d 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" @@ -59,7 +60,8 @@ template + BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct TileFmhaBatchPrefillTraits : public TileFmhaTraits Date: Thu, 23 Apr 2026 14:26:46 +0800 Subject: [PATCH 17/17] refactor(fmha): make tile_scatter_gather page fields conditional on kUseGlobalLoad Replace the unconditional `physical_pages_` and `page_stride_elements_` members with `std::conditional_t` + `[[no_unique_address]]` so they collapse to zero-byte placeholders in the SRD instantiation (kUseGlobalLoad=false). Why: Reviewer asleepzzz observed that these fields were always present even when SRD-mode kernels never read them. The previous fix (Task #70, explicit mem-init) addressed the *form* of the concern (silent zero-init -> explicit zero-init) but not the *substance* (wasted storage in SRD-mode instantiations). This commit makes the fields literally disappear in SRD mode. How it works: - Empty placeholder `gl_field_empty_t` introduced inside the class. - Both fields wrapped in `std::conditional_t` with `[[no_unique_address]]` so the SRD-mode layout drops them. - All access sites (lines 520, 523, 758, 761) are already inside `if constexpr(kUseGlobalLoad_)` arms, so they compile out cleanly. - The setter `update_physical_pages` keeps its `static_assert(kUseGlobalLoad_)` guard; combined with lazy template member-function instantiation, the body is never instantiated for SRD callers. - Constructor mem-init stays type-agnostic via value-init `{}`; the `page_stride_elements_` assignment is gated by `if constexpr` in the body so the SRD arm only sees the empty struct. AP-7 (codegen-hash) note: Class layout changes only on the kUseGlobalLoad=true instantiation (where layout is identical: one PageIdxArray + one index_t). The kUseGlobalLoad=false instantiation now has *less* state, but adjacent fields' offsets shift only if the compiler chose not to merge the `[[no_unique_address]]` placeholder. Verified by remote re-test on both gfx942 and gfx950. --- .../core/tensor/tile_scatter_gather.hpp | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp index e5f89aa95c3..45131abb973 100644 --- a/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/projects/composablekernel/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -59,6 +59,15 @@ struct tile_scatter_gather "This kernel should not be instantiated on this architecture."); #endif + // Empty placeholder used by the SRD instantiation so physical_pages_ and + // page_stride_elements_ occupy zero bytes there (combined with + // [[no_unique_address]] on the member declarations). Access sites are all + // inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD + // mode, so no caller needs to change. + struct gl_field_empty_t + { + }; + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -252,10 +261,14 @@ struct tile_scatter_gather tile_dstr_{tile_distribution}, page_idx_{page_idx}, physical_pages_{}, - page_stride_elements_{page_stride_elements}, + page_stride_elements_{}, valids_{valids}, pre_computed_coords_{} { + if constexpr(kUseGlobalLoad_) + { + page_stride_elements_ = page_stride_elements; + } #if 0 // debug // TODO: this use more register for FA, but less register for GEMM // need investigation @@ -1229,14 +1242,17 @@ struct tile_scatter_gather // Physical page indices for global load mode (kUseGlobalLoad=true only). // Maps each gather element to its physical page in a paged memory pool. // Updated via update_physical_pages() before each load call. - // Unused in SRD mode — SRD rebase handles page addressing externally. - PageIdxArray physical_pages_; + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + physical_pages_; // Page stride in elements for global load mode (kUseGlobalLoad=true only). // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements. // Set at construction time via the make_tile_scatter_gather overload that // takes bool_constant; immutable thereafter. - index_t page_stride_elements_; + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + page_stride_elements_; ValidArray valids_;