Skip to content

Commit

Permalink
hook: Add hook for cuGetProcAddress_v2() to support CUDA >=12
Browse files Browse the repository at this point in the history
CUDA 12.0 introduced a new function, cuGetProcAddress_v2(), which Runtime
API applications call instead of cuGetProcAddress() in order to obtain
Driver API symbols.

To maintain compatibility with CUDA >=12.0 applications,
add a hook for cuGetProcAddress_v2().

Closes #11

Signed-off-by: Xinyuan Lyu <[email protected]>
Reviewed-by: George Alexopoulos <[email protected]>
  • Loading branch information
pokerfaceSad authored and grgalex committed Jan 29, 2024
1 parent d56dc01 commit 000bfa0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/cuda_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ typedef enum CUmemAttach_flags_enum {
CU_MEM_ATTACH_GLOBAL = 0x1
} CUmemAttach_flags;

/*
* Flags to indicate CUDA symbol query status.
* For more details see https://docs.nvidia.com/cuda/archive/12.0.0/cuda-driver-api/group__CUDA__DRIVER__ENTRY__POINT.html
*/
typedef enum CUdriverProcAddressQueryResult_enum {
/* Symbol was succesfully found */
CU_GET_PROC_ADDRESS_SUCCESS = 0,
/* Symbol was not found in search */
CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1,
/* Symbol was found but version supplied was not sufficient */
CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2
} CUdriverProcAddressQueryResult;

typedef enum nvmlReturn_t_enum {
NVML_SUCCESS = 0,
NVML_ERROR_UNKNOWN = 999
Expand All @@ -93,6 +106,9 @@ typedef struct nvmlUtilization_st {
/* typedefs for CUDA functions, to make hooking code cleaner */
typedef CUresult (*cuGetProcAddress_func)(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags);
typedef CUresult (*cuGetProcAddress_v2_func)(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags,
CUdriverProcAddressQueryResult *symbolStatus);
typedef CUresult (*cuMemAllocManaged_func)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
typedef CUresult (*cuMemFree_func)(CUdeviceptr dptr);
Expand Down Expand Up @@ -135,6 +151,9 @@ typedef nvmlReturn_t (*nvmlDeviceGetHandleByIndex_func)(unsigned int index,
/* Hooked CUDA functions */
extern CUresult cuGetProcAddress(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags);
extern CUresult cuGetProcAddress_v2(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags,
CUdriverProcAddressQueryResult *symbolStatus);
extern CUresult cuMemGetInfo(size_t *free, size_t *total);
extern CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize);
extern CUresult cuMemFree(CUdeviceptr dptr);
Expand Down Expand Up @@ -163,6 +182,7 @@ extern CUresult cuMemcpyDtoDAsync(CUdeviceptr dstDevice,

/* Real CUDA functions */
extern cuGetProcAddress_func real_cuGetProcAddress;
extern cuGetProcAddress_v2_func real_cuGetProcAddress_v2;
extern cuMemAllocManaged_func real_cuMemAllocManaged;
extern cuMemFree_func real_cuMemFree;
extern cuMemGetInfo_func real_cuMemGetInfo;
Expand Down
80 changes: 80 additions & 0 deletions src/hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cuMemcpyHtoDAsync_func real_cuMemcpyHtoDAsync = NULL;
cuMemcpyDtoD_func real_cuMemcpyDtoD = NULL;
cuMemcpyDtoDAsync_func real_cuMemcpyDtoDAsync = NULL;
cuGetProcAddress_func real_cuGetProcAddress = NULL;
cuGetProcAddress_v2_func real_cuGetProcAddress_v2 = NULL;
cuMemAllocManaged_func real_cuMemAllocManaged = NULL;
cuMemFree_func real_cuMemFree = NULL;
cuMemGetInfo_func real_cuMemGetInfo = NULL;
Expand Down Expand Up @@ -175,6 +176,16 @@ static void bootstrap_cuda(void)
* Runtime <11.3.
*/
log_debug("%s", error);
real_cuGetProcAddress_v2 = (cuGetProcAddress_v2_func)
real_dlsym_225(cuda_handle, CUDA_SYMBOL_STRING(cuGetProcAddress_v2));
error = dlerror();
if (error != NULL)
/*
* Print a debug message instead of failing immediately, since
* this symbol may not be used. This may be the case for CUDA
* Runtime <12.0.
*/
log_debug("%s", error);
real_cuMemGetInfo = (cuMemGetInfo_func)
real_dlsym_225(cuda_handle,CUDA_SYMBOL_STRING(cuMemGetInfo));
error = dlerror();
Expand Down Expand Up @@ -430,6 +441,8 @@ void *dlsym_225(void *handle, const char *symbol)
return (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) {
return (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) {
return (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) {
return (void *)(&cuInit);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) {
Expand Down Expand Up @@ -467,6 +480,8 @@ void *dlsym_234(void *handle, const char *symbol)
return (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) {
return (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) {
return (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) {
return (void *)(&cuInit);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) {
Expand Down Expand Up @@ -535,6 +550,8 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
*pfn = (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, "cuGetProcAddress") == 0) {
*pfn = (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, "cuGetProcAddress_v2") == 0) {
*pfn = (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, "cuInit") == 0) {
*pfn = (void *)(&cuInit);
} else if (strcmp(symbol, "cuLaunchKernel") == 0) {
Expand Down Expand Up @@ -563,6 +580,69 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
}


CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, int cudaVersion,
cuuint64_t flags, CUdriverProcAddressQueryResult *symbolStatus)
{
/*
* cuGetProcAddress_v2() will be called before cuInit() in CUDA
* Runtime API (version >=12.0), so cuGetProcAddress_v2()
* should also serve as an entrypoint.
*
* Otherwise, real_cuGetProcAddress_v2 may be a
* NULL pointer when it is called.
*/
true_or_exit(pthread_once(&init_libnvshare_done, initialize_libnvshare) == 0);
true_or_exit(pthread_once(&init_done, initialize_client) == 0);
CUresult result = CUDA_SUCCESS;

if (real_cuGetProcAddress_v2 == NULL) return CUDA_ERROR_NOT_INITIALIZED;

/* This covers our custom "if" conditions.
* If we end up calling the real cuGetProcAddress_v2,
* it will overwrite this value.
*/
if (symbolStatus != NULL)
*symbolStatus = CU_GET_PROC_ADDRESS_SUCCESS;

if (strcmp(symbol, "cuMemAlloc") == 0) {
*pfn = (void *)(&cuMemAlloc);
} else if (strcmp(symbol, "cuMemFree") == 0) {
*pfn = (void *)(&cuMemFree);
} else if (strcmp(symbol, "cuMemGetInfo") == 0) {
*pfn = (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, "cuGetProcAddress") == 0) {
*pfn = (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, "cuGetProcAddress_v2") == 0) {
*pfn = (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, "cuInit") == 0) {
*pfn = (void *)(&cuInit);
} else if (strcmp(symbol, "cuLaunchKernel") == 0) {
*pfn = (void *)(&cuLaunchKernel);
} else if (strcmp(symbol, "cuMemcpy") == 0) {
*pfn = (void *)(&cuMemcpy);
} else if (strcmp(symbol, "cuMemcpyAsync") == 0) {
*pfn = (void *)(&cuMemcpyAsync);
} else if (strcmp(symbol, "cuMemcpyDtoH") == 0) {
*pfn = (void *)(&cuMemcpyDtoH);
} else if (strcmp(symbol, "cuMemcpyDtoHAsync") == 0) {
*pfn = (void *)(&cuMemcpyDtoHAsync);
} else if (strcmp(symbol, "cuMemcpyHtoD") == 0) {
*pfn = (void *)(&cuMemcpyHtoD);
} else if (strcmp(symbol, "cuMemcpyHtoDAsync") == 0) {
*pfn = (void *)(&cuMemcpyHtoDAsync);
} else if (strcmp(symbol, "cuMemcpyDtoD") == 0) {
*pfn = (void *)(&cuMemcpyDtoD);
} else if (strcmp(symbol, "cuMemcpyDtoDAsync") == 0) {
*pfn = (void *)(&cuMemcpyDtoDAsync);
} else {
result = real_cuGetProcAddress_v2(symbol, pfn, cudaVersion,
flags, symbolStatus);
}

return result;
}


CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize)
{
static int got_max_mem_size = 0;
Expand Down

0 comments on commit 000bfa0

Please sign in to comment.