Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 111 additions & 30 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,21 @@ foreach(BACKEND IN LISTS TILELANG_BACKENDS)
endforeach()

set(PREBUILD_CYTHON ON)

# TILELANG_USE_CUDA_STUBS selects whether the build uses CUDA stub libraries
# (cuda/cudart/nvrtc) in place of the real ones. Stubs are used to build wheels
# that can run across different CUDA Toolkit major versions and/or on CPU-only
# machines, because they avoid hard DT_NEEDED dependencies on versioned CUDA
# SONAMEs.
#
# The stubs are currently POSIX-only (dlopen/dlsym via <dlfcn.h>), so the
# option defaults to ON everywhere except native (non-Cygwin) Windows.
set(_TILELANG_USE_CUDA_STUBS_DEFAULT ON)
if(WIN32 AND NOT CYGWIN)
  set(_TILELANG_USE_CUDA_STUBS_DEFAULT OFF)
endif()
option(
  TILELANG_USE_CUDA_STUBS
  "Use POSIX dlopen-based CUDA stub libraries (cuda/cudart/nvrtc) for portable wheels"
  ${_TILELANG_USE_CUDA_STUBS_DEFAULT}
)
# The helper variable only exists to compute the default; do not leak it.
unset(_TILELANG_USE_CUDA_STUBS_DEFAULT)
# Configs end

include(cmake/load_tvm.cmake)
Expand Down Expand Up @@ -211,31 +226,94 @@ elseif(USE_CUDA)
# Set `USE_CUDA=/usr/local/cuda-x.y`
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)

# ============================================================================
# CUDA Driver Stub Library (libcuda_stub.so)
# ============================================================================
# This library provides drop-in replacements for CUDA driver API functions.
# Instead of linking directly against libcuda.so (which would fail on
# CPU-only machines), we link against this stub which loads libcuda.so
# lazily at runtime on first API call.
#
# The stub exports global C functions matching the CUDA driver API:
# - cuModuleLoadData, cuLaunchKernel, cuMemsetD32_v2, etc.
# These can be called directly without any wrapper macros.
# ============================================================================
add_library(cuda_stub SHARED src/target/stubs/cuda.cc)
target_include_directories(cuda_stub PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
# Export symbols with visibility="default" when building
target_compile_definitions(cuda_stub PRIVATE TILELANG_CUDA_STUB_EXPORTS)
# Use dlopen/dlsym for runtime library loading
target_link_libraries(cuda_stub PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(cuda_stub PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
# Use consistent naming
OUTPUT_NAME "cuda_stub"
)
# Build the CUDA stub libraries (cuda/cudart/nvrtc) and point TVM at them when
# TILELANG_USE_CUDA_STUBS is ON; otherwise make sure no stale stub names from a
# previous configure of this build tree remain in the cache.
if(TILELANG_USE_CUDA_STUBS)
  if(WIN32 AND NOT CYGWIN)
    message(FATAL_ERROR "TILELANG_USE_CUDA_STUBS=ON is not supported on Windows. "
                        "Please configure with -DTILELANG_USE_CUDA_STUBS=OFF.")
  endif()

  # ============================================================================
  # CUDA Driver Stub Library (libcuda_stub.so)
  # ============================================================================
  # This library provides drop-in replacements for CUDA driver API functions.
  # Instead of linking directly against libcuda.so (which would fail on
  # CPU-only machines), we link against this stub which loads libcuda.so
  # lazily at runtime on first API call.
  #
  # The stub exports global C functions matching the CUDA driver API:
  #   - cuModuleLoadData, cuLaunchKernel, cuMemsetD32_v2, etc.
  # These can be called directly without any wrapper macros.
  # ============================================================================
  add_library(cuda_stub SHARED src/target/stubs/cuda.cc)
  target_include_directories(cuda_stub PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
  # Export symbols with visibility="default" when building
  target_compile_definitions(cuda_stub PRIVATE TILELANG_CUDA_STUB_EXPORTS)
  # Use dlopen/dlsym for runtime library loading
  target_link_libraries(cuda_stub PRIVATE ${CMAKE_DL_LIBS})
  set_target_properties(cuda_stub PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    # Use consistent naming
    OUTPUT_NAME "cuda_stub"
  )

  # ============================================================================
  # CUDA Runtime Stub Library (libcudart_stub.so)
  # ============================================================================
  # libcudart's SONAME includes its major version (e.g. libcudart.so.11.0 / .12 / .13).
  # Link against this stub instead of the real libcudart so a single wheel can
  # run in environments that provide different libcudart major versions.
  #
  # The stub exports a minimal set of CUDA Runtime API entrypoints used by TVM
  # and lazily loads libcudart at runtime on first API call.
  # ============================================================================
  add_library(cudart_stub SHARED src/target/stubs/cudart.cc)
  target_include_directories(cudart_stub PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
  target_compile_definitions(cudart_stub PRIVATE TILELANG_CUDART_STUB_EXPORTS)
  target_link_libraries(cudart_stub PRIVATE ${CMAKE_DL_LIBS})
  set_target_properties(cudart_stub PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    OUTPUT_NAME "cudart_stub"
  )

  # Make TVM link against our CUDA Runtime stub instead of the real libcudart.
  #
  # NOTE: TVM's `find_cuda()` calls `find_library(CUDA_CUDART_LIBRARY cudart ...)`.
  # `find_library()` will not override an already-cached variable, so setting it
  # here ensures TVM doesn't record a DT_NEEDED on `libcudart.so.<major>`.
  # FORCE is deliberate: a real libcudart path cached by an earlier configure
  # must be replaced by the stub.
  set(CUDA_CUDART_LIBRARY cudart_stub CACHE STRING "CUDART library to link against" FORCE)

  # ============================================================================
  # NVRTC Stub Library (libnvrtc_stub.so)
  # ============================================================================
  # NVRTC's SONAME includes its major version (e.g. libnvrtc.so.11.2 / .12 / .13).
  # Link against this stub instead of the real NVRTC library so a single wheel
  # can run in environments that provide different NVRTC major versions.
  #
  # The stub exports a minimal set of NVRTC C API entrypoints used by TVM and
  # lazily loads libnvrtc at runtime on first API call.
  # ============================================================================
  add_library(nvrtc_stub SHARED src/target/stubs/nvrtc.cc)
  target_include_directories(nvrtc_stub PRIVATE ${CUDAToolkit_INCLUDE_DIRS})
  target_compile_definitions(nvrtc_stub PRIVATE TILELANG_NVRTC_STUB_EXPORTS)
  target_link_libraries(nvrtc_stub PRIVATE ${CMAKE_DL_LIBS})
  set_target_properties(nvrtc_stub PROPERTIES
    LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
    OUTPUT_NAME "nvrtc_stub"
  )

  # Make TVM link against our NVRTC stub instead of the real libnvrtc.
  #
  # NOTE: TVM's `find_cuda()` calls `find_library(CUDA_NVRTC_LIBRARY nvrtc ...)`.
  # `find_library()` will not override an already-cached variable, so setting it
  # here ensures TVM doesn't record a DT_NEEDED on `libnvrtc.so.<major>`.
  set(CUDA_NVRTC_LIBRARY nvrtc_stub CACHE STRING "NVRTC library to link against" FORCE)
else()
  # Stubs are disabled. If this build tree was previously configured with
  # TILELANG_USE_CUDA_STUBS=ON, the cache entries written above still name the
  # stub targets, which are no longer defined. Because `find_library()` does
  # not override an already-cached variable, those stale names would be passed
  # to the linker as-is. Drop them so the real libraries are searched for
  # again. Only entries that still hold the stub names are removed; a path
  # supplied by the user is left untouched.
  if("${CUDA_CUDART_LIBRARY}" STREQUAL "cudart_stub")
    unset(CUDA_CUDART_LIBRARY CACHE)
  endif()
  if("${CUDA_NVRTC_LIBRARY}" STREQUAL "nvrtc_stub")
    unset(CUDA_NVRTC_LIBRARY CACHE)
  endif()
endif()

file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/runtime.cc
Expand Down Expand Up @@ -344,15 +422,15 @@ set(TILELANG_OUTPUT_TARGETS
tvm_runtime
)

if(USE_CUDA)
if(USE_CUDA AND TILELANG_USE_CUDA_STUBS)
# Link against CUDA stub library instead of libcuda.so
# This enables lazy loading of libcuda.so at runtime, allowing
# `import tilelang` to succeed on CPU-only machines.
foreach(target IN LISTS TILELANG_OUTPUT_TARGETS)
target_link_libraries(${target} PUBLIC cuda_stub)
endforeach()
# Include CUDA stub in output targets for RPATH configuration
list(APPEND TILELANG_OUTPUT_TARGETS cuda_stub)
# Include CUDA stubs in output targets for RPATH configuration
list(APPEND TILELANG_OUTPUT_TARGETS cuda_stub cudart_stub nvrtc_stub)
endif()

unset(PATCHELF_EXECUTABLE CACHE)
Expand Down Expand Up @@ -388,14 +466,17 @@ foreach(target IN LISTS TILELANG_OUTPUT_TARGETS)
endforeach()

# Exclude libcuda.so to allow importing on a CPU-only machine
if(USE_CUDA AND PATCHELF_EXECUTABLE)
if(USE_CUDA AND TILELANG_USE_CUDA_STUBS AND PATCHELF_EXECUTABLE)
# Run `patchelf` on built libraries to remove libcuda.so dependency.
# Use `install(CODE ...)` instead of `add_custom_command(... POST_BUILD ...)`
# to avoid race conditions during linking.
foreach(target IN LISTS TILELANG_OUTPUT_TARGETS)
install(CODE "
execute_process(
COMMAND ${PATCHELF_EXECUTABLE} --remove-needed libcuda.so.1 --remove-needed libcuda.so \"$<TARGET_FILE:${target}>\"
COMMAND ${PATCHELF_EXECUTABLE}
--remove-needed libcuda.so.1
--remove-needed libcuda.so
\"$<TARGET_FILE:${target}>\"
WORKING_DIRECTORY \"${CMAKE_INSTALL_PREFIX}\"
RESULT_VARIABLE patchelf_result
)
Expand Down
6 changes: 6 additions & 0 deletions src/target/stubs/cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@

#include "cuda.h"

// Fail at compile time on native (non-Cygwin) Windows with an actionable
// message: this file goes on to include <dlfcn.h>, which the message below
// identifies as a POSIX-only requirement of the stub.
#if defined(_WIN32) && !defined(__CYGWIN__)
#error "cuda_stub is currently POSIX-only (requires <dlfcn.h> / dlopen). " \
"On Windows, build TileLang from source with -DTILELANG_USE_CUDA_STUBS=OFF " \
"to link against the real CUDA libraries."
#endif

#include <dlfcn.h>
#include <stdexcept>
#include <string>
Expand Down
Loading
Loading