From 28a6443c35fc5ff6534fcfd5f02df54dd2e93e8f Mon Sep 17 00:00:00 2001 From: xinyuanlv Date: Tue, 12 Dec 2023 17:15:34 +0800 Subject: [PATCH] hook: Add hook for cuGetProcAddress_v2() to support CUDA >=12 CUDA 12.0 introduced a new function, cuGetProcAddress_v2(), which Runtime API applications call instead of cuGetProcAddress() in order to obtain Driver API symbols. To maintain compatibility with CUDA >=12.0 applications, add a hook for cuGetProcAddress_v2(). Closes #11 Signed-off-by: Xinyuan Lyu --- src/cuda_defs.h | 17 ++++++++++++++ src/hook.c | 59 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/cuda_defs.h b/src/cuda_defs.h index 1f083bc..589f110 100644 --- a/src/cuda_defs.h +++ b/src/cuda_defs.h @@ -67,6 +67,16 @@ typedef enum CUmemAttach_flags_enum { CU_MEM_ATTACH_GLOBAL = 0x1 } CUmemAttach_flags; +/** + * Flags to indicate search status. For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddressQueryResult_enum +{ + CU_GET_PROC_ADDRESS_SUCCESS = 0, /**< Symbol was successfully found */ + CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1, /**< Symbol was not found in search */ + CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2 /**< Symbol was found but version supplied was not sufficient */ +} CUdriverProcAddressQueryResult; + typedef enum nvmlReturn_t_enum { NVML_SUCCESS = 0, NVML_ERROR_UNKNOWN = 999 @@ -93,6 +103,9 @@ typedef struct nvmlUtilization_st { /* typedefs for CUDA functions, to make hooking code cleaner */ typedef CUresult (*cuGetProcAddress_func)(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); +typedef CUresult (*cuGetProcAddress_v2_func)(const char *symbol, void **pfn, + int cudaVersion, cuuint64_t flags, + CUdriverProcAddressQueryResult *symbolStatus); typedef CUresult (*cuMemAllocManaged_func)(CUdeviceptr *dptr, size_t bytesize, unsigned int flags); typedef CUresult (*cuMemFree_func)(CUdeviceptr dptr); @@ -135,6 +148,9 @@ typedef nvmlReturn_t 
(*nvmlDeviceGetHandleByIndex_func)(unsigned int index, /* Hooked CUDA functions */ extern CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); +extern CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, + int cudaVersion, cuuint64_t flags, + CUdriverProcAddressQueryResult *symbolStatus); extern CUresult cuMemGetInfo(size_t *free, size_t *total); extern CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize); extern CUresult cuMemFree(CUdeviceptr dptr); @@ -163,6 +179,7 @@ extern CUresult cuMemcpyDtoDAsync(CUdeviceptr dstDevice, /* Real CUDA functions */ extern cuGetProcAddress_func real_cuGetProcAddress; +extern cuGetProcAddress_v2_func real_cuGetProcAddress_v2; extern cuMemAllocManaged_func real_cuMemAllocManaged; extern cuMemFree_func real_cuMemFree; extern cuMemGetInfo_func real_cuMemGetInfo; diff --git a/src/hook.c b/src/hook.c index 410504e..bccad05 100644 --- a/src/hook.c +++ b/src/hook.c @@ -60,6 +60,7 @@ cuMemcpyHtoDAsync_func real_cuMemcpyHtoDAsync = NULL; cuMemcpyDtoD_func real_cuMemcpyDtoD = NULL; cuMemcpyDtoDAsync_func real_cuMemcpyDtoDAsync = NULL; cuGetProcAddress_func real_cuGetProcAddress = NULL; +cuGetProcAddress_v2_func real_cuGetProcAddress_v2 = NULL; cuMemAllocManaged_func real_cuMemAllocManaged = NULL; cuMemFree_func real_cuMemFree = NULL; cuMemGetInfo_func real_cuMemGetInfo = NULL; @@ -175,6 +176,11 @@ static void bootstrap_cuda(void) * Runtime <11.3. 
*/ log_debug("%s", error); + real_cuGetProcAddress_v2 = (cuGetProcAddress_v2_func) + real_dlsym_225(cuda_handle, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)); + error = dlerror(); + if (error != NULL) + log_debug("%s", error); real_cuMemGetInfo = (cuMemGetInfo_func) real_dlsym_225(cuda_handle,CUDA_SYMBOL_STRING(cuMemGetInfo)); error = dlerror(); @@ -430,6 +436,8 @@ void *dlsym_225(void *handle, const char *symbol) return (void *)(&cuMemGetInfo); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) { return (void *)(&cuGetProcAddress); + } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) { + return (void *)(&cuGetProcAddress_v2); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) { return (void *)(&cuInit); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) { @@ -467,6 +475,8 @@ void *dlsym_234(void *handle, const char *symbol) return (void *)(&cuMemGetInfo); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) { return (void *)(&cuGetProcAddress); + } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) { + return (void *)(&cuGetProcAddress_v2); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) { return (void *)(&cuInit); } else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) { @@ -534,6 +544,8 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, *pfn = (void *)(&cuMemGetInfo); } else if (strcmp(symbol, "cuGetProcAddress") == 0) { *pfn = (void *)(&cuGetProcAddress); + } else if (strcmp(symbol, "cuGetProcAddress_v2") == 0) { + *pfn = (void *)(&cuGetProcAddress_v2); } else if (strcmp(symbol, "cuInit") == 0) { *pfn = (void *)(&cuInit); } else if (strcmp(symbol, "cuLaunchKernel") == 0) { @@ -562,6 +574,53 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, } +CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, int cudaVersion, + cuuint64_t flags, CUdriverProcAddressQueryResult *symbolStatus) 
+{
+	true_or_exit(pthread_once(&init_libnvshare_done, initialize_libnvshare) == 0);
+	true_or_exit(pthread_once(&init_done, initialize_client) == 0);
+	CUresult result = CUDA_SUCCESS;
+	/* The v2 pointer stays NULL if bootstrap could not resolve it; never call through it */
+	if (real_cuGetProcAddress_v2 == NULL) return CUDA_ERROR_NOT_INITIALIZED;
+	/* Hooked symbols are served locally (status reported below); the rest go to the driver */
+	if (strcmp(symbol, "cuMemAlloc") == 0) {
+		*pfn = (void *)(&cuMemAlloc);
+	} else if (strcmp(symbol, "cuMemFree") == 0) {
+		*pfn = (void *)(&cuMemFree);
+	} else if (strcmp(symbol, "cuMemGetInfo") == 0) {
+		*pfn = (void *)(&cuMemGetInfo);
+	} else if (strcmp(symbol, "cuGetProcAddress") == 0) {
+		*pfn = (void *)(&cuGetProcAddress);
+	} else if (strcmp(symbol, "cuGetProcAddress_v2") == 0) {
+		*pfn = (void *)(&cuGetProcAddress_v2);
+	} else if (strcmp(symbol, "cuInit") == 0) {
+		*pfn = (void *)(&cuInit);
+	} else if (strcmp(symbol, "cuLaunchKernel") == 0) {
+		*pfn = (void *)(&cuLaunchKernel);
+	} else if (strcmp(symbol, "cuMemcpy") == 0) {
+		*pfn = (void *)(&cuMemcpy);
+	} else if (strcmp(symbol, "cuMemcpyAsync") == 0) {
+		*pfn = (void *)(&cuMemcpyAsync);
+	} else if (strcmp(symbol, "cuMemcpyDtoH") == 0) {
+		*pfn = (void *)(&cuMemcpyDtoH);
+	} else if (strcmp(symbol, "cuMemcpyDtoHAsync") == 0) {
+		*pfn = (void *)(&cuMemcpyDtoHAsync);
+	} else if (strcmp(symbol, "cuMemcpyHtoD") == 0) {
+		*pfn = (void *)(&cuMemcpyHtoD);
+	} else if (strcmp(symbol, "cuMemcpyHtoDAsync") == 0) {
+		*pfn = (void *)(&cuMemcpyHtoDAsync);
+	} else if (strcmp(symbol, "cuMemcpyDtoD") == 0) {
+		*pfn = (void *)(&cuMemcpyDtoD);
+	} else if (strcmp(symbol, "cuMemcpyDtoDAsync") == 0) {
+		*pfn = (void *)(&cuMemcpyDtoDAsync);
+	} else {
+		return real_cuGetProcAddress_v2(symbol, pfn, cudaVersion, flags, symbolStatus);
+	}
+	if (symbolStatus != NULL) *symbolStatus = CU_GET_PROC_ADDRESS_SUCCESS;
+	return result;
+}
+
+
 CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize)
 {
 	static int got_max_mem_size = 0;