From daf072c3efce753065ec42a3e874154554ca7d02 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk <21087696+oleksandr-pavlyk@users.noreply.github.com> Date: Thu, 29 May 2025 13:23:21 -0500 Subject: [PATCH] Close gh-4845 by ensuring CTK minor version compatibility nvJitLink.h header file provides unversioned inline functions which inject versionsed symbols. For example: ``` static inline nvJitLinkResult nvJitLinkCreate( nvJitLinkHandle *handle, uint32_t numOptions, const char **options) { return __nvJitLinkCreate_12_8 (handle, numOptions, options); } ``` c/parallel uses unversioned symbols, but due to inlining, the object files of each TU in c/parallel contains versioned symbols, and hence the final shared library depends on the specific CTK version it was built with. The nvJitLink.so.12 does provide unversioned symbols too, which map to versioned symbols at run-time. ``` (nvbench) opavlyk@ee09c48-lcedt:~/repos/cccl$ nm -D /usr/local/cuda/lib64/libnvJitLink.so.12 | grep nvJitLinkCreate 00000000004ba560 T nvJitLinkCreate@@libnvJitLink.so.12 00000000004ba660 T __nvJitLinkCreate_12_0@@libnvJitLink.so.12 00000000004ba670 T __nvJitLinkCreate_12_1@@libnvJitLink.so.12 00000000004ba680 T __nvJitLinkCreate_12_2@@libnvJitLink.so.12 00000000004ba690 T __nvJitLinkCreate_12_3@@libnvJitLink.so.12 00000000004ba6a0 T __nvJitLinkCreate_12_4@@libnvJitLink.so.12 00000000004ba6b0 T __nvJitLinkCreate_12_5@@libnvJitLink.so.12 00000000004ba6c0 T __nvJitLinkCreate_12_6@@libnvJitLink.so.12 00000000004ba6d0 T __nvJitLinkCreate_12_7@@libnvJitLink.so.12 00000000004ba6e0 T __nvJitLinkCreate_12_8@@libnvJitLink.so.12 ``` This change replaces direct uses of `#include ` with `#include ` which defines `NVJITLINK_NO_INLINE` before including and simply declares unversioned symbols. Linking with subsequently result in using the dynamic unversioned symbols provided by nvJitLink.so.12 library guaranteeign CTK minor version compatibility. I verified that gh-4845 is resolved with this change by installing cuda-parallel wheel from this PR after torch built with CTK 12.8 was installed. ``` (pathfinder-trouble) opavlyk@ee09c48-lcedt:~$ python Python 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:21:13) [GCC 13.3.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> >>> import torch >>> import cuda.parallel.experimental.algorithms as algorithms >>> >>> import ctypes >>> >>> lib = ctypes.cdll.LoadLibrary("libnvJitLink.so.12") >>> lib >>> fn = lib.nvJitLinkVersion >>> fn.restype = ctypes.c_int >>> fn.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)] >>> maj = ctypes.c_int(0) >>> min = ctypes.c_int(0) >>> >>> >>> fn(maj, min) 0 >>> maj, min (c_int(12), c_int(8)) >>> quit() ``` --- c/parallel/src/nvrtc/command_list.h | 2 +- c/parallel/src/nvrtc/nvjitlink_helper.h | 23 +++++++++++++++++++++++ c/parallel/src/util/context.h | 3 ++- c/parallel/src/util/errors.h | 3 ++- 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 c/parallel/src/nvrtc/nvjitlink_helper.h diff --git a/c/parallel/src/nvrtc/command_list.h b/c/parallel/src/nvrtc/command_list.h index a5427649a62..8ee7ec6ae9f 100644 --- a/c/parallel/src/nvrtc/command_list.h +++ b/c/parallel/src/nvrtc/command_list.h @@ -19,9 +19,9 @@ #include #include -#include #include +#include #include struct nvrtc_ptx diff --git a/c/parallel/src/nvrtc/nvjitlink_helper.h b/c/parallel/src/nvrtc/nvjitlink_helper.h new file mode 100644 index 00000000000..6c69c55df15 --- /dev/null +++ b/c/parallel/src/nvrtc/nvjitlink_helper.h @@ -0,0 +1,23 @@ +#pragma once + +#define NVJITLINK_NO_INLINE +#include +#undef NVJITLINK_NO_INLINE + +// declare unversioned functions + +extern "C" { +nvJitLinkResult nvJitLinkCreate(nvJitLinkHandle*, uint32_t, const char**); +nvJitLinkResult nvJitLinkDestroy(nvJitLinkHandle*); +nvJitLinkResult nvJitLinkAddData(nvJitLinkHandle, nvJitLinkInputType, const void*, size_t, const char*); +nvJitLinkResult nvJitLinkAddFile(nvJitLinkHandle, nvJitLinkInputType, const char*); +nvJitLinkResult nvJitLinkComplete(nvJitLinkHandle); +nvJitLinkResult nvJitLinkGetLinkedCubinSize(nvJitLinkHandle, size_t*); +nvJitLinkResult nvJitLinkGetLinkedCubin(nvJitLinkHandle, void*); +nvJitLinkResult nvJitLinkGetLinkedPtxSize(nvJitLinkHandle, size_t*); +nvJitLinkResult nvJitLinkGetLinkedPtx(nvJitLinkHandle, char*); +nvJitLinkResult nvJitLinkGetErrorLogSize(nvJitLinkHandle, size_t*); +nvJitLinkResult nvJitLinkGetErrorLog(nvJitLinkHandle, char*); +nvJitLinkResult nvJitLinkGetInfoLogSize(nvJitLinkHandle, size_t*); +nvJitLinkResult nvJitLinkGetInfoLog(nvJitLinkHandle, char*); +} diff --git a/c/parallel/src/util/context.h b/c/parallel/src/util/context.h index 9038527f5a8..596f273510a 100644 --- a/c/parallel/src/util/context.h +++ b/c/parallel/src/util/context.h @@ -11,7 +11,8 @@ #pragma once #include -#include #include +#include + bool try_push_context(); diff --git a/c/parallel/src/util/errors.h b/c/parallel/src/util/errors.h index 2968b87f6da..7b5cb07c5c6 100644 --- a/c/parallel/src/util/errors.h +++ b/c/parallel/src/util/errors.h @@ -11,9 +11,10 @@ #pragma once #include -#include #include +#include + void check(nvrtcResult result); void check(CUresult result); void check(nvJitLinkResult result);