From fbf59e48daf9189b716b0ea5b75050df276638b8 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Sun, 1 Mar 2026 14:30:58 +0000 Subject: [PATCH 01/12] [CK_TILE] Add LLC-aware FMHA head grouping and head-major scheduling --- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 10 +- .../01_fmha/fmha_fwd_head_grouping.hpp | 414 ++++++++++++++++++ .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 90 +++- .../composablekernel/include/ck_tile/core.hpp | 1 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 125 ++++-- 5 files changed, 611 insertions(+), 29 deletions(-) create mode 100644 projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 3123e2bd596f..9990bf088c8f 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -242,6 +242,8 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; float scale_s; float logits_soft_cap; @@ -669,7 +671,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, - args.sink_ptr); + args.sink_ptr, + args.num_head_q_total, + args.head_start); } else { // create batch mode kernel arguments @@ -728,7 +732,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.block_scale_size_kv, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, - args.sink_ptr); + args.sink_ptr, + args.num_head_q_total, + args.head_start); } }(); diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp new file mode 100644 index 000000000000..a22f80080cb6 --- /dev/null +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -0,0 +1,414 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/host.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef CK_TILE_FMHA_ENABLE_HEAD_GROUPING +#define CK_TILE_FMHA_ENABLE_HEAD_GROUPING 1 +#endif + +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING +namespace fmha_fwd_head_grouping { + +inline bool log_enabled() +{ + const char* env = std::getenv("CK_TILE_FMHA_HEAD_GROUP_LOG"); + return env != nullptr && std::atoi(env) == 1; +} + +inline bool disabled_by_env() +{ + const char* env_disable = std::getenv("CK_TILE_FMHA_DISABLE_HEAD_GROUPING"); + if(env_disable != nullptr && std::atoi(env_disable) == 1) + return true; + return false; +} + +inline bool is_decimal_string(const std::string& s) +{ + if(s.empty()) + return false; + return std::all_of(s.begin(), s.end(), [](unsigned char c) { return std::isdigit(c) != 0; }); +} + +inline std::optional read_property_value(const std::string& filepath, + const std::string& key) +{ + std::ifstream fs(filepath); + if(!fs.is_open()) + return std::nullopt; + + std::string k, v; + while(fs >> k >> v) + { + if(k == key) + { + try + { + return std::stoll(v, nullptr, 0); + } + catch(...) + { + return std::nullopt; + } + } + std::string rest; + std::getline(fs, rest); + } + return std::nullopt; +} + +struct kfd_device_location +{ + int domain = 0; + int location_id = 0; +}; + +inline std::optional get_current_kfd_location() +{ + int device = 0; + if(hipGetDevice(&device) != hipSuccess) + return std::nullopt; + + char bdf[64] = {}; + if(hipDeviceGetPCIBusId(bdf, sizeof(bdf), device) == hipSuccess) + { + unsigned int domain = 0, bus = 0, dev = 0, fn = 0; + if(std::sscanf(bdf, "%x:%x:%x.%x", &domain, &bus, &dev, &fn) == 4) + { + return kfd_device_location{ + static_cast(domain), + static_cast(((bus & 0xff) << 8) | ((dev & 0x1f) << 3) | (fn & 0x7))}; + } + } + + hipDeviceProp_t props{}; + if(hipGetDeviceProperties(&props, device) != hipSuccess) + return std::nullopt; + + return kfd_device_location{props.pciDomainID, + ((props.pciBusID & 0xff) << 8) | ((props.pciDeviceID & 0x1f) << 3)}; +} + +inline std::optional find_matching_kfd_node(const kfd_device_location& loc) +{ + constexpr const char* kKfdNodesDir = "/sys/class/kfd/kfd/topology/nodes"; + DIR* dir = opendir(kKfdNodesDir); + if(dir == nullptr) + return std::nullopt; + + std::optional matched; + while(auto* ent = readdir(dir)) + { + const std::string node_name(ent->d_name); + if(!is_decimal_string(node_name)) + continue; + + const std::string prop_path = std::string(kKfdNodesDir) + "/" + node_name + "/properties"; + const auto location_val = read_property_value(prop_path, "location_id"); + if(!location_val.has_value() || static_cast(*location_val) != loc.location_id) + continue; + + const auto domain_val = read_property_value(prop_path, "domain"); + if(domain_val.has_value() && static_cast(*domain_val) != loc.domain) + continue; + + matched = node_name; + break; + } + + closedir(dir); + return matched; +} + +inline size_t read_kfd_node_l3_bytes(const std::string& node_name) +{ + const std::string caches_dir = "/sys/class/kfd/kfd/topology/nodes/" + node_name + "/caches"; + DIR* dir = opendir(caches_dir.c_str()); + if(dir == nullptr) + return 0; + + size_t l3_kb = 0; + while(auto* ent = readdir(dir)) + { + const std::string cache_name(ent->d_name); + if(!is_decimal_string(cache_name)) + continue; + + const std::string prop_path = caches_dir + "/" + cache_name + "/properties"; + const auto level_val = read_property_value(prop_path, "level"); + if(!level_val.has_value() || *level_val != 3) + continue; + + const auto size_val = read_property_value(prop_path, "size"); + if(!size_val.has_value() || *size_val <= 0) + continue; + + l3_kb = std::max(l3_kb, static_cast(*size_val)); + } + + closedir(dir); + return l3_kb * 1024ull; +} + +inline size_t get_kfd_sysfs_llc_cache_bytes() +{ + const auto loc = get_current_kfd_location(); + if(!loc.has_value()) + return 0; + + const auto node = find_matching_kfd_node(*loc); + if(!node.has_value()) + return 0; + + return read_kfd_node_l3_bytes(*node); +} + +inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch); + +inline size_t resolve_llc_cache_bytes_uncached(const std::string& arch) +{ + // If parsed LLC looks invalidly tiny, ignore it and fallback. + constexpr size_t kMinValidKfdLlcBytes = 32ull * 1024ull; + + const size_t kfd_llc_bytes = get_kfd_sysfs_llc_cache_bytes(); + if(kfd_llc_bytes >= kMinValidKfdLlcBytes) + return kfd_llc_bytes; + + const size_t default_cache_bytes = get_default_llc_cache_bytes_for_arch(arch); + if(default_cache_bytes > 0) + return default_cache_bytes; + + // No default configured -> no grouping. + return 0; +} + +inline bool ck_tile_is_rdna_arch(const std::string& arch) +{ + return arch.rfind("gfx11", 0) == 0 || arch.rfind("gfx12", 0) == 0; +} + +inline size_t get_default_llc_cache_bytes_for_arch(const std::string& arch) +{ + if(arch == "gfx1100") + return 96ull * 1024ull * 1024ull; + if(arch == "gfx1101") + return 64ull * 1024ull * 1024ull; + if(arch == "gfx1102") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1151") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1200") + return 32ull * 1024ull * 1024ull; + if(arch == "gfx1201") + return 64ull * 1024ull * 1024ull; + return 0; +} + +inline size_t get_llc_cache_bytes(const std::string& arch) +{ + // resolve once and reuse. + static const size_t resolved_llc_bytes = [&]() -> size_t { + const char* env_llc_mb = std::getenv("CK_TILE_FMHA_LLC_CACHE_MB"); + if(env_llc_mb != nullptr) + { + const int mb = std::atoi(env_llc_mb); + if(mb > 0) + return static_cast(mb) * 1024ull * 1024ull; + } + return resolve_llc_cache_bytes_uncached(arch); + }(); + + return resolved_llc_bytes; +} + +inline std::optional get_head_group_size(ck_tile::index_t nhead_q, + ck_tile::index_t nhead_k, + ck_tile::index_t batch, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + size_t elem_bytes_k, + size_t elem_bytes_v) +{ + if(disabled_by_env()) + return std::nullopt; + + const std::string arch = ck_tile::get_device_name(); + if(arch.empty() || !ck_tile_is_rdna_arch(arch)) + return std::nullopt; + + const size_t llc_bytes = get_llc_cache_bytes(arch); + if(llc_bytes == 0) + return std::nullopt; + + if(nhead_k <= 0 || nhead_q <= 0 || (nhead_q % nhead_k) != 0) + return std::nullopt; + if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0) + return std::nullopt; + static_cast(batch); + + const size_t kv_bytes_per_head = + static_cast(seqlen_k) * + (static_cast(hdim_q) * elem_bytes_k + static_cast(hdim_v) * elem_bytes_v); + if(kv_bytes_per_head == 0) + return std::nullopt; + + // large LLC GPUs (>= 64MB): slightly more cache-resident grouping + constexpr size_t kLargeLlcThresholdBytes = 64ull * 1024ull * 1024ull; + const bool is_large_llc = llc_bytes >= kLargeLlcThresholdBytes; + const long double llc_utilization = is_large_llc ? 0.85L : 1.0L; + const long double threshold_ratio = is_large_llc ? 1.3L : 1.5L; + const size_t target_llc_bytes = + static_cast(static_cast(llc_bytes) * llc_utilization); + if(target_llc_bytes == 0) + return std::nullopt; + + const size_t total_kv_bytes = static_cast(nhead_q) * kv_bytes_per_head; + if(static_cast(total_kv_bytes) < + static_cast(target_llc_bytes) * threshold_ratio) + return std::nullopt; + + ck_tile::index_t group = static_cast(target_llc_bytes / kv_bytes_per_head); + if(group < 1) + group = 1; + + const ck_tile::index_t min_group_size = std::max(1, nhead_q / 16); + if(group < min_group_size) + group = min_group_size; + + // Cap the number of groups to avoid excessive launch overhead. + constexpr ck_tile::index_t kMaxGroups = 8; + const ck_tile::index_t min_group_for_max_groups = + ck_tile::integer_divide_ceil(nhead_q, kMaxGroups); + if(group < min_group_for_max_groups) + group = min_group_for_max_groups; + + const ck_tile::index_t gqa_ratio = nhead_q / nhead_k; + if(gqa_ratio > 1) + { + group = ((group + gqa_ratio - 1) / gqa_ratio) * gqa_ratio; + } + + group = std::min(group, nhead_q); + if(group >= nhead_q) + return std::nullopt; + + return group; +} + +template +inline const void* ptr_offset(const void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +inline void* ptr_offset(void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +float run_fwd_head_grouped(const ck_tile::stream_config& sc, + const FmhaFwdTraits& fmha_traits, + const FmhaFwdArgs& base_args_in, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + ck_tile::index_t group_size_q, + bool use_blockscale_qscale, + RunKernelFn&& run_kernel_fn) +{ + auto base_args = base_args_in; + base_args.num_head_q_total = nhead; + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = std::min(group_size_q, nhead); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + + float total_time = 0.0f; + for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) + { + const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); + const ck_tile::index_t k_head_start = (gqa_ratio > 0 ? head_start / gqa_ratio : head_start); + const ck_tile::index_t k_heads = (gqa_ratio > 0 ? q_heads / gqa_ratio : q_heads); + + auto args = base_args; + args.nhead_q = q_heads; + args.nhead_k = k_heads; + args.head_start = head_start; + + args.q_ptr = ptr_offset(base_args.q_ptr, head_start * base_args.nhead_stride_q); + args.k_ptr = + ptr_offset(base_args.k_ptr, k_head_start * base_args.nhead_stride_k); + args.v_ptr = + ptr_offset(base_args.v_ptr, k_head_start * base_args.nhead_stride_v); + args.o_ptr = ptr_offset(base_args.o_ptr, head_start * base_args.nhead_stride_o); + + args.bias_ptr = + ptr_offset(base_args.bias_ptr, head_start * base_args.nhead_stride_bias); + args.lse_ptr = + ptr_offset(base_args.lse_ptr, head_start * base_args.nhead_stride_lse); + args.rand_val_ptr = ptr_offset( + base_args.rand_val_ptr, head_start * base_args.nhead_stride_randval); + + if(use_blockscale_qscale) + { + args.q_descale_ptr = ptr_offset(base_args.q_descale_ptr, + head_start * base_args.nhead_stride_q_descale); + args.k_descale_ptr = ptr_offset(base_args.k_descale_ptr, + k_head_start * base_args.nhead_stride_k_descale); + args.v_descale_ptr = ptr_offset(base_args.v_descale_ptr, + k_head_start * base_args.nhead_stride_v_descale); + } + else + { + args.q_descale_ptr = base_args.q_descale_ptr; + args.k_descale_ptr = base_args.k_descale_ptr; + args.v_descale_ptr = base_args.v_descale_ptr; + } + + args.sink_ptr = ptr_offset(base_args.sink_ptr, head_start); + + if(log_enabled()) + { + const ck_tile::index_t head_end = head_start + q_heads; + std::cout << "[LLC Head Grouping] group " << (head_start / group_sz) << "/" << n_groups + << " heads_q=[" << head_start << ", " << head_end << ") heads_k=[" + << k_head_start << ", " << (k_head_start + k_heads) << ")" << std::endl; + } + + const float t = run_kernel_fn(fmha_traits, args, sc); + if(t < 0.0f) + return t; + total_time += t; + } + return total_time; +} + +} // namespace fmha_fwd_head_grouping +#endif // CK_TILE_FMHA_ENABLE_HEAD_GROUPING diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 1227724d404d..a616da9d2b5e 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -6,14 +6,17 @@ #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" #include "fmha_fwd.hpp" +#include "fmha_fwd_head_grouping.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" #include +#include #include #include #include #include +#include #include #include #include @@ -1075,6 +1078,11 @@ fwd_result fmha_fwd_run(mode_enum mode, args.hdim_v = hdim_v; args.nhead_q = nhead; args.nhead_k = nhead_k; + if constexpr(std::is_same_v>) + { + args.num_head_q_total = nhead; + args.head_start = 0; + } args.stride_q = stride_q; args.stride_k = stride_k; @@ -1365,7 +1373,87 @@ fwd_result fmha_fwd_run(mode_enum mode, return fmha_fwd(fmha_traits, fmha_args, sc); }; - const float fwd_ave_time = run_fwd(stream_config); + + float fwd_ave_time = -1.0f; +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING + const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) && + !need_append_kvcache && + (mode == mode_enum::batch || mode == mode_enum::group); + + if(allow_head_grouping) + { + if(fmha_fwd_head_grouping::disabled_by_env()) + { + if(fmha_fwd_head_grouping::log_enabled()) + std::cout << "[LLC Head Grouping] disabled by env" << std::endl; + } + else + { + const auto group_size_opt = + fmha_fwd_head_grouping::get_head_group_size(nhead, + nhead_k, + batch, + max_seqlen_k, + hdim_q, + hdim_v, + sizeof(KDataType), + sizeof(VDataType)); + + if(group_size_opt.has_value() && group_size_opt.value() < nhead) + { + if(fmha_fwd_head_grouping::log_enabled()) + { + const std::string arch = ck_tile::get_device_name(); + const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = group_size_opt.value(); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + std::cout << "[LLC Head Grouping] enabled" << std::endl; + std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch) + << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) + << " nhead_q=" << nhead << " nhead_k=" << nhead_k + << " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz + << " groups=" << n_groups << std::endl; + } + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped( + stream_config, + fmha_traits, + fmha_args, + nhead, + nhead_k, + group_size_opt.value(), + qscale.type == quant_scale_enum::blockscale, + [&](const auto& traits, auto& args, const auto& sc) { + return fmha_fwd(traits, args, sc); + }); + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)" + << std::endl; + } + } + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl; + } +#endif + + if(fwd_ave_time < 0.0f) + fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index 2d4964f86a4b..f42526ddf75c 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -23,6 +23,7 @@ #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0039c57cfce9..e50e0584cf6e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -111,6 +111,10 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; + + // Optional global head count and head offset (for grouped launches & RNG correctness) + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; }; struct FmhaFwdLogitsSoftCapKargs @@ -380,9 +384,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -418,6 +424,8 @@ struct FmhaFwdKernel batch_stride_k, batch_stride_v, batch_stride_o}; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -554,9 +562,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -614,7 +624,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -673,9 +685,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -733,7 +747,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } template @@ -787,9 +803,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -826,6 +844,8 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -959,9 +979,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1014,7 +1036,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -1068,9 +1092,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1123,7 +1149,9 @@ struct FmhaFwdKernel block_scale_size_kv, cu_seqlen_q_ptr, cu_seqlen_k_ptr, - sink_ptr); + sink_ptr, + num_head_q_total, + head_start); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1158,6 +1186,46 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); +#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) +#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 +#else +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 +#endif +#endif + +#if CK_TILE_FMHA_FORCE_HEAD_MAJOR + // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. + // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. + const bool is_bhsd_layout = + (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); + if(is_bhsd_layout) + { + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; + const index_t num_head = gridDim.x; + const index_t blocks_per_batch = num_head * num_tile_total; + const index_t linear_id = + blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); + + const index_t i_batch = linear_id / blocks_per_batch; + const index_t rem0 = linear_id - i_batch * blocks_per_batch; + const index_t i_nhead = rem0 / num_tile_total; + const index_t i_block = rem0 - i_nhead * num_tile_total; + + index_t i_tile_m = i_block / num_tile_n1; + index_t i_tile_n = i_block - i_tile_m * num_tile_n1; + + if constexpr(kHasMask) + { + const index_t num_tile_m = num_tile_total / num_tile_n1; + i_tile_m = num_tile_m - 1 - i_tile_m; + } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +#endif + if(has_padded_seqlen_k) { // const index_t num_tile_m0 = seqlen_q / kM0; @@ -1179,7 +1247,8 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple( + static_cast(gridDim.z) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { @@ -1207,7 +1276,8 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple( + static_cast(gridDim.y) - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { @@ -1575,9 +1645,12 @@ struct FmhaFwdKernel auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { + const auto num_head_q_total = + (kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q); + const auto i_head_global = kargs.head_start + i_nhead_; return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, + i_head_global, + num_head_q_total, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val : *kargs.drop_seed.ptr, kargs.is_drop_seed_offset_from_host From 9ab672005b0f9a08095ed37affbc0b7051109cd2 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Tue, 3 Mar 2026 18:39:21 +0000 Subject: [PATCH 02/12] macro placement and env API updates --- .../01_fmha/fmha_fwd_head_grouping.hpp | 33 ++++++++++--------- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 17 +++++----- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index a22f80080cb6..7517fdd33029 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -8,10 +8,10 @@ #include #include #include -#include #include #include #include +#include #include #include @@ -20,20 +20,20 @@ #endif #if CK_TILE_FMHA_ENABLE_HEAD_GROUPING +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_HEAD_GROUP_LOG) +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_FMHA_DISABLE_HEAD_GROUPING) +CK_TILE_DECLARE_ENV_VAR_UINT64(CK_TILE_FMHA_LLC_CACHE_MB) + namespace fmha_fwd_head_grouping { inline bool log_enabled() { - const char* env = std::getenv("CK_TILE_FMHA_HEAD_GROUP_LOG"); - return env != nullptr && std::atoi(env) == 1; + return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_HEAD_GROUP_LOG)); } inline bool disabled_by_env() { - const char* env_disable = std::getenv("CK_TILE_FMHA_DISABLE_HEAD_GROUPING"); - if(env_disable != nullptr && std::atoi(env_disable) == 1) - return true; - return false; + return ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_FMHA_DISABLE_HEAD_GROUPING)); } inline bool is_decimal_string(const std::string& s) @@ -221,13 +221,17 @@ inline size_t get_llc_cache_bytes(const std::string& arch) { // resolve once and reuse. static const size_t resolved_llc_bytes = [&]() -> size_t { - const char* env_llc_mb = std::getenv("CK_TILE_FMHA_LLC_CACHE_MB"); - if(env_llc_mb != nullptr) + const uint64_t llc_mb = ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_FMHA_LLC_CACHE_MB)); + if(llc_mb > 0) { - const int mb = std::atoi(env_llc_mb); - if(mb > 0) - return static_cast(mb) * 1024ull * 1024ull; + constexpr uint64_t kBytesPerMb = 1024ull * 1024ull; + const uint64_t max_mb_for_size_t = static_cast( + std::numeric_limits::max() / static_cast(kBytesPerMb)); + + if(llc_mb <= max_mb_for_size_t) + return static_cast(llc_mb * kBytesPerMb); } + return resolve_llc_cache_bytes_uncached(arch); }(); @@ -258,7 +262,6 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea return std::nullopt; if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0) return std::nullopt; - static_cast(batch); const size_t kv_bytes_per_head = static_cast(seqlen_k) * @@ -354,8 +357,8 @@ float run_fwd_head_grouped(const ck_tile::stream_config& sc, for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) { const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); - const ck_tile::index_t k_head_start = (gqa_ratio > 0 ? head_start / gqa_ratio : head_start); - const ck_tile::index_t k_heads = (gqa_ratio > 0 ? q_heads / gqa_ratio : q_heads); + const ck_tile::index_t k_head_start = (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); + const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); auto args = base_args; args.nhead_q = q_heads; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index e50e0584cf6e..959c3ee0c0cf 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -15,6 +15,15 @@ #include #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 + +#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) +#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 +#else +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 +#endif +#endif + // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -1186,14 +1195,6 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); -#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 -#else -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 -#endif -#endif - #if CK_TILE_FMHA_FORCE_HEAD_MAJOR // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. From 49322ecafe1a794f4df7a4b665599fefad413abb Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Wed, 4 Mar 2026 05:09:02 +0000 Subject: [PATCH 03/12] formatting --- .../example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index 7517fdd33029..9cd1fb9cdcb7 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -356,9 +356,10 @@ float run_fwd_head_grouped(const ck_tile::stream_config& sc, float total_time = 0.0f; for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) { - const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); - const ck_tile::index_t k_head_start = (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); - const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); + const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); + const ck_tile::index_t k_head_start = + (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); + const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); auto args = base_args; args.nhead_q = q_heads; From b5ba2b08db1625311e4c3039f97e270a228b717f Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 6 Mar 2026 15:30:49 +0000 Subject: [PATCH 04/12] switch bshd head grouping to single-launch remap --- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 86 ++++++++- .../01_fmha/fmha_fwd_head_grouping.hpp | 102 ----------- .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 87 +-------- .../composablekernel/include/ck_tile/core.hpp | 8 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 170 +++++++++++------- .../ck_tile/ops/grouped_convolution.hpp | 2 +- 6 files changed, 199 insertions(+), 256 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 9990bf088c8f..ef91554477e7 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -10,10 +10,12 @@ #include "ck_tile/ops/fmha.hpp" #include "bias.hpp" +#include "fmha_fwd_head_grouping.hpp" #include "mask.hpp" #include "quant.hpp" #include "rotary.hpp" +#include #include #include #include @@ -242,8 +244,6 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; - ck_tile::index_t num_head_q_total = 0; - ck_tile::index_t head_start = 0; float scale_s; float logits_soft_cap; @@ -289,6 +289,10 @@ struct fmha_fwd_args ck_tile::index_t block_scale_size_q; ck_tile::index_t block_scale_size_kv; + + // Optional override for implicit single-launch head grouping. + // 0 means "auto decide in CK using LLC-aware policy". + ck_tile::index_t head_group_size_q = 0; }; struct fmha_fwd_pagedkv_args @@ -613,11 +617,81 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; +template +CK_TILE_HOST ck_tile::index_t fmha_fwd_resolve_head_group_size_q(const fmha_fwd_args& args) +{ + if(args.nhead_q <= 1) + return 0; + + if(args.head_group_size_q > 0) + { + const ck_tile::index_t explicit_group = std::min(args.head_group_size_q, args.nhead_q); + if(explicit_group < args.nhead_q) + return explicit_group; + return 0; + } + +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING + if(fmha_fwd_head_grouping::disabled_by_env()) + return 0; + + if(args.nhead_k <= 0 || (args.nhead_q % args.nhead_k) != 0) + return 0; + + // Apply implicit single-launch grouping only for bshd. + const bool is_bshd_layout = + (args.nhead_stride_q == args.hdim_q) && (args.stride_q > args.hdim_q); + if(!is_bshd_layout) + return 0; + + ck_tile::index_t seqlen_k_for_policy = args.seqlen_k; + if(args.batch > 0 && args.seqstart_k_ptr != nullptr && args.seqlen_k_ptr == nullptr) + { + // group-mode without explicit per-batch seqlen: use per-batch average as policy input. + seqlen_k_for_policy = ck_tile::integer_divide_ceil(args.seqlen_k, args.batch); + } + + const auto group_size_opt = + fmha_fwd_head_grouping::get_head_group_size(args.nhead_q, + args.nhead_k, + args.batch, + seqlen_k_for_policy, + args.hdim_q, + args.hdim_v, + sizeof(typename FmhaKernel::KDataType), + sizeof(typename FmhaKernel::VDataType)); + if(!group_size_opt.has_value()) + return 0; + + const ck_tile::index_t group_size = group_size_opt.value(); + if(group_size <= 0 || group_size >= args.nhead_q) + return 0; + + if(fmha_fwd_head_grouping::log_enabled()) + { + const std::string arch = ck_tile::get_device_name(); + const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); + const ck_tile::index_t gqa_ratio = (args.nhead_k > 0 ? (args.nhead_q / args.nhead_k) : 1); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(args.nhead_q, group_size); + std::cout << "[LLC Head Grouping] enabled (fmha_fwd auto)" + << " arch=" << (arch.empty() ? "unknown" : arch) + << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) << " nhead_q=" << args.nhead_q + << " nhead_k=" << args.nhead_k << " gqa_ratio=" << gqa_ratio + << " group_size=" << group_size << " groups=" << n_groups << std::endl; + } + + return group_size; +#else + return 0; +#endif +} + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { + const ck_tile::index_t head_group_size_q = fmha_fwd_resolve_head_group_size_q(args); + auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { @@ -672,8 +746,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr, - args.num_head_q_total, - args.head_start); + head_group_size_q); } else { // create batch mode kernel arguments @@ -733,8 +806,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr, - args.num_head_q_total, - args.head_start); + head_group_size_q); } }(); diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index 9cd1fb9cdcb7..bc4c0e50a1c6 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -312,107 +312,5 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea return group; } -template -inline const void* ptr_offset(const void* base, ck_tile::index_t offset_elems) -{ - if(base == nullptr) - return nullptr; - return static_cast(reinterpret_cast(base) + offset_elems); -} - -template -inline void* ptr_offset(void* base, ck_tile::index_t offset_elems) -{ - if(base == nullptr) - return nullptr; - return static_cast(reinterpret_cast(base) + offset_elems); -} - -template -float run_fwd_head_grouped(const ck_tile::stream_config& sc, - const FmhaFwdTraits& fmha_traits, - const FmhaFwdArgs& base_args_in, - ck_tile::index_t nhead, - ck_tile::index_t nhead_k, - ck_tile::index_t group_size_q, - bool use_blockscale_qscale, - RunKernelFn&& run_kernel_fn) -{ - auto base_args = base_args_in; - base_args.num_head_q_total = nhead; - const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); - const ck_tile::index_t group_sz = std::min(group_size_q, nhead); - const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); - - float total_time = 0.0f; - for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) - { - const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); - const ck_tile::index_t k_head_start = - (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); - const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); - - auto args = base_args; - args.nhead_q = q_heads; - args.nhead_k = k_heads; - args.head_start = head_start; - - args.q_ptr = ptr_offset(base_args.q_ptr, head_start * base_args.nhead_stride_q); - args.k_ptr = - ptr_offset(base_args.k_ptr, k_head_start * base_args.nhead_stride_k); - args.v_ptr = - ptr_offset(base_args.v_ptr, k_head_start * base_args.nhead_stride_v); - args.o_ptr = ptr_offset(base_args.o_ptr, head_start * base_args.nhead_stride_o); - - args.bias_ptr = - ptr_offset(base_args.bias_ptr, head_start * base_args.nhead_stride_bias); - args.lse_ptr = - ptr_offset(base_args.lse_ptr, head_start * base_args.nhead_stride_lse); - args.rand_val_ptr = ptr_offset( - base_args.rand_val_ptr, head_start * base_args.nhead_stride_randval); - - if(use_blockscale_qscale) - { - args.q_descale_ptr = ptr_offset(base_args.q_descale_ptr, - head_start * base_args.nhead_stride_q_descale); - args.k_descale_ptr = ptr_offset(base_args.k_descale_ptr, - k_head_start * base_args.nhead_stride_k_descale); - args.v_descale_ptr = ptr_offset(base_args.v_descale_ptr, - k_head_start * base_args.nhead_stride_v_descale); - } - else - { - args.q_descale_ptr = base_args.q_descale_ptr; - args.k_descale_ptr = base_args.k_descale_ptr; - args.v_descale_ptr = base_args.v_descale_ptr; - } - - args.sink_ptr = ptr_offset(base_args.sink_ptr, head_start); - - if(log_enabled()) - { - const ck_tile::index_t head_end = head_start + q_heads; - std::cout << "[LLC Head Grouping] group " << (head_start / group_sz) << "/" << n_groups - << " heads_q=[" << head_start << ", " << head_end << ") heads_k=[" - << k_head_start << ", " << (k_head_start + k_heads) << ")" << std::endl; - } - - const float t = run_kernel_fn(fmha_traits, args, sc); - if(t < 0.0f) - return t; - total_time += t; - } - return total_time; -} - } // namespace fmha_fwd_head_grouping #endif // CK_TILE_FMHA_ENABLE_HEAD_GROUPING diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index a616da9d2b5e..f6ca276c32e3 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -6,7 +6,6 @@ #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" #include "fmha_fwd.hpp" -#include "fmha_fwd_head_grouping.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" @@ -1078,11 +1077,6 @@ fwd_result fmha_fwd_run(mode_enum mode, args.hdim_v = hdim_v; args.nhead_q = nhead; args.nhead_k = nhead_k; - if constexpr(std::is_same_v>) - { - args.num_head_q_total = nhead; - args.head_start = 0; - } args.stride_q = stride_q; args.stride_k = stride_k; @@ -1374,86 +1368,7 @@ fwd_result fmha_fwd_run(mode_enum mode, return fmha_fwd(fmha_traits, fmha_args, sc); }; - float fwd_ave_time = -1.0f; -#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING - const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) && - !need_append_kvcache && - (mode == mode_enum::batch || mode == mode_enum::group); - - if(allow_head_grouping) - { - if(fmha_fwd_head_grouping::disabled_by_env()) - { - if(fmha_fwd_head_grouping::log_enabled()) - std::cout << "[LLC Head Grouping] disabled by env" << std::endl; - } - else - { - const auto group_size_opt = - fmha_fwd_head_grouping::get_head_group_size(nhead, - nhead_k, - batch, - max_seqlen_k, - hdim_q, - hdim_v, - sizeof(KDataType), - sizeof(VDataType)); - - if(group_size_opt.has_value() && group_size_opt.value() < nhead) - { - if(fmha_fwd_head_grouping::log_enabled()) - { - const std::string arch = ck_tile::get_device_name(); - const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); - const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); - const ck_tile::index_t group_sz = group_size_opt.value(); - const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); - std::cout << "[LLC Head Grouping] enabled" << std::endl; - std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch) - << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) - << " nhead_q=" << nhead << " nhead_k=" << nhead_k - << " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz - << " groups=" << n_groups << std::endl; - } - fmha_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_fwd_args fmha_args; - init_args(fmha_args); - - fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped( - stream_config, - fmha_traits, - fmha_args, - nhead, - nhead_k, - group_size_opt.value(), - qscale.type == quant_scale_enum::blockscale, - [&](const auto& traits, auto& args, const auto& sc) { - return fmha_fwd(traits, args, sc); - }); - } - else if(fmha_fwd_head_grouping::log_enabled()) - { - std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)" - << std::endl; - } - } - } - else if(fmha_fwd_head_grouping::log_enabled()) - { - std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl; - } -#endif - - if(fwd_ave_time < 0.0f) - fwd_ave_time = run_fwd(stream_config); + float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index f42526ddf75c..c377d6b4b953 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -20,9 +20,17 @@ #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" #include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 959c3ee0c0cf..fb18f43f73f9 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -15,15 +15,6 @@ #include #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 - -#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 -#else -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 -#endif -#endif - // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -121,9 +112,8 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - // Optional global head count and head offset (for grouped launches & RNG correctness) - ck_tile::index_t num_head_q_total = 0; - ck_tile::index_t head_start = 0; + // Optional implicit head-group size for single-launch grouping. + ck_tile::index_t head_group_size_q = 0; }; struct FmhaFwdLogitsSoftCapKargs @@ -393,11 +383,10 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -433,8 +422,7 @@ struct FmhaFwdKernel batch_stride_k, batch_stride_v, batch_stride_o}; - kargs.num_head_q_total = num_head_q_total; - kargs.head_start = head_start; + kargs.head_group_size_q = head_group_size_q; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -571,11 +559,10 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { return MakeKargsImpl( q_ptr, @@ -634,8 +621,7 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - num_head_q_total, - head_start); + head_group_size_q); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -694,11 +680,10 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { return MakeKargsImpl( q_ptr, @@ -757,8 +742,7 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - num_head_q_total, - head_start); + head_group_size_q); } template @@ -812,11 +796,10 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -853,8 +836,7 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; - kargs.num_head_q_total = num_head_q_total; - kargs.head_start = head_start; + kargs.head_group_size_q = head_group_size_q; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -988,11 +970,10 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { return MakeKargsImpl( q_ptr, @@ -1046,8 +1027,7 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - num_head_q_total, - head_start); + head_group_size_q); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -1101,11 +1081,10 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t num_head_q_total = 0, - ck_tile::index_t head_start = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t head_group_size_q = 0) { return MakeKargsImpl( q_ptr, @@ -1159,8 +1138,7 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - num_head_q_total, - head_start); + head_group_size_q); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1195,11 +1173,86 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); -#if CK_TILE_FMHA_FORCE_HEAD_MAJOR +#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) +#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 +#else +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 +#endif +#endif + // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. const bool is_bhsd_layout = (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); + + // Single-launch head grouping for bshd only. + if((kargs.head_group_size_q > 0) && !is_bhsd_layout) + { + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; + const index_t num_head = gridDim.x; + const index_t batch_size = has_padded_seqlen_k ? gridDim.y : gridDim.z; + const index_t linear_id = + blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); + + const index_t group_sz = ck_tile::min(kargs.head_group_size_q, num_head); + if(group_sz > 0 && group_sz < num_head) + { + const index_t full_groups = num_head / group_sz; + const index_t tail_heads = num_head - full_groups * group_sz; + const index_t spans_per_group = batch_size * num_tile_total; + const index_t blocks_per_group = spans_per_group * group_sz; + + index_t i_group = 0; + index_t i_batch = 0; + index_t i_block = 0; + index_t i_nhead_in_group = 0; + + if((tail_heads > 0) && (linear_id >= full_groups * blocks_per_group)) + { + // Tail group decode: [group=tail] -> [batch] -> [block] -> [head_in_group] + const index_t tail_linear = linear_id - full_groups * blocks_per_group; + const index_t tail_group_heads = tail_heads; + const index_t tail_blocks_per_batch = num_tile_total * tail_group_heads; + + i_group = full_groups; + i_batch = tail_linear / tail_blocks_per_batch; + + const index_t rem1 = tail_linear - i_batch * tail_blocks_per_batch; + i_block = rem1 / tail_group_heads; + i_nhead_in_group = rem1 - i_block * tail_group_heads; + } + else + { + // Full group decode: [group] -> [batch] -> [block] -> [head_in_group] + i_group = linear_id / blocks_per_group; + + const index_t rem0 = linear_id - i_group * blocks_per_group; + const index_t blocks_per_batch = num_tile_total * group_sz; + i_batch = rem0 / blocks_per_batch; + + const index_t rem1 = rem0 - i_batch * blocks_per_batch; + i_block = rem1 / group_sz; + i_nhead_in_group = rem1 - i_block * group_sz; + } + + const index_t i_nhead = i_group * group_sz + i_nhead_in_group; + + index_t i_tile_m = i_block / num_tile_n1; + index_t i_tile_n = i_block - i_tile_m * num_tile_n1; + + if constexpr(kHasMask) + { + const index_t num_tile_m = num_tile_total / num_tile_n1; + i_tile_m = num_tile_m - 1 - i_tile_m; + } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + +#if CK_TILE_FMHA_FORCE_HEAD_MAJOR if(is_bhsd_layout) { const index_t num_tile_n1 = @@ -1646,12 +1699,9 @@ struct FmhaFwdKernel auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { - const auto num_head_q_total = - (kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q); - const auto i_head_global = kargs.head_start + i_nhead_; return BlockDropout{i_batch_, - i_head_global, - num_head_q_total, + i_nhead_, + kargs.num_head_q, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val : *kargs.drop_seed.ptr, kargs.is_drop_seed_offset_from_host diff --git a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp index 3c7b00782f65..5bc4f0c6a042 100644 --- a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp +++ b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp @@ -2,10 +2,10 @@ // SPDX-License-Identifier: MIT #pragma once -#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" From cb1de3254a0116666c871363365a21bfbe90c971 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 6 Mar 2026 15:36:34 +0000 Subject: [PATCH 05/12] formatting --- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fb18f43f73f9..0a1de838095c 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -15,6 +15,15 @@ #include #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0 + +#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) +#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 +#else +#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 +#endif +#endif + // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] @@ -1173,14 +1182,6 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); -#if !defined(CK_TILE_FMHA_FORCE_HEAD_MAJOR) -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx11__) || defined(__gfx12__)) -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 1 -#else -#define CK_TILE_FMHA_FORCE_HEAD_MAJOR 0 -#endif -#endif - // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. const bool is_bhsd_layout = From 5495f9732e22c10411e41e0df1ec158b233c894c Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 6 Mar 2026 17:11:22 +0000 Subject: [PATCH 06/12] remove unnecessary changes --- .../example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index f6ca276c32e3..fa6268ecbbb8 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -10,7 +10,6 @@ #include "ck_tile/utility/json_dump.hpp" #include -#include #include #include #include @@ -1368,7 +1367,7 @@ fwd_result fmha_fwd_run(mode_enum mode, return fmha_fwd(fmha_traits, fmha_args, sc); }; - float fwd_ave_time = run_fwd(stream_config); + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; From 0c97da41e1af56d36352123f2677bb8312245154 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 6 Mar 2026 19:08:01 +0000 Subject: [PATCH 07/12] Remove head-group count caps --- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 14 ++++---------- .../ck_tile/01_fmha/fmha_fwd_head_grouping.hpp | 14 +------------- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index ef91554477e7..01dea109661b 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -632,20 +632,17 @@ CK_TILE_HOST ck_tile::index_t fmha_fwd_resolve_head_group_size_q(const fmha_fwd_ } #if CK_TILE_FMHA_ENABLE_HEAD_GROUPING - if(fmha_fwd_head_grouping::disabled_by_env()) - return 0; - - if(args.nhead_k <= 0 || (args.nhead_q % args.nhead_k) != 0) - return 0; - // Apply implicit single-launch grouping only for bshd. const bool is_bshd_layout = (args.nhead_stride_q == args.hdim_q) && (args.stride_q > args.hdim_q); if(!is_bshd_layout) return 0; + if(args.batch <= 0) + return 0; + ck_tile::index_t seqlen_k_for_policy = args.seqlen_k; - if(args.batch > 0 && args.seqstart_k_ptr != nullptr && args.seqlen_k_ptr == nullptr) + if(args.seqstart_k_ptr != nullptr && args.seqlen_k_ptr == nullptr) { // group-mode without explicit per-batch seqlen: use per-batch average as policy input. seqlen_k_for_policy = ck_tile::integer_divide_ceil(args.seqlen_k, args.batch); @@ -654,7 +651,6 @@ CK_TILE_HOST ck_tile::index_t fmha_fwd_resolve_head_group_size_q(const fmha_fwd_ const auto group_size_opt = fmha_fwd_head_grouping::get_head_group_size(args.nhead_q, args.nhead_k, - args.batch, seqlen_k_for_policy, args.hdim_q, args.hdim_v, @@ -664,8 +660,6 @@ CK_TILE_HOST ck_tile::index_t fmha_fwd_resolve_head_group_size_q(const fmha_fwd_ return 0; const ck_tile::index_t group_size = group_size_opt.value(); - if(group_size <= 0 || group_size >= args.nhead_q) - return 0; if(fmha_fwd_head_grouping::log_enabled()) { diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index bc4c0e50a1c6..3bfe2895ebe8 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -240,7 +240,6 @@ inline size_t get_llc_cache_bytes(const std::string& arch) inline std::optional get_head_group_size(ck_tile::index_t nhead_q, ck_tile::index_t nhead_k, - ck_tile::index_t batch, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -260,7 +259,7 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea if(nhead_k <= 0 || nhead_q <= 0 || (nhead_q % nhead_k) != 0) return std::nullopt; - if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0) + if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0) return std::nullopt; const size_t kv_bytes_per_head = @@ -288,17 +287,6 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea if(group < 1) group = 1; - const ck_tile::index_t min_group_size = std::max(1, nhead_q / 16); - if(group < min_group_size) - group = min_group_size; - - // Cap the number of groups to avoid excessive launch overhead. - constexpr ck_tile::index_t kMaxGroups = 8; - const ck_tile::index_t min_group_for_max_groups = - ck_tile::integer_divide_ceil(nhead_q, kMaxGroups); - if(group < min_group_for_max_groups) - group = min_group_for_max_groups; - const ck_tile::index_t gqa_ratio = nhead_q / nhead_k; if(gqa_ratio > 1) { From aee75887d43aaa58c510aa2717370faee2eb090d Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Sat, 7 Mar 2026 17:17:32 +0000 Subject: [PATCH 08/12] Revert single-launch head-grouping series --- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 80 +-------- .../01_fmha/fmha_fwd_head_grouping.hpp | 116 ++++++++++++- .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 88 +++++++++- .../composablekernel/include/ck_tile/core.hpp | 8 - .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 153 ++++++------------ .../ck_tile/ops/grouped_convolution.hpp | 2 +- 6 files changed, 261 insertions(+), 186 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 01dea109661b..9990bf088c8f 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -10,12 +10,10 @@ #include "ck_tile/ops/fmha.hpp" #include "bias.hpp" -#include "fmha_fwd_head_grouping.hpp" #include "mask.hpp" #include "quant.hpp" #include "rotary.hpp" -#include #include #include #include @@ -244,6 +242,8 @@ struct fmha_fwd_args ck_tile::index_t hdim_v; ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; float scale_s; float logits_soft_cap; @@ -289,10 +289,6 @@ struct fmha_fwd_args ck_tile::index_t block_scale_size_q; ck_tile::index_t block_scale_size_kv; - - // Optional override for implicit single-launch head grouping. - // 0 means "auto decide in CK using LLC-aware policy". - ck_tile::index_t head_group_size_q = 0; }; struct fmha_fwd_pagedkv_args @@ -617,75 +613,11 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; -template -CK_TILE_HOST ck_tile::index_t fmha_fwd_resolve_head_group_size_q(const fmha_fwd_args& args) -{ - if(args.nhead_q <= 1) - return 0; - - if(args.head_group_size_q > 0) - { - const ck_tile::index_t explicit_group = std::min(args.head_group_size_q, args.nhead_q); - if(explicit_group < args.nhead_q) - return explicit_group; - return 0; - } - -#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING - // Apply implicit single-launch grouping only for bshd. - const bool is_bshd_layout = - (args.nhead_stride_q == args.hdim_q) && (args.stride_q > args.hdim_q); - if(!is_bshd_layout) - return 0; - - if(args.batch <= 0) - return 0; - - ck_tile::index_t seqlen_k_for_policy = args.seqlen_k; - if(args.seqstart_k_ptr != nullptr && args.seqlen_k_ptr == nullptr) - { - // group-mode without explicit per-batch seqlen: use per-batch average as policy input. - seqlen_k_for_policy = ck_tile::integer_divide_ceil(args.seqlen_k, args.batch); - } - - const auto group_size_opt = - fmha_fwd_head_grouping::get_head_group_size(args.nhead_q, - args.nhead_k, - seqlen_k_for_policy, - args.hdim_q, - args.hdim_v, - sizeof(typename FmhaKernel::KDataType), - sizeof(typename FmhaKernel::VDataType)); - if(!group_size_opt.has_value()) - return 0; - - const ck_tile::index_t group_size = group_size_opt.value(); - - if(fmha_fwd_head_grouping::log_enabled()) - { - const std::string arch = ck_tile::get_device_name(); - const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); - const ck_tile::index_t gqa_ratio = (args.nhead_k > 0 ? (args.nhead_q / args.nhead_k) : 1); - const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(args.nhead_q, group_size); - std::cout << "[LLC Head Grouping] enabled (fmha_fwd auto)" - << " arch=" << (arch.empty() ? "unknown" : arch) - << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) << " nhead_q=" << args.nhead_q - << " nhead_k=" << args.nhead_k << " gqa_ratio=" << gqa_ratio - << " group_size=" << group_size << " groups=" << n_groups << std::endl; - } - - return group_size; -#else - return 0; -#endif -} - template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); - const ck_tile::index_t head_group_size_q = fmha_fwd_resolve_head_group_size_q(args); - auto kargs = [&] { + auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { @@ -740,7 +672,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr, - head_group_size_q); + args.num_head_q_total, + args.head_start); } else { // create batch mode kernel arguments @@ -800,7 +733,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr, args.sink_ptr, - head_group_size_q); + args.num_head_q_total, + args.head_start); } }(); diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp index 3bfe2895ebe8..9cd1fb9cdcb7 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp @@ -240,6 +240,7 @@ inline size_t get_llc_cache_bytes(const std::string& arch) inline std::optional get_head_group_size(ck_tile::index_t nhead_q, ck_tile::index_t nhead_k, + ck_tile::index_t batch, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, @@ -259,7 +260,7 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea if(nhead_k <= 0 || nhead_q <= 0 || (nhead_q % nhead_k) != 0) return std::nullopt; - if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0) + if(seqlen_k <= 0 || hdim_q <= 0 || hdim_v <= 0 || batch <= 0) return std::nullopt; const size_t kv_bytes_per_head = @@ -287,6 +288,17 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea if(group < 1) group = 1; + const ck_tile::index_t min_group_size = std::max(1, nhead_q / 16); + if(group < min_group_size) + group = min_group_size; + + // Cap the number of groups to avoid excessive launch overhead. + constexpr ck_tile::index_t kMaxGroups = 8; + const ck_tile::index_t min_group_for_max_groups = + ck_tile::integer_divide_ceil(nhead_q, kMaxGroups); + if(group < min_group_for_max_groups) + group = min_group_for_max_groups; + const ck_tile::index_t gqa_ratio = nhead_q / nhead_k; if(gqa_ratio > 1) { @@ -300,5 +312,107 @@ inline std::optional get_head_group_size(ck_tile::index_t nhea return group; } +template +inline const void* ptr_offset(const void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +inline void* ptr_offset(void* base, ck_tile::index_t offset_elems) +{ + if(base == nullptr) + return nullptr; + return static_cast(reinterpret_cast(base) + offset_elems); +} + +template +float run_fwd_head_grouped(const ck_tile::stream_config& sc, + const FmhaFwdTraits& fmha_traits, + const FmhaFwdArgs& base_args_in, + ck_tile::index_t nhead, + ck_tile::index_t nhead_k, + ck_tile::index_t group_size_q, + bool use_blockscale_qscale, + RunKernelFn&& run_kernel_fn) +{ + auto base_args = base_args_in; + base_args.num_head_q_total = nhead; + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = std::min(group_size_q, nhead); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + + float total_time = 0.0f; + for(ck_tile::index_t head_start = 0; head_start < nhead; head_start += group_sz) + { + const ck_tile::index_t q_heads = std::min(group_sz, nhead - head_start); + const ck_tile::index_t k_head_start = + (gqa_ratio >= 1 ? head_start / gqa_ratio : head_start); + const ck_tile::index_t k_heads = (gqa_ratio >= 1 ? q_heads / gqa_ratio : q_heads); + + auto args = base_args; + args.nhead_q = q_heads; + args.nhead_k = k_heads; + args.head_start = head_start; + + args.q_ptr = ptr_offset(base_args.q_ptr, head_start * base_args.nhead_stride_q); + args.k_ptr = + ptr_offset(base_args.k_ptr, k_head_start * base_args.nhead_stride_k); + args.v_ptr = + ptr_offset(base_args.v_ptr, k_head_start * base_args.nhead_stride_v); + args.o_ptr = ptr_offset(base_args.o_ptr, head_start * base_args.nhead_stride_o); + + args.bias_ptr = + ptr_offset(base_args.bias_ptr, head_start * base_args.nhead_stride_bias); + args.lse_ptr = + ptr_offset(base_args.lse_ptr, head_start * base_args.nhead_stride_lse); + args.rand_val_ptr = ptr_offset( + base_args.rand_val_ptr, head_start * base_args.nhead_stride_randval); + + if(use_blockscale_qscale) + { + args.q_descale_ptr = ptr_offset(base_args.q_descale_ptr, + head_start * base_args.nhead_stride_q_descale); + args.k_descale_ptr = ptr_offset(base_args.k_descale_ptr, + k_head_start * base_args.nhead_stride_k_descale); + args.v_descale_ptr = ptr_offset(base_args.v_descale_ptr, + k_head_start * base_args.nhead_stride_v_descale); + } + else + { + args.q_descale_ptr = base_args.q_descale_ptr; + args.k_descale_ptr = base_args.k_descale_ptr; + args.v_descale_ptr = base_args.v_descale_ptr; + } + + args.sink_ptr = ptr_offset(base_args.sink_ptr, head_start); + + if(log_enabled()) + { + const ck_tile::index_t head_end = head_start + q_heads; + std::cout << "[LLC Head Grouping] group " << (head_start / group_sz) << "/" << n_groups + << " heads_q=[" << head_start << ", " << head_end << ") heads_k=[" + << k_head_start << ", " << (k_head_start + k_heads) << ")" << std::endl; + } + + const float t = run_kernel_fn(fmha_traits, args, sc); + if(t < 0.0f) + return t; + total_time += t; + } + return total_time; +} + } // namespace fmha_fwd_head_grouping #endif // CK_TILE_FMHA_ENABLE_HEAD_GROUPING diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index fa6268ecbbb8..a616da9d2b5e 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -6,10 +6,12 @@ #include "ck_tile/host.hpp" #include "ck_tile/ref/naive_attention.hpp" #include "fmha_fwd.hpp" +#include "fmha_fwd_head_grouping.hpp" #include "utils.hpp" #include "ck_tile/utility/json_dump.hpp" #include +#include #include #include #include @@ -1076,6 +1078,11 @@ fwd_result fmha_fwd_run(mode_enum mode, args.hdim_v = hdim_v; args.nhead_q = nhead; args.nhead_k = nhead_k; + if constexpr(std::is_same_v>) + { + args.num_head_q_total = nhead; + args.head_start = 0; + } args.stride_q = stride_q; args.stride_k = stride_k; @@ -1367,7 +1374,86 @@ fwd_result fmha_fwd_run(mode_enum mode, return fmha_fwd(fmha_traits, fmha_args, sc); }; - const float fwd_ave_time = run_fwd(stream_config); + float fwd_ave_time = -1.0f; +#if CK_TILE_FMHA_ENABLE_HEAD_GROUPING + const bool allow_head_grouping = !i_perm && !use_kvcache && (num_splits <= 1) && + !need_append_kvcache && + (mode == mode_enum::batch || mode == mode_enum::group); + + if(allow_head_grouping) + { + if(fmha_fwd_head_grouping::disabled_by_env()) + { + if(fmha_fwd_head_grouping::log_enabled()) + std::cout << "[LLC Head Grouping] disabled by env" << std::endl; + } + else + { + const auto group_size_opt = + fmha_fwd_head_grouping::get_head_group_size(nhead, + nhead_k, + batch, + max_seqlen_k, + hdim_q, + hdim_v, + sizeof(KDataType), + sizeof(VDataType)); + + if(group_size_opt.has_value() && group_size_opt.value() < nhead) + { + if(fmha_fwd_head_grouping::log_enabled()) + { + const std::string arch = ck_tile::get_device_name(); + const size_t llc_bytes = fmha_fwd_head_grouping::get_llc_cache_bytes(arch); + const ck_tile::index_t gqa_ratio = (nhead_k > 0 ? (nhead / nhead_k) : 1); + const ck_tile::index_t group_sz = group_size_opt.value(); + const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(nhead, group_sz); + std::cout << "[LLC Head Grouping] enabled" << std::endl; + std::cout << "[LLC Head Grouping] arch=" << (arch.empty() ? "unknown" : arch) + << " llc_mb=" << (llc_bytes / (1024ull * 1024ull)) + << " nhead_q=" << nhead << " nhead_k=" << nhead_k + << " gqa_ratio=" << gqa_ratio << " group_size=" << group_sz + << " groups=" << n_groups << std::endl; + } + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + fwd_ave_time = fmha_fwd_head_grouping::run_fwd_head_grouped( + stream_config, + fmha_traits, + fmha_args, + nhead, + nhead_k, + group_size_opt.value(), + qscale.type == quant_scale_enum::blockscale, + [&](const auto& traits, auto& args, const auto& sc) { + return fmha_fwd(traits, args, sc); + }); + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)" + << std::endl; + } + } + } + else if(fmha_fwd_head_grouping::log_enabled()) + { + std::cout << "[LLC Head Grouping] disabled by conditions/layout" << std::endl; + } +#endif + + if(fwd_ave_time < 0.0f) + fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index c377d6b4b953..f42526ddf75c 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -20,17 +20,9 @@ #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" #include "ck_tile/core/arch/mma/mma.hpp" -#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" -#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" -#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" -#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" -#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0a1de838095c..959c3ee0c0cf 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -121,8 +121,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_o; - // Optional implicit head-group size for single-launch grouping. - ck_tile::index_t head_group_size_q = 0; + // Optional global head count and head offset (for grouped launches & RNG correctness) + ck_tile::index_t num_head_q_total = 0; + ck_tile::index_t head_start = 0; }; struct FmhaFwdLogitsSoftCapKargs @@ -392,10 +393,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -431,7 +433,8 @@ struct FmhaFwdKernel batch_stride_k, batch_stride_v, batch_stride_o}; - kargs.head_group_size_q = head_group_size_q; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -568,10 +571,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -630,7 +634,8 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - head_group_size_q); + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -689,10 +694,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -751,7 +757,8 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - head_group_size_q); + num_head_q_total, + head_start); } template @@ -805,10 +812,11 @@ struct FmhaFwdKernel drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { Kargs kargs{{q_ptr, k_ptr, @@ -845,7 +853,8 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_q_ptr), reinterpret_cast(seqlen_k_ptr)}; - kargs.head_group_size_q = head_group_size_q; + kargs.num_head_q_total = num_head_q_total; + kargs.head_start = head_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -979,10 +988,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1036,7 +1046,8 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - head_group_size_q); + num_head_q_total, + head_start); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -1090,10 +1101,11 @@ struct FmhaFwdKernel const std::tuple& drop_seed_offset, ck_tile::index_t block_scale_size_q, ck_tile::index_t block_scale_size_kv, - const void* cu_seqlen_q_ptr = nullptr, - const void* cu_seqlen_k_ptr = nullptr, - const void* sink_ptr = nullptr, - ck_tile::index_t head_group_size_q = 0) + const void* cu_seqlen_q_ptr = nullptr, + const void* cu_seqlen_k_ptr = nullptr, + const void* sink_ptr = nullptr, + ck_tile::index_t num_head_q_total = 0, + ck_tile::index_t head_start = 0) { return MakeKargsImpl( q_ptr, @@ -1147,7 +1159,8 @@ struct FmhaFwdKernel cu_seqlen_q_ptr, cu_seqlen_k_ptr, sink_ptr, - head_group_size_q); + num_head_q_total, + head_start); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1182,78 +1195,11 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); +#if CK_TILE_FMHA_FORCE_HEAD_MAJOR // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. const bool is_bhsd_layout = (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); - - // Single-launch head grouping for bshd only. - if((kargs.head_group_size_q > 0) && !is_bhsd_layout) - { - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; - const index_t num_head = gridDim.x; - const index_t batch_size = has_padded_seqlen_k ? gridDim.y : gridDim.z; - const index_t linear_id = - blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); - - const index_t group_sz = ck_tile::min(kargs.head_group_size_q, num_head); - if(group_sz > 0 && group_sz < num_head) - { - const index_t full_groups = num_head / group_sz; - const index_t tail_heads = num_head - full_groups * group_sz; - const index_t spans_per_group = batch_size * num_tile_total; - const index_t blocks_per_group = spans_per_group * group_sz; - - index_t i_group = 0; - index_t i_batch = 0; - index_t i_block = 0; - index_t i_nhead_in_group = 0; - - if((tail_heads > 0) && (linear_id >= full_groups * blocks_per_group)) - { - // Tail group decode: [group=tail] -> [batch] -> [block] -> [head_in_group] - const index_t tail_linear = linear_id - full_groups * blocks_per_group; - const index_t tail_group_heads = tail_heads; - const index_t tail_blocks_per_batch = num_tile_total * tail_group_heads; - - i_group = full_groups; - i_batch = tail_linear / tail_blocks_per_batch; - - const index_t rem1 = tail_linear - i_batch * tail_blocks_per_batch; - i_block = rem1 / tail_group_heads; - i_nhead_in_group = rem1 - i_block * tail_group_heads; - } - else - { - // Full group decode: [group] -> [batch] -> [block] -> [head_in_group] - i_group = linear_id / blocks_per_group; - - const index_t rem0 = linear_id - i_group * blocks_per_group; - const index_t blocks_per_batch = num_tile_total * group_sz; - i_batch = rem0 / blocks_per_batch; - - const index_t rem1 = rem0 - i_batch * blocks_per_batch; - i_block = rem1 / group_sz; - i_nhead_in_group = rem1 - i_block * group_sz; - } - - const index_t i_nhead = i_group * group_sz + i_nhead_in_group; - - index_t i_tile_m = i_block / num_tile_n1; - index_t i_tile_n = i_block - i_tile_m * num_tile_n1; - - if constexpr(kHasMask) - { - const index_t num_tile_m = num_tile_total / num_tile_n1; - i_tile_m = num_tile_m - 1 - i_tile_m; - } - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } - } - -#if CK_TILE_FMHA_FORCE_HEAD_MAJOR if(is_bhsd_layout) { const index_t num_tile_n1 = @@ -1700,9 +1646,12 @@ struct FmhaFwdKernel auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { if constexpr(kHasDropout) { + const auto num_head_q_total = + (kargs.num_head_q_total > 0 ? kargs.num_head_q_total : kargs.num_head_q); + const auto i_head_global = kargs.head_start + i_nhead_; return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, + i_head_global, + num_head_q_total, kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val : *kargs.drop_seed.ptr, kargs.is_drop_seed_offset_from_host diff --git a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp index 5bc4f0c6a042..3c7b00782f65 100644 --- a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp +++ b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp @@ -2,10 +2,10 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" -#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" From 73ce5d9600a31cb65d23fb2957ddc46d1b646b45 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Wed, 11 Mar 2026 18:05:02 +0000 Subject: [PATCH 09/12] fix missing return in check_hdim compatibility rule --- .../example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e888cbd3832f..f2489be405bf 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -806,7 +806,7 @@ def check_hdim(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: kernel_ctx.pipeline.F_bias != "no" or kernel_ctx.pipeline.F_dropout == "t" ): - False + return False return True def check_feature( From 90eed51df6f33385fae5edcb1758e8bc4969650b Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Thu, 12 Mar 2026 22:04:12 +0000 Subject: [PATCH 10/12] gate head-major elemwise-bias path for gfx12 on ROCm 7.1 --- .../composablekernel/include/ck_tile/core.hpp | 8 +++ .../include/ck_tile/ops/fmha.hpp | 1 + .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 66 ++++++++++++------- .../ck_tile/ops/grouped_convolution.hpp | 2 +- 4 files changed, 51 insertions(+), 26 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/core.hpp b/projects/composablekernel/include/ck_tile/core.hpp index f42526ddf75c..c377d6b4b953 100644 --- a/projects/composablekernel/include/ck_tile/core.hpp +++ b/projects/composablekernel/include/ck_tile/core.hpp @@ -20,9 +20,17 @@ #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" #include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha.hpp b/projects/composablekernel/include/ck_tile/ops/fmha.hpp index 0639fa1b36ef..aa96c61958c5 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha.hpp @@ -46,6 +46,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vr_gltr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 959c3ee0c0cf..20834e88ac11 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1196,34 +1196,50 @@ struct FmhaFwdKernel has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); #if CK_TILE_FMHA_FORCE_HEAD_MAJOR - // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q. - // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1. - const bool is_bhsd_layout = - (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); - if(is_bhsd_layout) + // compiler-workaround gate (ROCm 7.1 + gfx12). + // Keep head-major enabled for all unaffected kernels. +#if defined(__gfx12__) && (HIP_VERSION_MAJOR == 7) && (HIP_VERSION_MINOR == 1) + constexpr bool kSkipHeadMajor = kIsGroupMode && kHasMask && !kHasDropout && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) && + kPadHeadDimQ && kPadHeadDimV && + (FmhaPipeline::kN1 == 256) && + std::is_same_v && + std::is_same_v && + std::is_same_v; +#else + constexpr bool kSkipHeadMajor = false; +#endif + if constexpr(!kSkipHeadMajor) { - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; - const index_t num_head = gridDim.x; - const index_t blocks_per_batch = num_head * num_tile_total; - const index_t linear_id = - blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); - - const index_t i_batch = linear_id / blocks_per_batch; - const index_t rem0 = linear_id - i_batch * blocks_per_batch; - const index_t i_nhead = rem0 / num_tile_total; - const index_t i_block = rem0 - i_nhead * num_tile_total; - - index_t i_tile_m = i_block / num_tile_n1; - index_t i_tile_n = i_block - i_tile_m * num_tile_n1; - - if constexpr(kHasMask) + // bhsd should satisfy stride_q == hdim_q and nhead_stride_q > hdim_q + // The extra nhead_stride_q guard prevents bshd false-positive when nhead == 1 + const bool is_bhsd_layout = + (kargs.stride_q == kargs.hdim_q) && (kargs.nhead_stride_q > kargs.hdim_q); + if(is_bhsd_layout) { - const index_t num_tile_m = num_tile_total / num_tile_n1; - i_tile_m = num_tile_m - 1 - i_tile_m; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const index_t num_tile_total = has_padded_seqlen_k ? gridDim.z : gridDim.y; + const index_t num_head = gridDim.x; + const index_t blocks_per_batch = num_head * num_tile_total; + const index_t linear_id = + blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z); + + const index_t i_batch = linear_id / blocks_per_batch; + const index_t rem0 = linear_id - i_batch * blocks_per_batch; + const index_t i_nhead = rem0 / num_tile_total; + const index_t i_block = rem0 - i_nhead * num_tile_total; + + index_t i_tile_m = i_block / num_tile_n1; + index_t i_tile_n = i_block - i_tile_m * num_tile_n1; + + if constexpr(kHasMask) + { + const index_t num_tile_m = num_tile_total / num_tile_n1; + i_tile_m = num_tile_m - 1 - i_tile_m; + } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } #endif diff --git a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp index 3c7b00782f65..5bc4f0c6a042 100644 --- a/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp +++ b/projects/composablekernel/include/ck_tile/ops/grouped_convolution.hpp @@ -2,10 +2,10 @@ // SPDX-License-Identifier: MIT #pragma once -#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp" From ebd06161f463854c07c549600231514515483917 Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 13 Mar 2026 07:29:47 +0000 Subject: [PATCH 11/12] Restore executable bit on CMakeLists files --- projects/miopen/test/CMakeLists.txt | 0 projects/rocsolver/clients/CMakeLists.txt | 0 projects/rocsolver/clients/gtest/CMakeLists.txt | 0 projects/rocsolver/library/src/CMakeLists.txt | 0 4 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 projects/miopen/test/CMakeLists.txt mode change 100644 => 100755 projects/rocsolver/clients/CMakeLists.txt mode change 100644 => 100755 projects/rocsolver/clients/gtest/CMakeLists.txt mode change 100644 => 100755 projects/rocsolver/library/src/CMakeLists.txt diff --git a/projects/miopen/test/CMakeLists.txt b/projects/miopen/test/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/projects/rocsolver/clients/CMakeLists.txt b/projects/rocsolver/clients/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/projects/rocsolver/clients/gtest/CMakeLists.txt b/projects/rocsolver/clients/gtest/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/projects/rocsolver/library/src/CMakeLists.txt b/projects/rocsolver/library/src/CMakeLists.txt old mode 100644 new mode 100755 From eceb7c9d4a39f4219b68510e57c0393ab72a683e Mon Sep 17 00:00:00 2001 From: Hosang Yoon Date: Fri, 13 Mar 2026 07:48:03 +0000 Subject: [PATCH 12/12] Drop unintended untracked header include --- projects/composablekernel/include/ck_tile/ops/fmha.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha.hpp b/projects/composablekernel/include/ck_tile/ops/fmha.hpp index f68db3366fbf..8a5d77bf462e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha.hpp @@ -47,7 +47,6 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vr_gltr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"