diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a99bf8d1e4bee..20d8d9901d818 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -29,6 +29,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(CheckSymbolExists) include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index 826bb7c468a02..dff6ed187616d 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,13 +1,14 @@ set(PATCH_CLANG ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) set(PATCH_GFX12X ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx12x_support.patch) +set(PATCH_GFX950 ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx950.patch) include(FetchContent) onnxruntime_fetchcontent_declare(composable_kernel URL ${DEP_URL_composable_kernel} URL_HASH SHA1=${DEP_SHA1_composable_kernel} PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_CLANG} && - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} - EXCLUDE_FROM_ALL + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX950} ) FetchContent_GetProperties(composable_kernel) diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 495ff093326ad..90c0f447800c6 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -2,21 +2,11 @@ # Licensed under the MIT License. add_definitions(-DUSE_MIGRAPHX=1) - set(BUILD_LIBRARY_ONLY 1) - add_definitions("-DONNX_ML=1") - add_definitions("-DONNX_NAMESPACE=onnx") - include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR}) - set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) - include_directories(${onnx_SOURCE_DIR}) + include_directories(${protobuf_SOURCE_DIR} ${eigen_SOURCE_DIR} ${onnx_SOURCE_DIR}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) + if (CMAKE_COMPILER_IS_GNUCC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") endif() - set(CXX_VERSION_DEFINED TRUE) - set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) - if ( CMAKE_COMPILER_IS_GNUCC ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") - endif() # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) @@ -33,8 +23,6 @@ find_package(hip REQUIRED) find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host) - file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" @@ -42,17 +30,17 @@ "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_shared_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) + add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE migraphx::c hip::host ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/migraphx/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) - if(MSVC) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1 ONNX_ML=1 ONNX_NAMESPACE=onnx) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32 shlwapi) else() target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") @@ -62,6 +50,15 @@ target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs) endif() + set(CMAKE_REQUIRED_LIBRARIES migraphx::c) + + check_symbol_exists(migraphx_onnx_options_set_external_data_path + "migraphx/migraphx.h" HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + + if(HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH=1) + endif() + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_training) target_link_libraries(onnxruntime_providers_migraphx PRIVATE onnxruntime_training) @@ -70,16 +67,10 @@ endif() endif() - if(CMAKE_SYSTEM_NAME STREQUAL "Windows") - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - else() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() + install(TARGETS onnxruntime_providers_migraphx + EXPORT onnxruntime_providers_migraphxTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} + ) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d1fb06a95f4c9..b7b38e497340b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -610,7 +610,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_COREML) @@ -691,9 +690,6 @@ endif() if(onnxruntime_USE_MIGRAPHX) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/migraphx/*) - list(APPEND onnxruntime_test_framework_src_patterns "${ONNXRUNTIME_ROOT}/core/providers/migraphx/migraphx_execution_provider_utils.h") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_NNAPI_BUILTIN) diff --git a/cmake/patches/composable_kernel/Add_gfx12x_support.patch b/cmake/patches/composable_kernel/Add_gfx12x_support.patch index ef529184d2ed8..072aefab45a0a 100644 --- a/cmake/patches/composable_kernel/Add_gfx12x_support.patch +++ b/cmake/patches/composable_kernel/Add_gfx12x_support.patch @@ -14,7 +14,8 @@ index bc326c8b5..db5ad5052 100644 @@ -127,8 +127,10 @@ else() rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") +- rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") ++ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102;gfx1151") + elseif(GPU_ARCH MATCHES "gfx12") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") else() @@ -259,7 +260,8 @@ index 55f562061..69a7abf62 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) - #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +- #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) ++ #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || defined(__gfx1151__) #define __gfx11__ #endif +#if defined(__gfx1200__) || defined(__gfx1201__) diff --git a/cmake/patches/composable_kernel/Add_gfx950.patch b/cmake/patches/composable_kernel/Add_gfx950.patch new file mode 100644 index 0000000000000..f16524bda5be2 --- /dev/null +++ b/cmake/patches/composable_kernel/Add_gfx950.patch @@ -0,0 +1,14 @@ +diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp +index 55f562061..ee340eba1 100644 +--- a/include/ck/ck.hpp ++++ b/include/ck/ck.hpp +@@ -53,7 +53,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + + // define general macros for various architectures + #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ +- defined(__gfx942__) ++ defined(__gfx942__) || defined(__gfx950__) + #define __gfx9__ + #endif + #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 86c0b60db2bc4..193f706549ec6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -711,15 +711,13 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_bf16_enable; // MIGraphX BF16 precision. Default 0 = false, nonzero = true int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, nonzero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true - const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true - const char* migraphx_load_model_path; // migraphx model path name - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false + const char* migraphx_cache_dir; // MIGraphX model cache directory + int migraphx_exhaustive_tune; // MIGraphX tuned compile. Default = false, nonzero = true /** \brief MIGraphX memory limit (To use all possible memory pass in maximum size_t) * Defaults to SIZE_MAX. diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index c0d8a4f02bbc3..776fd5fec367f 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -562,7 +562,10 @@ static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_devi sizeof(feature_levels) )); - auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE); + // Use compute queue whenever possible on supported hardware to avoid TDR and maintain UI QoS + // Core and generic devices only have compute queues, DX11 has "immediate" submission, DX12 has both + auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE) || + (feature_levels.MaxSupportedFeatureLevel >= D3D_FEATURE_LEVEL_12_0); if (use_compute_command_list) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..17dfdf4519b16 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -23,11 +23,11 @@ void MIGraphXAllocator::CheckDevice() const { #endif } -void* MIGraphXAllocator::Alloc(size_t size) { +void* MIGraphXAllocator::Alloc(const size_t size) { CheckDevice(); void* p = nullptr; if (size > 0) { - HIP_CALL_THROW(hipMalloc((void**)&p, size)); + HIP_CALL_THROW(hipMalloc(&p, size)); } return p; } @@ -37,7 +37,7 @@ void MIGraphXAllocator::Free(void* p) { (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown } -void* MIGraphXExternalAllocator::Alloc(size_t size) { +void* MIGraphXExternalAllocator::Alloc(const size_t size) { void* p = nullptr; if (size > 0) { p = alloc_(size); @@ -51,27 +51,27 @@ void* MIGraphXExternalAllocator::Alloc(size_t size) { void MIGraphXExternalAllocator::Free(void* p) { free_(p); - std::lock_guard lock(lock_); - auto it = reserved_.find(p); - if (it != reserved_.end()) { + std::lock_guard lock(lock_); + if (const auto it = reserved_.find(p); it != reserved_.end()) { reserved_.erase(it); if (empty_cache_) empty_cache_(); } } -void* MIGraphXExternalAllocator::Reserve(size_t size) { +void* MIGraphXExternalAllocator::Reserve(const size_t size) { void* p = Alloc(size); - if (!p) return nullptr; - std::lock_guard lock(lock_); - ORT_ENFORCE(reserved_.find(p) == reserved_.end()); - reserved_.insert(p); + if (p != nullptr) { + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + } return p; } -void* MIGraphXPinnedAllocator::Alloc(size_t size) { +void* MIGraphXPinnedAllocator::Alloc(const size_t size) { void* p = nullptr; if (size > 0) { - HIP_CALL_THROW(hipHostMalloc((void**)&p, size)); + HIP_CALL_THROW(hipHostMalloc(&p, size)); } return p; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h index f6b7788e0604c..c06b650e67dfd 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -11,27 +11,27 @@ namespace onnxruntime { class MIGraphXAllocator : public IAllocator { public: - MIGraphXAllocator(int device_id, const char* name) + MIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - static_cast(device_id)), + device_id), OrtMemTypeDefault)) {} - virtual void* Alloc(size_t size) override; - virtual void Free(void* p) override; + void* Alloc(size_t size) override; + void Free(void* p) override; private: void CheckDevice() const; }; -class MIGraphXExternalAllocator : public MIGraphXAllocator { +class MIGraphXExternalAllocator final : public MIGraphXAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); typedef void (*ExternalEmptyCache)(); public: - MIGraphXExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + MIGraphXExternalAllocator(const OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) : MIGraphXAllocator(device_id, name) { alloc_ = reinterpret_cast(alloc); free_ = reinterpret_cast(free); @@ -52,11 +52,11 @@ class MIGraphXExternalAllocator : public MIGraphXAllocator { class MIGraphXPinnedAllocator final : public IAllocator { public: - MIGraphXPinnedAllocator(const int device_id, const char* name) + MIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, - static_cast(device_id)), + device_id), OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index aa8b21ea3fe52..ea24b597714b9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include #include +#include +#include #include -#include +#include #include -#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -40,7 +42,7 @@ namespace onnxruntime { class Memcpy final : public OpKernel { public: - Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + explicit Memcpy(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* ctx) const override { const auto* X = ctx->Input(0); @@ -56,16 +58,13 @@ class Memcpy final : public OpKernel { } }; -template -KernelCreateInfo BuildKernelCreateInfo(); - ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, 1, kMIGraphXExecutionProvider, - (*KernelDefBuilder::Create()) - .InputMemoryType(OrtMemTypeCPUInput, 0) + KernelDefBuilder::Create() + ->InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); @@ -74,14 +73,11 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 1, kMIGraphXExecutionProvider, - (*KernelDefBuilder::Create()) - .OutputMemoryType(OrtMemTypeCPUOutput, 0) + KernelDefBuilder::Create() + ->OutputMemoryType(OrtMemTypeCPUOutput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), Memcpy); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost); - static std::shared_ptr s_kernel_registry; void InitializeRegistry() { @@ -106,9 +102,8 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, - info.device_id)}, + : IExecutionProvider{kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + OrtDevice::VendorIds::AMD, info.device_id)}, info_(info) { InitProviderOrtApi(); get_flags_from_session_info(info); @@ -116,17 +111,25 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv get_flags_from_env(); } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { -} - void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); + HIP_CALL_THROW(hipGetDeviceProperties(&device_prop_, info.device_id)); t_ = migraphx::target(info.target_device.c_str()); // Quantization fp16_enable_ = info.fp16_enable; +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + bf16_enable_ = info.bf16_enable; +#endif + + if (bf16_enable_ and fp16_enable_) { + bf16_enable_ = false; + fp16_enable_ = false; + LOGS_DEFAULT(FATAL) << "MIGraphX: BF16 and FP16 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) fp8_enable_ = info.fp8_enable; #else @@ -136,6 +139,8 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut int8_enable_ = info.int8_enable; if (int8_enable_ and fp8_enable_) { + int8_enable_ = false; + fp8_enable_ = false; LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; } @@ -158,10 +163,7 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut } // Save/load migraphx compiled models - save_compiled_model_ = info.save_compiled_model; - save_compiled_path_ = info.save_model_file; - load_compiled_model_ = info.load_compiled_model; - load_compiled_path_ = info.load_model_file; + model_cache_path_ = info.model_cache_dir; exhaustive_tune_ = info.exhaustive_tune; @@ -170,19 +172,34 @@ void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecut } void MIGraphXExecutionProvider::get_flags_from_env() { - LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; - // whether fp16 is enable - const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX ENV Override Variables Set:"; + // whether fp16 is enabled + const std::string fp16_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { - fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + fp16_enable_ = std::stoi(fp16_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; } + const std::string bf16_enable_env = GetEnvironmentVar(migraphx_env_vars::kBF16Enable); + if (!bf16_enable_env.empty()) { +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + bf16_enable_ = std::stoi(bf16_enable_env) != 0; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_BF16_ENABLE: " << fp16_enable_; +#else + LOGS_DEFAULT(WARNING) << "MIGraphX: BF16 Quantization requires ROCm 6.4.2 or greater"; + bf16_enable_ = false; +#endif + } + + if (bf16_enable_ and fp16_enable_) { + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP16 and BF16 Quantization Mutually exclusive. Ignoring both flags"; + } + // whether fp8 quantization is enabled - const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + const std::string fp8_enable_env = GetEnvironmentVar(migraphx_env_vars::kFP8Enable); if (!fp8_enable_env.empty()) { #if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) - fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + fp8_enable_ = std::stoi(fp8_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; #else LOGS_DEFAULT(WARNING) << "MIGraphX: FP8 Quantization requires ROCm 6.4 or greater"; @@ -191,9 +208,9 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // whether int8 is enabled - const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); + const std::string int8_enable_env = GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { - int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + int8_enable_ = std::stoi(int8_enable_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } @@ -203,23 +220,22 @@ void MIGraphXExecutionProvider::get_flags_from_env() { if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); + GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } - const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); + const std::string cache_path = GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = - onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); + GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); if (!int8_use_native_migraphx_calibration_table_env.empty()) { - int8_use_native_migraphx_calibration_table_ = - (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + int8_use_native_migraphx_calibration_table_ = std::stoi(int8_use_native_migraphx_calibration_table_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " << int8_use_native_migraphx_calibration_table_; } @@ -239,69 +255,50 @@ void MIGraphXExecutionProvider::get_flags_from_env() { } // Save/load migraphx compiled models - const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); - if (!save_comp_model_env.empty()) { - save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; - } - - const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { - save_compiled_path_ = save_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; - } - - const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); - if (!load_comp_model_env.empty()) { - load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; - } - - const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); - if (load_compiled_model_ && !load_model_path_env.empty()) { - load_compiled_path_ = load_model_path_env; - LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; + const auto model_cache_path_env = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); + if (!model_cache_path_env.empty()) { + model_cache_path_ = GetEnvironmentVar(migraphx_env_vars::kModelCachePath); + LOGS_DEFAULT(INFO) << "\n" + << migraphx_env_vars::kModelCachePath << ": " << model_cache_path_; } // dump unsupported ops - const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); + const std::string dump_model_ops_env = GetEnvironmentVar(migraphx_env_vars::kDumpModelOps); if (!dump_model_ops_env.empty()) { - dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + dump_model_ops_ = std::stoi(dump_model_ops_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile - const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); + const std::string exhaustive_tune_env = GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { - exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + exhaustive_tune_ = std::stoi(exhaustive_tune_env) != 0; LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } } -void MIGraphXExecutionProvider::print_migraphx_ep_flags() { - LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id - << "\n migraphx_fp16_enable: " << fp16_enable_ - << "\n migraphx_fp8_enable: " << fp8_enable_ - << "\n migraphx_int8_enable: " << int8_enable_ +void MIGraphXExecutionProvider::print_migraphx_ep_flags() const { + LOGS_DEFAULT(VERBOSE) << "\n " << migraphx_provider_option::kDeviceId << ": " << info_.device_id + << "\n " << migraphx_provider_option::kFp16Enable << ": " << fp16_enable_ + << "\n " << migraphx_provider_option::kBf16Enable << ": " << bf16_enable_ + << "\n " << migraphx_provider_option::kFp8Enable << ": " << fp8_enable_ + << "\n " << migraphx_provider_option::kInt8Enable << ": " << int8_enable_ << "\n dump_model_ops: " << dump_model_ops_ - << "\n exhaustive_tune: " << exhaustive_tune_ - << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n " << migraphx_provider_option::kExhaustiveTune << ": " << exhaustive_tune_ + << "\n " << migraphx_provider_option::kInt8CalibTable << ": " << int8_calibration_cache_name_ << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ - << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << "\n migraphx_save_compiled_model: " << save_compiled_model_ - << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ - << "\n migraphx_load_compiled_model: " << load_compiled_model_ - << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; + << "\n " << migraphx_provider_option::kInt8UseNativeCalibTable << ": " << int8_use_native_migraphx_calibration_table_ + << "\n " << migraphx_provider_option::kModelCacheDir << ": " << model_cache_path_; } -AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, - size_t migx_mem_limit, +AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, + const size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { - AllocatorCreationInfo default_memory_info( + const AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { return std::make_unique(id, HIP, external_allocator_info.alloc, @@ -312,40 +309,38 @@ AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::Devic false); return CreateAllocator(default_memory_info); - } else { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP); - }, - device_id, - true, - {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, - // make it stream aware - true, - // enable cross stream sharing? - false); - - // ROCM malloc/free is expensive so always use an arena - return CreateAllocator(default_memory_info); } + const AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId id) { + return std::make_unique(id, HIP); + }, + device_id, + true, + {default_memory_arena_cfg ? *default_memory_arena_cfg + : OrtArenaCfg(migx_mem_limit, static_cast(arena_extend_strategy), + -1, -1, -1, -1L)}, + // make it stream aware + true, + // enable cross stream sharing? + false); + + // ROCM malloc/free is expensive so always use an arena + return CreateAllocator(default_memory_info); } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { - AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return std::make_unique(device_id, onnxruntime::CUDA); }, - info_.device_id); - AllocatorCreationInfo pinned_allocator_info( - [](OrtDevice::DeviceId device_id) { - return std::make_unique(device_id, onnxruntime::CUDA_PINNED); + const AllocatorCreationInfo default_memory_info( + [](const OrtDevice::DeviceId device_id) { return std::make_unique(device_id, CUDA); }, info_.device_id); + const AllocatorCreationInfo pinned_allocator_info( + [](const OrtDevice::DeviceId device_id) { + return std::make_unique(device_id, CUDA_PINNED); }, - info_.device_id); - return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; + 0); + return {CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; } -std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { - return std::make_unique(); +std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { + return std::make_unique(); } static bool IsTypeSupported(const NodeArg* node_arg) { @@ -356,6 +351,7 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: @@ -379,13 +375,16 @@ static bool IsTypeSupported(const NodeArg* node_arg) { } } -static bool getMIGraphXType(ONNXTensorElementDataType type, +static bool getMIGraphXType(const ONNXTensorElementDataType type, migraphx_shape_datatype_t& mgx_type) { mgx_type = migraphx_shape_float_type; switch (type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + mgx_type = migraphx_shape_bf16_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; @@ -401,12 +400,13 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: mgx_type = migraphx_shape_fp8e5m2_type; break; +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4) case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; +#endif case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: - mgx_type = migraphx_shape_int8_type; - break; + // No `break` intentional case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -420,8 +420,7 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, mgx_type = migraphx_shape_int64_type; break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: - mgx_type = migraphx_shape_uint8_type; - break; + // No `break` intentional case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -438,8 +437,8 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, mgx_type = migraphx_shape_bool_type; break; default: - LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU"; - LOGS_DEFAULT(WARNING) << "implementation"; + LOGS_DEFAULT(WARNING) << "MIGraphX: unsupported data type " << type + << ", fallback to CPU" << "implementation"; return false; } @@ -456,7 +455,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { return result; } -static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node) { +static bool IsUnsupportedOpMode(const GraphViewer& graph_viewer, const Node* node) { std::vector input_nodes; const auto& optype = node->OpType(); if (optype == "ArgMax" or optype == "ArgMin") { @@ -630,7 +629,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } } } else if (optype == "Split") { - // cannot process input dim of 0 size + // cannot process input dim of size 0 const auto arg_s = node->InputDefs()[0]->Shape(); if (arg_s != nullptr) { const auto& tensor_dims = arg_s->dim(); @@ -672,24 +671,21 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return false; } -void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, +void SubgraphPostProcessing(const GraphViewer& graph_viewer, std::vector>& clusters, [[maybe_unused]] const logging::Logger& logger) { // Then check whether a subgraph should fall back to CPU - // 1. Check whether a subgraph contains a RNN operator + // 1. Check whether a subgraph contains an RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"}; - auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) { + const auto it = std::remove_if(clusters.begin(), clusters.end(), [&](auto git) { for (auto index : git) { - auto node = graph_viewer.GetNode(index); - if (node->OpType() == "Reshape") { - const auto& args = node->InputDefs(); - if (args.size() == 2) { - std::vector node_inputs; - if (canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { - return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { - return std::find(git.begin(), git.end(), index) != git.end(); - })); + if (auto node = graph_viewer.GetNode(index); node->OpType() == "Reshape") { + if (const auto& args = node->InputDefs(); args.size() == 2) { + if (std::vector node_inputs; canEvalNodeArgument(graph_viewer, node, {1}, node_inputs)) { + return !std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto i) { + return std::find(git.begin(), git.end(), i) != git.end(); + }); } else { return true; } @@ -711,7 +707,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v const auto& node = graph_viewer.GetNode(nid); const auto& op_type = node->OpType(); if (op_names.count(op_type) > 0) { - // check number of elements in input + // check the number of elements in input auto inputs = node->InputDefs(); if (std::any_of(inputs.begin(), inputs.end(), [&](auto& arg) { const auto& arg_s = arg->Shape(); @@ -741,7 +737,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v } static bool IsNodeSupported(const std::set& op_set, - const onnxruntime::GraphViewer& graph_viewer, + const GraphViewer& graph_viewer, const NodeIndex node_idx, [[maybe_unused]] const logging::Logger& logger) { const auto& node = graph_viewer.GetNode(node_idx); @@ -757,7 +753,7 @@ static bool IsNodeSupported(const std::set& op_set, // check data type bool are_types_supported = true; - node->ForEachDef([&are_types_supported](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) { + node->ForEachDef([&are_types_supported](const NodeArg& node_arg, bool /*is_input*/) { are_types_supported &= IsTypeSupported(&node_arg); }); @@ -794,7 +790,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } // Find inputs and outputs of the subgraph - std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unique_ptr sub_graph = IndexedSubGraph::Create(); std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; std::unordered_set erased; int input_order = 0; @@ -875,15 +871,15 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st // Sort inputs and outputs by the order they were added std::multimap inputs, outputs; - for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) { - inputs.insert(std::pair(it->second, it->first)); + for (auto& [fst, snd] : fused_inputs) { + inputs.insert({snd, fst}); } - for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { - outputs.insert(std::pair(it->second, it->first)); + for (auto& [fst, snd] : fused_outputs) { + outputs.insert({snd, fst}); } - // It is possible that an output of an node is put bebind the output of an later + // It is possible that an output of a node is put behind the output of a later // node in the graph output list. So we should sort the output name according // to the graph output names std::vector output_names; @@ -900,7 +896,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } for (auto& name : graph_output_names) { - if (std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end()) + if (graph_out_names.find(name) != graph_out_names.end()) output_names.push_back(name); } @@ -946,6 +942,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Atan", "Atanh", "ATen", + "Attention", "AveragePool", "BatchNormalization", "BiasGelu", @@ -988,6 +985,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Greater", "GreaterOrEqual", "GroupNormalization", + "GroupNorm", "GroupQueryAttention", "HardSigmoid", "HardSwish", @@ -1019,6 +1017,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "MultiHeadAttention", "Neg", "NegativeLogLikelihoodLoss", + "NhwcConv", "NonMaxSuppression", "NonZero", "Not", @@ -1095,7 +1094,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { // Collect inputs that are initializers graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, - &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { + &graph_viewer](const NodeArg& node_arg, bool is_input) { if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) { mgx_required_initializers.insert(node_arg.Name()); } }, @@ -1109,7 +1108,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, } // Returns a vector clusters(or node_idx). For each unsupported node, the graph -// is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph). +// is split into 3 parts. supported_cluster + (Unsupported_node + rest_of_the_graph). // This functions returns vector of all supported_subgraphx by amdmigraphx static std::vector> GetPartitionedSubgraphs(const std::vector& topological_order, @@ -1126,7 +1125,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, if (!this_subgraph.empty()) { mgx_subgraphx.push_back(std::move(this_subgraph)); } - // Point prev to node idx past this unsuported node. + // Point prev to node idx past this unsupported node. prev = ++it; } @@ -1140,7 +1139,7 @@ GetPartitionedSubgraphs(const std::vector& topological_order, } std::vector> -MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, +MIGraphXExecutionProvider::GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const { @@ -1168,7 +1167,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v // If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. if (unsupported_nodes.empty()) { - auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); auto sub_graph = GetSubGraph(node_indices, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } else { // unsupported_nodes_idx.empty() @@ -1245,43 +1244,40 @@ bool get_input_output_names(const GraphViewer& graph, // Attempt to load a model and catch any exceptions on load fail. // Useful to default to EP to trigger the compile if file doesn't exist or loading fails. -bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { - try { - if (load_enable) { - LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; - prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(WARNING) << "load model : Success"; - return true; - } else { - return false; - } - } catch (...) { - return false; +bool load_precompiled_model(migraphx::program& prog, const std::filesystem::path& path) try { + if (!path.empty() && exists(path)) { + LOGS_DEFAULT(VERBOSE) << "Attempting to load model at:" << path.string(); + prog = migraphx::load(path.string().c_str()); + LOGS_DEFAULT(VERBOSE) << "load model : Success"; + return true; } return false; +} catch (...) { + return false; } -void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { - if (save_enable) { - LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; +void save_compiled_model(const migraphx::program& prog, const std::filesystem::path& path) { + if (!path.empty()) { + LOGS_DEFAULT(VERBOSE) << "Model Save at " << path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); - migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + save(prog, path.string().c_str(), fo); + LOGS_DEFAULT(VERBOSE) << "Model Save: Complete"; } } -// Order matters here especially if the program uses mixed quantization +// Order matters here, especially if the program uses mixed quantization // Calibrate on full precision for int8/fp8 and then quantize down to fp16 -void calibrate_and_quantize(migraphx::program& prog, +void calibrate_and_quantize(const migraphx::program& prog, const migraphx::target& t, - const migraphx::program_parameters quant_params, - bool fp16_enable, - bool int8_enable, - bool fp8_enable, - bool int8_calibration_cache_available, + const migraphx::program_parameters& quant_params, + const bool fp16_enable, + const bool bf16_enable, + const bool int8_enable, + const bool fp8_enable, + const bool int8_calibration_cache_available, std::unordered_map& dynamic_range_map) { - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + // Read in the calibration data and map it to a migraphx parameter map for the calibration ops if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { LOGS_DEFAULT(WARNING) << "Quantizing input program"; @@ -1289,8 +1285,8 @@ void calibrate_and_quantize(migraphx::program& prog, // Add all calibration data read in from int8 table for (auto& [cal_key, cal_val] : dynamic_range_map) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + const auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, &cal_val)); } // perform static quantization on the programs @@ -1319,11 +1315,19 @@ void calibrate_and_quantize(migraphx::program& prog, migraphx::quantize_fp16(prog); LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } + +#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 4 && HIP_VERSION_PATCH >= 2) + if (bf16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to bf16"; + migraphx::quantize_bf16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing bf16: Complete"; + } +#endif } -void compile_program(migraphx::program& prog, +void compile_program(const migraphx::program& prog, const migraphx::target& t, - bool exhaustive_tune) { + const bool exhaustive_tune) { LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; migraphx::compile_options co; co.set_fast_math(false); @@ -1332,6 +1336,27 @@ void compile_program(migraphx::program& prog, LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; } +std::string to_hex(const uint64_t v) { + std::array s{}; + auto [ptr, _] = std::to_chars(s.data(), s.data() + s.size(), v, 16); + return std::string{s.data(), ptr}; +} + +template +std::string make_hash(T v) { + std::array temp{}; + MurmurHash3::x86_128(v.data(), gsl::narrow_cast(v.size()), temp[0], temp.data()); + return to_hex(temp[0] | static_cast(temp[1]) << 32); +} + +template <> +std::string make_hash(const char* v) { + return make_hash(std::string_view{v}); +} + +constexpr std::uint64_t MIGraphX_Version = + ((MIGRAPHX_VERSION_MAJOR << 16) | (MIGRAPHX_VERSION_MINOR << 8) | MIGRAPHX_VERSION_PATCH); + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1339,6 +1364,33 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& for (const auto& fused_node_graph : fused_nodes) { const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; const Node& fused_node = fused_node_graph.fused_node; + + std::filesystem::path model_cache_file; + auto mxr_filename_prefix = to_hex(MIGraphX_Version) + "-" + GenerateGraphId(graph_body_viewer) + "-" + make_hash(std::string_view{device_prop_.gcnArchName}) + "-"; + + // Get model input names (only first layer) + const Graph* cur_graph = &graph_body_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + const Graph& main_graph = *cur_graph; + const auto& input_tensor = main_graph.GetInputs(); + for (auto i : input_tensor) { + session_input_names.insert(i->Name()); + } + + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + std::vector input_shapes; + for (std::size_t i = 0; i < session_input_names.size(); ++i) { + auto tensor_shape = input_tensor[i]->Shape(); + for (int j = 1; j < tensor_shape->dim_size(); ++j) { + input_shapes.push_back(tensor_shape->dim(j).dim_value()); + } + } + model_cache_file = model_cache_path_ / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + // map parameter input name to index std::unordered_map input_name_index; const auto& input_defs = fused_node.InputDefs(); @@ -1369,15 +1421,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::program prog; if (!no_input_shape) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; +#ifndef ENABLE_TRAINING_CORE +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +#endif +#endif prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); migraphx::program_parameters quant_params; - calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); compile_program(prog, t_, exhaustive_tune_); - save_compiled_model(prog, save_compiled_model_, save_compiled_path_); + save_compiled_model(prog, model_cache_file); } auto prog_output_shapes = prog.get_output_shapes(); @@ -1394,14 +1451,12 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& map_input_index_[fused_node.Name()] = input_name_index; map_no_input_shape_[fused_node.Name()] = no_input_shape; NodeComputeInfo compute_info; - compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { - std::unique_ptr p = std::make_unique(); + compute_info.create_state_func = [=](const ComputeContext* context, FunctionState* state) { + auto p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, - int8_calibration_cache_available_, dynamic_range_map_, - save_compiled_model_, save_compiled_path_, - load_compiled_model_, load_compiled_path_, dump_model_ops_}; + map_no_input_shape_[context->node_name], fp16_enable_, bf16_enable_, fp8_enable_, int8_enable_, + int8_calibration_cache_available_, dynamic_range_map_, model_cache_path_.string(), dump_model_ops_}; *state = p.release(); return 0; }; @@ -1411,9 +1466,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + compute_info.compute_func = [this, mxr_filename_prefix](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::KernelContext ctx(context); - MIGraphXFuncState* mgx_state = reinterpret_cast(state); + auto mgx_state = static_cast(state); std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; std::unordered_map& map_dynamic_range = mgx_state->dynamic_range_map; @@ -1423,6 +1478,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool bf16_enable = mgx_state->bf16_enable; bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1431,8 +1487,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // from input data bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; + std::vector input_shapes; + if (no_input_shape) { - LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; + LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1444,7 +1502,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; + LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1470,19 +1528,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& cmp_options.set_input_parameter_shape(name, ort_lens); input_shape_match = false; } + input_shapes.insert(input_shapes.end(), tensor_shape.begin(), tensor_shape.end()); } } } } - // input shapes are different, needs to re-parse onnx and - // re-compile the program + // input shapes are different, needs to reparse onnx and recompile the program if (!input_shape_match) { - if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { - LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling" << std::endl; + std::filesystem::path model_cache_file; + // empty cache path means the MXR caching is disabled - always compile + if (!model_cache_path_.empty()) { + model_cache_file = mgx_state->model_cache_dir / (mxr_filename_prefix + make_hash(input_shapes) + ".mxr"); + } + if (!load_precompiled_model(prog, model_cache_file)) { + LOGS_DEFAULT(VERBOSE) << "Input shape mismatch detected. Recompiling"; #ifndef ENABLE_TRAINING_CORE -#if HIP_VERSION_MAJOR > 6 || (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR >= 2) - cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string()); +#ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + cmp_options.set_external_data_path(model_path_.parent_path().string()); #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); @@ -1509,10 +1572,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& } } } - calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, fp8_enable, int8_calibration_cache_available, map_dynamic_range); compile_program(prog, t, exhaustive_tune_); - save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); + save_compiled_model(prog, model_cache_file); } mgx_state->prog = prog; @@ -1526,7 +1589,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; + LOGS_DEFAULT(VERBOSE) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1540,20 +1603,20 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; + LOGS_DEFAULT(VERBOSE) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } - // It is a output argument + // It is an output argument else { auto compute_output_index = [](const std::string& name) -> int { - std::string out_name_prefix = "#output_"; - auto pos = name.find(out_name_prefix); + const std::string out_name_prefix = "#output_"; + const auto pos = name.find(out_name_prefix); if (pos == std::string::npos) { return -1; } - std::string index_str = name.substr(pos + out_name_prefix.length()); + const std::string index_str = name.substr(pos + out_name_prefix.length()); return std::stoi(index_str); }; @@ -1576,13 +1639,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& { // lock to avoid race condition - std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + std::lock_guard lock(*mgx_state->mgx_mu_ptr); void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); auto prog_outputs = prog.run_async(m, static_cast(rocm_stream)); - // In case of input parameters are reused as output parameter call hipMemcpy + // In the case of input parameters are reused as output parameter calls hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { for (std::size_t i = 0; i < output_num; ++i) { @@ -1601,7 +1664,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& static_cast(rocm_stream))); } } - }; + } return Status::OK(); }; @@ -1627,23 +1690,19 @@ OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) } Status MIGraphXExecutionProvider::Sync() const { - HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); - - auto status = hipStreamQuery(stream_); - if (status != hipSuccess) { - return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL); + HIP_CALL_THROW(hipStreamSynchronize(nullptr)); + if (hipStreamQuery(stream_) != hipSuccess) { + return {common::ONNXRUNTIME, common::EP_FAIL}; } return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status MIGraphXExecutionProvider::OnRunStart(const RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { - auto status = hipStreamQuery(stream_); - - if (status != hipSuccess) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const RunOptions& /*run_options*/) { + if (hipStreamQuery(stream_) != hipSuccess) { HIP_CALL_THROW(hipStreamSynchronize(stream_)); } return Status::OK(); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index aecccdd54d697..c5c1f0f2f1650 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,6 +3,7 @@ #pragma once +#include #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include @@ -16,20 +17,17 @@ namespace onnxruntime { namespace migraphx_env_vars { -static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; -static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; -static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; -static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; -static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; -static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; -static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; -static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; -static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; -static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; - -}; // namespace migraphx_env_vars +constexpr auto kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"; +constexpr auto kBF16Enable = "ORT_MIGRAPHX_BF16_ENABLE"; +constexpr auto kFP8Enable = "ORT_MIGRAPHX_FP8_ENABLE"; +constexpr auto kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE"; +constexpr auto kDumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; +constexpr auto kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; +constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"; +constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; +constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"; +constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; +} // namespace migraphx_env_vars // Information to construct kernel function state. struct MIGraphXFuncState { @@ -44,45 +42,43 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool bf16_enable = false; bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; - bool save_compiled_mode = false; - std::string save_compiled_path; - bool load_compiled_mode = false; - std::string load_compiled_path; + std::filesystem::path model_cache_dir; bool dump_model_ops = false; bool exhaustive_tune = false; }; // Logical device representation. -class MIGraphXExecutionProvider : public IExecutionProvider { +class MIGraphXExecutionProvider final : public IExecutionProvider { public: explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); - ~MIGraphXExecutionProvider(); + ~MIGraphXExecutionProvider() override = default; void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); void get_flags_from_env(); - void print_migraphx_ep_flags(); + void print_migraphx_ep_flags() const; Status Sync() const override; - Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunStart(const RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const RunOptions& run_options) override; std::vector> - GetCapability(const onnxruntime::GraphViewer& graph_viewer, + GetCapability(const GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; - common::Status Compile(const std::vector& fused_nodes, - std::vector& node_compute_funcs) override; + Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; - virtual std::shared_ptr GetKernelRegistry() const override; - std::unique_ptr GetDataTransfer() const override; + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; static AllocatorPtr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t migx_mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); @@ -100,6 +96,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool bf16_enable_ = false; bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; @@ -107,14 +104,13 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool int8_use_native_migraphx_calibration_table_ = false; std::string calibration_cache_path_; std::unordered_map dynamic_range_map_; - bool save_compiled_model_ = false; - std::string save_compiled_path_; - bool load_compiled_model_ = false; - std::string load_compiled_path_; + std::filesystem::path model_cache_path_{}; + std::set session_input_names; bool dump_model_ops_ = false; migraphx::target t_; std::mutex mgx_mu_; hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index cf21d791cfe6b..b2fda0885ff5c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -17,28 +17,6 @@ const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -namespace migraphx { -namespace provider_option_names { -constexpr const char* kDeviceId = "device_id"; -constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kFp8Enable = "migx_fp8_enable"; -constexpr const char* kInt8Enable = "migx_int8_enable"; -constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; -constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; -constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; -constexpr const char* kSaveModelPath = "migx_save_model_name"; -constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; -constexpr const char* kLoadModelPath = "migx_load_model_name"; -constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; -constexpr const char* kMemLimit = "migx_mem_limit"; -constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy"; -constexpr const char* kGpuExternalAlloc = "migx_external_alloc"; -constexpr const char* kGpuExternalFree = "migx_external_free"; -constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache"; - -} // namespace provider_option_names -} // namespace migraphx - MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { MIGraphXExecutionProviderInfo info{}; void* alloc = nullptr; @@ -47,7 +25,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( - migraphx::provider_option_names::kDeviceId, + migraphx_provider_option::kDeviceId, [&info](const std::string& value_str) -> Status { ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); int num_devices{}; @@ -59,7 +37,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalAlloc, + migraphx_provider_option::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { size_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); @@ -67,7 +45,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalFree, + migraphx_provider_option::kGpuExternalFree, [&free](const std::string& value_str) -> Status { size_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); @@ -75,21 +53,21 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddValueParser( - migraphx::provider_option_names::kGpuExternalEmptyCache, + migraphx_provider_option::kGpuExternalEmptyCache, [&empty_cache](const std::string& value_str) -> Status { size_t address; ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); empty_cache = reinterpret_cast(address); return Status::OK(); }) - .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) - .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) - .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) - .AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit) - .AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) + .AddAssignmentToReference(migraphx_provider_option::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx_provider_option::kBf16Enable, info.bf16_enable) + .AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable) + .AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(migraphx_provider_option::kModelCacheDir, info.model_cache_dir) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .Parse(options)); MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; @@ -100,34 +78,33 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, - {migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, - {migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, - {migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, - {migraphx::provider_option_names::kArenaExtendStrategy, - EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, + {migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx_provider_option::kBf16Enable, MakeStringWithClassicLocale(info.bf16_enable)}, + {migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, + {migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, + {migraphx_provider_option::kModelCacheDir, MakeStringWithClassicLocale(info.model_cache_dir)}, + {migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)}, + {migraphx_provider_option::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, + {migraphx_provider_option::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, + {migraphx_provider_option::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.empty_cache))}, + {migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, + {migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, }; return options; } ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { const ProviderOptions options{ - {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, - {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, - {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, - {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, - {migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, - {migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, - {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, + {migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx_provider_option::kBf16Enable, MakeStringWithClassicLocale(info.migraphx_bf16_enable)}, + {migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, + {migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, + {migraphx_provider_option::kModelCacheDir, MakeStringWithClassicLocale(info.migraphx_cache_dir)}, + {migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)}, + {migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.migraphx_arena_extend_strategy))}, + {migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index a598052c5f025..76745ed831f5e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -14,6 +14,23 @@ namespace onnxruntime { +namespace migraphx_provider_option { +constexpr auto kDeviceId = "device_id"; +constexpr auto kFp16Enable = "migraphx_fp16_enable"; +constexpr auto kBf16Enable = "migraphx_bf16_enable"; +constexpr auto kFp8Enable = "migraphx_fp8_enable"; +constexpr auto kInt8Enable = "migraphx_int8_enable"; +constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name"; +constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table"; +constexpr auto kModelCacheDir = "migraphx_model_cache_dir"; +constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune"; +constexpr auto kMemLimit = "migraphx_mem_limit"; +constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy"; +constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"; +constexpr auto kGpuExternalFree = "migraphx_external_free"; +constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"; +} // namespace migraphx_provider_option + // Information needed to construct MIGraphX execution providers. struct MIGraphXExecutionProviderExternalAllocatorInfo { void* alloc{nullptr}; @@ -42,14 +59,12 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool bf16_enable{false}; bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; - bool save_compiled_model{true}; - std::string save_model_file{"./compiled_model.mxr"}; - bool load_compiled_model{true}; - std::string load_model_file{"./compiled_model.mxr"}; + std::filesystem::path model_cache_dir{}; bool exhaustive_tune{false}; size_t mem_limit{std::numeric_limits::max()}; // Will be over-ridden by contents of `default_memory_arena_cfg` (if specified) @@ -75,11 +90,14 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { (static_cast(info.fp16_enable) << 18) ^ (static_cast(info.int8_enable) << 19) ^ (static_cast(info.int8_use_native_calibration_table) << 20) ^ - (static_cast(info.save_compiled_model) << 21) ^ - (static_cast(info.load_compiled_model) << 22) ^ - (static_cast(info.exhaustive_tune) << 23); + (static_cast(info.exhaustive_tune) << 21) ^ + (static_cast(info.bf16_enable) << 22); onnxruntime::HashCombine(data, value); + onnxruntime::HashCombine(info.target_device, value); + onnxruntime::HashCombine(info.default_memory_arena_cfg, value); + onnxruntime::HashCombine(info.int8_calibration_table_name, value); + onnxruntime::HashCombine(info.model_cache_dir, value); onnxruntime::HashCombine(info.mem_limit, value); // Memory pointers diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 9274b5696185c..cb25db032ebf2 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -3,9 +3,11 @@ #pragma once +#include #include #include #include +#include #include #include #include @@ -14,12 +16,13 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" #include "core/common/path_string.h" +#include "core/framework/murmurhash3.h" namespace fs = std::filesystem; namespace onnxruntime { -bool IsGraphInput(const GraphViewer& graph, const std::string& name) { +inline bool IsGraphInput(const GraphViewer& graph, const std::string& name) { const auto& graph_inputs = graph.GetInputs(); std::vector input_names(graph_inputs.size()); std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { @@ -28,12 +31,12 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { +inline bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } -const Node* GetInputNode(const Node& node, int arg_index) { +inline const Node* GetInputNode(const Node& node, int arg_index) { int index = 0; for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) { if (index == arg_index) { @@ -44,7 +47,7 @@ const Node* GetInputNode(const Node& node, int arg_index) { return nullptr; } -std::size_t getNodeInputNum(const Node& node) { +inline std::size_t getNodeInputNum(const Node& node) { std::size_t node_num = 0; for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) { node_num++; @@ -53,14 +56,14 @@ std::size_t getNodeInputNum(const Node& node) { return node_num; } -bool isInputNode(const Node* node, const std::string& name) { +inline bool isInputNode(const Node* node, const std::string& name) { auto outputs = node->OutputDefs(); return std::any_of(outputs.begin(), outputs.end(), [&](auto out) { return (out->Name() == name); }); } -bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { +inline bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector& input_nodes) { if (node == nullptr) { return false; } @@ -113,10 +116,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, - const Node* node, - std::vector indices, - std::vector& input_nodes) { +inline bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -152,7 +155,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, return true; } -float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { +inline float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { int s = (input >> 31) & 0x01; int e = ((input & 0x7f800000) >> 23) - 127; int p = -1; @@ -184,12 +187,12 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models * */ -bool ReadDynamicRange(const std::string file_name, - const bool is_calibration_table, - std::unordered_map& dynamic_range_map) { +inline bool ReadDynamicRange(const std::string file_name, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { std::ifstream infile(file_name, std::ios::binary | std::ios::in); - if (!infile) { + if (!infile.good()) { return false; } @@ -240,7 +243,7 @@ bool ReadDynamicRange(const std::string file_name, * Get cache by name * */ -std::string GetCachePath(const std::string& root, const std::string& name) { +inline std::string GetCachePath(const std::string& root, const std::string& name) { if (root.empty()) { return name; } else { @@ -250,4 +253,83 @@ std::string GetCachePath(const std::string& root, const std::string& name) { } } +inline std::string GenerateGraphId(const GraphViewer& graph_viewer) { + HashValue model_hash; + + // find the top level graph + const Graph* cur_graph = &graph_viewer.GetGraph(); + while (cur_graph->IsSubgraph()) { + cur_graph = cur_graph->ParentGraph(); + } + + const Graph& main_graph = *cur_graph; + uint32_t hash[4] = {0, 0, 0, 0}; + + auto hash_str = [&hash](const std::string& str) { + MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); + }; + + // Use the model's file name instead of the entire path to avoid cache regeneration if a path changes + const fs::path path{main_graph.ModelPath()}; + + if (path.has_filename()) { + const auto model_name = path.filename().string(); + + LOGS_DEFAULT(INFO) << "Model name is '" << model_name << "'"; + // Ensure enough characters are hashed in case model names are too short + const size_t model_name_length = model_name.length(); + constexpr size_t hash_string_length = 500; + std::string repeat_model_name = model_name; + for (size_t i = model_name_length; i > 0 && i < hash_string_length; i += model_name_length) { + repeat_model_name += model_name; + } + hash_str(repeat_model_name); + } else { + LOGS_DEFAULT(INFO) << "Model path is empty"; + } + + // fingerprint current graph by hashing graph inputs + for (const auto* node_arg : graph_viewer.GetInputsIncludingInitializers()) { + hash_str(node_arg->Name()); + } + + // hashing outputs, inputs and inputs shapes of each node + const int number_of_ort_nodes = graph_viewer.NumberOfNodes(); + std::vector nodes_vector(number_of_ort_nodes); + std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); + const std::vector& node_index = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto& index : nodes_vector) { + const auto& node = graph_viewer.GetNode(node_index[index]); + for (const auto* node_arg : node->OutputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + } + } + for (const auto* node_arg : node->InputDefs()) { + if (node_arg != nullptr && node_arg->Exists()) { + hash_str(node_arg->Name()); + if (node_arg->Shape() == nullptr) { + continue; + } + int dim_size = node_arg->Shape()->dim_size(); + for (int i = 0; i < dim_size; i++) { + hash_str(std::to_string(node_arg->Shape()->dim(i).dim_value())); + } + } + } + } + +#ifdef __linux__ + hash_str("LINUX"); +#elif defined(_WIN32) + hash_str("WINDOWS"); +#endif + + model_hash = hash[0] | static_cast(hash[1]) << 32; + + std::array s; + auto [ptr, ec] = std::to_chars(s.data(), s.data() + s.size(), model_hash, 16); + return std::string{s.data(), ptr}; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 2b035b20f619f..954a09faede9b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -5,4 +5,5 @@ #include #include -#include +#include +#include \ No newline at end of file diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 4a3945ac680d0..30650005bbc21 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,28 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License #include +#include +#include +#include + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#include +#endif #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include "migraphx_execution_provider.h" -#include "migraphx_execution_provider_info.h" -#include "migraphx_provider_factory_creator.h" -#include "migraphx_allocator.h" -#include "gpu_data_transfer.h" +#include "core/providers/migraphx/migraphx_execution_provider.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_provider_factory_creator.h" +#include "core/providers/migraphx/migraphx_allocator.h" +#include "core/providers/migraphx/gpu_data_transfer.h" #include "core/framework/provider_options.h" #include "core/session/onnxruntime_c_api.h" -using namespace onnxruntime; - namespace onnxruntime { void InitializeRegistry(); void DeleteRegistry(); -struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} - ~MIGraphXProviderFactory() override {} +struct MIGraphXProviderFactory final : IExecutionProviderFactory { + explicit MIGraphXProviderFactory(MIGraphXExecutionProviderInfo info) : info_{std::move(info)} {} + ~MIGraphXProviderFactory() override = default; std::unique_ptr CreateProvider() override; @@ -35,15 +42,15 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { } struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { - std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + std::unique_ptr CreateMIGraphXPinnedAllocator(const OrtDevice::DeviceId device_id, const char* name) override { return std::make_unique(device_id, name); } - void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { + void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, const size_t count) override { // hipMemcpy() operates on the default stream HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); @@ -52,24 +59,26 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer // to device memory, but the DMA to final destination may not have completed. - HIP_CALL_THROW(hipStreamSynchronize(0)); + HIP_CALL_THROW(hipStreamSynchronize(nullptr)); } // Used by onnxruntime_pybind_state.cc - void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { + void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, const size_t count) override { // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } - std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { - return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, migx_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); + std::shared_ptr CreateMIGraphXAllocator(const OrtDevice::DeviceId device_id, const size_t mem_limit, const ArenaExtendStrategy arena_extend_strategy, const MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { + return MIGraphXExecutionProvider::CreateMIGraphXAllocator(device_id, mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } } g_info; -struct MIGraphX_Provider : Provider { +struct MIGraphX_Provider final : Provider { void* GetInfo() override { return &g_info; } - std::shared_ptr CreateExecutionProviderFactory(int device_id) override { + virtual ~MIGraphX_Provider() = default; + + std::shared_ptr CreateExecutionProviderFactory(const int device_id) override { MIGraphXExecutionProviderInfo info; info.device_id = static_cast(device_id); info.target_device = "gpu"; @@ -82,6 +91,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.bf16_enable = options.migraphx_bf16_enable; info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; @@ -90,26 +100,21 @@ struct MIGraphX_Provider : Provider { info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; } info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; - info.save_compiled_model = options.migraphx_save_compiled_model; - info.save_model_file = ""; - if (options.migraphx_save_model_path != nullptr) { - info.save_model_file = options.migraphx_save_model_path; - } - info.load_compiled_model = options.migraphx_load_compiled_model; - info.load_model_file = ""; - if (options.migraphx_load_model_path != nullptr) { - info.load_model_file = options.migraphx_load_model_path; + info.model_cache_dir = ""; + if (options.migraphx_cache_dir != nullptr) { + info.model_cache_dir = options.migraphx_cache_dir; } - info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); + info.arena_extend_strategy = static_cast(options.migraphx_arena_extend_strategy); info.mem_limit = options.migraphx_mem_limit; return std::make_shared(info); } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& migx_options = *reinterpret_cast(provider_options); + auto internal_options = MIGraphXExecutionProviderInfo::FromProviderOptions(options); + auto& migx_options = *static_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_bf16_enable = internal_options.bf16_enable; migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; @@ -126,25 +131,35 @@ struct MIGraphX_Provider : Provider { strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); #endif dest[str_size] = '\0'; - migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + migx_options.migraphx_int8_calibration_table_name = static_cast(dest); } migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; - migx_options.migraphx_save_compiled_model = internal_options.save_compiled_model; - migx_options.migraphx_save_model_path = internal_options.save_model_file.c_str(); - migx_options.migraphx_load_compiled_model = internal_options.load_compiled_model; - migx_options.migraphx_load_model_path = internal_options.load_model_file.c_str(); + migx_options.migraphx_cache_dir = internal_options.model_cache_dir.string().c_str(); migx_options.migraphx_arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); migx_options.migraphx_mem_limit = internal_options.mem_limit; } ProviderOptions GetProviderOptions(const void* provider_options) override { auto& options = *reinterpret_cast(provider_options); - return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + return MIGraphXExecutionProviderInfo::ToProviderOptions(options); } void Initialize() override { +#ifdef _WIN32 + HMODULE module = nullptr; + if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + static_cast(static_cast(InitializeRegistry)), + &module) != 0) { + char buffer[MAX_PATH]; + if (GetModuleFileName(module, buffer, sizeof(buffer)) != 0) { + PathRemoveFileSpec(buffer); + SetDllDirectory(buffer); + } + } +#endif InitializeRegistry(); } @@ -157,7 +172,6 @@ struct MIGraphX_Provider : Provider { } // namespace onnxruntime extern "C" { - ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index d1c9457bafa0f..313603a4ecbf0 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -1,7 +1,11 @@ -// Copyright 2019 AMD AMDMIGraphX +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License -#include "core/framework/provider_options.h" -#include "onnxruntime_c_api.h" +#pragma once + +#include +#include "core/framework/ortdevice.h" +#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { class IAllocator; @@ -12,11 +16,11 @@ enum class ArenaExtendStrategy : int32_t; struct MIGraphXExecutionProviderExternalAllocatorInfo; struct ProviderInfo_MIGraphX { - virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; - virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(OrtDevice::DeviceId device_id, const char* name) = 0; virtual void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) = 0; virtual void MIGraphXMemcpy_DeviceToHost(void* dst, const void* src, size_t count) = 0; - virtual std::shared_ptr CreateMIGraphXAllocator(int16_t device_id, size_t migx_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::MIGraphXExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, size_t mem_limit, ArenaExtendStrategy arena_extend_strategy, MIGraphXExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; protected: ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3db35ae8769e0..26775e478a3da 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2002,7 +2002,7 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(in return s_library_dnnl.Get().CreateExecutionProviderFactory(use_arena); } -std::shared_ptr MIGraphXProviderFactoryCreator::Create(int device_id) { +std::shared_ptr MIGraphXProviderFactoryCreator::Create(const int device_id) { return s_library_migraphx.Get().CreateExecutionProviderFactory(device_id); } diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 958c9fc46bcd8..431fb0f422b81 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -207,6 +207,7 @@ std::unique_ptr GetGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX + void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes) { GetProviderInfo_MIGraphX().MIGraphXMemcpy_HostToDevice(dst, src, num_bytes); } diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index d1d4d6f3cdad5..352fe7755dc80 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -390,10 +390,10 @@ void addOrtValueMethods(pybind11::module& m) { py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction()); #elif USE_CANN py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction()); -#elif USE_DML - py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); #elif USE_MIGRAPHX py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction()); +#elif USE_DML + py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction()); #else py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index bdc4f65e590d9..12e3883e35e52 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -953,9 +953,7 @@ static std::shared_ptr CreateExecutionProviderFactory #endif } else if (type == kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX - std::string calibration_table; - std::string save_model_path; - std::string load_model_path; + std::string model_cache_path, cal_table_name; auto it = provider_options_map.find(type); if (it != provider_options_map.end()) { OrtMIGraphXProviderOptions params{ @@ -964,12 +962,10 @@ static std::shared_ptr CreateExecutionProviderFactory 0, 0, 0, + 0, + nullptr, nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", - 1, + false, SIZE_MAX, 0}; for (auto option : it->second) { @@ -979,7 +975,7 @@ static std::shared_ptr CreateExecutionProviderFactory } else { ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); } - } else if (option.first == "migraphx_fp16_enable") { + } else if (option.first == migraphx_provider_option::kFp16Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_fp16_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -989,7 +985,17 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_fp8_enable") { + } else if (option.first == migraphx_provider_option::kBf16Enable) { + if (option.second == "True" || option.second == "true") { + params.migraphx_bf16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_bf16_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_bf16_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == migraphx_provider_option::kFp8Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_fp8_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -999,7 +1005,7 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_int8_enable") { + } else if (option.first == migraphx_provider_option::kInt8Enable) { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; } else if (option.second == "False" || option.second == "false") { @@ -1009,16 +1015,16 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_int8_calibration_table_name") { + } else if (option.first == migraphx_provider_option::kInt8CalibTable) { if (!option.second.empty()) { - calibration_table = option.second; - params.migraphx_int8_calibration_table_name = calibration_table.c_str(); + cal_table_name = option.second; + params.migraphx_int8_calibration_table_name = cal_table_name.c_str(); } else { ORT_THROW( "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " "file name i.e. 'cal_table'.\n"); } - } else if (option.first == "migraphx_use_native_calibration_table") { + } else if (option.first == migraphx_provider_option::kInt8UseNativeCalibTable) { if (option.second == "True" || option.second == "true") { params.migraphx_use_native_calibration_table = true; } else if (option.second == "False" || option.second == "false") { @@ -1028,45 +1034,16 @@ static std::shared_ptr CreateExecutionProviderFactory "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" " 'True' or 'False'. Default value is 'False'.\n"); } - } else if (option.first == "migraphx_save_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_save_model_path") { - if (!option.second.empty()) { - save_model_path = option.second; - params.migraphx_save_model_path = save_model_path.c_str(); - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a " - "file name i.e. 'compiled_model.mxr'.\n"); - } - } else if (option.first == "migraphx_load_compiled_model") { - if (option.second == "True" || option.second == "true") { - params.migraphx_fp16_enable = true; - } else if (option.second == "False" || option.second == "false") { - params.migraphx_fp16_enable = false; - } else { - ORT_THROW( - "[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be" - " 'True' or 'False'. Default value is 'False'.\n"); - } - } else if (option.first == "migraphx_load_model_path") { + } else if (option.first == migraphx_provider_option::kModelCacheDir) { if (!option.second.empty()) { - load_model_path = option.second; - params.migraphx_load_model_path = load_model_path.c_str(); + model_cache_path = option.second; + params.migraphx_cache_dir = model_cache_path.c_str(); } else { ORT_THROW( "[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a " "file name i.e. 'compiled_model.mxr'.\n"); } - } else if (option.first == "migraphx_exhaustive_tune") { + } else if (option.first == migraphx_provider_option::kExhaustiveTune) { if (option.second == "True" || option.second == "true") { params.migraphx_exhaustive_tune = true; } else if (option.second == "False" || option.second == "false") { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index b3251abbc427e..0aa8c4cf81129 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -40,13 +40,13 @@ struct OrtStatus { #define BACKEND_PROC "CPU" #endif -#if USE_DNNL +#ifdef USE_DNNL #define BACKEND_DNNL "-DNNL" #else #define BACKEND_DNNL "" #endif -#if USE_MIGRAPHX +#ifdef USE_MIGRAPHX #define BACKEND_MIGRAPHX "-MIGRAPHX" #else #define BACKEND_MIGRAPHX "" diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 2e4aa3923b649..75094002933ec 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -86,11 +86,9 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, + nullptr, nullptr, - 1, - "./compiled_model.mxr", - 1, - "./compiled_model.mxr", 1, SIZE_MAX, 0}; diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py index 6b028d8f05e11..596f3d5711743 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py @@ -10,12 +10,15 @@ from setuptools import setup from torch.utils import cpp_extension +# Resolve symlink in path to fix discrepancy between platforms +fused_ops_dir = os.path.realpath(os.path.dirname(__file__)) + filenames = [ - os.path.join(os.path.dirname(__file__), "fused_ops_frontend.cpp"), - os.path.join(os.path.dirname(__file__), "multi_tensor_adam.cu"), - os.path.join(os.path.dirname(__file__), "multi_tensor_scale_kernel.cu"), - os.path.join(os.path.dirname(__file__), "multi_tensor_axpby_kernel.cu"), - os.path.join(os.path.dirname(__file__), "multi_tensor_l2norm_kernel.cu"), + os.path.join(fused_ops_dir, "fused_ops_frontend.cpp"), + os.path.join(fused_ops_dir, "multi_tensor_adam.cu"), + os.path.join(fused_ops_dir, "multi_tensor_scale_kernel.cu"), + os.path.join(fused_ops_dir, "multi_tensor_axpby_kernel.cu"), + os.path.join(fused_ops_dir, "multi_tensor_l2norm_kernel.cu"), ] use_rocm = bool(os.environ["ONNXRUNTIME_ROCM_VERSION"]) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 29cb17257ecf1..da3b185b43a22 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -722,6 +722,8 @@ def generate_build_tree( cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) + if args.rocm_gfx_arch: + cmake_args.append("-DCMAKE_HIP_ARCHITECTURES=" + args.rocm_gfx_arch) if args.use_tensorrt: cmake_args.append("-Donnxruntime_TENSORRT_HOME=" + tensorrt_home) diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 7448ebe931d1e..d9f32bd3eca57 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -722,6 +722,9 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: migx_group.add_argument("--use_rocm", action="store_true", help="Enable ROCm EP.") migx_group.add_argument("--rocm_version", help="ROCm stack version.") migx_group.add_argument("--rocm_home", help="Path to ROCm installation directory.") + migx_group.add_argument( + "--rocm_gfx_arch", help='Provide gfx arch. Example --rocm_gfx_arch gfx942 or --rocm_gfx_arch "gfx90a;gfx942" ' + ) # --- WebNN --- webnn_group = parser.add_argument_group("WebNN Execution Provider")