Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,3 @@ set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
# Install the v3 forward and backward kernel directories of every arch listed in
# V3_ASM_ARCHS under lib/aiter/<arch>/, skipping any file named codegen.py.
foreach(ARCH IN LISTS V3_ASM_ARCHS)
  install(DIRECTORY
            ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_fwd
            ${__AITER_SOURCE_DIR}/hsa/${ARCH}/fmha_v3_bwd
          DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
          PATTERN "codegen.py" EXCLUDE)
endforeach()
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,6 @@ hipError_t ck_attn_bwd(

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
if (uses_bwd_v3)
{
set_aiter_asm_dir();
}

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
Expand Down Expand Up @@ -979,10 +975,6 @@ hipError_t ck_attn_varlen_bwd(

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
if (uses_bwd_v3)
{
set_aiter_asm_dir();
}

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,6 @@ hipError_t ck_attn_fwd(

// print ck traits and args when needed
log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args);
if (uses_fwd_v3)
{
set_aiter_asm_dir();
}

float average_runtime = aiter::mha_fwd(fmha_args,
stream_config,
Expand Down Expand Up @@ -434,10 +430,6 @@ hipError_t ck_attn_varlen_fwd(

// print ck traits and args when needed
log_fwd_config(__FUNCTION__, data_type_str, is_group_mode, has_logits_soft_cap, mask_type, bias_type, has_lse, has_dropout, is_v_rowmajor, do_fp8_static_quant, uses_fwd_v3, fmha_args);
if (uses_fwd_v3)
{
set_aiter_asm_dir();
}

float average_runtime = aiter::mha_fwd(fmha_args,
stream_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
************************************************************************/

#include <utility>
#include <dlfcn.h>
#include <filesystem>
#include <mutex> //once_flag
#include "ck_fused_attn_utils.hpp"
#include "ck_fused_attn/ck_fused_attn.hpp"
#include "mask.hpp"
Expand All @@ -16,40 +13,6 @@

namespace ck_fused_attn{

void set_aiter_asm_dir() {
static std::once_flag aiter_asm_dir_once;
std::call_once(aiter_asm_dir_once, []() {
hipDeviceProp_t prop;
hipError_t res= hipGetDeviceProperties(&prop, 0);
if (res != hipSuccess) {
throw std::runtime_error(std::string(
"hipGetDeviceProperties failed with error: ") + hipGetErrorString(res));
}
const char *arh_str = nullptr;
switch (prop.major*10 + prop.minor) {
case 94: // Gfx942
arh_str = "gfx942/"; // trailing slash is mandatory
break;
case 95: // Gfx950
arh_str = "gfx950/"; // trailing slash is mandatory
break;
default:
// Unsupported V3 architecture
return;
}
Dl_info info;
dladdr((void*)set_aiter_asm_dir, &info);
setenv("AITER_ASM_DIR",
(std::filesystem::path(info.dli_fname).parent_path() / "aiter" / arh_str).c_str(), 1);
if (const char* env_p = std::getenv("NVTE_LOG_CK_CONFIG") ) {
if (std::string(env_p) == "1"){
// Print the set environment variable for debugging purposes
std::cout << "AITER_ASM_DIR set to: " << getenv("AITER_ASM_DIR") << std::endl;
}
}
});
}

std::string get_data_type_str(DType dtype){
std::string data_type_str;
if(dtype==DType::kFloat16){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ enum class BiasType;
std::string get_data_type_str(DType dtype);
BiasShape get_bias_shape(uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_h);
std::pair<bias_enum, BiasShape> get_ck_bias_type_shape(BiasType attn_bias_type, uint64_t b, uint64_t h, uint64_t bias_b, uint64_t bias_h);
// Sets the AITER_ASM_DIR environment variable (overwriting any existing value) to
// "<directory of the shared object containing this function>/aiter/<arch>/", where
// <arch> is gfx942 or gfx950 as reported by hipGetDeviceProperties for device 0.
// Does nothing for other architectures. The body is guarded by std::call_once.
// Throws std::runtime_error if hipGetDeviceProperties fails.
void set_aiter_asm_dir();

}//namespace ck_fused_attn
#endif // CK_FUSED_ATTN_UTILS_H