From 118e2fac059305dbb3f301914b66ec7fe72c3f86 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 25 Feb 2026 13:45:30 +0000 Subject: [PATCH 1/4] [GSan] Implement shadow memory allocator This implements an allocator that hooks into PyTorch's memory allocation API to map tensors into a GSan-managed virtual address space. We also create a corresponding shadow memory region that is mapped into the lower half of the reserved address space. Usage is like: ```python from triton.experimental import gsan allocator = gsan.get_allocator() pool = torch.cuda.MemPool(allocator.allocator()) with torch.cuda.use_mem_pool(pool): t = torch.empty(4096, dtype=torch.uint8, device="cuda") ``` git-pr-chain: gsan_implement_shadow_memory_allocator_cc5d --- .../workflows/integration-tests-nvidia.yml | 2 + Makefile | 5 + python/test/gsan/test_allocator.py | 152 ++++ python/test/gsan/test_utils.py | 18 + python/triton/experimental/gsan/__init__.py | 3 + python/triton/experimental/gsan/_allocator.py | 66 ++ python/triton/experimental/gsan/_utils.py | 103 +++ python/triton/experimental/gsan/src/GSan.h | 123 +++ .../experimental/gsan/src/GSanAllocator.cc | 729 ++++++++++++++++++ python/triton/runtime/build.py | 143 +++- third_party/nvidia/backend/driver.py | 21 +- 11 files changed, 1338 insertions(+), 27 deletions(-) create mode 100644 python/test/gsan/test_allocator.py create mode 100644 python/test/gsan/test_utils.py create mode 100644 python/triton/experimental/gsan/__init__.py create mode 100644 python/triton/experimental/gsan/_allocator.py create mode 100644 python/triton/experimental/gsan/_utils.py create mode 100644 python/triton/experimental/gsan/src/GSan.h create mode 100644 python/triton/experimental/gsan/src/GSanAllocator.cc diff --git a/.github/workflows/integration-tests-nvidia.yml b/.github/workflows/integration-tests-nvidia.yml index fb522ea3d771..ceb57657376d 100644 --- a/.github/workflows/integration-tests-nvidia.yml +++ b/.github/workflows/integration-tests-nvidia.yml @@ -92,6 
+92,8 @@ jobs: run: make NUM_PROCS=24 test-unit - name: Run gluon tests run: make NUM_PROCS=24 test-gluon + - name: Run gsan tests + run: make NUM_PROCS=24 test-gsan - name: Run interpreter tests if: ${{ matrix.config.runner_type == 'nvidia-h100' }} run: make test-interpret diff --git a/Makefile b/Makefile index 3813532c0018..53eb64509acd 100644 --- a/Makefile +++ b/Makefile @@ -49,6 +49,11 @@ test-gluon: all $(PYTEST) -n $(NUM_PROCS) python/test/gluon/ python/tutorials/gluon/ $(PYTEST) -n 2 python/examples/gluon/ +.PHONY: test-gsan +test-gsan: all + $(PYTEST) --tb=short python/test/gsan --ignore python/test/gsan/test_gsan_failures.py + $(PYTEST) --tb=short python/test/gsan/test_gsan_failures.py + .PHONY: test-regression test-regression: all $(PYTEST) -n $(NUM_PROCS) python/test/regression diff --git a/python/test/gsan/test_allocator.py b/python/test/gsan/test_allocator.py new file mode 100644 index 000000000000..207a45ee3ad0 --- /dev/null +++ b/python/test/gsan/test_allocator.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import pytest +import torch + +from triton._internal_testing import is_cuda +from triton.experimental.gsan import create_mem_pool +from triton.experimental.gsan._allocator import get_reserve_pointer, get_reserve_size, gsan_free, gsan_malloc +from triton.experimental.gsan._utils import shadow_region, uint8_cuda_tensor_from_ptr + + +def shadow_tensor_for(real: torch.Tensor) -> torch.Tensor: + reserve_ptr = get_reserve_pointer() + reserve_size = get_reserve_size() + shadow_ptr, shadow_size = shadow_region(real.data_ptr(), real.untyped_storage().nbytes(), reserve_ptr, reserve_size) + return uint8_cuda_tensor_from_ptr(shadow_ptr, shadow_size, torch.cuda.current_device()) + + +@pytest.fixture +def _direct_allocator(): + device = torch.cuda.current_device() + stream = 0 + reserve_ptr = get_reserve_pointer() + reserve_size = get_reserve_size() + allocated = set() + + def malloc(size: int) -> int: + ptr_int = gsan_malloc(size, device, stream) 
+ if ptr_int != 0: + allocated.add(ptr_int) + return ptr_int + + def free(ptr: int, size: int = 0) -> None: + gsan_free(ptr, device, size, stream) + if ptr in allocated: + allocated.remove(ptr) + + try: + yield malloc, free, reserve_ptr, reserve_size + finally: + # Cleanup any allocated pointers + for ptr in list(allocated): + gsan_free(ptr, device, 0, stream) + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_malloc_edge_cases(_direct_allocator): + malloc, free, reserve_ptr, reserve_size = _direct_allocator + + # Invalid sizes are rejected. + assert malloc(0) == 0 + assert malloc(-1) == 0 + assert malloc(reserve_size) == 0 # larger than the full real region + + # Null free is a no-op. + free(0) + + +def test_malloc_free(_direct_allocator): + malloc, free, reserve_ptr, reserve_size = _direct_allocator + real_base = reserve_ptr + reserve_size // 2 + + # First valid allocation should come from the real base and be reusable. + p0 = malloc(1) + assert p0 == real_base + free(p0) + assert malloc(1) == p0 + + p1 = malloc(1) + _ = malloc(1) + + free(p1) + p3 = malloc(1) + assert p3 == p1 + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_malloc_fragmentation_reuse_and_coalesce(_direct_allocator): + malloc, free, _, _ = _direct_allocator + + p0 = malloc(1) + p1 = malloc(1) + assert p0 != 0 and p1 != 0 + assert p0 < p1 + + block = p1 - p0 + assert block > 0 + + # Reuse exact freed block under fragmentation. + free(p1) + p1_reuse = malloc(1) + assert p1_reuse == p1 + + # Free two siblings and request a slightly larger block; should coalesce. 
+ free(p0) + free(p1_reuse) + parent = malloc(block + 1) + assert parent == p0 + + free(parent) + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_free_invalid_pointer_and_double_free(_direct_allocator): + malloc, free, _, _ = _direct_allocator + + p0 = malloc(1) + assert p0 != 0 + + # Invalid interior-pointer free should not free p0 and must not crash. + free(p0 + 1) + + free(p0) + free(p0) # double free must be a no-op + + # p0 should become reusable after the valid free above. + p0_reuse = malloc(1) + assert p0_reuse == p0 + + free(p0_reuse) + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_mem_pool(): + pool = create_mem_pool() + with torch.cuda.use_mem_pool(pool): + real = torch.empty(4096, dtype=torch.uint8, device="cuda") + + reserve_ptr = get_reserve_pointer() + reserve_size = get_reserve_size() + assert reserve_ptr != 0 + assert reserve_size > 0 + + # Check real allocation is in higher half of reserve + real_base = reserve_ptr + reserve_size // 2 + assert real_base <= real.data_ptr() < reserve_ptr + reserve_size + + shadow = shadow_tensor_for(real) + assert reserve_ptr <= shadow.data_ptr() < reserve_ptr + reserve_size // 2 + + # Test that real and shadow allocation can be used + real.zero_() + real.add_(7) + # Note: shadow memory is zero-initialized by the allocator + shadow.add_(3) + + assert torch.all(real == 7).item() + assert torch.all(shadow == 3).item() + del pool + del real + del shadow + torch.cuda.synchronize() diff --git a/python/test/gsan/test_utils.py b/python/test/gsan/test_utils.py new file mode 100644 index 000000000000..b2e550f7ce16 --- /dev/null +++ b/python/test/gsan/test_utils.py @@ -0,0 +1,18 @@ +import pytest +import torch + +from triton._internal_testing import is_cuda +from triton.experimental.gsan._utils import uint8_cuda_tensor_from_ptr + + +@pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") +def test_uint8_cuda_tensor_from_ptr_delete_tensor(): + if 
torch.cuda.device_count() < 1: + pytest.skip("requires at least 1 CUDA device") + + torch.cuda.set_device(0) + view = uint8_cuda_tensor_from_ptr(12345, 10, 1) + assert view.data_ptr() == 12345 + assert view.shape == (10, ) + assert view.dtype == torch.uint8 + del view diff --git a/python/triton/experimental/gsan/__init__.py b/python/triton/experimental/gsan/__init__.py new file mode 100644 index 000000000000..57ad109c3ce0 --- /dev/null +++ b/python/triton/experimental/gsan/__init__.py @@ -0,0 +1,3 @@ +from ._allocator import create_mem_pool, get_allocator + +__all__ = ["create_mem_pool", "get_allocator"] diff --git a/python/triton/experimental/gsan/_allocator.py b/python/triton/experimental/gsan/_allocator.py new file mode 100644 index 000000000000..1a99fa07ac12 --- /dev/null +++ b/python/triton/experimental/gsan/_allocator.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import functools +from pathlib import Path +from types import ModuleType + +from triton.runtime import driver as runtime_driver +from triton.runtime.build import compile_module_from_file + +_THIS_DIR = Path(__file__).resolve().parent +_GSAN_SOURCE_PATH = _THIS_DIR / "src" / "GSanAllocator.cc" + + +@functools.lru_cache() +def _load_gsan_module() -> ModuleType: + if runtime_driver.active.get_current_target().backend != "cuda": + raise RuntimeError("GSan allocator requires the CUDA backend.") + + from triton.backends.nvidia.driver import library_dirs, include_dirs + + return compile_module_from_file( + src_path=str(_GSAN_SOURCE_PATH), + name="gsan_allocator", + library_dirs=library_dirs(), + include_dirs=include_dirs, + libraries=["libcuda.so.1"], + ) + + +@functools.lru_cache() +def _compile_gsan_allocator() -> str: + return str(_load_gsan_module().__file__) + + +@functools.lru_cache() +def get_allocator(): + from torch.cuda.memory import CUDAPluggableAllocator + so_name = _compile_gsan_allocator() + return CUDAPluggableAllocator(so_name, "gsanMalloc", "gsanFree") + + +def 
create_mem_pool(): + from torch.cuda.memory import MemPool + return MemPool(get_allocator().allocator()) + + +def gsan_malloc(size: int, device: int, stream: int = 0) -> int: + module = _load_gsan_module() + return module.malloc(size, device, stream) + + +def gsan_free(ptr: int, device: int, size: int = 0, stream: int = 0) -> None: + module = _load_gsan_module() + module.free(ptr, device, size, stream) + + +def get_reserve_pointer() -> int: + return _load_gsan_module().get_reserve_pointer() + + +def get_reserve_size() -> int: + return _load_gsan_module().get_reserve_size() + + +def get_global_state_pointer() -> int: + return _load_gsan_module().get_global_state_pointer() diff --git a/python/triton/experimental/gsan/_utils.py b/python/triton/experimental/gsan/_utils.py new file mode 100644 index 000000000000..0903fa0f753f --- /dev/null +++ b/python/triton/experimental/gsan/_utils.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import ctypes +import torch + +_DLPACK_CAPSULE_NAME = b"dltensor" +_DL_UINT = 1 +_DL_BITS_UINT8 = 8 +_DL_LANES = 1 +_DL_CUDA = 2 + + +class _DLDevice(ctypes.Structure): + _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)] + + +class _DLDataType(ctypes.Structure): + _fields_ = [("code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)] + + +class _DLTensor(ctypes.Structure): + _fields_ = [ + ("data", ctypes.c_void_p), + ("device", _DLDevice), + ("ndim", ctypes.c_int), + ("dtype", _DLDataType), + ("shape", ctypes.POINTER(ctypes.c_int64)), + ("strides", ctypes.POINTER(ctypes.c_int64)), + ("byte_offset", ctypes.c_uint64), + ] + + +class _DLManagedTensor(ctypes.Structure): + pass + + +_DLManagedTensorHandle = ctypes.POINTER(_DLManagedTensor) +_DLManagedTensorDeleter = ctypes.CFUNCTYPE(None, _DLManagedTensorHandle) + +_DLManagedTensor._fields_ = [ + ("dl_tensor", _DLTensor), + ("manager_ctx", ctypes.c_void_p), + ("deleter", _DLManagedTensorDeleter), +] + +_PyCapsule_New = 
ctypes.pythonapi.PyCapsule_New +_PyCapsule_New.restype = ctypes.py_object +_PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] + +# Hold ctypes-backed DLPack payloads until the tensor deleter runs. +_DLPACK_STATE: dict[int, tuple[object, object, object]] = {} + + +@_DLManagedTensorDeleter +def _dl_managed_tensor_deleter(dl_managed_tensor: _DLManagedTensorHandle) -> None: + if not dl_managed_tensor: + return + _DLPACK_STATE.pop(ctypes.addressof(dl_managed_tensor.contents), None) + + +def uint8_cuda_tensor_from_ptr(data_ptr: int, numel: int, device_index: int) -> torch.Tensor: + numel = int(numel) + if numel < 0: + raise ValueError(f"numel must be >= 0, got {numel}") + + shape = (ctypes.c_int64 * 1)(numel) + strides = (ctypes.c_int64 * 1)(1) + dl_managed_tensor = _DLManagedTensor() + dl_managed_tensor.dl_tensor.data = ctypes.c_void_p(int(data_ptr)) + dl_managed_tensor.dl_tensor.device = _DLDevice(_DL_CUDA, device_index) + dl_managed_tensor.dl_tensor.ndim = 1 + dl_managed_tensor.dl_tensor.dtype = _DLDataType(_DL_UINT, _DL_BITS_UINT8, _DL_LANES) + dl_managed_tensor.dl_tensor.shape = ctypes.cast(shape, ctypes.POINTER(ctypes.c_int64)) + dl_managed_tensor.dl_tensor.strides = ctypes.cast(strides, ctypes.POINTER(ctypes.c_int64)) + dl_managed_tensor.dl_tensor.byte_offset = 0 + dl_managed_tensor.manager_ctx = None + dl_managed_tensor.deleter = _dl_managed_tensor_deleter + + dl_managed_tensor_ptr = ctypes.addressof(dl_managed_tensor) + _DLPACK_STATE[dl_managed_tensor_ptr] = (dl_managed_tensor, shape, strides) + + try: + dlpack_capsule = _PyCapsule_New( + ctypes.c_void_p(dl_managed_tensor_ptr), + _DLPACK_CAPSULE_NAME, + None, + ) + return torch.from_dlpack(dlpack_capsule) + except Exception: + _DLPACK_STATE.pop(dl_managed_tensor_ptr, None) + raise + + +SHADOW_SIZE_BYTES = 24 +SHADOW_GRANULARITY_BYTES = 4 + + +def shadow_region(real_ptr: int, real_size_bytes: int, reserve_ptr: int, reserve_size: int) -> tuple[int, int]: + real_base = reserve_ptr + 
reserve_size // 2 + word_offset = (real_ptr - real_base) // SHADOW_GRANULARITY_BYTES + shadow_ptr = reserve_ptr + word_offset * SHADOW_SIZE_BYTES + shadow_size = ((real_size_bytes + SHADOW_GRANULARITY_BYTES - 1) // SHADOW_GRANULARITY_BYTES) * SHADOW_SIZE_BYTES + return shadow_ptr, shadow_size diff --git a/python/triton/experimental/gsan/src/GSan.h b/python/triton/experimental/gsan/src/GSan.h new file mode 100644 index 000000000000..53000fc4ad92 --- /dev/null +++ b/python/triton/experimental/gsan/src/GSan.h @@ -0,0 +1,123 @@ +#include +#include + +#ifdef __CUDACC__ +#define GSAN_HOST_DEVICE __host__ __device__ +#else +#define GSAN_HOST_DEVICE +#endif + +namespace gsan { + +// Reserve 1 TiB, should be big enough for a while :) +static constexpr size_t kReserveSize = 1ull << 40; +static constexpr int kShadowMemGranularityBytes = 4; +static_assert((kReserveSize & (kReserveSize - 1)) == 0, + "kReserveSize must be a power of 2"); + +using thread_id_t = uint16_t; + +enum class AtomicScope : uint8_t { + NonAtomic, + CTA, + GPU, + System, + MAX_VALUE = System, +}; + +using epoch_t = uint16_t; + +struct alignas(4) ScalarClock { + epoch_t epoch; + thread_id_t threadId : 12; // Supports 4096 threads + AtomicScope scope : 2; +}; +static constexpr int kMaxThreads = 1 << 12; +static_assert(sizeof(ScalarClock) == 4); +static_assert(static_cast(AtomicScope::MAX_VALUE) == 3); + +// TODO: Change to struct-of-array for better coalescing? 
+struct alignas(4) ShadowCell { + static constexpr int kReadClockSize = 4; + ScalarClock readClocks[kReadClockSize]; + ScalarClock writeClock; + uint16_t numReads; + uint16_t lock; +}; +static_assert(sizeof(ShadowCell) == 24); +static_assert(alignof(ShadowCell) == 4); + +static constexpr int kShadowSizeBytes = sizeof(ShadowCell); + +struct GlobalState { + // Base address of gsan managed memory + uintptr_t reserveBase; + uintptr_t globalsBase; + + uint32_t rngSeed; + + thread_id_t numSms; + thread_id_t numDevices; + // numThreads = numSms * numDevices + thread_id_t numThreads; + + uint16_t clockBufferSize; +}; + +struct ThreadState { + GlobalState *globals; + uintptr_t reserveBase; + + // monotonic counter, used for stochastic read clock updates + uint32_t numReads; + + // Index to head of the circular clock buffer, plus a bit to mark if the + // vector clock has changed since the last clock buffer write (to allow reuse) + uint32_t clockBufferDirty : 1; + uint32_t clockBufferHead : 31; + + // Reader-writer lock controlling access to the vector clock and clock buffer + uint32_t lock; + + thread_id_t threadId; + + // Local vector clock, shape [numThreads] + // Followed by the clock buffer, shape [clockBufferSize, numThreads] + epoch_t vectorClock[]; +}; + +// Place the thread state for each device at a fixed stride for ease of +// address calculation. 
+static constexpr uintptr_t kPerDeviceStateStride = 1ull << 30; +static constexpr uintptr_t kMaxGPUs = 16; +static constexpr uintptr_t kGlobalsReserveSize = + kPerDeviceStateStride * kMaxGPUs; + +inline GSAN_HOST_DEVICE GlobalState *getGlobalState(ThreadState *threadState) { + auto threadAddr = (uintptr_t)threadState; + return (GlobalState *)(threadAddr & ~(kPerDeviceStateStride - 1)); +} + +inline GSAN_HOST_DEVICE uintptr_t getRealBaseAddress(uintptr_t reserveBase) { + return reserveBase + kReserveSize / 2; +} + +inline GSAN_HOST_DEVICE uintptr_t getReserveBaseFromAddress(uintptr_t addr) { + return addr & ~(kReserveSize - 1); +} + +// Assumes address is in gsan-managed memory +inline GSAN_HOST_DEVICE uintptr_t getShadowAddress(uintptr_t virtualAddress) { + auto reserveBase = getReserveBaseFromAddress(virtualAddress); + auto realBase = getRealBaseAddress(reserveBase); + auto byteOffset = virtualAddress - realBase; + auto wordOffset = byteOffset / kShadowMemGranularityBytes; + return reserveBase + kShadowSizeBytes * wordOffset; +} + +inline GSAN_HOST_DEVICE bool isGsanManaged(uintptr_t addr, + uintptr_t reserveBase) { + return getReserveBaseFromAddress(addr) == reserveBase; +} + +} // namespace gsan diff --git a/python/triton/experimental/gsan/src/GSanAllocator.cc b/python/triton/experimental/gsan/src/GSanAllocator.cc new file mode 100644 index 000000000000..09cae6d85a00 --- /dev/null +++ b/python/triton/experimental/gsan/src/GSanAllocator.cc @@ -0,0 +1,729 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "GSan.h" + +// #define GSAN_LOG_ALLOCATIONS +#ifdef GSAN_LOG_ALLOCATIONS +#define LOGF(...) printf(__VA_ARGS__); +#else +#define LOGF(...) +#endif + +extern "C" { +void *gsanMalloc(ssize_t size, int device, void *stream); +void gsanFree(void *ptr, ssize_t size, int device, void *stream); +} + +namespace { + +// We use a tree structure to manage virtual address allocations. 
+// +// This is a binary tree where each node represents a power-of-two-sized region +// of memory. Each node tracks the largest free node in its subtree. This +// allows us to allocate best-fit regions in O(log(AddressSpaceSize)), and the same +// for deallocation. +// +// Note that we don't really care about being compact/defragmented in any way, +// since we can reserve millions of times more virtual memory than there is +// physical memory. +// We also sit underneath PyTorch's CUDACachingAllocator, which manages most +// of the hard parts for us and only asks us to allocate large blocks that it +// will divide up as needed. +struct AllocNode { + CUdeviceptr virtualAddress = 0; + AllocNode *parent = nullptr; + std::unique_ptr leftChild; + std::unique_ptr rightChild; + size_t size = 0; + size_t maxFreeBlockSize = 0; + + // Allocation handles, used only by leaf nodes + CUmemGenericAllocationHandle realHandle = 0; + CUmemGenericAllocationHandle shadowHandle = 0; +}; + +struct GSanConfig { + int numGPUs = 4; + int numSMs = 152; + int numThreads = 4 * 152; + int clockBufferSize = 1024; + uint32_t rngSeed = 0x12345678u; +}; + +struct AllocatorState { + GSanConfig config; + + // User memory + shadow memory + CUdeviceptr reserveBaseAddress = 0; + AllocNode treeRoot; + + // GSan global state + CUdeviceptr globalStateAddress = 0; + CUmemGenericAllocationHandle perDeviceHandles[gsan::kMaxGPUs] = {0}; + size_t perDeviceStateSize = 0; +}; + +void printCUDAError(CUresult err) { + const char *errs = ""; + cuGetErrorString(err, &errs); + fprintf(stderr, "gsan allocator encountered an unexpected error: %s\n", errs); +} + +static AllocatorState *alloc = nullptr; +static std::mutex mut; + +size_t cdiv(size_t num, size_t den) { return (num + (den - 1)) / den; } + +size_t roundUp(size_t val, size_t alignment) { + return cdiv(val, alignment) * alignment; +} + +size_t roundUpToPowerOfTwo(size_t value) { + if (value <= 1) + return 1; + if ((value & (value - 1)) == 0) + return value; + 
size_t rounded = 1; + while (rounded < value) + rounded <<= 1; + return rounded; +} + +size_t getShadowSize(size_t realMemSize) { + auto wordSize = cdiv(realMemSize, gsan::kShadowMemGranularityBytes); + return wordSize * gsan::kShadowSizeBytes; +} + +bool isLeaf(const AllocNode *node) { + return node->leftChild == nullptr && node->rightChild == nullptr; +} + +void recomputeNodeState(AllocNode *node) { + assert((node->leftChild == nullptr) == (node->rightChild == nullptr) && + "allocator tree node should have both children or none"); + + if (isLeaf(node)) { + assert( + (node->maxFreeBlockSize == 0 || node->maxFreeBlockSize == node->size) && + "leaf nodes should be either fully free or fully allocated"); + return; + } + + node->maxFreeBlockSize = std::max(node->leftChild->maxFreeBlockSize, + node->rightChild->maxFreeBlockSize); +} + +void recomputeToRoot(AllocNode *node) { + for (AllocNode *curr = node; curr != nullptr; curr = curr->parent) + recomputeNodeState(curr); +} + +void splitNode(AllocNode *node) { + assert(isLeaf(node)); + assert(node->maxFreeBlockSize == node->size); + const size_t halfSize = node->size / 2; + auto left = std::make_unique(); + auto right = std::make_unique(); + + left->virtualAddress = node->virtualAddress; + left->size = halfSize; + left->maxFreeBlockSize = halfSize; + left->parent = node; + + right->virtualAddress = node->virtualAddress + halfSize; + right->size = halfSize; + right->maxFreeBlockSize = halfSize; + right->parent = node; + + node->leftChild = std::move(left); + node->rightChild = std::move(right); + node->maxFreeBlockSize = halfSize; +} + +AllocNode *allocateNode(AllocNode *root, size_t allocSize) { + AllocNode *node = root; + if (node == nullptr || node->maxFreeBlockSize < allocSize) + return nullptr; + + if (isLeaf(node)) { + assert(node->maxFreeBlockSize == node->size); + + while (node->size > 1 && (node->size / 2) >= allocSize) { + splitNode(node); + node = node->leftChild.get(); + } + node->maxFreeBlockSize = 0; + 
recomputeToRoot(node->parent); + return node; + } + + auto *left = node->leftChild.get(); + auto *right = node->rightChild.get(); + const bool leftFits = left->maxFreeBlockSize >= allocSize; + const bool rightFits = right->maxFreeBlockSize >= allocSize; + + AllocNode *next = nullptr; + // Prefer the tighter-fitting subtree to keep larger blocks available. + if (leftFits && rightFits) { + next = (left->maxFreeBlockSize <= right->maxFreeBlockSize) ? left : right; + } else if (leftFits) { + next = left; + } else { + next = right; + } + return allocateNode(next, allocSize); +} + +AllocNode *findNodeByAddress(AllocNode *root, CUdeviceptr address) { + AllocNode *node = root; + while (node != nullptr) { + if (address < node->virtualAddress || + address >= node->virtualAddress + node->size) + return nullptr; + + if (!node->leftChild && !node->rightChild) + return node; + + if (node->rightChild && address >= node->rightChild->virtualAddress) { + node = node->rightChild.get(); + } else { + node = node->leftChild.get(); + } + } + return nullptr; +} + +bool canCoalesce(const AllocNode *node) { + if (node == nullptr) + return false; + assert((node->leftChild == nullptr) == (node->rightChild == nullptr) && + "allocator tree node should have both children or none"); + if (!node->leftChild) + return false; + + const auto *left = node->leftChild.get(); + const auto *right = node->rightChild.get(); + const bool leftFree = left->maxFreeBlockSize == left->size; + const bool rightFree = right->maxFreeBlockSize == right->size; + return leftFree && rightFree; +} + +void coalesceUp(AllocNode *node) { + if (node == nullptr) + return; + while (node != nullptr && canCoalesce(node)) { + node->leftChild.reset(); + node->rightChild.reset(); + node->maxFreeBlockSize = node->size; + node = node->parent; + } + recomputeToRoot(node); +} + +void freeNode(AllocNode *leaf) { + assert(isLeaf(leaf)); + leaf->realHandle = 0; + leaf->shadowHandle = 0; + leaf->maxFreeBlockSize = leaf->size; + 
coalesceUp(leaf->parent); +} + +int gsanEnsureInit() { + if (alloc) + return 0; + + CUdeviceptr reserveBase; + CUresult err = cuMemAddressReserve(&reserveBase, /*size*/ gsan::kReserveSize, + /*alignment*/ gsan::kReserveSize, + /*addr*/ 0, /*flags*/ 0); + if (err != CUDA_SUCCESS) { + printCUDAError(err); + return -1; + } + + CUdeviceptr globalsBase; + err = cuMemAddressReserve(&globalsBase, /*size*/ gsan::kGlobalsReserveSize, + /*alignment*/ gsan::kGlobalsReserveSize, + /*addr*/ 0, /*flags*/ 0); + if (err != CUDA_SUCCESS) { + printCUDAError(err); + return -1; + } + alloc = new AllocatorState(); + alloc->reserveBaseAddress = reserveBase; + alloc->globalStateAddress = globalsBase; + + auto *root = &alloc->treeRoot; + root->virtualAddress = gsan::getRealBaseAddress(reserveBase); + root->size = gsan::kReserveSize / 2; + root->maxFreeBlockSize = root->size; + return 0; +} + +CUresult ensureContext(int device) { + CUcontext ctx = 0; + CUresult res = cuCtxGetCurrent(&ctx); + if (res != CUDA_SUCCESS) + return res; + if (ctx) + return res; + + res = cuDevicePrimaryCtxRetain(&ctx, device); + if (res != CUDA_SUCCESS) + return res; + return cuCtxSetCurrent(ctx); +} + +CUresult refreshConfigForDevice(int device) { + if (alloc == nullptr) + return CUDA_ERROR_NOT_INITIALIZED; + + int numGPUs = 0; + CUresult err = cuDeviceGetCount(&numGPUs); + if (err != CUDA_SUCCESS) + return err; + if (numGPUs <= 0) + return CUDA_ERROR_NO_DEVICE; + if (numGPUs > static_cast(gsan::kMaxGPUs)) + return CUDA_ERROR_NOT_SUPPORTED; + if (device < 0 || device >= numGPUs) + return CUDA_ERROR_INVALID_DEVICE; + + CUdevice cuDevice = 0; + err = cuDeviceGet(&cuDevice, device); + if (err != CUDA_SUCCESS) + return err; + + int numSMs = 0; + err = cuDeviceGetAttribute(&numSMs, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, + cuDevice); + if (err != CUDA_SUCCESS) + return err; + if (numSMs <= 0) + return CUDA_ERROR_INVALID_VALUE; + + auto &config = alloc->config; + // Triton may execute more than one instrumented 
launch on a kernel's first + // user-visible invocation (e.g. compile/warmup paths). Using 3 slots avoids + // aliasing back to the same logical thread ID when two launches occur. + constexpr int kGSanThreadSlotsPerDeviceThread = 3; + config.numGPUs = numGPUs; + config.numSMs = numSMs; + config.numThreads = + kGSanThreadSlotsPerDeviceThread * config.numGPUs * config.numSMs; + return CUDA_SUCCESS; +} + +CUresult ensureRuntimeStateMapped(int device) { + if (alloc == nullptr) + return CUDA_ERROR_NOT_INITIALIZED; + CUresult err = refreshConfigForDevice(device); + if (err != CUDA_SUCCESS) + return err; + auto &config = alloc->config; + if (alloc->perDeviceHandles[device] != 0) + return CUDA_SUCCESS; + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + + size_t granularity = 0; + err = cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM); + if (err != CUDA_SUCCESS) + return err; + + static_assert(alignof(gsan::GlobalState) >= alignof(gsan::ThreadState)); + static_assert(alignof(gsan::ThreadState) >= alignof(gsan::epoch_t)); + + auto numSMs = config.numSMs; + auto numThreads = config.numThreads; + assert(numThreads <= gsan::kMaxThreads); + auto clockSizeBytes = sizeof(gsan::epoch_t) * config.numThreads; + // 1 local clock + the circular clock buffer + auto clocksPerThread = 1 + config.clockBufferSize; + auto perDeviceStateSize = ( + // Each device has a local copy of the constant global state + sizeof(gsan::GlobalState) + + // Plus per-thread state for each SM + config.numSMs * + (sizeof(gsan::ThreadState) + clockSizeBytes * clocksPerThread)); + assert(perDeviceStateSize <= gsan::kPerDeviceStateStride); + + size_t allocSize = roundUp(perDeviceStateSize, granularity); + + CUmemGenericAllocationHandle allocHandle = 0; + bool mapped = false; + CUmemAccessDesc accessDesc = {}; + gsan::GlobalState globals = {}; + CUdeviceptr deviceAddr 
= + alloc->globalStateAddress + device * gsan::kPerDeviceStateStride; + + err = cuMemCreate(&allocHandle, allocSize, &prop, 0); + if (err != CUDA_SUCCESS) + goto error; + + err = cuMemMap(deviceAddr, allocSize, /*offset*/ 0, allocHandle, /*flags*/ 0); + if (err != CUDA_SUCCESS) + goto error; + mapped = true; + + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + err = cuMemSetAccess(deviceAddr, allocSize, &accessDesc, 1); + if (err != CUDA_SUCCESS) + goto error; + + err = cuMemsetD8(deviceAddr, 0, allocSize); + if (err != CUDA_SUCCESS) + goto error; + + globals.reserveBase = static_cast(alloc->reserveBaseAddress); + globals.globalsBase = static_cast(alloc->globalStateAddress); + globals.rngSeed = config.rngSeed; + globals.numSms = static_cast(config.numSMs); + globals.numDevices = static_cast(config.numGPUs); + globals.numThreads = static_cast(config.numThreads); + globals.clockBufferSize = config.clockBufferSize; + err = cuMemcpyHtoD(deviceAddr, &globals, sizeof(globals)); + if (err != CUDA_SUCCESS) + goto error; + + alloc->perDeviceHandles[device] = allocHandle; + alloc->perDeviceStateSize = allocSize; + return CUDA_SUCCESS; + +error: + if (mapped) + cuMemUnmap(deviceAddr, allocSize); + if (allocHandle != 0) + cuMemRelease(allocHandle); + return err; +} + +CUresult mapNodeHandles(AllocNode *node, + CUmemGenericAllocationHandle realHandle, + CUmemGenericAllocationHandle shadowHandle, int device, + bool *realMapped, bool *shadowMapped) { + assert(node != nullptr); + assert(realMapped != nullptr); + assert(shadowMapped != nullptr); + + const auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); + const auto shadowSize = getShadowSize(node->size); + + CUresult err = cuMemMap(node->virtualAddress, node->size, /*offset*/ 0, + realHandle, /*flags*/ 0); + if (err != CUDA_SUCCESS) + return err; + *realMapped = true; + + err = cuMemMap(shadowAddress, shadowSize, 
/*offset*/ 0, shadowHandle, + /*flags*/ 0); + if (err != CUDA_SUCCESS) + return err; + *shadowMapped = true; + + CUmemAccessDesc accessDesc = {}; + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = device; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + + err = cuMemSetAccess(node->virtualAddress, node->size, &accessDesc, 1); + if (err != CUDA_SUCCESS) + return err; + + return cuMemSetAccess(shadowAddress, shadowSize, &accessDesc, 1); +} + +void unmapNodeHandles(AllocNode *node, bool realMapped, bool shadowMapped) { + assert(node != nullptr); + const auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); + const auto shadowSize = getShadowSize(node->size); + if (shadowMapped) + cuMemUnmap(shadowAddress, shadowSize); + if (realMapped) + cuMemUnmap(node->virtualAddress, node->size); +} + +} // namespace + +// TODO: Handle streams? +extern "C" void *gsanMalloc(ssize_t size, int device, + [[maybe_unused]] void *stream) { + if (size <= 0) + return nullptr; + + std::lock_guard lg(mut); + if (gsanEnsureInit() != 0) + return nullptr; + + CUresult err = ensureContext(device); + if (err != CUDA_SUCCESS) { + printCUDAError(err); + return nullptr; + } + err = ensureRuntimeStateMapped(device); + if (err != CUDA_SUCCESS) { + printCUDAError(err); + return nullptr; + } + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + + size_t granularity = 0; + err = cuMemGetAllocationGranularity(&granularity, &prop, + CU_MEM_ALLOC_GRANULARITY_MINIMUM); + if (err != CUDA_SUCCESS) { + printCUDAError(err); + return nullptr; + } + size_t allocSize = roundUp(static_cast(size), granularity); + AllocNode *node = allocateNode(&alloc->treeRoot, allocSize); + if (node == nullptr) + return nullptr; + + CUmemGenericAllocationHandle realHandle = 0; + CUmemGenericAllocationHandle shadowHandle = 0; + bool realMapped = false; + bool shadowMapped 
= false; + auto cuStream = reinterpret_cast(stream); + auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); + auto shadowSize = getShadowSize(allocSize); + + err = cuMemCreate(&realHandle, node->size, &prop, 0); + if (err != CUDA_SUCCESS) + goto error; + + err = cuMemCreate(&shadowHandle, shadowSize, &prop, 0); + if (err != CUDA_SUCCESS) + goto error; + + err = mapNodeHandles(node, realHandle, shadowHandle, device, &realMapped, + &shadowMapped); + if (err != CUDA_SUCCESS) + goto error; + + // Zero-initialize shadow memory + err = cuMemsetD8Async(shadowAddress, 0, shadowSize, cuStream); + if (err != CUDA_SUCCESS) + goto error; + + node->realHandle = realHandle; + node->shadowHandle = shadowHandle; + LOGF("gsanMalloc: %p, 0x%zxu", reinterpret_cast(node->virtualAddress), + size); + return reinterpret_cast(node->virtualAddress); + +error: + printCUDAError(err); + unmapNodeHandles(node, realMapped, shadowMapped); + if (shadowHandle != 0) + cuMemRelease(shadowHandle); + if (realHandle != 0) + cuMemRelease(realHandle); + freeNode(node); + return nullptr; +} + +extern "C" void gsanFree(void *void_ptr, [[maybe_unused]] ssize_t size, + [[maybe_unused]] int device, + [[maybe_unused]] void *stream) { + LOGF("gsanFree: %p, 0x%zx", void_ptr, size); + auto ptr = reinterpret_cast(void_ptr); + if (!ptr) + return; + + std::lock_guard lg(mut); + if (alloc == nullptr) + return; + + AllocNode *node = findNodeByAddress(&alloc->treeRoot, ptr); + if (node == nullptr || node->maxFreeBlockSize != 0 || + node->virtualAddress != ptr) { + fprintf(stderr, "gsanFree called with an invalid pointer\n"); + return; + } + + const auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); + const auto shadowSize = getShadowSize(node->size); + + CUresult err = cuMemUnmap(node->virtualAddress, node->size); + if (err != CUDA_SUCCESS) + printCUDAError(err); + + err = cuMemUnmap(shadowAddress, shadowSize); + if (err != CUDA_SUCCESS) + printCUDAError(err); + + err = 
cuMemRelease(node->realHandle); + if (err != CUDA_SUCCESS) + printCUDAError(err); + + err = cuMemRelease(node->shadowHandle); + if (err != CUDA_SUCCESS) + printCUDAError(err); + + freeNode(node); +} + +void *gsanGetReservePointer() { + std::lock_guard lg(mut); + if (gsanEnsureInit() != 0) + return nullptr; + return reinterpret_cast(alloc->reserveBaseAddress); +} + +namespace { + +constexpr const char *kModuleName = "gsan_allocator"; + +bool parseIntArg(PyObject *obj, const char *name, int *out) { + long value = PyLong_AsLong(obj); + if (value == -1 && PyErr_Occurred()) + return false; + if (value < std::numeric_limits::min() || + value > std::numeric_limits::max()) { + PyErr_Format(PyExc_OverflowError, "%s is out of range for int", name); + return false; + } + *out = static_cast(value); + return true; +} + +bool parseVoidPtrArg(PyObject *obj, void **out) { + *out = PyLong_AsVoidPtr(obj); + return !(*out == nullptr && PyErr_Occurred()); +} + +PyObject *pyMalloc(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + (void)self; + if (nargs != 2 && nargs != 3) { + PyErr_Format(PyExc_TypeError, + "%s.malloc expected 2 or 3 positional arguments, got %zd", + kModuleName, nargs); + return nullptr; + } + + Py_ssize_t size = PyLong_AsSsize_t(args[0]); + if (size == -1 && PyErr_Occurred()) + return nullptr; + + int device = 0; + if (!parseIntArg(args[1], "device", &device)) + return nullptr; + + void *stream = nullptr; + if (nargs == 3 && !parseVoidPtrArg(args[2], &stream)) + return nullptr; + + return PyLong_FromVoidPtr(gsanMalloc(size, device, stream)); +} + +PyObject *pyFree(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { + (void)self; + if (nargs < 2 || nargs > 4) { + PyErr_Format( + PyExc_TypeError, + "%s.free expected between 2 and 4 positional arguments, got %zd", + kModuleName, nargs); + return nullptr; + } + + void *ptr = nullptr; + if (!parseVoidPtrArg(args[0], &ptr)) + return nullptr; + + int device = 0; + if (!parseIntArg(args[1], "device", 
&device)) + return nullptr; + + Py_ssize_t size = 0; + if (nargs >= 3) { + size = PyLong_AsSsize_t(args[2]); + if (size == -1 && PyErr_Occurred()) + return nullptr; + } + + void *stream = nullptr; + if (nargs == 4 && !parseVoidPtrArg(args[3], &stream)) + return nullptr; + + gsanFree(ptr, size, device, stream); + Py_RETURN_NONE; +} + +PyObject *pyGetReservePointer(PyObject *self, PyObject *const *args, + Py_ssize_t nargs) { + (void)self; + if (nargs != 0) { + PyErr_Format( + PyExc_TypeError, + "%s.get_reserve_pointer expected 0 positional arguments, got %zd", + kModuleName, nargs); + return nullptr; + } + return PyLong_FromVoidPtr(gsanGetReservePointer()); +} + +PyObject *pyGetReserveSize(PyObject *self, PyObject *args) { + return PyLong_FromUnsignedLongLong(gsan::kReserveSize); +} + +PyObject *pyGetShadowSizeBytes(PyObject *self, PyObject *args) { + return PyLong_FromLong(sizeof(gsan::ShadowCell)); +} + +PyObject *pyGetGlobalStatePointer(PyObject *self, PyObject *args) { + (void)self; + std::lock_guard lg(mut); + if (gsanEnsureInit() != 0) { + PyErr_SetString(PyExc_RuntimeError, "failed to initialize gsan allocator"); + return nullptr; + } + return PyLong_FromUnsignedLongLong(alloc->globalStateAddress); +} + +PyMethodDef kGSanAllocatorMethods[] = { + {"malloc", reinterpret_cast(pyMalloc), METH_FASTCALL, + "Allocate GSan memory. 
Returns a CUDA pointer as an integer."}, + {"free", reinterpret_cast(pyFree), METH_FASTCALL, + "Free GSan memory by pointer."}, + {"get_reserve_pointer", reinterpret_cast(pyGetReservePointer), + METH_FASTCALL, "Return the reserve base pointer as an integer."}, + {"get_reserve_size", reinterpret_cast(pyGetReserveSize), + METH_NOARGS, "Return the reserve size in bytes."}, + {"get_shadow_size_bytes", + reinterpret_cast(pyGetShadowSizeBytes), METH_NOARGS, + "Return the shadow cell size in bytes."}, + {"get_global_state_pointer", + reinterpret_cast(pyGetGlobalStatePointer), METH_NOARGS, + "Return the pointer to the GSan global state region."}, + {nullptr, nullptr, 0, nullptr}, +}; + +PyModuleDef kGSanAllocatorModuleDef = { + PyModuleDef_HEAD_INIT, "gsan_allocator", nullptr, -1, kGSanAllocatorMethods, +}; + +} // namespace + +PyMODINIT_FUNC PyInit_gsan_allocator(void) { + return PyModule_Create(&kGSanAllocatorModuleDef); +} diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 786f51e54db7..11b3e7ca3574 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -3,13 +3,14 @@ import functools import hashlib import importlib.util -import logging import os +import re import shutil import subprocess import sysconfig import tempfile -import re +import logging +from pathlib import Path from types import ModuleType @@ -17,20 +18,52 @@ from .. import knobs +@functools.lru_cache() +def _find_compiler(language: str) -> str: + if language == "c": + cc = os.environ.get("CC") + if cc is not None: + return cc + clang = shutil.which("clang") + gcc = shutil.which("gcc") + cc = gcc if gcc is not None else clang + if cc is not None: + return cc + raise RuntimeError( + "Failed to find C compiler. 
Please specify via CC environment variable or set triton.knobs.build.impl.") + + assert language == "c++" + cxx = os.environ.get("CXX") + if cxx is not None: + return cxx + + clangxx = shutil.which("clang++") + gxx = shutil.which("g++") + cxx = gxx if gxx is not None else clangxx + if cxx is not None: + return cxx + + raise RuntimeError( + "Failed to find C++ compiler. Please specify via CXX environment variable or set triton.knobs.build.impl.") + + +def _language_from_filename(source_name: str) -> str: + ext = Path(source_name).suffix + if ext == ".c": + return "c" + if ext in {".cc", ".cpp", ".cxx"}: + return "c++" + print(source_name) + raise ValueError(f"Unrecognized file extension: {source_name}") + + def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str], - ccflags: list[str]) -> str: + ccflags: list[str], language: str = "c") -> str: if impl := knobs.build.impl: return impl(name, src, srcdir, library_dirs, include_dirs, libraries) suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) - cc = os.environ.get("CC") - if cc is None: - clang = shutil.which("clang") - gcc = shutil.which("gcc") - cc = gcc if gcc is not None else clang - if cc is None: - raise RuntimeError( - "Failed to find C compiler. Please specify via CC environment variable or set triton.knobs.build.impl.") + so = os.path.join(srcdir, f'{name}{suffix}') + cc = _find_compiler(language) scheme = sysconfig.get_default_scheme() # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install # path changes to include 'local'. This change is required to use triton with system-wide python. 
@@ -41,6 +74,8 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + if language == "c++": + cc_cmd.insert(3, "-std=c++17") cc_cmd += [_library_flag(lib) for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] @@ -62,6 +97,14 @@ def platform_key() -> str: return ",".join([machine(), system(), *architecture()]) +def _get_file_extension(language): + if language == "c": + return ".c" + if language == "c++": + return ".cpp" + raise ValueError(f"Unexpected languange: {language}") + + def _load_module_from_path(name: str, path: str) -> ModuleType: spec = importlib.util.spec_from_file_location(name, path) if not spec or not spec.loader: @@ -71,15 +114,23 @@ def _load_module_from_path(name: str, path: str) -> ModuleType: return mod -def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None, - include_dirs: list[str] | None = None, libraries: list[str] | None = None, - ccflags: list[str] | None = None) -> ModuleType: - key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest() - cache = get_cache_manager(key) +def _get_cache_manager(src: bytes, language: str): + digest = hashlib.sha256() + digest.update(src) + digest.update(platform_key().encode("utf-8")) + digest.update(language.encode("utf-8")) + key = digest.hexdigest() + return get_cache_manager(key) + + +def _compile_so(src: bytes, src_path: str, name: str, library_dirs: list[str] | None, include_dirs: list[str] | None, + libraries: list[str] | None, ccflags: list[str] | None, load_module: bool, language: str): + cache = _get_cache_manager(src, language) suffix = sysconfig.get_config_var("EXT_SUFFIX") cache_path = cache.get_file(f"{name}{suffix}") - 
if cache_path is not None: + if not load_module: + return cache_path try: return _load_module_from_path(name, cache_path) except (RuntimeError, ImportError): @@ -87,11 +138,57 @@ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None log.warning(f"Triton cache error: compiled module {name}.so could not be loaded") with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, name + ".c") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or []) + so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [], + language=language) with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True) - return _load_module_from_path(name, cache_path) + return _load_module_from_path(name, cache_path) if load_module else cache_path + + +def _compile_so_from_file(src_path: str, name: str, library_dirs: list[str] | None, include_dirs: list[str] | None, + libraries: list[str] | None, ccflags: list[str] | None, load_module: bool): + src_path = os.path.abspath(src_path) + src_name = os.path.basename(src_path) + with open(src_path, "rb") as f: + src = f.read() + + language = _language_from_filename(src_name) + return _compile_so(src=src, src_path=src_path, name=name, library_dirs=library_dirs, include_dirs=include_dirs, + libraries=libraries, ccflags=ccflags, language=language, load_module=load_module) + + +def _compile_so_from_src(src: str, name: str, library_dirs: list[str] | None, include_dirs: list[str] | None, + libraries: list[str] | None, ccflags: list[str] | None, language, load_module: bool): + src_bytes = src.encode("utf-8") + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, f"{name}{_get_file_extension(language)}") + with open(src_path, "wb") as f: + f.write(src_bytes) + return _compile_so(src=src_bytes, src_path=src_path, 
name=name, library_dirs=library_dirs, + include_dirs=include_dirs, libraries=libraries, ccflags=ccflags, language=language, + load_module=load_module) + + +def compile_so_from_file(src_path: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None, + ccflags: list[str] | None = None) -> str: + return _compile_so_from_file(src_path, name, library_dirs, include_dirs, libraries, ccflags, load_module=False) + + +def compile_so_from_src(src: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None, + ccflags: list[str] | None = None, language="c") -> str: + return _compile_so_from_src(src, name, library_dirs, include_dirs, libraries, ccflags, language, load_module=False) + + +def compile_module_from_file(src_path: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None, + ccflags: list[str] | None = None) -> ModuleType: + return _compile_so_from_file(src_path, name, library_dirs, include_dirs, libraries, ccflags, load_module=True) + + +def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None, + include_dirs: list[str] | None = None, libraries: list[str] | None = None, + ccflags: list[str] | None = None, language="c") -> ModuleType: + return _compile_so_from_src(src, name, library_dirs, include_dirs, libraries, ccflags, language, load_module=True) diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index b88d0015743a..5853c00a83f7 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -5,8 +5,9 @@ import ctypes import sys from pathlib import Path +import re from triton import knobs -from triton.runtime.build import compile_module_from_src +from triton.runtime.build import compile_module_from_file from triton.runtime import _allocation from 
triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver, decompose_descriptor, expand_signature, wrap_handle_tensordesc_impl @@ -20,6 +21,7 @@ ARG_CONSTEXPR = None ARG_KERNEL = None ARG_TUPLE = None +GSAN_PER_DEVICE_STATE_STRIDE = 1 << 30 @functools.lru_cache() @@ -98,8 +100,8 @@ def __new__(cls): return cls.instance def __init__(self): - mod = compile_module_from_src( - src=Path(os.path.join(dirname, "driver.c")).read_text(), + mod = compile_module_from_file( + src_path=os.path.join(dirname, "driver.c"), name="cuda_utils", library_dirs=library_dirs(), include_dirs=include_dirs, @@ -279,6 +281,9 @@ def __init__(self, src, metadata): launcher = triton.runtime.driver.active.utils.launch expanded_signature = expand_signature(signature.values(), tensordesc_meta, "nvTmaDesc") + self.gsan_enabled = "gsan" in getattr(metadata, "instrumentation_mode", "") + if self.gsan_enabled: + expanded_signature.append("*i8") self.arg_annotations = annotate_arguments(expanded_signature) self.kernel_signature = make_kernel_signature(expanded_signature) self.num_ctas = getattr(metadata, "num_ctas", 1) @@ -315,10 +320,18 @@ def allocate_default_profile_scratch(size, align): _allocation._profile_allocator) else: profile_scratch = allocate_default_profile_scratch(self.profile_scratch_size, self.profile_scratch_align) + _allocation._profile_allocator) + kernel_args = args + if self.gsan_enabled: + from triton.experimental.gsan import _allocator as gsan_allocator + + device = triton.runtime.driver.active.get_current_device() + gsan_state_ptr = gsan_allocator.get_global_state_pointer() + device * GSAN_PER_DEVICE_STATE_STRIDE + kernel_args = (*args, gsan_state_ptr) self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, global_scratch, - profile_scratch, self.arg_annotations, self.kernel_signature, args) + profile_scratch, self.arg_annotations, 
self.kernel_signature, kernel_args) class CudaDriver(GPUDriver): From e56b30b33d611ee3067ddaecb5b0096c787ab590 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 27 Feb 2026 15:18:22 +0000 Subject: [PATCH 2/4] Misc cleanup/fixes --- Makefile | 3 +- python/test/gsan/test_allocator.py | 9 +- python/triton/experimental/gsan/_allocator.py | 3 +- python/triton/experimental/gsan/_utils.py | 16 ++- python/triton/experimental/gsan/src/GSan.h | 4 +- .../experimental/gsan/src/GSanAllocator.cc | 126 ++++++++++++------ python/triton/runtime/build.py | 14 +- third_party/nvidia/backend/driver.py | 14 +- 8 files changed, 112 insertions(+), 77 deletions(-) diff --git a/Makefile b/Makefile index 53eb64509acd..0e7573ff4820 100644 --- a/Makefile +++ b/Makefile @@ -51,8 +51,7 @@ test-gluon: all .PHONY: test-gsan test-gsan: all - $(PYTEST) --tb=short python/test/gsan --ignore python/test/gsan/test_gsan_failures.py - $(PYTEST) --tb=short python/test/gsan/test_gsan_failures.py + $(PYTEST) --tb=short python/test/gsan .PHONY: test-regression test-regression: all diff --git a/python/test/gsan/test_allocator.py b/python/test/gsan/test_allocator.py index 207a45ee3ad0..ea343d37e0f7 100644 --- a/python/test/gsan/test_allocator.py +++ b/python/test/gsan/test_allocator.py @@ -6,14 +6,7 @@ from triton._internal_testing import is_cuda from triton.experimental.gsan import create_mem_pool from triton.experimental.gsan._allocator import get_reserve_pointer, get_reserve_size, gsan_free, gsan_malloc -from triton.experimental.gsan._utils import shadow_region, uint8_cuda_tensor_from_ptr - - -def shadow_tensor_for(real: torch.Tensor) -> torch.Tensor: - reserve_ptr = get_reserve_pointer() - reserve_size = get_reserve_size() - shadow_ptr, shadow_size = shadow_region(real.data_ptr(), real.untyped_storage().nbytes(), reserve_ptr, reserve_size) - return uint8_cuda_tensor_from_ptr(shadow_ptr, shadow_size, torch.cuda.current_device()) +from triton.experimental.gsan._utils import shadow_tensor_for 
@pytest.fixture diff --git a/python/triton/experimental/gsan/_allocator.py b/python/triton/experimental/gsan/_allocator.py index 1a99fa07ac12..f85ce79d4bea 100644 --- a/python/triton/experimental/gsan/_allocator.py +++ b/python/triton/experimental/gsan/_allocator.py @@ -29,7 +29,8 @@ def _load_gsan_module() -> ModuleType: @functools.lru_cache() def _compile_gsan_allocator() -> str: - return str(_load_gsan_module().__file__) + # __file__ for a compiled module is the so file + return _load_gsan_module().__file__ @functools.lru_cache() diff --git a/python/triton/experimental/gsan/_utils.py b/python/triton/experimental/gsan/_utils.py index 0903fa0f753f..ec1a25683826 100644 --- a/python/triton/experimental/gsan/_utils.py +++ b/python/triton/experimental/gsan/_utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from triton.experimental.gsan._allocator import get_reserve_pointer, get_reserve_size + import ctypes import torch @@ -43,9 +45,8 @@ class _DLManagedTensor(ctypes.Structure): ("deleter", _DLManagedTensorDeleter), ] -_PyCapsule_New = ctypes.pythonapi.PyCapsule_New -_PyCapsule_New.restype = ctypes.py_object -_PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] +PyCapsule_NewType = ctypes.CFUNCTYPE(ctypes.py_object, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p) +PyCapsule_New = PyCapsule_NewType(ctypes.pythonapi.PyCapsule_New) # Hold ctypes-backed DLPack payloads until the tensor deleter runs. 
_DLPACK_STATE: dict[int, tuple[object, object, object]] = {} @@ -80,7 +81,7 @@ def uint8_cuda_tensor_from_ptr(data_ptr: int, numel: int, device_index: int) -> _DLPACK_STATE[dl_managed_tensor_ptr] = (dl_managed_tensor, shape, strides) try: - dlpack_capsule = _PyCapsule_New( + dlpack_capsule = PyCapsule_New( ctypes.c_void_p(dl_managed_tensor_ptr), _DLPACK_CAPSULE_NAME, None, @@ -101,3 +102,10 @@ def shadow_region(real_ptr: int, real_size_bytes: int, reserve_ptr: int, reserve shadow_ptr = reserve_ptr + word_offset * SHADOW_SIZE_BYTES shadow_size = ((real_size_bytes + SHADOW_GRANULARITY_BYTES - 1) // SHADOW_GRANULARITY_BYTES) * SHADOW_SIZE_BYTES return shadow_ptr, shadow_size + + +def shadow_tensor_for(real: torch.Tensor) -> torch.Tensor: + reserve_ptr = get_reserve_pointer() + reserve_size = get_reserve_size() + shadow_ptr, shadow_size = shadow_region(real.data_ptr(), real.untyped_storage().nbytes(), reserve_ptr, reserve_size) + return uint8_cuda_tensor_from_ptr(shadow_ptr, shadow_size, torch.cuda.current_device()) diff --git a/python/triton/experimental/gsan/src/GSan.h b/python/triton/experimental/gsan/src/GSan.h index 53000fc4ad92..d8a6b3104757 100644 --- a/python/triton/experimental/gsan/src/GSan.h +++ b/python/triton/experimental/gsan/src/GSan.h @@ -47,8 +47,6 @@ struct alignas(4) ShadowCell { static_assert(sizeof(ShadowCell) == 24); static_assert(alignof(ShadowCell) == 4); -static constexpr int kShadowSizeBytes = sizeof(ShadowCell); - struct GlobalState { // Base address of gsan managed memory uintptr_t reserveBase; @@ -112,7 +110,7 @@ inline GSAN_HOST_DEVICE uintptr_t getShadowAddress(uintptr_t virtualAddress) { auto realBase = getRealBaseAddress(reserveBase); auto byteOffset = virtualAddress - realBase; auto wordOffset = byteOffset / kShadowMemGranularityBytes; - return reserveBase + kShadowSizeBytes * wordOffset; + return reserveBase + sizeof(ShadowCell) * wordOffset; } inline GSAN_HOST_DEVICE bool isGsanManaged(uintptr_t addr, diff --git 
a/python/triton/experimental/gsan/src/GSanAllocator.cc b/python/triton/experimental/gsan/src/GSanAllocator.cc index 09cae6d85a00..a92006090008 100644 --- a/python/triton/experimental/gsan/src/GSanAllocator.cc +++ b/python/triton/experimental/gsan/src/GSanAllocator.cc @@ -3,12 +3,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include "GSan.h" @@ -50,14 +52,15 @@ struct AllocNode { // Allocation handles, used only by leaf nodes CUmemGenericAllocationHandle realHandle = 0; CUmemGenericAllocationHandle shadowHandle = 0; + size_t allocSize = 0; }; struct GSanConfig { - int numGPUs = 4; - int numSMs = 152; - int numThreads = 4 * 152; - int clockBufferSize = 1024; - uint32_t rngSeed = 0x12345678u; + int numGPUs; + int numSMs; + int numThreads; + int clockBufferSize; + uint32_t rngSeed; }; struct AllocatorState { @@ -88,20 +91,22 @@ size_t roundUp(size_t val, size_t alignment) { return cdiv(val, alignment) * alignment; } -size_t roundUpToPowerOfTwo(size_t value) { - if (value <= 1) - return 1; - if ((value & (value - 1)) == 0) - return value; - size_t rounded = 1; - while (rounded < value) - rounded <<= 1; - return rounded; +uint32_t roundDownToPowerOfTwo(uint32_t x) { + if (x == 0) + return 0; + + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + + return x - (x >> 1); } size_t getShadowSize(size_t realMemSize) { auto wordSize = cdiv(realMemSize, gsan::kShadowMemGranularityBytes); - return wordSize * gsan::kShadowSizeBytes; + return wordSize * sizeof(gsan::ShadowCell); } bool isLeaf(const AllocNode *node) { @@ -232,6 +237,7 @@ void coalesceUp(AllocNode *node) { void freeNode(AllocNode *leaf) { assert(isLeaf(leaf)); + leaf->allocSize = 0; leaf->realHandle = 0; leaf->shadowHandle = 0; leaf->maxFreeBlockSize = leaf->size; @@ -265,8 +271,16 @@ int gsanEnsureInit() { auto *root = &alloc->treeRoot; root->virtualAddress = gsan::getRealBaseAddress(reserveBase); - root->size = gsan::kReserveSize / 2; - 
root->maxFreeBlockSize = root->size; + + // Choose size so that both shadow memory and real memory definitely fit in + // the address reservation + auto shadowSize = gsan::kReserveSize / 2; + auto realSize = gsan::kShadowMemGranularityBytes * + (shadowSize / sizeof(gsan::ShadowCell)); + realSize = std::min(gsan::kReserveSize / 2, realSize); + realSize = roundDownToPowerOfTwo(realSize); + root->size = realSize; + root->maxFreeBlockSize = realSize; return 0; } @@ -313,14 +327,41 @@ CUresult refreshConfigForDevice(int device) { return CUDA_ERROR_INVALID_VALUE; auto &config = alloc->config; - // Triton may execute more than one instrumented launch on a kernel's first - // user-visible invocation (e.g. compile/warmup paths). Using 3 slots avoids - // aliasing back to the same logical thread ID when two launches occur. - constexpr int kGSanThreadSlotsPerDeviceThread = 3; config.numGPUs = numGPUs; config.numSMs = numSMs; - config.numThreads = - kGSanThreadSlotsPerDeviceThread * config.numGPUs * config.numSMs; + config.numThreads = config.numGPUs * config.numSMs; + + // Seed rng for stochastic read clocks + auto userSeed = getenv("TRITON_GSAN_SEED"); + if (userSeed) { + auto res = + std::from_chars(userSeed, userSeed + strlen(userSeed), config.rngSeed); + if (res.ec != std::errc()) { + auto errc = make_error_code(res.ec); + auto msg = errc.message(); + fprintf(stderr, "Invalid TRITON_GSAN_SEED value: %s", msg.c_str()); + return CUDA_ERROR_INVALID_VALUE; + } + } else { + std::uniform_int_distribution dist; + std::random_device rd{}; + config.rngSeed = dist(rd); + } + + auto userClockSize = getenv("TRITON_GSAN_CLOCK_BUFFER_SIZE"); + if (userClockSize) { + auto res = + std::from_chars(userSeed, userSeed + strlen(userSeed), config.rngSeed); + if (res.ec != std::errc()) { + auto errc = make_error_code(res.ec); + auto msg = errc.message(); + fprintf(stderr, "Invalid TRITON_CLOCK_BUFFER_SIZE value: %s", + msg.c_str()); + return CUDA_ERROR_INVALID_VALUE; + } + } else { + 
config.clockBufferSize = 1024; + } return CUDA_SUCCESS; } @@ -354,15 +395,14 @@ CUresult ensureRuntimeStateMapped(int device) { auto clockSizeBytes = sizeof(gsan::epoch_t) * config.numThreads; // 1 local clock + the circular clock buffer auto clocksPerThread = 1 + config.clockBufferSize; - auto perDeviceStateSize = ( - // Each device has a local copy of the constant global state - sizeof(gsan::GlobalState) + - // Plus per-thread state for each SM - config.numSMs * - (sizeof(gsan::ThreadState) + clockSizeBytes * clocksPerThread)); - assert(perDeviceStateSize <= gsan::kPerDeviceStateStride); - + auto perSMStateSize = + sizeof(gsan::ThreadState) + clockSizeBytes * clocksPerThread; + perSMStateSize = roundUp(perSMStateSize, alignof(gsan::ThreadState)); + // Each device has a local copy of the constant global state + auto perDeviceStateSize = + (sizeof(gsan::GlobalState) + config.numSMs * perSMStateSize); size_t allocSize = roundUp(perDeviceStateSize, granularity); + assert(allocSize <= gsan::kPerDeviceStateStride); CUmemGenericAllocationHandle allocHandle = 0; bool mapped = false; @@ -507,7 +547,7 @@ extern "C" void *gsanMalloc(ssize_t size, int device, auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); auto shadowSize = getShadowSize(allocSize); - err = cuMemCreate(&realHandle, node->size, &prop, 0); + err = cuMemCreate(&realHandle, allocSize, &prop, 0); if (err != CUDA_SUCCESS) goto error; @@ -525,6 +565,7 @@ extern "C" void *gsanMalloc(ssize_t size, int device, if (err != CUDA_SUCCESS) goto error; + node->allocSize = allocSize; node->realHandle = realHandle; node->shadowHandle = shadowHandle; LOGF("gsanMalloc: %p, 0x%zxu", reinterpret_cast(node->virtualAddress), @@ -562,9 +603,9 @@ extern "C" void gsanFree(void *void_ptr, [[maybe_unused]] ssize_t size, } const auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); - const auto shadowSize = getShadowSize(node->size); + const auto shadowSize = getShadowSize(node->allocSize); - CUresult err 
= cuMemUnmap(node->virtualAddress, node->size); + CUresult err = cuMemUnmap(node->virtualAddress, node->allocSize); if (err != CUDA_SUCCESS) printCUDAError(err); @@ -612,8 +653,8 @@ bool parseVoidPtrArg(PyObject *obj, void **out) { return !(*out == nullptr && PyErr_Occurred()); } -PyObject *pyMalloc(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { - (void)self; +PyObject *pyMalloc([[maybe_unused]] PyObject *self, PyObject *const *args, + Py_ssize_t nargs) { if (nargs != 2 && nargs != 3) { PyErr_Format(PyExc_TypeError, "%s.malloc expected 2 or 3 positional arguments, got %zd", @@ -636,8 +677,8 @@ PyObject *pyMalloc(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { return PyLong_FromVoidPtr(gsanMalloc(size, device, stream)); } -PyObject *pyFree(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { - (void)self; +PyObject *pyFree([[maybe_unused]] PyObject *self, PyObject *const *args, + Py_ssize_t nargs) { if (nargs < 2 || nargs > 4) { PyErr_Format( PyExc_TypeError, @@ -669,9 +710,8 @@ PyObject *pyFree(PyObject *self, PyObject *const *args, Py_ssize_t nargs) { Py_RETURN_NONE; } -PyObject *pyGetReservePointer(PyObject *self, PyObject *const *args, - Py_ssize_t nargs) { - (void)self; +PyObject *pyGetReservePointer([[maybe_unused]] PyObject *self, + PyObject *const *args, Py_ssize_t nargs) { if (nargs != 0) { PyErr_Format( PyExc_TypeError, @@ -690,8 +730,8 @@ PyObject *pyGetShadowSizeBytes(PyObject *self, PyObject *args) { return PyLong_FromLong(sizeof(gsan::ShadowCell)); } -PyObject *pyGetGlobalStatePointer(PyObject *self, PyObject *args) { - (void)self; +PyObject *pyGetGlobalStatePointer([[maybe_unused]] PyObject *self, + PyObject *args) { std::lock_guard lg(mut); if (gsanEnsureInit() != 0) { PyErr_SetString(PyExc_RuntimeError, "failed to initialize gsan allocator"); diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 11b3e7ca3574..3db9346c7524 100644 --- a/python/triton/runtime/build.py +++ 
b/python/triton/runtime/build.py @@ -114,18 +114,26 @@ def _load_module_from_path(name: str, path: str) -> ModuleType: return mod -def _get_cache_manager(src: bytes, language: str): +def _get_cache_manager(src: bytes, config: dict[str, list[str] | None]): digest = hashlib.sha256() digest.update(src) digest.update(platform_key().encode("utf-8")) - digest.update(language.encode("utf-8")) + for k, vs in config.items(): + if vs is None: + continue + digest.update(k.encode("utf-8")) + for v in vs: + digest.update(v.encode("utf-8")) + digest.update(b":") key = digest.hexdigest() return get_cache_manager(key) def _compile_so(src: bytes, src_path: str, name: str, library_dirs: list[str] | None, include_dirs: list[str] | None, libraries: list[str] | None, ccflags: list[str] | None, load_module: bool, language: str): - cache = _get_cache_manager(src, language) + config = dict(language=[language], library_dirs=library_dirs, include_dirs=include_dirs, libraries=libraries, + ccflags=ccflags) + cache = _get_cache_manager(src, config=config) suffix = sysconfig.get_config_var("EXT_SUFFIX") cache_path = cache.get_file(f"{name}{suffix}") if cache_path is not None: diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 5853c00a83f7..2ce247642c5f 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -21,7 +21,6 @@ ARG_CONSTEXPR = None ARG_KERNEL = None ARG_TUPLE = None -GSAN_PER_DEVICE_STATE_STRIDE = 1 << 30 @functools.lru_cache() @@ -281,9 +280,6 @@ def __init__(self, src, metadata): launcher = triton.runtime.driver.active.utils.launch expanded_signature = expand_signature(signature.values(), tensordesc_meta, "nvTmaDesc") - self.gsan_enabled = "gsan" in getattr(metadata, "instrumentation_mode", "") - if self.gsan_enabled: - expanded_signature.append("*i8") self.arg_annotations = annotate_arguments(expanded_signature) self.kernel_signature = make_kernel_signature(expanded_signature) self.num_ctas = 
getattr(metadata, "num_ctas", 1) @@ -320,18 +316,10 @@ def allocate_default_profile_scratch(size, align): _allocation._profile_allocator) else: profile_scratch = allocate_default_profile_scratch(self.profile_scratch_size, self.profile_scratch_align) - _allocation._profile_allocator) - kernel_args = args - if self.gsan_enabled: - from triton.experimental.gsan import _allocator as gsan_allocator - - device = triton.runtime.driver.active.get_current_device() - gsan_state_ptr = gsan_allocator.get_global_state_pointer() + device * GSAN_PER_DEVICE_STATE_STRIDE - kernel_args = (*args, gsan_state_ptr) self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, self.launch_pdl, kernel_metadata, launch_metadata, launch_enter_hook, launch_exit_hook, global_scratch, - profile_scratch, self.arg_annotations, self.kernel_signature, kernel_args) + profile_scratch, self.arg_annotations, self.kernel_signature, args) class CudaDriver(GPUDriver): From 4f3ecb5ec359cea8f4887f1ac8c41294074ffb45 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 4 Mar 2026 14:32:22 +0000 Subject: [PATCH 3/4] Sync stream before dealloc --- python/test/gsan/test_allocator.py | 2 ++ python/triton/experimental/gsan/src/GSanAllocator.cc | 12 +++++++++--- third_party/nvidia/backend/driver.py | 2 -- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/test/gsan/test_allocator.py b/python/test/gsan/test_allocator.py index ea343d37e0f7..191af8a79911 100644 --- a/python/test/gsan/test_allocator.py +++ b/python/test/gsan/test_allocator.py @@ -91,6 +91,7 @@ def test_malloc_fragmentation_reuse_and_coalesce(_direct_allocator): assert parent == p0 free(parent) + torch.cuda.synchronize() @pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") @@ -111,6 +112,7 @@ def test_free_invalid_pointer_and_double_free(_direct_allocator): assert p0_reuse == p0 free(p0_reuse) + torch.cuda.synchronize() @pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") diff --git 
a/python/triton/experimental/gsan/src/GSanAllocator.cc b/python/triton/experimental/gsan/src/GSanAllocator.cc index a92006090008..90b2f44fc02d 100644 --- a/python/triton/experimental/gsan/src/GSanAllocator.cc +++ b/python/triton/experimental/gsan/src/GSanAllocator.cc @@ -584,8 +584,7 @@ extern "C" void *gsanMalloc(ssize_t size, int device, } extern "C" void gsanFree(void *void_ptr, [[maybe_unused]] ssize_t size, - [[maybe_unused]] int device, - [[maybe_unused]] void *stream) { + [[maybe_unused]] int device, void *stream) { LOGF("gsanFree: %p, 0x%zx", void_ptr, size); auto ptr = reinterpret_cast(void_ptr); if (!ptr) @@ -602,10 +601,17 @@ extern "C" void gsanFree(void *void_ptr, [[maybe_unused]] ssize_t size, return; } + // Wait for outstanding work on the deallocation stream, including the + // allocator's own async shadow memset from gsanMalloc, before unmapping. + auto cuStream = reinterpret_cast(stream); + CUresult err = cuStreamSynchronize(cuStream); + if (err != CUDA_SUCCESS) + printCUDAError(err); + const auto shadowAddress = gsan::getShadowAddress(node->virtualAddress); const auto shadowSize = getShadowSize(node->allocSize); - CUresult err = cuMemUnmap(node->virtualAddress, node->allocSize); + err = cuMemUnmap(node->virtualAddress, node->allocSize); if (err != CUDA_SUCCESS) printCUDAError(err); diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 2ce247642c5f..b6eff3fa7990 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -4,8 +4,6 @@ import triton import ctypes import sys -from pathlib import Path -import re from triton import knobs from triton.runtime.build import compile_module_from_file from triton.runtime import _allocation From 661198b1254d9db1c65879367bc3bc9c6ff6eacb Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Thu, 12 Mar 2026 21:24:24 +0000 Subject: [PATCH 4/4] More misc changes --- Makefile | 2 +- python/test/gsan/test_allocator.py | 3 +-- 
python/test/gsan/test_utils.py | 5 +---- python/triton/experimental/gsan/_utils.py | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 0e7573ff4820..8cbe51d24af4 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ test-gluon: all .PHONY: test-gsan test-gsan: all - $(PYTEST) --tb=short python/test/gsan + $(PYTEST) -n $(NUM_PROCS) python/test/gsan .PHONY: test-regression test-regression: all diff --git a/python/test/gsan/test_allocator.py b/python/test/gsan/test_allocator.py index 191af8a79911..a2440eca7ea5 100644 --- a/python/test/gsan/test_allocator.py +++ b/python/test/gsan/test_allocator.py @@ -101,8 +101,7 @@ def test_free_invalid_pointer_and_double_free(_direct_allocator): p0 = malloc(1) assert p0 != 0 - # Invalid interior-pointer free should not free p0 and must not crash. - free(p0 + 1) + free(p0 + 1) # freeing an invalid pointer should not crash. free(p0) free(p0) # double free must be a no-op diff --git a/python/test/gsan/test_utils.py b/python/test/gsan/test_utils.py index b2e550f7ce16..f00f0c9ced98 100644 --- a/python/test/gsan/test_utils.py +++ b/python/test/gsan/test_utils.py @@ -7,12 +7,9 @@ @pytest.mark.skipif(not is_cuda(), reason="requires CUDA backend") def test_uint8_cuda_tensor_from_ptr_delete_tensor(): - if torch.cuda.device_count() < 1: - pytest.skip("requires at least 1 CUDA device") - - torch.cuda.set_device(0) view = uint8_cuda_tensor_from_ptr(12345, 10, 1) assert view.data_ptr() == 12345 assert view.shape == (10, ) assert view.dtype == torch.uint8 + assert view.device == torch.device("cuda:1") del view diff --git a/python/triton/experimental/gsan/_utils.py b/python/triton/experimental/gsan/_utils.py index ec1a25683826..a4e153043cc1 100644 --- a/python/triton/experimental/gsan/_utils.py +++ b/python/triton/experimental/gsan/_utils.py @@ -108,4 +108,4 @@ def shadow_tensor_for(real: torch.Tensor) -> torch.Tensor: reserve_ptr = get_reserve_pointer() reserve_size = get_reserve_size() shadow_ptr, 
shadow_size = shadow_region(real.data_ptr(), real.untyped_storage().nbytes(), reserve_ptr, reserve_size) - return uint8_cuda_tensor_from_ptr(shadow_ptr, shadow_size, torch.cuda.current_device()) + return uint8_cuda_tensor_from_ptr(shadow_ptr, shadow_size, real.device.index)