Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] fix: obtain AMD GPU memory info through rocm_smi library #21190

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmake/onnxruntime_providers_rocm.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@

# Locate the ROCm communication (RCCL) and tracing (roctracer) libraries; configuration fails if either is missing.
find_library(RCCL_LIB rccl REQUIRED)
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
# NOTE(review): this assignment is overwritten by the second set(ONNXRUNTIME_ROCM_LIBS ...) below,
# which lists the same libraries plus ${ROCM_SMI_LIBRARY}.
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB})
# ROCm SMI is required because hipMemGetInfoAlt() falls back to it when hipMemGetInfo() fails.
# Presumably the rocm_smi package config defines ROCM_SMI_LIBRARY, ROCM_SMI_INCLUDE_DIR and
# ROCM_SMI_LIB_DIR - TODO confirm against the installed rocm_smi-config.cmake.
find_package(rocm_smi REQUIRED)
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
# Directory-scoped include/link paths so <rocm_smi/rocm_smi.h> and the SMI library can be found.
include_directories(${ROCM_SMI_INCLUDE_DIR})
link_directories(${ROCM_SMI_LIB_DIR})

file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h"
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/rocm/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@
}

size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState<miopenConvAlgoPerf_t>& s,
const miopenConvFwdAlgorithm_t* algo, int n_algo) {
const miopenConvFwdAlgorithm_t* algo, int n_algo, int device_id = 0) {
// TODO: get maximum available size from memory arena
size_t free, total;
HIP_CALL_THROW(hipMemGetInfo(&free, &total));
onnxruntime::rocm::hipMemGetInfoAlt(device_id, &free, &total);
// Assuming 10% of fragmentation
free = static_cast<size_t>(static_cast<double>(free) * 0.9);
size_t max_ws_size = 0;
Expand Down Expand Up @@ -283,7 +283,7 @@
int algo_count = 1;
const ROCMExecutionProvider* rocm_ep = static_cast<const ROCMExecutionProvider*>(this->Info().GetExecutionProvider());
static constexpr int num_algos = MIOPEN_CONVOLUTION_FWD_ALGO_COUNT;
size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos)
size_t max_ws_size = rocm_ep->GetMiopenConvUseMaxWorkspace() ? GetMaxWorkspaceSize(GetMiopenHandle(context), s_, kAllAlgos, num_algos, rocm_ep->GetDeviceId())

Check warning on line 286 in onnxruntime/core/providers/rocm/nn/conv.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/nn/conv.cc:286: Lines should be <= 120 characters long [whitespace/line_length] [2]
: AlgoSearchWorkspaceSize;
IAllocatorUniquePtr<void> algo_search_workspace = GetTransientScratchBuffer<void>(max_ws_size);
MIOPEN_RETURN_IF_ERROR(miopenFindConvolutionForwardAlgorithm(
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@
template void RocmCall<hiprandStatus_t, true>(hiprandStatus_t retCode, const char* exprString, const char* libName, hiprandStatus_t successCode, const char* msg, const char* file, const int line);
template Status RocmCall<hipfftResult, false>(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line);
template void RocmCall<hipfftResult, true>(hipfftResult retCode, const char* exprString, const char* libName, hipfftResult successCode, const char* msg, const char* file, const int line);
// Explicit instantiation of RocmCall for ROCm SMI status codes in its Status-returning form;
// this is the specialization the ROCMSMI_CALL macro expands to.
template Status RocmCall<rsmi_status_t, false>(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line);

Check warning on line 146 in onnxruntime/core/providers/rocm/rocm_call.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/rocm_call.cc:146: Lines should be <= 120 characters long [whitespace/line_length] [2]
// Explicit instantiation of RocmCall for ROCm SMI status codes in its void-returning form;
// this is the specialization the ROCMSMI_CALL_THROW macro expands to.
template void RocmCall<rsmi_status_t, true>(rsmi_status_t retCode, const char* exprString, const char* libName, rsmi_status_t successCode, const char* msg, const char* file, const int line);

Check warning on line 147 in onnxruntime/core/providers/rocm/rocm_call.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/rocm_call.cc:147: Lines should be <= 120 characters long [whitespace/line_length] [2]

#ifdef ORT_USE_NCCL
template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,17 @@ inline int warpSizeDynamic() {
return deviceProp.warpSize;
}

/// Queries free and total device memory, falling back to the ROCm SMI library when HIP cannot report it.
///
/// hipMemGetInfo() is tried first and, on success, its results are returned unchanged. If it fails, the
/// visible-VRAM totals for `deviceId` are read through ROCm SMI and free memory is computed as total - used.
///
/// @param deviceId Device index handed to the ROCm SMI queries. It is not used by the HIP call.
///                 NOTE(review): HIP and ROCm SMI may enumerate devices differently (e.g. when device
///                 visibility is restricted) - confirm the index mapping for such setups.
/// @param pFree    Receives the free memory in bytes.
/// @param pTotal   Receives the total memory in bytes.
/// @throws Whatever ROCMSMI_CALL_THROW raises when a ROCm SMI call does not return RSMI_STATUS_SUCCESS.
inline void hipMemGetInfoAlt(uint32_t deviceId, size_t* pFree, size_t* pTotal) {
  const auto status = hipMemGetInfo(pFree, pTotal);
  if (status == hipSuccess) {
    return;
  }

  ROCMSMI_CALL_THROW(rsmi_init(0));

  // Balances the rsmi_init() above if one of the queries below throws; without it the SMI library
  // would stay initialized on every failed query. Disarmed on the success path so that a failing
  // rsmi_shut_down() is still reported through ROCMSMI_CALL_THROW, as before.
  struct RsmiShutdownGuard {
    bool armed{true};
    ~RsmiShutdownGuard() {
      if (armed) {
        (void)rsmi_shut_down();  // already unwinding: a destructor must not throw
      }
    }
  } shutdown_guard;

  // The ROCm SMI API writes uint64_t values; use correctly typed locals instead of aliasing size_t*.
  uint64_t total_bytes = 0;
  uint64_t used_bytes = 0;
  ROCMSMI_CALL_THROW(rsmi_dev_memory_total_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &total_bytes));
  ROCMSMI_CALL_THROW(rsmi_dev_memory_usage_get(deviceId, RSMI_MEM_TYPE_VIS_VRAM, &used_bytes));

  shutdown_guard.armed = false;
  ROCMSMI_CALL_THROW(rsmi_shut_down());

  *pTotal = static_cast<size_t>(total_bytes);
  // Clamp at zero so an inconsistent reading (used > total) cannot wrap around to a huge free size.
  *pFree = used_bytes < total_bytes ? static_cast<size_t>(total_bytes - used_bytes) : 0;
}

} // namespace rocm
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in

size_t free = 0;
size_t total = 0;
HIP_CALL_THROW(hipMemGetInfo(&free, &total));
onnxruntime::rocm::hipMemGetInfoAlt(info_.device_id, &free, &total);

OverrideTunableOpInfoByEnv(info_);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/rocm/rocm_pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <hipsparse/hipsparse.h>
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>
#include <rocm_smi/rocm_smi.h>

#ifdef ORT_USE_NCCL
#include <rccl/rccl.h>
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/shared_inc/rocm_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

// Status-returning call wrappers: each forwards the expression's return code to RocmCall<..., false>
// together with the stringified expression, a library label, that library's success code, an empty
// message, and the call site's file and line.
#define HIP_CALL(expr) (RocmCall<hipError_t, false>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
#define ROCBLAS_CALL(expr) (RocmCall<rocblas_status, false>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__))
// ROCm SMI variant; success is RSMI_STATUS_SUCCESS.
#define ROCMSMI_CALL(expr) (RocmCall<rsmi_status_t, false>((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__))

Check warning on line 20 in onnxruntime/core/providers/rocm/shared_inc/rocm_call.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/shared_inc/rocm_call.h:20: Lines should be <= 120 characters long [whitespace/line_length] [2]

#define HIPSPARSE_CALL(expr) (RocmCall<hipsparseStatus_t, false>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
#define HIPRAND_CALL(expr) (RocmCall<hiprandStatus_t, false>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__))
Expand All @@ -27,6 +28,7 @@

// _THROW call wrappers: same arguments as the plain *_CALL macros, but routed to the
// RocmCall<..., true> specialization (the void-returning form).
#define HIP_CALL_THROW(expr) (RocmCall<hipError_t, true>((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__))
#define ROCBLAS_CALL_THROW(expr) (RocmCall<rocblas_status, true>((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__))
// ROCm SMI variant; success is RSMI_STATUS_SUCCESS. Used by hipMemGetInfoAlt().
#define ROCMSMI_CALL_THROW(expr) (RocmCall<rsmi_status_t, true>((expr), #expr, "ROCMSMI", RSMI_STATUS_SUCCESS, "", __FILE__, __LINE__))

Check warning on line 31 in onnxruntime/core/providers/rocm/shared_inc/rocm_call.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/shared_inc/rocm_call.h:31: Lines should be <= 120 characters long [whitespace/line_length] [2]

#define HIPSPARSE_CALL_THROW(expr) (RocmCall<hipsparseStatus_t, true>((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__))
#define HIPRAND_CALL_THROW(expr) (RocmCall<hiprandStatus_t, true>((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__))
Expand Down
Loading