diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/CMakeLists.txt index e9e9fa4cde7e..940dbf10fc0e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/CMakeLists.txt @@ -51,6 +51,7 @@ iree_cc_library( # Query for list of device libs to build. get_property(AMD_DEVICE_LIBS GLOBAL PROPERTY AMD_DEVICE_LIBS) +set(AMD_REQUIRED_LIBS "ocml" "ockl" "opencl") set(_platform_lib_reldir "iree_platform_libs/rocm") file(MAKE_DIRECTORY "${IREE_COMPILER_DYLIB_DIR}/${_platform_lib_reldir}") @@ -59,7 +60,7 @@ file(MAKE_DIRECTORY "${IREE_COMPILER_DYLIB_DIR}/${_platform_lib_reldir}") set(_all_device_bc_deps) set(_all_device_bc_copy_commands) set(_all_device_bc_files) -foreach (_device_lib_target ${AMD_DEVICE_LIBS}) +foreach (_device_lib_target ${AMD_REQUIRED_LIBS}) get_target_property(_device_basename ${_device_lib_target} ARCHIVE_OUTPUT_NAME) get_target_property(_device_output_path ${_device_lib_target} OUTPUT_NAME) set(_device_bc_relpath "${_platform_lib_reldir}/${_device_basename}.bc") diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp index eaca11302800..b4373f849937 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Target/ROCM/ROCMTargetUtils.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/HAL/Target/ROCM/ROCMTarget.h" #include "iree/compiler/Utils/ToolUtils.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" @@ -96,13 +97,8 @@ linkWithBitcodeVector(llvm::Module *module, static std::vector getROCDLPaths(std::string targetChip, std::string bitCodeDir) { // AMDGPU bitcodes. - int lenOfChipPrefix = 3; - std::string chipId = targetChip.substr(lenOfChipPrefix); - std::string chipISABC = "oclc_isa_version_" + chipId + ".bc"; static const std::vector rocdlFilenames( - {"opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc", - "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc", - "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc", chipISABC}); + {"opencl.bc", "ocml.bc", "ockl.bc"}); // Construct full path to ROCDL bitcode libraries. std::vector result; @@ -113,6 +109,64 @@ static std::vector getROCDLPaths(std::string targetChip, return result; } +static void overridePlatformGlobal(llvm::Module *module, StringRef globalName, + uint32_t newValue, llvm::Type *globalTy) { + // NOTE: the global will not be defined if it is not used in the module. + auto *globalValue = module->getNamedGlobal(globalName); + if (!globalValue) + return; + globalValue->setDSOLocal(true); + globalValue->setConstant(true); + globalValue->setInitializer(llvm::ConstantInt::get( + globalValue->getValueType(), + APInt(globalValue->getValueType()->getIntegerBitWidth(), newValue))); +} + +static LogicalResult linkModuleWithGlobal(llvm::Module *module, + std::string &targetChip) { + // Link target chip ISA version as global. + const int kLenOfChipPrefix = 3; + std::string chipId = targetChip.substr(kLenOfChipPrefix); + // i.e gfx90a -> 9000 series. + int chipArch = stoi(chipId.substr(0, chipId.length() - 1)) * 100; + // Oldest GFX arch supported is gfx60x. + if (chipArch < 6000) + return failure(); + // Latest GFX arch supported is gfx115x. + if (chipArch > 11000) + return failure(); + // Get chip code from suffix. i.e gfx1103 -> `3`. + // gfx90a -> `a` == `10`. + // gfx90c -> `c` == `12`. + std::string chipSuffix = chipId.substr(chipId.length() - 1); + uint32_t chipCode; + if (chipSuffix == "a") { + chipCode = chipArch + 10; + } else if (chipSuffix == "c") { + chipCode = chipArch + 12; + } else { + if (!std::isdigit(chipSuffix[0])) + return failure(); + chipCode = chipArch + stoi(chipSuffix); + } + auto *int32Type = llvm::Type::getInt32Ty(module->getContext()); + overridePlatformGlobal(module, "__oclc_ISA_version", chipCode, int32Type); + + // Link oclc configurations as globals. + auto *boolType = llvm::Type::getInt8Ty(module->getContext()); + static const std::vector> rocdlGlobalParams( + {{"__oclc_finite_only_opt", false}, + {"__oclc_daz_opt", false}, + {"__oclc_correctly_rounded_sqrt32", true}, + {"__oclc_unsafe_math_opt", false}, + {"__oclc_wavefrontsize64", true}}); + for (auto &globalParam : rocdlGlobalParams) { + overridePlatformGlobal(module, globalParam.first, globalParam.second, + boolType); + } + return success(); +} + // Links ROCm-Device-Libs into the given module if the module needs it. void linkROCDLIfNecessary(llvm::Module *module, std::string targetChip, std::string bitCodeDir) { @@ -122,6 +176,9 @@ void linkROCDLIfNecessary(llvm::Module *module, std::string targetChip, if (!succeeded(HAL::linkWithBitcodeVector( module, getROCDLPaths(targetChip, bitCodeDir)))) { llvm::WithColor::error(llvm::errs()) << "Fail to Link ROCDL.\n"; + } + if (!succeeded(HAL::linkModuleWithGlobal(module, targetChip))) { + llvm::WithColor::error(llvm::errs()) << "Fail to Link with Globals.\n"; }; }