Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 7 additions & 22 deletions csrc/cpp_itfs/mha_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
AiterAsmKernel* impl_ptr_pre = nullptr;
AiterAsmKernel* impl_ptr_dqdkdv = nullptr;
AiterAsmKernel* impl_ptr_post = nullptr;
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;
static SynchronizedCache<std::string_view, AiterAsmKernel> impl_ptr_map;

auto it_pre = pre_cfgs->find(pre_kernel);
if(it_pre != pre_cfgs->end())
Expand All @@ -381,13 +381,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_odo = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_pre = result.first->second.get();
impl_ptr_pre =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand All @@ -402,13 +397,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_kv = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_dqdkdv = result.first->second.get();
impl_ptr_dqdkdv =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand All @@ -425,13 +415,8 @@ float fmha_v3_bwd(mha_bwd_args a, const ck_tile::stream_config& s)
const char* co_name = cfg.co_name.c_str();
ts_dq = cfg.ts;

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name);
}

impl_ptr_post = result.first->second.get();
impl_ptr_post =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name); });
}
else
{
Expand Down
11 changes: 3 additions & 8 deletions csrc/cpp_itfs/mha_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,14 @@ float fmha_fwd_v3(mha_fwd_args a, const ck_tile::stream_config& s)
};

AiterAsmKernel* impl_ptr = nullptr;
static thread_local std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>>
impl_ptr_map;
static SynchronizedCache<std::string_view, AiterAsmKernel> impl_ptr_map;

const auto& cfg = it->second;
const char* name = cfg.knl_name.c_str();
std::string co_name = get_kernel_co_name(cfg.co_name, arch_id);

auto result = impl_ptr_map.emplace(name, nullptr);
if(result.second)
{
result.first->second = std::make_unique<AiterAsmKernel>(name, co_name.c_str());
}
impl_ptr = result.first->second.get();
impl_ptr =
&impl_ptr_map.get_or_create(name, [&]() { return AiterAsmKernel(name, co_name.c_str()); });

fmha_fwd_v3_args args;
size_t arg_size = sizeof(args);
Expand Down
11 changes: 5 additions & 6 deletions csrc/cpp_itfs/moe/asm_moe.cpp.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,18 @@ struct __attribute__((packed)) KernelArgs
};


unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };
static const unsigned char hsaco[{{bin_size}}] = { {{bin_data}} };

class FMoeKernel
{
private:
std::unique_ptr<AiterAsmKernelFast> asm_kernel=nullptr;
AiterAsmKernelFast asm_kernel;
uint32_t sub_GU = 512;
bool is_int4 = false;

public:
FMoeKernel()
FMoeKernel() : asm_kernel("{{kernel_name}}", hsaco)
{
asm_kernel=std::make_unique<AiterAsmKernelFast>("{{kernel_name}}", hsaco);
this->sub_GU = {{selected_tile}};
};

Expand Down Expand Up @@ -181,11 +180,11 @@ public:

if constexpr (switchGxy)
{
asm_kernel->launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream});
asm_kernel.launch_kernel({&args, &arg_size, gdy, gdx, gdz, bdx, 1, 1, stream});
}
else
{
asm_kernel->launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream});
asm_kernel.launch_kernel({&args, &arg_size, gdx, gdy, gdz, bdx, 1, 1, stream});
}
};
};
Expand Down
202 changes: 136 additions & 66 deletions csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <sstream>
#include <stdexcept>
#include <utility>
#include <fstream>
#include <mutex>
#include <memory>
#ifdef AITER_EMBEDDED_HSA_HEADER
#include AITER_EMBEDDED_HSA_HEADER
#endif
Expand All @@ -27,7 +30,7 @@ namespace aiter_detail {
inline thread_local bool g_aiter_can_throw = false;

template <typename... Args>
[[noreturn, noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args)
[[noreturn, gnu::noinline]] inline void aiter_check_fatal(const char* file, size_t line, Args&&... args)
{
std::cerr << "[AITER] " << file << ":" << line << " ";
(std::cerr << ... << std::forward<Args>(args)) << std::endl;
Expand Down Expand Up @@ -133,55 +136,83 @@ struct AiterAsmKernelArgs

static const std::string get_gpu_arch();

inline void load_asm_kernel(const char* name,
const char* hsaco,
hipModule_t& module,
hipFunction_t& kernel_func)
namespace aiter_detail {
// Taken from
// https://github.com/llvm/llvm-project/blob/b0230f59969b9e8e7e0aff44cd34718987098462/llvm/lib/Frontend/Offloading/OffloadWrapper.cpp#L226
struct FatBinaryWrapper
{
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
std::string arch_name = get_gpu_arch();
if(AITER_ASM_DIR != nullptr)
{
std::string hsa_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco;
AITER_LOG_INFO("hipModuleLoad: " << hsa_path << " GetFunction: " << name);
HIP_CALL(hipModuleLoad(&module, hsa_path.c_str()));
}
else
{
#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP)
std::string fname = "hsa/" + arch_name + "/" + hsaco;
auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname);
AITER_CHECK(hasco_obj != AITER_EMBEDDED_HSA_MAP.end(), "hasco_obj not found");
AITER_CHECK(hasco_obj->second.data() != nullptr, "hasco_obj is nullptr");
AITER_LOG_INFO("hipModuleLoad: " << fname << " GetFunction: " << name);
HIP_CALL(hipModuleLoadData(&module, hasco_obj->second.data()));
#endif
}
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success");
}
uint32_t magic = 0x48495046; // "HIPF";
uint32_t version = 1;
const void* binary = nullptr;
intptr_t __pad = 0;
};

// Hand-written declarations of HIP runtime registration entry points, so this
// header can call them directly.
// NOTE(review): presumably these are the hooks that compiler-generated host
// code uses and they are not exposed by the public HIP headers -- confirm.

// Registers the binary described by `data` and returns an opaque handle.
// AiterAsmKernelFast::init treats a null return value as a load failure.
extern "C" void* __hipRegisterFatBinary(const FatBinaryWrapper* data) noexcept;
// Counterpart of __hipRegisterFatBinary; ~AiterAsmKernelFast passes it the
// handle obtained at registration.
extern "C" void __hipUnregisterFatBinary(void* module) noexcept;
// Associates a host-side key (`hostFunction`) with a named device function in
// `module`. AiterAsmKernelFast::init passes its own `this` pointer as the key,
// the kernel name for both name arguments, -1 for `threadLimit` and nullptr
// for the remaining pointers; launch_kernel later hands the same key to
// hipGetFuncBySymbol.
extern "C" void __hipRegisterFunction(void* module,
const void* hostFunction,
const char* deviceFunction,
const char* deviceName,
int threadLimit,
void* tid,
void* bid,
void* blockDim,
void* gridDim,
void* wSize) noexcept;
} // namespace aiter_detail

class AiterAsmKernel
namespace {

class AiterAsmKernelFast
{
private:
hipModule_t module;
hipFunction_t kernel_func;
void* module = nullptr;

protected:
AiterAsmKernelFast() = default;
void init(const char* kernel_name, const void* hsaco)
{
aiter_detail::FatBinaryWrapper fat_bin{};
fat_bin.binary = hsaco;
module = aiter_detail::__hipRegisterFatBinary(&fat_bin);
AITER_CHECK(module != nullptr, "failed to load module for ", kernel_name);
aiter_detail::__hipRegisterFunction(module,
static_cast<void*>(this),
kernel_name,
kernel_name,
-1,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr);
}

public:
AiterAsmKernel(const char* name, const char* hsaco)
AiterAsmKernelFast(const char* kernel_name, const void* hsaco)
{
load_asm_kernel(name, hsaco, module, kernel_func);
init(kernel_name, hsaco);
};

~AiterAsmKernel() { HIP_CALL(hipModuleUnload(module)); }
~AiterAsmKernelFast() { aiter_detail::__hipUnregisterFatBinary(module); }

AiterAsmKernelFast(AiterAsmKernelFast&) = delete;
AiterAsmKernelFast(AiterAsmKernelFast&&) = delete;
AiterAsmKernelFast& operator=(AiterAsmKernelFast&) = delete;
AiterAsmKernelFast& operator=(AiterAsmKernelFast&&) = delete;

void launch_kernel(const AiterAsmKernelArgs& kargs)
{
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
hipFunction_t kernel_func = nullptr;
// TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg
// Don't error check here.
// Failure to load the func would cause hipModuleLaunchKernel to fail anyways.
(void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast<void*>(this));

HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func,
kargs.gdx,
Expand All @@ -197,44 +228,59 @@ class AiterAsmKernel
};
};

class AiterAsmKernelFast

class AiterAsmKernel: private AiterAsmKernelFast
{
private:
hipModule_t module;
hipFunction_t kernel_func;
std::unique_ptr<char[]> hsaco_data;

public:
AiterAsmKernelFast(const char* name, void* hsaco)
const void* load_hsaco_file(const char* hsaco_path)
{
HIP_CALL(hipModuleLoadData(&module, hsaco));
HIP_CALL(hipModuleGetFunction(&kernel_func, module, name));
AITER_LOG_INFO("hipModuleGetFunction: " << name << " Success");
};
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
std::string arch_name = get_gpu_arch();
if(AITER_ASM_DIR != nullptr)
{
std::string full_path = std::string(AITER_ASM_DIR) + "/" + arch_name + "/" + hsaco_path;

~AiterAsmKernelFast() { HIP_CALL(hipModuleUnload(module)); }
std::ifstream file(full_path, std::ios::binary | std::ios::ate);

void launch_kernel(const AiterAsmKernelArgs& kargs)
{
void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER,
kargs.args_ptr,
HIP_LAUNCH_PARAM_BUFFER_SIZE,
kargs.arg_size_ptr,
HIP_LAUNCH_PARAM_END};
AITER_CHECK(file.is_open(), "failed to open ", full_path.c_str());

HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func,
kargs.gdx,
kargs.gdy,
kargs.gdz,
kargs.bdx,
kargs.bdy,
kargs.bdz,
0,
kargs.stream,
nullptr,
(void**)&config));
size_t file_size = file.tellg();
hsaco_data.reset(new char[file_size]);

file.seekg(0, std::ios::beg);
AITER_CHECK(
file.read(hsaco_data.get(), file_size), "failed to read ", full_path.c_str());
return hsaco_data.get();
}
else
{
#if defined(AITER_EMBEDDED_HSA_HEADER) && defined(AITER_EMBEDDED_HSA_MAP)
// Embedded build: look the code object up under the key "hsa/<arch>/<path>".
// The key is built from the `hsaco_path` parameter; the previous spelling
// `hsaco` names nothing declared in this function or class and failed to
// compile whenever the embedded map was enabled.
std::string fname = "hsa/" + arch_name + "/" + hsaco_path;
auto hasco_obj = AITER_EMBEDDED_HSA_MAP.find(fname);
AITER_CHECK(hasco_obj != AITER_EMBEDDED_HSA_MAP.end(), "hasco_obj not found");
AITER_CHECK(hasco_obj->second.data() != nullptr, "hasco_obj is nullptr");
// The returned pointer refers to storage owned by the embedded map entry.
return hasco_obj->second.data();
#else
// No embedded table is compiled in, and this branch is only reached when
// AITER_ASM_DIR is unset, so the check below always fires and reports the
// missing variable; the return only satisfies the function's return type.
AITER_CHECK(AITER_ASM_DIR != nullptr, "AITER_ASM_DIR not set");
return nullptr;
#endif
}
}

public:
AiterAsmKernel(const char* kernel_name, const char* hsaco_path)
{
init(kernel_name, load_hsaco_file(hsaco_path));
};

using AiterAsmKernelFast::launch_kernel;
};


} // namespace

static const std::string get_gpu_arch()
{
int device_count;
Expand Down Expand Up @@ -345,3 +391,27 @@ class HipDeviceGuard
private:
int prev_device_{};
};

/// Map of lazily constructed values whose lookups and insertions are
/// serialized by an internal mutex.
///
/// @tparam Key       key type stored in the underlying std::unordered_map
/// @tparam T         mapped value type; built in place from the factory result
/// @tparam Hash      hash functor forwarded to the map
/// @tparam KeyEqual  equality functor forwarded to the map
template <class Key, class T, class Hash = std::hash<Key>, class KeyEqual = std::equal_to<Key>>
struct SynchronizedCache
{
    /// Returns the value stored under `k`, creating it first if it is absent.
    ///
    /// @param k        key to look up (forwarded into the map on insertion)
    /// @param factory  callable with no arguments whose result initializes the
    ///                 new value; it runs only when `k` is not yet present and
    ///                 it runs while the internal mutex is held
    /// @return reference to the value stored in the map. The reference is
    ///         handed out after the lock is released, so any further
    ///         synchronization on the value itself is up to the caller.
    template <typename K, typename F>
    inline T& get_or_create(K&& k, F&& factory)
    {
        // try_emplace only builds the mapped value when it actually inserts.
        // Passing a shim that converts to T defers the factory call to that
        // point, so a cache hit never invokes the factory.
        struct LazyValue
        {
            F& make;
            operator T() && { return make(); }
        };

        const std::lock_guard<std::mutex> hold(mutex_);
        return entries_.try_emplace(std::forward<K>(k), LazyValue{factory}).first->second;
    }

private:
    std::mutex mutex_;                                  // guards entries_
    std::unordered_map<Key, T, Hash, KeyEqual> entries_; // key -> cached value
};
Loading
Loading