[Dist] Add lazy-loading stubs for CUDART + NVRTC (CUDA 11/12/13 compatible wheels)#1821
Conversation
Build and ship libcudart_stub.so and libnvrtc_stub.so, then force TVM to link against them so the wheel does not hard-depend on libcudart.so.<major> / libnvrtc.so.<major>. This allows a single wheel to run across CUDA major versions where only libcudart/libnvrtc 12 or 13 is present.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds optional CUDA stub shared libraries (cuda_stub, cudart_stub, nvrtc_stub) and POSIX runtime-loading implementations for CUDA Runtime and NVRTC; updates CMake wiring, output targets, and install-time patchelf removal to account for the new stubs and conditional compilation flags. Changes
Sequence Diagram(s)sequenceDiagram
actor App as Application
participant Stub as "CUDA / NVRTC Stub"
participant Loader as "dlopen / dlsym"
participant Lib as "System CUDA Library\n(libcudart / libnvrtc)"
App->>Stub: call CUDA/NVRTC API (e.g., cudaMalloc / nvrtcCompileProgram)
alt first call (no handle)
Stub->>Loader: dlopen("libX.so.13") / dlopen("libX.so.12") / ...
Loader-->>Stub: handle or NULL
alt handle obtained
Stub->>Loader: dlsym(handle, "symbol")
Loader-->>Stub: function pointer
Stub->>Stub: store pointer in API struct
end
end
alt library loaded
Stub->>Lib: invoke resolved symbol(...)
Lib-->>Stub: result
Stub-->>App: return result
else library unavailable
Stub-->>App: return fallback error/result
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
No actionable comments were generated in the recent review. 🎉 🧹 Recent nitpick comments
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/target/stubs/cudart.cc`:
- Around line 528-539: Add a compile-time guard to prevent building this
CUDA-12-only wrapper against CUDA 11 headers by adding a static_assert that
checks CUDA_MAJOR_VERSION >= 12 with a clear error message; place this near the
top of the cudart stub file (or immediately before the cudaGraphInstantiate
wrapper) so the assertion triggers at compile time if the build headers are CUDA
11.x, and ensure it references the same CUDA header macros used elsewhere so
wrappers like TILELANG_CUDART_STUB_API cudaGraphInstantiate, GetCUDARuntimeAPI,
and MissingLibraryError are only compiled when CUDA_MAJOR_VERSION >= 12.
In `@src/target/stubs/nvrtc.cc`:
- Around line 80-90: The NVRTC API stub is missing CUBIN retrieval functions;
add members nvrtcGetCUBINSize_ and nvrtcGetCUBIN_ to the NVRTCAPI struct
(matching the decltype pattern used for nvrtcGetPTXSize_ and nvrtcGetPTX_) and
initialize them to nullptr so the stub exposes nvrtcGetCUBINSize and
nvrtcGetCUBIN symbols for CUBIN compilation compatibility with
tilelang/contrib/nvrtc.py.
🧹 Nitpick comments (2)
src/target/stubs/cudart.cc (1)
149-212:LOOKUP_REQUIREDblanks the entire API struct on any single symbol failure — includingcudaGetErrorString_.If
GetLibCudartHandle()succeeds (i.e., the.sowas loaded) but any single required symbol is missing,CreateCUDARuntimeAPIreturns a zeroedCUDARuntimeAPI{}, discarding the previously resolvedcudaGetErrorString_pointer (line 164-165). TheFallbackCudaErrorStringfallback covers this, so it won't crash, but the error message will say "libcudart not found" when in reality libcudart was found but is missing symbols — potentially confusing during debugging.Consider adjusting the fallback message or preserving
cudaGetErrorString_even when other lookups fail.src/target/stubs/nvrtc.cc (1)
156-162: Inconsistent defensive zeroing of output parameters vs. cudart.cc.In
cudart.cc, wrappers likecudaGetDevice,cudaGetDeviceCount,cudaMalloc, etc., zero their output pointers in the missing-library path before returning the error. Here,nvrtcVersiondoesn't zero*major/*minor,nvrtcCreateProgramdoesn't null*prog,nvrtcGetPTXSizedoesn't zero*ptxSizeRet, andnvrtcGetProgramLogSizedoesn't zero*logSizeRet.For consistency and defensive robustness:
Example fix for nvrtcVersion
TILELANG_NVRTC_STUB_API nvrtcResult nvrtcVersion(int *major, int *minor) { auto *api = GetNVRTCAPI(); if (api->nvrtcVersion_ == nullptr) { + if (major != nullptr) *major = 0; + if (minor != nullptr) *minor = 0; return MissingLibraryError(); } return api->nvrtcVersion_(major, minor); }Apply similarly to
nvrtcCreateProgram(*prog),nvrtcGetPTXSize(*ptxSizeRet), andnvrtcGetProgramLogSize(*logSizeRet).Also applies to: 164-174, 194-201, 211-218
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/target/stubs/cudart.cc`:
- Around line 25-33: The file defines Windows visibility macros
(TILELANG_CUDART_STUB_API) but unconditionally includes <dlfcn.h> and uses
dlopen/dlsym/dlerror, which prevents compilation on Windows; either remove the
Windows branch or add a Windows dynamic-load path: wrap the `#include` <dlfcn.h>
and all uses of dlopen/dlsym/dlerror in an `#ifndef` _WIN32 guard and implement
the corresponding Windows equivalents
(LoadLibrary/GetProcAddress/FormatMessage/GetLastError) in the _WIN32 branch, or
if Windows is not supported, delete the _WIN32/#else visibility block so only
the POSIX visibility macro and dlopen-based logic remain; update any references
in the file that call dlopen/dlsym/dlerror accordingly and keep
TILELANG_CUDART_STUB_API unchanged for exported symbols.
🧹 Nitpick comments (3)
src/target/stubs/nvrtc.cc (1)
23-35: Windows export macros defined but implementation is POSIX-only.
dlfcn.h,dlopen,dlsym, anddlerrorare used unconditionally, so this file won't compile on Windows despite the__declspec(dllexport/dllimport)branch. If Windows support isn't planned, consider simplifying the export macro to the GCC visibility attribute only, or add a compile guard /#errorfor unsupported platforms to make the intent explicit.src/target/stubs/cudart.cc (2)
156-209: All-or-nothingLOOKUP_REQUIREDnullifies the entire API if any single symbol is missing.If
dlopenresolves to an older or stripped libcudart that lacks even one "required" symbol (e.g.,cudaGraphInstantiate,cudaIpcGetMemHandle),CreateCUDARuntimeAPIreturns a defaultCUDARuntimeAPI{}with all pointers null — including basic operations likecudaMallocandcudaMemcpy.Consider splitting symbols into tiers: core (memory, streams, events) vs. optional (graph capture, IPC,
cudaFuncSetAttribute). Core symbols would remain fatal; others would gracefully degrade toMissingLibraryErrorindividually, similar to howcudaFuncSetAttribute_is already handled on line 206.
132-135: No diagnostic whenTryLoadLibCudartfails.When all
dlopenattempts fail,GetLibCudartHandlesilently returnsnullptrand every subsequent API call returnscudaErrorUnknown. A one-timefprintf(stderr, ...)or similar diagnostic on load failure would significantly ease debugging in environments where libcudart is unexpectedly absent.
Add compile-time checks for CUDA 12+ runtime headers to avoid signature mismatches (e.g. cudaGraphInstantiate) when the stub is built with older toolkits.
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/target/stubs/cudart.cc`:
- Around line 552-562: The three-arg cudaGraphInstantiate wrapper risks calling
a CUDA-11 five-arg symbol from an unversioned libcudart.so; update the runtime
loader (GetCUDARuntimeAPI or wherever graph API pointers are populated) to
resolve cudaRuntimeGetVersion and require version >= 12000 before assigning the
CUDA-12 three-arg function pointers (e.g., api->cudaGraphInstantiate_); if the
version check fails, leave those pointers null so cudaGraphInstantiate returns
MissingLibraryError, or alternatively drop resolving the unversioned
libcudart.so fallback entirely—implement one of these two fixes and ensure
cudaGraphInstantiate uses the null-check path already present.
🧹 Nitpick comments (1)
src/target/stubs/cudart.cc (1)
167-230: All-or-nothing symbol resolution makes the stub fragile for optional/uncommon APIs.If any
LOOKUP_REQUIREDsymbol fails to resolve (e.g.,cudaIpcGetMemHandleorcudaGraphInstantiateon a minimal/stripped CUDA runtime),CreateCUDARuntimeAPIreturns a fully zeroed struct — causing every wrapper (including fundamental ones likecudaMalloc,cudaMemcpy) to fail withcudaErrorUnknown.Consider splitting into tiers: resolve core APIs as required, and treat graph/IPC/capture APIs as optional (like
cudaFuncSetAttributealready is). Each wrapper already null-checks its own pointer, so optional symbols degrade gracefully per-function.♻️ Suggested approach
- LOOKUP_REQUIRED(cudaStreamBeginCapture) - LOOKUP_REQUIRED(cudaStreamEndCapture) - LOOKUP_REQUIRED(cudaGraphInstantiate) - LOOKUP_REQUIRED(cudaGraphLaunch) - LOOKUP_REQUIRED(cudaGraphDestroy) - LOOKUP_REQUIRED(cudaGraphExecDestroy) - LOOKUP_REQUIRED(cudaIpcGetMemHandle) - LOOKUP_REQUIRED(cudaIpcOpenMemHandle) - LOOKUP_REQUIRED(cudaIpcCloseMemHandle) + // Optional: CUDA Graph APIs (may be absent in minimal runtimes) + api.cudaStreamBeginCapture_ = GetSymbol<decltype(api.cudaStreamBeginCapture_)>(handle, "cudaStreamBeginCapture"); + api.cudaStreamEndCapture_ = GetSymbol<decltype(api.cudaStreamEndCapture_)>(handle, "cudaStreamEndCapture"); + api.cudaGraphInstantiate_ = GetSymbol<decltype(api.cudaGraphInstantiate_)>(handle, "cudaGraphInstantiate"); + api.cudaGraphLaunch_ = GetSymbol<decltype(api.cudaGraphLaunch_)>(handle, "cudaGraphLaunch"); + api.cudaGraphDestroy_ = GetSymbol<decltype(api.cudaGraphDestroy_)>(handle, "cudaGraphDestroy"); + api.cudaGraphExecDestroy_ = GetSymbol<decltype(api.cudaGraphExecDestroy_)>(handle, "cudaGraphExecDestroy"); + + // Optional: IPC APIs + api.cudaIpcGetMemHandle_ = GetSymbol<decltype(api.cudaIpcGetMemHandle_)>(handle, "cudaIpcGetMemHandle"); + api.cudaIpcOpenMemHandle_ = GetSymbol<decltype(api.cudaIpcOpenMemHandle_)>(handle, "cudaIpcOpenMemHandle"); + api.cudaIpcCloseMemHandle_ = GetSymbol<decltype(api.cudaIpcCloseMemHandle_)>(handle, "cudaIpcCloseMemHandle");
- Adjusted comments in CMakeLists.txt to reflect the correct major versions for libcudart and NVRTC. - Modified cudart.cc to support CUDA 11.x, including changes to the function pointer typedefs and the GraphInstantiate function to handle both legacy and new signatures. - Updated nvrtc.cc to include support for NVRTC versions 11.0 to 13.x, ensuring compatibility across different CUDA environments.
- Introduced an option to use POSIX dlopen-based CUDA stub libraries for better compatibility across different CUDA Toolkit versions and CPU-only machines. - Updated CMakeLists.txt to conditionally enable CUDA stubs based on the platform. - Added compile-time checks in CUDA, CUDART, and NVRTC stub implementations to ensure they are only built on POSIX systems, providing clear error messages for Windows users. - Enhanced documentation within the CMake configuration for clarity on the use of CUDA stubs.
Wheels that link directly against versioned CUDA libraries (libcudart.so., libnvrtc.so.) often break when run in an environment with a different CUDA
Toolkit major (e.g. build with 12.x, run with 13.x).
versioned SONAMEs.
What’s in this PR
Add src/target/stubs/cudart.cc: a cudart_stub shared library that exports the CUDA Runtime API subset used by TVM/TileLang and lazily loads the real libcudart via
Add src/target/stubs/nvrtc.cc: an nvrtc_stub shared library that exports the NVRTC C API subset used by TVM/TileLang and lazily loads the real libnvrtc at runtime.
Update CMakeLists.txt to build cudart_stub and nvrtc_stub when USE_CUDA is enabled, and force TVM’s cached CUDA_CUDART_LIBRARY / CUDA_NVRTC_LIBRARY variables to
point to these stub targets before add_subdirectory(tvm).
Add cudart_stub / nvrtc_stub to TILELANG_OUTPUT_TARGETS so existing install/RPATH/patchelf handling is applied consistently.
CUDA 11/12 API compatibility
Testing
Summary by CodeRabbit
New Features
Chores