diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index 105fb35401..3ba559edfb 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -371,7 +371,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr_pre = nullptr; AiterAsmKernel* impl_ptr_dqdkdv = nullptr; AiterAsmKernel* impl_ptr_post = nullptr; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; auto it_pre = pre_cfgs->find(pre_kernel); if(it_pre != pre_cfgs->end()) @@ -381,13 +381,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_odo = cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_pre = result.first->second.get(); + impl_ptr_pre = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { @@ -402,13 +397,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_kv = cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_dqdkdv = result.first->second.get(); + impl_ptr_dqdkdv = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { @@ -425,13 +415,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* co_name = cfg.co_name.c_str(); ts_dq = cfg.ts; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - - impl_ptr_post = result.first->second.get(); + impl_ptr_post = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index de0df023b5..31984c89c2 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ -239,19 +239,14 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) }; AiterAsmKernel* impl_ptr = nullptr; - static thread_local std::unordered_map> - impl_ptr_map; + static SynchronizedCache impl_ptr_map; const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name.c_str()); - } - impl_ptr = result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name.c_str()); }); fmha_fwd_v3_args args; size_t arg_size = sizeof(args); diff --git a/csrc/cpp_itfs/moe/asm_moe.cpp.jinja b/csrc/cpp_itfs/moe/asm_moe.cpp.jinja index 05e0b5d33d..faa83cb44a 100644 --- a/csrc/cpp_itfs/moe/asm_moe.cpp.jinja +++ b/csrc/cpp_itfs/moe/asm_moe.cpp.jinja @@ -63,19 +63,18 @@ struct __attribute__((packed)) KernelArgs }; -unsigned char hsaco[{{bin_size}}] = { {{bin_data}} }; +static const unsigned char hsaco[{{bin_size}}] = { {{bin_data}} }; class FMoeKernel { private: - std::unique_ptr asm_kernel=nullptr; + AiterAsmKernelFast asm_kernel; uint32_t sub_GU = 512; bool is_int4 = false; public: - FMoeKernel() + FMoeKernel() : asm_kernel("{{kernel_name}}", hsaco) { - asm_kernel=std::make_unique("{{kernel_name}}", hsaco); this->sub_GU = {{selected_tile}}; }; @@ -181,11 +180,11 @@ public: if constexpr (switchGxy) { - asm_kernel->launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream}); + asm_kernel.launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream}); } else { - asm_kernel->launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream}); + asm_kernel.launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream}); } }; }; diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index 7eca12994e..2a324d8d66 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -18,6 +18,9 @@ #include #include #include +#include +#include +#include #ifdef AITER_EMBEDDED_HSA_HEADER #include AITER_EMBEDDED_HSA_HEADER #endif @@ -27,7 +30,7 @@ namespace aiter_detail { inline thread_local bool g_aiter_can_throw = false; template -[[noreturn, noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args) +[[noreturn, gnu::noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args) { std::cerr << "[AITER] " << file << ":" << line << " "; (std::cerr << ... << std::forward(args)) << std::endl; @@ -133,55 +136,83 @@ struct AiterAsmKernelArgs static const std::string get_gpu_arch(); -inline void load_asm_kernel(const char* name, - const char* hsaco, - hipModule_t& module, - hipFunction_t& kernel_func) +namespace aiter_detail { +// Taken from +// https://github.com/llvm/llvm-project/blob/b0230f59969b9e8e7e0aff44cd34718987098462/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp#L226 +struct FatBinaryWrapper { - const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); - std::string arch_name = get_gpu_arch(); - if(AITER_ASM_DIR != nullptr) - { - std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco; - AITER_LOG_INFO("hipModuleLoad: " << hsa_path << " GetFunction: " << name); - HIP_CALL(hipModuleLoad(&module, hsa_path.c_str())); - } - else - { -#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP) - std::string fname = "hsa/" + arch_name + "/" + hsaco; - auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname); - AITER_CHECK(hasco_obj != AITER_EMBEDDED_HSA_MAP.end(), "hasco_obj not found"); - AITER_CHECK(hasco_obj->second.data() != nullptr, "hasco_obj is nullptr"); - AITER_LOG_INFO("hipModuleLoad: " << fname << " GetFunction: " << name); - HIP_CALL(hipModuleLoadData(&module, hasco_obj->second.data())); -#endif - } - HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); - AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success"); -} + uint32_t magic = 0x48495046; // "HIPF"; + uint32_t version = 1; + const void* binary = nullptr; + intptr_t __pad = 0; +}; + +extern "C" void* __hipRegisterFatBinary(const FatBinaryWrapper* data) noexcept; +extern "C" void __hipUnregisterFatBinary(void* module) noexcept; +extern "C" void __hipRegisterFunction(void* module, + const void* hostFunction, + const char* deviceFunction, + const char* deviceName, + int threadLimit, + void* tid, + void* bid, + void* blockDim, + void* gridDim, + void* wSize) noexcept; +} // namespace aiter_detail -class AiterAsmKernel +namespace { + +class AiterAsmKernelFast { private: - hipModule_t module; - hipFunction_t kernel_func; + void* module = nullptr; + + protected: + AiterAsmKernelFast() = default; + void init(const char* kernel_name, const void* hsaco) + { + aiter_detail::FatBinaryWrapper fat_bin{}; + fat_bin.binary = hsaco; + module = aiter_detail::__hipRegisterFatBinary(&fat_bin); + AITER_CHECK(module != nullptr, "failed to load module for ", kernel_name); + aiter_detail::__hipRegisterFunction(module, + static_cast(this), + kernel_name, + kernel_name, + -1, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); + } public: - AiterAsmKernel(const char* name, const char* hsaco) + AiterAsmKernelFast(const char* kernel_name, const void* hsaco) { - load_asm_kernel(name, hsaco, module, kernel_func); + init(kernel_name, hsaco); }; - ~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); } + ~AiterAsmKernelFast() { aiter_detail::__hipUnregisterFatBinary(module); } + + AiterAsmKernelFast(AiterAsmKernelFast&) = delete; + AiterAsmKernelFast(AiterAsmKernelFast&&) = delete; + AiterAsmKernelFast& operator=(AiterAsmKernelFast&) = delete; + AiterAsmKernelFast& operator=(AiterAsmKernelFast&&) = delete; void launch_kernel(const AiterAsmKernelArgs& kargs) { - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - kargs.args_ptr, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - kargs.arg_size_ptr, - HIP_LAUNCH_PARAM_END}; + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kargs.args_ptr, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + kargs.arg_size_ptr, + HIP_LAUNCH_PARAM_END}; + hipFunction_t kernel_func = nullptr; + // TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg + // Don't error check here. + // Failure to load the func would cause hipModuleLaunchKernel to fail anyways. + (void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast(this)); HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func, kargs.gdx, @@ -197,44 +228,59 @@ class AiterAsmKernel }; }; -class AiterAsmKernelFast + +class AiterAsmKernel: private AiterAsmKernelFast { private: - hipModule_t module; - hipFunction_t kernel_func; + std::unique_ptr hsaco_data; - public: - AiterAsmKernelFast(const char* name, void* hsaco) + const void* load_hsaco_file(const char* hsaco_path) { - HIP_CALL(hipModuleLoadData(&module, hsaco)); - HIP_CALL(hipModuleGetFunction(&kernel_func, module, name)); - AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success"); - }; + const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); + std::string arch_name = get_gpu_arch(); + if(AITER_ASM_DIR != nullptr) + { + std::string full_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco_path; - ~AiterAsmKernelFast() { HIP_CALL(hipModuleUnload(module)); } + std::ifstream file(full_path, std::ios::binary | std::ios::ate); - void launch_kernel(const AiterAsmKernelArgs& kargs) - { - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - kargs.args_ptr, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - kargs.arg_size_ptr, - HIP_LAUNCH_PARAM_END}; + AITER_CHECK(file.is_open(), "failed to open ", full_path.c_str()); - HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func, - kargs.gdx, - kargs.gdy, - kargs.gdz, - kargs.bdx, - kargs.bdy, - kargs.bdz, - 0, - kargs.stream, - nullptr, - (void**)&config)); + size_t file_size = file.tellg(); + hsaco_data.reset(new char[file_size]); + + file.seekg(0, std::ios::beg); + AITER_CHECK( + file.read(hsaco_data.get(), file_size), "failed to read ", full_path.c_str()); + return hsaco_data.get(); + } + else + { +#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP) + std::string fname = "hsa/" + arch_name + "/" + hsaco; + auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname); + AITER_CHECK(hasco_obj != AITER_EMBEDDED_HSA_MAP.end(), "hasco_obj not found"); + AITER_CHECK(hasco_obj->second.data() != nullptr, "hasco_obj is nullptr"); + return hasco_obj->second.data(); +#else + AITER_CHECK(AITER_ASM_DIR != nullptr, "AITER_ASM_DIR not set"); + return nullptr; +#endif + } + } + + public: + AiterAsmKernel(const char* kernel_name, const char* hsaco_path) + { + init(kernel_name, load_hsaco_file(hsaco_path)); }; + + using AiterAsmKernelFast::launch_kernel; }; + +} // namespace + static const std::string get_gpu_arch() { int device_count; @@ -345,3 +391,27 @@ class HipDeviceGuard private: int prev_device_{}; }; + +template , class KeyEqual = std::equal_to> +struct SynchronizedCache +{ + template + inline T& get_or_create(K&& k, F&& factory) + { + std::lock_guard map_mu_guard(map_mu); + + struct Wrapper + { + F& f; + // Makes sure we only invoke lambda on insert + operator T() && { return f(); } + }; + + auto [it, _] = map.try_emplace(std::forward(k), Wrapper{factory}); + return it->second; + } + + private: + std::mutex map_mu; + std::unordered_map map; +}; diff --git a/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu b/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu index e483369ed5..a530ce9d7f 100644 --- a/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu +++ b/csrc/py_itfs_cu/asm_a8w8_blockscale_bpreshuffle.cu @@ -152,45 +152,39 @@ struct KernelSelector { } }; - static std::unordered_map, SimpleHash> heuristic_cache; - static std::unordered_map> kernel_cache; - - static std::tuple select_kernel( - int M, - int N, - int K, - const std::string& arch_id, - std::optional splitK, - std::optional bpreshuffle, - const char* kernelName, - CFG* config_map) { + static SynchronizedCache, SimpleHash> heuristic_cache; + static SynchronizedCache kernel_cache; + + static std::tuple select_kernel(int M, + int N, + int K, + const std::string& arch_id, + std::optional splitK, + std::optional bpreshuffle, + const char* kernelName, + CFG* config_map) + { if (kernelName && kernelName[0] != 0) { return std::make_tuple(arch_id + kernelName, splitK.value_or(0) ?: 1); } DictKey key(M, N, K, splitK, bpreshuffle); - auto it = heuristic_cache.find(key); - if (it != heuristic_cache.end()) { - return it->second; // find it and return - } - auto result = get_heuristic_fp8_kernel(M, N, K, arch_id, splitK, bpreshuffle, config_map); - heuristic_cache[key] = result; - return result; + + return heuristic_cache.get_or_create(key, [&]() { + return get_heuristic_fp8_kernel(M, N, K, arch_id, splitK, bpreshuffle, config_map); + }); } static AiterAsmKernel* get_kernel(const std::string& kernel_name, const std::string& co_name) { - auto result = kernel_cache.emplace(kernel_name, nullptr); - if (result.second) { - result.first->second = std::make_unique(kernel_name.c_str(), co_name.c_str()); - } - return result.first->second.get(); + return &kernel_cache.get_or_create( + kernel_name, [&]() { return AiterAsmKernel(kernel_name.c_str(), co_name.c_str()); }); } }; -std::unordered_map, KernelSelector::SimpleHash> +SynchronizedCache, KernelSelector::SimpleHash> KernelSelector::heuristic_cache; -std::unordered_map> KernelSelector::kernel_cache; +SynchronizedCache KernelSelector::kernel_cache; static KernelArgs setup_kernel_args( aiter_tensor_t* A, diff --git a/csrc/py_itfs_cu/asm_fmoe.cu b/csrc/py_itfs_cu/asm_fmoe.cu index ec80e12438..2f27abbda4 100755 --- a/csrc/py_itfs_cu/asm_fmoe.cu +++ b/csrc/py_itfs_cu/asm_fmoe.cu @@ -71,8 +71,7 @@ struct __attribute__((packed)) KernelArgs class FMoeKernel { private: - hipModule_t module; - hipFunction_t kernel_func; + AiterAsmKernel kernel; uint32_t sub_GU = 512; bool is_int4 = false; uint32_t num_persistent_tgs = 0; @@ -82,9 +81,8 @@ class FMoeKernel FMoeKernel(const char* name, const char* hsaco, uint32_t sub_GU = 512, - uint32_t num_persistent_tgs = 0) + uint32_t num_persistent_tgs = 0) : kernel(name, hsaco) { - load_asm_kernel(name, hsaco, module, kernel_func); this->sub_GU = sub_GU; this->num_persistent_tgs = num_persistent_tgs; this->name = name; @@ -180,11 +178,6 @@ class FMoeKernel args.ps_deno = ((inter_dim + sub_GU - 1) / sub_GU); args.total_tgs = this->num_persistent_tgs / args.ps_deno * args.ps_deno; - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, - &args, - HIP_LAUNCH_PARAM_BUFFER_SIZE, - &arg_size, - HIP_LAUNCH_PARAM_END}; int bdx; int gdx; int gdy; @@ -208,13 +201,27 @@ class FMoeKernel if constexpr(switchGxy) { - HIP_CALL_LAUNCH(hipModuleLaunchKernel( - kernel_func, gdy, gdx, gdz, bdx, 1, 1, 0, stream, nullptr, (void**)&config)); + kernel.launch_kernel({&args, + &arg_size, + gdy, // gdx + gdx, // gdy + gdz, // gdz + bdx, // bdx + 1, // bdy + 1, // bdz + stream}); } else { - HIP_CALL_LAUNCH(hipModuleLaunchKernel( - kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, stream, nullptr, (void**)&config)); + kernel.launch_kernel({&args, + &arg_size, + gdx, // gdx + gdy, // gdy + gdz, // gdz + bdx, // bdx + 1, // bdy + 1, // bdz + stream}); } }; }; @@ -231,7 +238,7 @@ FMoeKernel* get_heuristic_kernel( std::string arch_id = get_gpu_arch(); std::string selectedKl = kernel_name.empty() ? "" : arch_id + kernel_name; int vskip = 1; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; const char* vs_env_value = std::getenv("AITER_ENABLE_VSKIP"); if(vs_env_value != nullptr && std::string(vs_env_value) == "0") @@ -285,15 +292,13 @@ FMoeKernel* get_heuristic_kernel( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); if(cfg.ps == 1) num_persistent_tgs = cfg.tg_num_perCU * num_cu; else num_persistent_tgs = 0; - if(result.second) - result.first->second = - std::make_unique(name, co_name, cfg.subGU_n, num_persistent_tgs); - impl_ptr = result.first->second.get(); + + impl_ptr = &impl_ptr_map.get_or_create( + name, [&]() { return FMoeKernel(name, co_name, cfg.subGU_n, num_persistent_tgs); }); } else AITER_CHECK(false, __func__, " not find kernel " + selectedKl); @@ -406,7 +411,7 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( ActivationType act = static_cast(activation); FMoeKernel* impl_ptr = nullptr; int inter_dim = down->size(2); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; struct FMoeKernelConfig { @@ -482,13 +487,8 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( const char* name = config.name.c_str(); const char* co_name = config.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = - std::make_unique(name, co_name, config.tile_size); - } - impl_ptr = result.first->second.get(); + impl_ptr = &impl_ptr_map.get_or_create( + name, [&]() { return FMoeKernel(name, co_name, config.tile_size); }); } } impl_ptr->launch_kernel<1, 2>(out, @@ -545,7 +545,7 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( int inter_dim = down->size(2); inter_dim *= model_dim / gate->size(2); int sub_X_cnt = sorted_expert_ids->size(0); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string kernel_name_str = kernel_name ? kernel_name : ""; if(gate->dtype() == AITER_DTYPE_u32 || gate->dtype() == AITER_DTYPE_i32) // int4 diff --git a/csrc/py_itfs_cu/asm_gemm_a16w16.cu b/csrc/py_itfs_cu/asm_gemm_a16w16.cu index f78f15f187..a9a2d79db0 100644 --- a/csrc/py_itfs_cu/asm_gemm_a16w16.cu +++ b/csrc/py_itfs_cu/asm_gemm_a16w16.cu @@ -175,7 +175,7 @@ AiterAsmKernel* get_or_load_kernel(const std::string& selectedKernelName, unsigned int& SUBM, unsigned int& SUBN) { - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; auto it_kl = config_map->find(selectedKernelName); AITER_CHECK(it_kl != config_map->end(), __func__, " not find kernel~ " + selectedKernelName); @@ -186,11 +186,7 @@ AiterAsmKernel* get_or_load_kernel(const std::string& selectedKernelName, SUBM = cfg.tileM; SUBN = cfg.tileN; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - result.first->second = std::make_unique(name, co_name); - - return result.first->second.get(); + return &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } AITER_CTYPES_ERROR_DEF diff --git a/csrc/py_itfs_cu/asm_gemm_a4w4.cu b/csrc/py_itfs_cu/asm_gemm_a4w4.cu index 9b1dc2e4ce..33e788378e 100644 --- a/csrc/py_itfs_cu/asm_gemm_a4w4.cu +++ b/csrc/py_itfs_cu/asm_gemm_a4w4.cu @@ -211,12 +211,12 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( std::hash()(log2) ^ std::hash()(shuffle); } }; - static std::unordered_map, SimpleHash> + static SynchronizedCache, SimpleHash> heuristic_kernel_dict; AITER_CHECK(!config_map->empty(), __func__, " no kernel support a4w4 for this gpu arch"); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); std::string kname = (kernelName && kernelName[0] != 0) ? (arch_id + kernelName) : ""; @@ -224,23 +224,11 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( int selectedksplit = (log2_k_split >= 0) ? log2_k_split : 0; if(kname.empty()) { - auto it = heuristic_kernel_dict.find(DictKey(Mdim, Ndim, Kdim, log2_k_split, bpreshuffle)); - if(it != heuristic_kernel_dict.end()) - { - auto res = it->second; - kname = std::get<0>(res); - selectedksplit = std::get<1>(res); - } - else - { - auto it = get_heuristic_kernel( - Mdim, Ndim, Kdim, arch_id, log2_k_split, bpreshuffle, config_map); - - kname = std::get<0>(it); - selectedksplit = std::get<1>(it); - heuristic_kernel_dict[{Mdim, Ndim, Kdim, log2_k_split, bpreshuffle}] = - std::make_tuple(kname, selectedksplit); - } + std::tie(kname, selectedksplit) = heuristic_kernel_dict.get_or_create( + DictKey(Mdim, Ndim, Kdim, log2_k_split, bpreshuffle), [&]() { + return get_heuristic_kernel( + Mdim, Ndim, Kdim, arch_id, log2_k_split, bpreshuffle, config_map); + }); } AiterAsmKernel* impl_ptr = nullptr; @@ -269,12 +257,8 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( gdz = (Kdim + k_per_tg - 1) / k_per_tg; } - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel " + kname); diff --git a/csrc/py_itfs_cu/asm_gemm_a8w8.cu b/csrc/py_itfs_cu/asm_gemm_a8w8.cu index fb759450d6..985f830c86 100644 --- a/csrc/py_itfs_cu/asm_gemm_a8w8.cu +++ b/csrc/py_itfs_cu/asm_gemm_a8w8.cu @@ -184,14 +184,14 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( std::hash()(splitk_key) ^ std::hash()(shuffle_key); } }; - static std::unordered_map, SimpleHash> + static SynchronizedCache, SimpleHash> heuristic_kernel_dict; if(config_map->empty()) { AITER_CHECK(false, __func__, " no kernel support a8w8 for this gpu arch"); } - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); std::string selectedName = (kernelName && kernelName[0] != '\0') ? arch_id + kernelName @@ -199,22 +199,11 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( int selectedksplit = opt_splitK.value_or(0) ?: 1; if(selectedName.empty()) { - auto it = heuristic_kernel_dict.find(DictKey(Mdim, Ndim, Kdim, opt_splitK, opt_bpreshuffle)); - if(it != heuristic_kernel_dict.end()) - { - auto res = it->second; - selectedName = std::get<0>(res); - selectedksplit = std::get<1>(res); - } - else - { - auto it = get_heuristic_kernel(Mdim, Ndim, Kdim, arch_id, opt_splitK, opt_bpreshuffle, config_map); - - selectedName = std::get<0>(it); - selectedksplit = std::get<1>(it); - heuristic_kernel_dict[{Mdim, Ndim, Kdim, opt_splitK, opt_bpreshuffle}] = - std::make_tuple(selectedName, selectedksplit); - } + std::tie(selectedName, selectedksplit) = heuristic_kernel_dict.get_or_create( + DictKey(Mdim, Ndim, Kdim, splitK, bpreshuffle), [&]() { + return get_heuristic_kernel( + Mdim, Ndim, Kdim, arch_id, splitK, bpreshuffle, config_map); + }); } AiterAsmKernel* impl_ptr = nullptr; @@ -279,12 +268,9 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( } } gdx = gdx * selectedksplit; - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel ", selectedName); diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index cdc415fc6c..9aade7fdb9 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -254,7 +254,7 @@ void mla_decode_stage1_asm_fwd( // Get kernel using config dispatch std::string arch_id = get_gpu_arch(); CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = persistent ? 1 : 0; int prefill = 0; // decode stage @@ -351,13 +351,9 @@ void mla_decode_stage1_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); - + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel ", kernelName); @@ -506,7 +502,7 @@ void mla_prefill_ps_asm_fwd( AITER_CHECK(false, __func__, ": fp8 mla persistent prefill is not supported on gfx942"); } CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = 1; // ps_prefill always uses persistent scheduling int prefill = 1; // prefill stage @@ -526,12 +522,9 @@ void mla_prefill_ps_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel ", kernelName); @@ -620,7 +613,7 @@ void mla_prefill_asm_fwd( std::string arch_id = get_gpu_arch(); CFG* config_map = &cfg_mla_asm; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int ps = 0; // prefill without persistent scheduling int prefill = 1; // prefill stage @@ -638,12 +631,9 @@ void mla_prefill_asm_fwd( const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel ", kernelName); diff --git a/csrc/py_itfs_cu/asm_moe_2stage.cu b/csrc/py_itfs_cu/asm_moe_2stage.cu index 4c1255eea2..d9004fb31f 100644 --- a/csrc/py_itfs_cu/asm_moe_2stage.cu +++ b/csrc/py_itfs_cu/asm_moe_2stage.cu @@ -174,7 +174,7 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( QuantType qt = static_cast(quant_type); CFG *config_map = get_cfg(input, out, w1, qt, sorted_weights != nullptr); - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; int model_dim = input->size(1); int hidden_dim = inter_dim; int sub_X_cnt = sorted_expert_ids->size(0); @@ -197,12 +197,8 @@ AITER_CTYPES_DEFINE_ENTRYPOINT_VOID( "ASM kernel ", name, " is not supported for inter_dim=", inter_dim, " (tile_n=", cfg.tile_n, ", block_m=", block_m, ")"); - auto result = impl_ptr_map.emplace(name, nullptr); - if (result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else { diff --git a/csrc/py_itfs_cu/asm_pa.cu b/csrc/py_itfs_cu/asm_pa.cu index 0c0062c168..fb96d2b3d3 100644 --- a/csrc/py_itfs_cu/asm_pa.cu +++ b/csrc/py_itfs_cu/asm_pa.cu @@ -272,7 +272,7 @@ void pa_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_size }; int qTile = 0; CFG* config_map = &cfg_pa_asm; // only one config csv in hsa//pa, now - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string kernelName = (kernelName_ != nullptr) ? arch_id + std::string(kernelName_) : ""; int ps = 0; if (kernelName.empty()) @@ -290,12 +290,9 @@ void pa_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_size const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel ", kernelName); @@ -451,7 +448,7 @@ void pa_ps_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_siz ") exceeds maximum available qTile. Please reduce gqa_ratio or max_qlen."); CFG* config_map = &cfg_pa_asm; // only one config csv in hsa//pa, now - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; std::string arch_id = get_gpu_arch(); std::string kernelName = (kernelName_ != nullptr) ? std::string(kernelName_) : get_heuristic_kernel(q_type, kv_type, gqa, mtp, msk, hp, block_size, arch_id, ps, qTile, quant_type, config_map); @@ -469,12 +466,9 @@ void pa_ps_fwd(aiter_tensor_t* Q, // [num_seqs, num_heads, head_siz const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); if(cfg.ps) { gdx = get_num_cu_func(); diff --git a/csrc/py_itfs_cu/asm_topksoftmax.cu b/csrc/py_itfs_cu/asm_topksoftmax.cu index 9e2f82050e..2c892a236d 100644 --- a/csrc/py_itfs_cu/asm_topksoftmax.cu +++ b/csrc/py_itfs_cu/asm_topksoftmax.cu @@ -93,7 +93,7 @@ void topk_softmax_asm(aiter_tensor_t* topk_weights, // [num_tokens, topk args.out_stride = out_stride * 4; CFG* config_map = &cfg_topksoftmax; - static std::unordered_map> impl_ptr_map; + static SynchronizedCache impl_ptr_map; AiterAsmKernel* impl_ptr = nullptr; auto [kernelName, subm] = get_heuristic_kernel_topksoftmax(arch_id, dtype, MAX_SUBM, num_experts, topk, config_map); @@ -103,12 +103,9 @@ void topk_softmax_asm(aiter_tensor_t* topk_weights, // [num_tokens, topk const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); - auto result = impl_ptr_map.emplace(name, nullptr); - if(result.second) - { - result.first->second = std::make_unique(name, co_name); - } - impl_ptr = result.first->second.get(); + + impl_ptr = + &impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); }); } else AITER_CHECK(false, __func__, " not find kernel " + kernelName);