Skip to content

Commit

Permalink
* hook cuGetProcAddress_v2
Browse files Browse the repository at this point in the history
* cuGetProcAddress and _v2 serve as lib entrypoint
  • Loading branch information
pokerfaceSad committed Dec 4, 2023
1 parent 67bed3f commit 70df1b4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
17 changes: 17 additions & 0 deletions src/cuda_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ typedef enum CUmemAttach_flags_enum {
CU_MEM_ATTACH_GLOBAL = 0x1
} CUmemAttach_flags;

/**
 * Outcome of a driver symbol search. For more details see ::cuGetProcAddress
 */
typedef enum CUdriverProcAddressQueryResult_enum {
    /** Symbol was successfully found */
    CU_GET_PROC_ADDRESS_SUCCESS = 0,
    /** Symbol was not found in search */
    CU_GET_PROC_ADDRESS_SYMBOL_NOT_FOUND = 1,
    /** Symbol was found but version supplied was not sufficient */
    CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT = 2
} CUdriverProcAddressQueryResult;

typedef enum nvmlReturn_t_enum {
NVML_SUCCESS = 0,
NVML_ERROR_UNKNOWN = 999
Expand All @@ -93,6 +103,9 @@ typedef struct nvmlUtilization_st {
/* typedefs for CUDA functions, to make hooking code cleaner */
typedef CUresult (*cuGetProcAddress_func)(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags);
/*
 * Pointer type for the real cuGetProcAddress_v2. Takes the same arguments as
 * cuGetProcAddress_func plus a trailing symbolStatus out-parameter.
 */
typedef CUresult (*cuGetProcAddress_v2_func)(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags,
CUdriverProcAddressQueryResult *symbolStatus);
typedef CUresult (*cuMemAllocManaged_func)(CUdeviceptr *dptr, size_t bytesize,
unsigned int flags);
typedef CUresult (*cuMemFree_func)(CUdeviceptr dptr);
Expand Down Expand Up @@ -135,6 +148,9 @@ typedef nvmlReturn_t (*nvmlDeviceGetHandleByIndex_func)(unsigned int index,
/* Hooked CUDA functions */
extern CUresult cuGetProcAddress(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags);
/*
 * Hooked variant of cuGetProcAddress that additionally carries the
 * symbolStatus out-parameter.
 */
extern CUresult cuGetProcAddress_v2(const char *symbol, void **pfn,
int cudaVersion, cuuint64_t flags,
CUdriverProcAddressQueryResult *symbolStatus);
extern CUresult cuMemGetInfo(size_t *free, size_t *total);
extern CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize);
extern CUresult cuMemFree(CUdeviceptr dptr);
Expand Down Expand Up @@ -163,6 +179,7 @@ extern CUresult cuMemcpyDtoDAsync(CUdeviceptr dstDevice,

/* Real CUDA functions */
extern cuGetProcAddress_func real_cuGetProcAddress;
extern cuGetProcAddress_v2_func real_cuGetProcAddress_v2;
extern cuMemAllocManaged_func real_cuMemAllocManaged;
extern cuMemFree_func real_cuMemFree;
extern cuMemGetInfo_func real_cuMemGetInfo;
Expand Down
66 changes: 64 additions & 2 deletions src/hook.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ cuMemcpyHtoDAsync_func real_cuMemcpyHtoDAsync = NULL;
cuMemcpyDtoD_func real_cuMemcpyDtoD = NULL;
cuMemcpyDtoDAsync_func real_cuMemcpyDtoDAsync = NULL;
cuGetProcAddress_func real_cuGetProcAddress = NULL;
cuGetProcAddress_v2_func real_cuGetProcAddress_v2 = NULL;
cuMemAllocManaged_func real_cuMemAllocManaged = NULL;
cuMemFree_func real_cuMemFree = NULL;
cuMemGetInfo_func real_cuMemGetInfo = NULL;
Expand Down Expand Up @@ -93,6 +94,10 @@ struct cuda_mem_allocation {
/* Linked list that holds all memory allocations of current application. */
struct cuda_mem_allocation *cuda_allocation_list = NULL;

/*
 * pthread_once guards for the one-time initialization steps:
 * init_libnvshare_done is paired with initialize_libnvshare and init_done
 * with initialize_client. They live at file scope so that every hooked entry
 * point that triggers initialization (cuGetProcAddress, cuGetProcAddress_v2,
 * cuInit) shares the same guards.
 */
static pthread_once_t init_libnvshare_done = PTHREAD_ONCE_INIT;
static pthread_once_t init_done = PTHREAD_ONCE_INIT;

/* Load real CUDA {Driver API, NVML} functions and bootstrap auxiliary stuff. */
static void bootstrap_cuda(void)
{
Expand Down Expand Up @@ -171,6 +176,11 @@ static void bootstrap_cuda(void)
* Runtime <11.3.
*/
log_debug("%s", error);
/*
 * Resolve the real cuGetProcAddress_v2. dlerror() must be queried AFTER the
 * lookup so that `error` describes this lookup rather than a previous one.
 * A missing symbol is only logged at debug level; presumably it is absent
 * on drivers that predate cuGetProcAddress_v2 (TODO confirm minimum version).
 */
real_cuGetProcAddress_v2 = (cuGetProcAddress_v2_func)
    real_dlsym_225(cuda_handle, CUDA_SYMBOL_STRING(cuGetProcAddress_v2));
error = dlerror();
if (error != NULL)
    log_debug("%s", error);
real_cuMemGetInfo = (cuMemGetInfo_func)
real_dlsym_225(cuda_handle,CUDA_SYMBOL_STRING(cuMemGetInfo));
error = dlerror();
Expand Down Expand Up @@ -426,6 +436,8 @@ void *dlsym_225(void *handle, const char *symbol)
return (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) {
return (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) {
return (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) {
return (void *)(&cuInit);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) {
Expand Down Expand Up @@ -463,6 +475,8 @@ void *dlsym_234(void *handle, const char *symbol)
return (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress)) == 0) {
return (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuGetProcAddress_v2)) == 0) {
return (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuInit)) == 0) {
return (void *)(&cuInit);
} else if (strcmp(symbol, CUDA_SYMBOL_STRING(cuLaunchKernel)) == 0) {
Expand Down Expand Up @@ -510,6 +524,8 @@ void *dlsym_234(void *handle, const char *symbol)
CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
cuuint64_t flags)
{
true_or_exit(pthread_once(&init_libnvshare_done, initialize_libnvshare) == 0);
true_or_exit(pthread_once(&init_done, initialize_client) == 0);
CUresult result = CUDA_SUCCESS;

if (real_cuGetProcAddress == NULL) return CUDA_ERROR_NOT_INITIALIZED;
Expand All @@ -522,6 +538,8 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
*pfn = (void *)(&cuMemGetInfo);
} else if (strcmp(symbol, "cuGetProcAddress") == 0) {
*pfn = (void *)(&cuGetProcAddress);
} else if (strcmp(symbol, "cuGetProcAddress_v2") == 0) {
*pfn = (void *)(&cuGetProcAddress_v2);
} else if (strcmp(symbol, "cuInit") == 0) {
*pfn = (void *)(&cuInit);
} else if (strcmp(symbol, "cuLaunchKernel") == 0) {
Expand Down Expand Up @@ -549,6 +567,52 @@ CUresult cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion,
return result;
}

/*
 * Hooked cuGetProcAddress_v2.
 *
 * Runs the one-time library/client initialization, then resolves `symbol`:
 * if it names one of the functions this library hooks, the address of the
 * hook is stored in *pfn; every other symbol is forwarded unchanged to the
 * real cuGetProcAddress_v2.
 *
 * symbol       - name of the driver function being looked up
 * pfn          - receives the resolved function address
 * cudaVersion  - forwarded as-is to the real implementation
 * flags        - forwarded as-is to the real implementation
 * symbolStatus - optional (may be NULL); receives the lookup status
 *
 * Returns CUDA_ERROR_NOT_INITIALIZED when the real cuGetProcAddress_v2 was
 * not resolved, CUDA_SUCCESS for a hooked symbol, and otherwise whatever the
 * real cuGetProcAddress_v2 returns.
 */
CUresult cuGetProcAddress_v2(const char *symbol, void **pfn, int cudaVersion,
                             cuuint64_t flags, CUdriverProcAddressQueryResult *symbolStatus)
{
    /* Symbols intercepted by this library, paired with the hook to return. */
    static const struct {
        const char *name;
        void *hook;
    } hooked_symbols[] = {
        { "cuMemAlloc",          (void *)(&cuMemAlloc) },
        { "cuMemFree",           (void *)(&cuMemFree) },
        { "cuMemGetInfo",        (void *)(&cuMemGetInfo) },
        { "cuGetProcAddress",    (void *)(&cuGetProcAddress) },
        { "cuGetProcAddress_v2", (void *)(&cuGetProcAddress_v2) },
        { "cuInit",              (void *)(&cuInit) },
        { "cuLaunchKernel",      (void *)(&cuLaunchKernel) },
        { "cuMemcpy",            (void *)(&cuMemcpy) },
        { "cuMemcpyAsync",       (void *)(&cuMemcpyAsync) },
        { "cuMemcpyDtoH",        (void *)(&cuMemcpyDtoH) },
        { "cuMemcpyDtoHAsync",   (void *)(&cuMemcpyDtoHAsync) },
        { "cuMemcpyHtoD",        (void *)(&cuMemcpyHtoD) },
        { "cuMemcpyHtoDAsync",   (void *)(&cuMemcpyHtoDAsync) },
        { "cuMemcpyDtoD",        (void *)(&cuMemcpyDtoD) },
        { "cuMemcpyDtoDAsync",   (void *)(&cuMemcpyDtoDAsync) },
    };
    size_t i;

    /* This function can be the first entry into the library, so initialize here. */
    true_or_exit(pthread_once(&init_libnvshare_done, initialize_libnvshare) == 0);
    true_or_exit(pthread_once(&init_done, initialize_client) == 0);

    /*
     * Guard the pointer that is actually called below. Checking
     * real_cuGetProcAddress instead would allow a NULL function pointer call
     * when only the non-_v2 symbol was resolved.
     */
    if (real_cuGetProcAddress_v2 == NULL) return CUDA_ERROR_NOT_INITIALIZED;

    for (i = 0; i < sizeof(hooked_symbols) / sizeof(hooked_symbols[0]); i++) {
        if (strcmp(symbol, hooked_symbols[i].name) == 0) {
            *pfn = hooked_symbols[i].hook;
            /*
             * The real function is bypassed on this path, so report the
             * status ourselves; otherwise the caller reads an indeterminate
             * value.
             */
            if (symbolStatus != NULL)
                *symbolStatus = CU_GET_PROC_ADDRESS_SUCCESS;
            return CUDA_SUCCESS;
        }
    }

    /* Not a hooked symbol: defer entirely to the real implementation. */
    return real_cuGetProcAddress_v2(symbol, pfn, cudaVersion, flags, symbolStatus);
}


CUresult cuMemAlloc(CUdeviceptr *dptr, size_t bytesize)
{
Expand Down Expand Up @@ -659,8 +723,6 @@ CUresult cuMemGetInfo(size_t *free, size_t *total)
CUresult cuInit(unsigned int flags)
{
CUresult result = CUDA_SUCCESS;
static pthread_once_t init_libnvshare_done = PTHREAD_ONCE_INIT;
static pthread_once_t init_done = PTHREAD_ONCE_INIT;

true_or_exit(pthread_once(&init_libnvshare_done, initialize_libnvshare) == 0);
true_or_exit(pthread_once(&init_done, initialize_client) == 0);
Expand Down

0 comments on commit 70df1b4

Please sign in to comment.