diff --git a/csrc/cpp_itfs/mha_bwd.cu b/csrc/cpp_itfs/mha_bwd.cu index da7e957b5b..fc7b2f8be7 100644 --- a/csrc/cpp_itfs/mha_bwd.cu +++ b/csrc/cpp_itfs/mha_bwd.cu @@ -371,6 +371,11 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) AiterAsmKernel* impl_ptr_post = nullptr; static std::unordered_map> impl_ptr_map; + // Include device ID in cache key so each GPU gets its own loaded module + int current_device; + HIP_CALL(hipGetDevice(&current_device)); + std::string dev_prefix = std::to_string(current_device) + ":"; + auto it_pre = pre_cfgs->find(pre_kernel); if(it_pre != pre_cfgs->end()) { @@ -378,8 +383,9 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); ts_odo = cfg.ts; + std::string key = dev_prefix + name; - auto result = impl_ptr_map.emplace(name, nullptr); + auto result = impl_ptr_map.emplace(key, nullptr); if(result.second) { result.first->second = std::make_unique(name, co_name); @@ -399,8 +405,9 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); ts_kv = cfg.ts; + std::string key = dev_prefix + name; - auto result = impl_ptr_map.emplace(name, nullptr); + auto result = impl_ptr_map.emplace(key, nullptr); if(result.second) { result.first->second = std::make_unique(name, co_name); @@ -422,8 +429,9 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s) const char* name = cfg.knl_name.c_str(); const char* co_name = cfg.co_name.c_str(); ts_dq = cfg.ts; + std::string key = dev_prefix + name; - auto result = impl_ptr_map.emplace(name, nullptr); + auto result = impl_ptr_map.emplace(key, nullptr); if(result.second) { result.first->second = std::make_unique(name, co_name); diff --git a/csrc/cpp_itfs/mha_fwd.cu b/csrc/cpp_itfs/mha_fwd.cu index 5cf5f41389..e50c396784 100644 --- a/csrc/cpp_itfs/mha_fwd.cu +++ b/csrc/cpp_itfs/mha_fwd.cu @@ 
-242,11 +242,17 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s) static thread_local std::unordered_map> impl_ptr_map; + // Include device ID in cache key so each GPU gets its own loaded module + int current_device; + HIP_CALL(hipGetDevice(&current_device)); + std::string dev_prefix = std::to_string(current_device) + ":"; + const auto& cfg = it->second; const char* name = cfg.knl_name.c_str(); std::string co_name = get_kernel_co_name(cfg.co_name, arch_id); + std::string key = dev_prefix + name; - auto result = impl_ptr_map.emplace(name, nullptr); + auto result = impl_ptr_map.emplace(key, nullptr); if(result.second) { result.first->second = std::make_unique(name, co_name.c_str());