62 changes: 32 additions & 30 deletions .github/workflows/dist.yml
@@ -115,22 +115,14 @@ jobs:
strategy:
matrix:
target:
- { runner: ubuntu-latest, toolkit: "CUDA-12.8" }
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8" }
- { runner: ubuntu-latest, toolkit: "Nightly-CUDA-13.0" }
- { runner: ubuntu-24.04-arm, toolkit: "Nightly-CUDA-13.0" }
- { runner: ubuntu-latest, toolkit: "CUDA-12.8", test_backends: "cu118 cu130" }
- { runner: ubuntu-24.04-arm, toolkit: "CUDA-12.8", test_backends: "cu126 cu130" }
- { runner: macos-latest, toolkit: "Metal" }
python-version:
# Build wheels for different Python ABIs
- "3.9"
# - "3.14t" # let user to build from source for now
# TODO: Add cp315-abi3.abi3t after PEP 803
include:
# map build version to test version
# Python 3.9 implicitly restricts the torch version (e.g. on arm64).
# This test relies on torch to provide libnvrtc, so we need a fairly new torch.
- { python-version: "3.9", test-python-version: "3.12" }
# - { python-version: "3.14t", test-python-version: "3.14t" }
fail-fast: false
timeout-minutes: 120
runs-on: ${{ matrix.target.runner }}
@@ -171,23 +163,11 @@ jobs:
CUDA_VERSION_MAJMIN="$(echo ${CUDA_VERSION} | cut -d '.' -f-2)"
CUDA_VERSION_MAJMIN_NODOT="${CUDA_VERSION_MAJMIN//./}"
echo "CUDA_VERSION=${CUDA_VERSION}" | tee -a "${GITHUB_ENV}"
if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then
# Use torch nightly builds
export UV_INDEX="https://download.pytorch.org/whl/nightly/cu${CUDA_VERSION_MAJMIN_NODOT}"
else
export UV_INDEX="https://download.pytorch.org/whl/cu${CUDA_VERSION_MAJMIN_NODOT}"
echo "UV_TORCH_BACKEND=cu${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}"
fi
echo "UV_INDEX=${UV_INDEX}" | tee -a "${GITHUB_ENV}"
echo "UV_TORCH_BACKEND=cu${CUDA_VERSION_MAJMIN_NODOT}" | tee -a "${GITHUB_ENV}"
fi

if [[ "${{ env.IS_RELEASE }}" == "true" ]]; then
if [[ "${{ matrix.target.toolkit }}" == "Nightly-"* ]]; then
# Avoid using the same file name for different toolkits.
echo "NO_GIT_VERSION=ON" | tee -a "${GITHUB_ENV}"
else
echo "NO_VERSION_LABEL=ON" | tee -a "${GITHUB_ENV}"
fi
echo "NO_VERSION_LABEL=ON" | tee -a "${GITHUB_ENV}"
fi

if [[ "${{ runner.os }}" == "Linux" ]]; then
@@ -206,7 +186,7 @@ jobs:
id: setup-uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.test-python-version }}
python-version: ${{ matrix.python-version }}
activate-environment: true

- name: Test built wheels
@@ -221,11 +201,6 @@
uv venv test-venv
source test-venv/bin/activate

uv pip install --upgrade pip setuptools wheel
if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then
uv pip install --prerelease=allow -v torch
fi

uv pip install -v "${WHEEL}"
(
set -e
@@ -237,6 +212,33 @@
)
done

- name: Test built wheels with different CUDA
if: contains(matrix.target.toolkit, 'CUDA')
run: |
export UV_PYTHON=3.12
VERSION_LOG=$(mktemp)
for UV_TORCH_BACKEND in ${{ matrix.target.test_backends }}; do
export UV_TORCH_BACKEND
for WHEEL in wheelhouse/*.whl; do
echo "Testing wheel: ${WHEEL}"
(
set -e
uv venv test-venv
source test-venv/bin/activate

uv pip install -v "${WHEEL}"
(
set -e
cd /
uv run --no-project -- python -c "import tilelang, torch; print(tilelang.__version__, torch.__version__)" >> "$VERSION_LOG"
)
deactivate
rm -rf test-venv
)
done
done
cat "$VERSION_LOG"

- name: Upload wheels
# Skip regular PRs to save artifact storage; wheels are only needed for releases.
if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
20 changes: 20 additions & 0 deletions docs/runtime_internals/stubs.md
@@ -0,0 +1,20 @@
# CUDA/CUDART/NVRTC Stubs

This document describes the stub mechanism in TileLang for CUDA-related libraries.

## Purpose

1. **CUDA Driver (`cuda`)**: Allows TileLang to be imported on systems without a GPU (e.g., CI/compilation nodes) by lazy-loading `libcuda.so` only when needed.
2. **Runtime & Compiler (`cudart`, `nvrtc`)**: Resolves SONAME versioning mismatches, enabling a single build to work across different CUDA versions. This is achieved by reusing the existing CUDA runtime loaded by frameworks like PyTorch.

## Implementation

The stubs in `src/target/stubs/` implement a lazy-loading mechanism:

- **Lazy Loading**: Libraries are loaded via `dlopen` only upon the first API call.
- **Global Symbol Reuse**: For `cudart` and `nvrtc`, the stubs first check the global namespace (`RTLD_DEFAULT`) to reuse any already-loaded symbols (e.g., from PyTorch); see the sketch after this list.
- **Versioning Support**: Handles ABI differences between CUDA versions (e.g., `cudaGraphInstantiate` changes in CUDA 12).
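
A minimal C++ sketch of this lookup order, mirroring `TryLoadLibCudart` in `src/target/stubs/cudart.cc` (the representative symbol and fallback chain follow the stub; the locally defined wrapper exists only so the sketch links on its own):

```cpp
// Sketch of the symbol lookup order used by the cudart/nvrtc stubs.
// Assumes a glibc-style dlfcn with RTLD_DEFAULT / RTLD_NEXT.
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>

// Stand-in for the stub's own exported wrapper; the real stub forwards this
// call to the resolved symbol. Defined here only so the sketch is self-contained.
extern "C" const char *cudaGetErrorString(int /*error*/) { return "stub"; }

void *ResolveCudartNamespace() {
  // 1. Prefer symbols already loaded into the process (e.g. by PyTorch),
  //    but skip the stub's own exported wrapper.
  void *sym = dlsym(RTLD_DEFAULT, "cudaGetErrorString");
  if (sym != nullptr && sym != reinterpret_cast<void *>(&cudaGetErrorString)) {
    return RTLD_DEFAULT;
  }
  // 2. Otherwise look past this library in the link-map search order.
  sym = dlsym(RTLD_NEXT, "cudaGetErrorString");
  if (sym != nullptr) {
    return RTLD_NEXT;
  }
  // 3. Fail at call time (not import time) with an actionable message.
  std::fprintf(stderr, "libcudart symbols not found; install a CUDA-enabled "
                       "PyTorch before using TileLang.\n");
  std::abort();
}
```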

## Build Option

`TILELANG_USE_CUDA_STUBS` (Default: `ON`) controls this behavior. When enabled, TileLang links against these stubs instead of the system CUDA toolkit.
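
As a closing illustration of the versioning support noted under Implementation, the sketch below shows how a stub can bridge the `cudaGraphInstantiate` ABI change between CUDA 11 and CUDA 12. The function-pointer types mirror those in `src/target/stubs/cudart.cc`; the dispatch logic itself is a hedged guess, not necessarily the stub's exact behavior.

```cpp
// Illustrative dispatch across the two cudaGraphInstantiate ABIs.
// Requires the CUDA runtime headers; `ns` is the namespace handle returned by
// a lookup like the one sketched above (RTLD_DEFAULT, RTLD_NEXT, or a dlopen handle).
#include <cuda_runtime_api.h>
#include <cstddef>
#include <dlfcn.h>

using CudaGraphInstantiateLegacy =
    cudaError_t (*)(cudaGraphExec_t *, cudaGraph_t, cudaGraphNode_t *, char *, size_t);
using CudaGraphInstantiateWithFlags =
    cudaError_t (*)(cudaGraphExec_t *, cudaGraph_t, unsigned long long);

cudaError_t GraphInstantiateCompat(void *ns, cudaGraphExec_t *exec, cudaGraph_t graph) {
  // CUDA 11.4+ and CUDA 12+ export the flags-based variant under its own name.
  if (auto fn = reinterpret_cast<CudaGraphInstantiateWithFlags>(
          dlsym(ns, "cudaGraphInstantiateWithFlags"))) {
    return fn(exec, graph, /*flags=*/0);
  }
  // Older CUDA 11 runtimes only export the legacy five-argument form.
  if (auto fn = reinterpret_cast<CudaGraphInstantiateLegacy>(
          dlsym(ns, "cudaGraphInstantiate"))) {
    return fn(exec, graph, /*pErrorNode=*/nullptr, /*pLogBuffer=*/nullptr,
              /*bufferSize=*/0);
  }
  return cudaErrorUnknown;
}
```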
9 changes: 6 additions & 3 deletions src/target/stubs/cuda.cc
@@ -5,9 +5,12 @@
* This file implements lazy loading of libcuda.so and provides global wrapper
* functions that serve as drop-in replacements for the CUDA driver API.
*
* The library is loaded on first API call using dlopen(). If loading fails
* (e.g., on a CPU-only machine), an exception is thrown at call time rather
* than at import time, allowing tilelang to be imported without CUDA.
* Motivation
* ----------
* The primary purpose is to allow TileLang to be imported on systems without
* a GPU (e.g., CI/compilation nodes). The library is loaded on first API call
* using dlopen(). If loading fails, an exception is thrown at call time rather
* than at import time.
*/

#include "cuda.h"
11 changes: 11 additions & 0 deletions src/target/stubs/cuda.h
@@ -2,6 +2,17 @@
* \file cuda.h
* \brief Stub library for lazy loading libcuda.so at runtime.
*
* Motivation
* ----------
* libcuda.so is the CUDA Driver API library. Linking directly against it
* creates a strong dependency on the presence of the NVIDIA driver at build
* time and runtime.
*
* This stub library allows TileLang to:
* 1. Be imported on CPU-only machines (no libcuda.so present).
* 2. Avoid versioning conflicts by loading the available libcuda.so
* dynamically.
*
* This library provides drop-in replacements for CUDA driver API functions.
* It allows tilelang to be imported on CPU-only machines without CUDA
* installed. The actual libcuda.so is loaded lazily on first API call.
66 changes: 26 additions & 40 deletions src/target/stubs/cudart.cc
@@ -5,17 +5,20 @@
*
* Motivation
* ----------
* libcudart's SONAME encodes its major version (e.g. libcudart.so.11.0,
* libcudart.so.12, libcudart.so.13). If we link libtvm.so / libtvm_runtime.so
* directly against a specific SONAME, a wheel built against one CUDA toolkit
* becomes unusable in another environment that only provides a different
* libcudart major version.
* The primary purpose is to resolve SONAME mismatches (e.g., libcudart.so.11.0
* vs libcudart.so.12), allowing a single build to work across different CUDA
* versions. This is achieved by reusing the CUDA runtime already loaded by
* frameworks like PyTorch.
*
* This stub exports the subset of CUDA Runtime API entrypoints used by TVM in
* this repository. The real libcudart is loaded lazily via dlopen() on first
* API call, and symbols are resolved via dlsym().
*/

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include <cuda_runtime_api.h>

#if defined(_WIN32) && !defined(__CYGWIN__)
@@ -26,6 +29,8 @@

#include <dlfcn.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// This stub supports CUDA 11+.
@@ -49,19 +54,6 @@ static_assert(CUDART_VERSION >= 11000,

namespace {

// Try multiple major versions for cross-toolkit compatibility.
constexpr const char *kLibCudartPaths[] = {
"libcudart.so.13",
"libcudart.so.12",
// CUDA 11 typically uses `libcudart.so.11.0` (and may also provide a
// `libcudart.so.11` symlink depending on the packaging).
"libcudart.so.11.0",
"libcudart.so.11",
// Unversioned name typically only exists with development packages, but try
// it as a last resort.
"libcudart.so",
};

using CudaGraphInstantiateLegacy = cudaError_t (*)(cudaGraphExec_t *pGraphExec,
cudaGraph_t graph,
cudaGraphNode_t *pErrorNode,
@@ -71,26 +63,24 @@ using CudaGraphInstantiateWithFlags = cudaError_t (*)(
cudaGraphExec_t *pGraphExec, cudaGraph_t graph, unsigned long long flags);

void *TryLoadLibCudart() {
// If libcudart is already loaded in the current process (e.g. via PyTorch or
// another CUDA-enabled library), prefer reusing that instance to avoid
// loading multiple libcudart versions in one process.
#ifdef RTLD_NOLOAD
for (const char *path : kLibCudartPaths) {
void *existing = dlopen(path, RTLD_LAZY | RTLD_LOCAL | RTLD_NOLOAD);
if (existing != nullptr) {
return existing;
}
// First, check if the symbols are already available globally.
// This handles cases where PyTorch or another library has already loaded
// libcudart, making its symbols available in the global namespace.
// We use a representative symbol like cudaGetErrorString.
// dlsym with RTLD_DEFAULT searches the global scope.
void *sym = dlsym(RTLD_DEFAULT, "cudaGetErrorString");
if (sym != nullptr && sym != reinterpret_cast<void *>(&cudaGetErrorString)) {
return RTLD_DEFAULT;
}
#endif

void *handle = nullptr;
for (const char *path : kLibCudartPaths) {
handle = dlopen(path, RTLD_LAZY | RTLD_LOCAL);
if (handle != nullptr) {
break;
}
sym = dlsym(RTLD_NEXT, "cudaGetErrorString");
if (sym != nullptr) {
return RTLD_NEXT;
}
return handle;

fprintf(stderr,
"TileLang Error: libcudart symbols not found globally. "
"Make sure PyTorch with CUDA is installed before using TileLang.\n");
abort();
}

template <typename T> T GetSymbol(void *handle, const char *name) {
@@ -179,10 +169,6 @@ const char *FallbackCudaErrorString(cudaError_t error) {
CUDARuntimeAPI CreateCUDARuntimeAPI() {
CUDARuntimeAPI api{};
void *handle = GetLibCudartHandle();
if (handle == nullptr) {
return api;
}

#define LOOKUP_REQUIRED(name) \
api.name##_ = GetSymbol<decltype(api.name##_)>(handle, #name); \
if (api.name##_ == nullptr) { \
68 changes: 25 additions & 43 deletions src/target/stubs/nvrtc.cc
@@ -4,19 +4,20 @@
*
* Motivation
* ----------
* NVRTC's SONAME encodes its major version (e.g. libnvrtc.so.12,
* libnvrtc.so.13). If we link libtvm.so directly against a specific SONAME, a
* wheel built in one CUDA toolkit environment becomes unusable in another
* environment that only provides a different NVRTC major version.
* Similar to cudart, the primary purpose is to resolve SONAME mismatches,
* allowing a single build to work across different CUDA versions. This is
* achieved by reusing the NVRTC library already loaded by frameworks like
* PyTorch.
*
* This stub exports a minimal set of NVRTC C API entrypoints used by
* TVM/TileLang. The actual libnvrtc is loaded lazily via dlopen() on first API
* call, and symbols are resolved via dlsym().
*
* As a result, the final wheel can run in environments that have NVRTC from
* CUDA 11/12/13 available (as long as the required symbols exist).
*/

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include <nvrtc.h>

#if defined(_WIN32) && !defined(__CYGWIN__)
@@ -27,47 +28,32 @@

#include <dlfcn.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

// Export symbols with default visibility for the shared stub library.
#define TILELANG_NVRTC_STUB_API __attribute__((visibility("default")))

namespace {

// Try multiple major versions for cross-toolkit compatibility.
constexpr const char *kLibNvrtcPaths[] = {
"libnvrtc.so.13",
"libnvrtc.so.12",
// CUDA 11 typically uses `libnvrtc.so.11.2` (and may also provide a
// `libnvrtc.so.11` symlink depending on the packaging).
"libnvrtc.so.11.2",
"libnvrtc.so.11.1",
"libnvrtc.so.11.0",
"libnvrtc.so.11",
// Unversioned name typically only exists with development packages, but try
// it as a last resort.
"libnvrtc.so",
};

void *TryLoadLibNvrtc() {
// If libnvrtc is already loaded in the current process, prefer reusing that
// instance to avoid loading multiple NVRTC versions in one process.
#ifdef RTLD_NOLOAD
for (const char *path : kLibNvrtcPaths) {
void *existing = dlopen(path, RTLD_LAZY | RTLD_LOCAL | RTLD_NOLOAD);
if (existing != nullptr) {
return existing;
}
// First, check if the symbols are already available globally.
// This handles cases where PyTorch or another library has already loaded
// libnvrtc.
// We use a representative symbol like nvrtcVersion.
void *sym = dlsym(RTLD_DEFAULT, "nvrtcVersion");
if (sym != nullptr && sym != reinterpret_cast<void *>(&nvrtcVersion)) {
return RTLD_DEFAULT;
}
#endif

void *handle = nullptr;
for (const char *path : kLibNvrtcPaths) {
handle = dlopen(path, RTLD_LAZY | RTLD_LOCAL);
if (handle != nullptr) {
break;
}
sym = dlsym(RTLD_NEXT, "nvrtcVersion");
if (sym != nullptr) {
return RTLD_NEXT;
}
return handle;

fprintf(stderr,
"TileLang Error: libnvrtc symbols not found globally. "
"Make sure PyTorch with CUDA is installed before using TileLang.\n");
abort();
}

template <typename T> T GetSymbol(void *handle, const char *name) {
@@ -100,10 +86,6 @@ void *GetLibNvrtcHandle() {
NVRTCAPI CreateNVRTCAPI() {
NVRTCAPI api{};
void *handle = GetLibNvrtcHandle();
if (handle == nullptr) {
return api;
}

#define LOOKUP_REQUIRED(name) \
api.name##_ = GetSymbol<decltype(api.name##_)>(handle, #name); \
if (api.name##_ == nullptr) { \