diff --git a/python/triton/experimental/gsan/src/GSan.h b/python/triton/experimental/gsan/src/GSan.h index 1c4d1fc0e39c..853a3cb0a328 100644 --- a/python/triton/experimental/gsan/src/GSan.h +++ b/python/triton/experimental/gsan/src/GSan.h @@ -1,14 +1,24 @@ -#include -#include +#pragma once -#ifdef __CUDACC__ +#if defined(__CUDA__) && defined(__clang__) +#define GSAN_DEVICE __attribute__((device)) +#define GSAN_HOST_DEVICE __attribute__((host)) __attribute__((device)) +#elif defined(__CUDACC__) +#define GSAN_DEVICE __device__ #define GSAN_HOST_DEVICE __host__ __device__ #else +#define GSAN_DEVICE #define GSAN_HOST_DEVICE #endif namespace gsan { +using size_t = __SIZE_TYPE__; +using uint8_t = __UINT8_TYPE__; +using uint16_t = __UINT16_TYPE__; +using uint32_t = __UINT32_TYPE__; +using uintptr_t = __UINTPTR_TYPE__; + // Reserve 1 PiB, should be big enough for a while :) static constexpr size_t kReserveSize = 1ull << 40; static constexpr int kShadowMemGranularityBytes = 4; diff --git a/python/triton/experimental/gsan/src/GSanLibrary.cu b/python/triton/experimental/gsan/src/GSanLibrary.cu index f90f47c63783..52e2f422cd82 100644 --- a/python/triton/experimental/gsan/src/GSanLibrary.cu +++ b/python/triton/experimental/gsan/src/GSanLibrary.cu @@ -1,11 +1,16 @@ #include "GSan.h" #include "Hash.cuh" -#include -#include -#include -#include -#include +extern "C" GSAN_DEVICE void __assertfail(const char *assertion, + const char *file, unsigned line, + const char *function, + __SIZE_TYPE__ charSize); + +static GSAN_DEVICE inline void __assert_fail(const char *assertion, + const char *file, unsigned line, + const char *function) { + __assertfail(assertion, file, line, function, sizeof(char)); +} namespace gsan { @@ -14,7 +19,7 @@ struct Location { unsigned line; }; -__device__ const char *getSourceFile(Location loc) { +GSAN_DEVICE const char *getSourceFile(Location loc) { return loc.file == nullptr ? "" : loc.file; } @@ -29,7 +34,9 @@ __device__ const char *getSourceFile(Location loc) { namespace gsan { namespace { -static constexpr uint32_t writerFlag = 1u << 31; +static constexpr uint32_t kWriterFlag = 1u << 31; +static constexpr epoch_t kMaxEpoch = static_cast(~0u); +static constexpr uint16_t kMaxUint16 = static_cast(~0u); enum class AtomicSem : uint8_t { Relaxed = 1, @@ -38,65 +45,67 @@ enum class AtomicSem : uint8_t { AcquireRelease = 4, }; -__device__ void rwLockAcquireRead(uint32_t &lock) { +GSAN_DEVICE void rwLockAcquireRead(uint32_t &lock) { uint32_t old = __scoped_atomic_fetch_add(&lock, 1, __ATOMIC_ACQUIRE, __MEMORY_SCOPE_WRKGRP); - if ((old & writerFlag) == 0) + if ((old & kWriterFlag) == 0) return; do { old = __scoped_atomic_load_n(&lock, __ATOMIC_ACQUIRE, __MEMORY_SCOPE_WRKGRP); - } while ((old & writerFlag) != 0); + } while ((old & kWriterFlag) != 0); } -__device__ void rwLockAcquireWrite(uint32_t &lock) { +GSAN_DEVICE void rwLockAcquireWrite(uint32_t &lock) { uint32_t actual = 0; - while (!__scoped_atomic_compare_exchange_n(&lock, &actual, writerFlag, true, + while (!__scoped_atomic_compare_exchange_n(&lock, &actual, kWriterFlag, true, __ATOMIC_ACQUIRE, __ATOMIC_RELAXED, __MEMORY_SCOPE_WRKGRP)) { actual = 0; } } -__device__ void rwLockReleaseRead(uint32_t &lock) { +GSAN_DEVICE void rwLockReleaseRead(uint32_t &lock) { __scoped_atomic_fetch_sub(&lock, 1, __ATOMIC_RELAXED, __MEMORY_SCOPE_WRKGRP); } -__device__ void rwLockReleaseWrite(uint32_t &lock) { +GSAN_DEVICE void rwLockReleaseWrite(uint32_t &lock) { // Note we don't set 0 as there may be readers who've already // incremented optimistically - __scoped_atomic_fetch_and(&lock, ~writerFlag, __ATOMIC_RELEASE, + __scoped_atomic_fetch_and(&lock, ~kWriterFlag, __ATOMIC_RELEASE, __MEMORY_SCOPE_WRKGRP); } -__device__ inline uintptr_t roundUp(uintptr_t ptr, uintptr_t align) { +GSAN_DEVICE inline uintptr_t roundUp(uintptr_t ptr, uintptr_t align) { return ptr % align == 0 ? ptr : ptr + align - (ptr % align); } -__device__ uint32_t getSmId() { return __nvvm_read_ptx_sreg_smid(); } +GSAN_DEVICE uint32_t getSmId() { return __nvvm_read_ptx_sreg_smid(); } -__device__ uintptr_t getThreadStateStrideBytes(GlobalState *globals) { +GSAN_DEVICE uint32_t getThreadIdxX() { return __nvvm_read_ptx_sreg_tid_x(); } + +GSAN_DEVICE uintptr_t getThreadStateStrideBytes(GlobalState *globals) { auto clocksPerThread = 1u + globals->clockBufferSize; return sizeof(ThreadState) + sizeof(epoch_t) * globals->numThreads * clocksPerThread; } -__device__ thread_id_t getDeviceThreadId(GlobalState *globals, uint32_t smid) { +GSAN_DEVICE thread_id_t getDeviceThreadId(GlobalState *globals, uint32_t smid) { auto globalsBase = static_cast(globals->globalsBase); auto deviceBase = reinterpret_cast(globals); auto deviceIdx = (deviceBase - globalsBase) / kPerDeviceStateStride; return static_cast(deviceIdx * globals->numSms + smid); } -__device__ uintptr_t getThreadStateBaseAddress(uintptr_t globalsAddr) { +GSAN_DEVICE uintptr_t getThreadStateBaseAddress(uintptr_t globalsAddr) { uintptr_t stateBase = globalsAddr; stateBase = roundUp(stateBase + sizeof(GlobalState), alignof(ThreadState)); return stateBase; } -__device__ ThreadState *getThreadStateById(GlobalState *globals, - thread_id_t tid) { +GSAN_DEVICE ThreadState *getThreadStateById(GlobalState *globals, + thread_id_t tid) { uint32_t deviceIdx = tid / globals->numSms; uint32_t smid = tid % globals->numSms; uintptr_t stateBase = static_cast(globals->globalsBase) + @@ -106,7 +115,7 @@ __device__ ThreadState *getThreadStateById(GlobalState *globals, return reinterpret_cast(stateBase + stateStride * smid); } -__device__ ThreadState *getThreadState(GlobalState *globals) { +GSAN_DEVICE ThreadState *getThreadState(GlobalState *globals) { uint32_t smid = getSmId(); uintptr_t stateBase = getThreadStateBaseAddress(reinterpret_cast(globals)); @@ -128,13 +137,13 @@ __device__ ThreadState *getThreadState(GlobalState *globals) { return state; } -__device__ epoch_t *getClockBufferBase(ThreadState *state) { +GSAN_DEVICE epoch_t *getClockBufferBase(ThreadState *state) { auto *globals = getGlobalState(state); return state->vectorClock + globals->numThreads; } -__device__ epoch_t *getClockBufferSlot(ThreadState *state, epoch_t token, - Location loc) { +GSAN_DEVICE epoch_t *getClockBufferSlot(ThreadState *state, epoch_t token, + Location loc) { assert_msg(loc, token != 0, "Invalid GSan clock token"); assert_msg(loc, token <= state->clockBufferHead, "Future GSan clock token"); auto *globals = getGlobalState(state); @@ -144,11 +153,10 @@ __device__ epoch_t *getClockBufferSlot(ThreadState *state, epoch_t token, return getClockBufferBase(state) + slot * globals->numThreads; } -__device__ epoch_t publishClockBuffer(ThreadState *state, Location loc) { +GSAN_DEVICE epoch_t publishClockBuffer(ThreadState *state, Location loc) { auto *globals = getGlobalState(state); uint32_t nextHead = state->clockBufferHead + 1; - assert_msg(loc, nextHead <= std::numeric_limits::max(), - "GSan clock buffer token overflowed"); + assert_msg(loc, nextHead <= kMaxEpoch, "GSan clock buffer token overflowed"); epoch_t *slot = getClockBufferBase(state) + ((nextHead - 1) % globals->clockBufferSize) * globals->numThreads; @@ -159,7 +167,7 @@ __device__ epoch_t publishClockBuffer(ThreadState *state, Location loc) { return static_cast(nextHead); } -__device__ AtomicSem decodeAtomicSem(uint32_t sem) { +GSAN_DEVICE AtomicSem decodeAtomicSem(uint32_t sem) { switch (sem) { case 1: return AtomicSem::Relaxed; @@ -170,11 +178,12 @@ __device__ AtomicSem decodeAtomicSem(uint32_t sem) { case 4: return AtomicSem::AcquireRelease; default: - assert(false || !"Unexpected atomic semantic type"); + __builtin_trap(); + return AtomicSem::Relaxed; } } -__device__ AtomicScope decodeAtomicScope(uint32_t scope) { +GSAN_DEVICE AtomicScope decodeAtomicScope(uint32_t scope) { switch (scope) { case 1: return AtomicScope::GPU; @@ -183,20 +192,21 @@ __device__ AtomicScope decodeAtomicScope(uint32_t scope) { case 3: return AtomicScope::System; default: - assert(false || !"Unexpected atomic scope"); + __builtin_trap(); + return AtomicScope::NonAtomic; } } -__device__ bool hasAcquire(AtomicSem sem) { +GSAN_DEVICE bool hasAcquire(AtomicSem sem) { return sem == AtomicSem::Acquire || sem == AtomicSem::AcquireRelease; } -__device__ bool hasRelease(AtomicSem sem) { +GSAN_DEVICE bool hasRelease(AtomicSem sem) { return sem == AtomicSem::Release || sem == AtomicSem::AcquireRelease; } -__device__ bool scopeCoversPair(AtomicScope scope, thread_id_t lhs, - thread_id_t rhs, GlobalState *globals) { +GSAN_DEVICE bool scopeCoversPair(AtomicScope scope, thread_id_t lhs, + thread_id_t rhs, GlobalState *globals) { switch (scope) { case AtomicScope::CTA: return lhs == rhs; @@ -210,27 +220,26 @@ __device__ bool scopeCoversPair(AtomicScope scope, thread_id_t lhs, return false; } -__device__ bool areAtomicScopesCompatible(AtomicScope lhs, thread_id_t lhsTid, - AtomicScope rhs, thread_id_t rhsTid, - GlobalState *globals) { +GSAN_DEVICE bool areAtomicScopesCompatible(AtomicScope lhs, thread_id_t lhsTid, + AtomicScope rhs, thread_id_t rhsTid, + GlobalState *globals) { if (!isAtomicScope(lhs) || !isAtomicScope(rhs)) return false; return scopeCoversPair(lhs, lhsTid, rhsTid, globals) && scopeCoversPair(rhs, lhsTid, rhsTid, globals); } -__device__ void initThread(GlobalState *globals, Location loc) { +GSAN_DEVICE void initThread(GlobalState *globals, Location loc) { auto *state = getThreadState(globals); - if (threadIdx.x == 0) { + if (getThreadIdxX() == 0) { auto smid = getSmId(); auto tid = getDeviceThreadId(globals, smid); // Preserve the synchronized vector clock from prior launches on this // stream and advance the local epoch for the new kernel entry. auto *clock = state->vectorClock; - assert_msg(loc, clock[tid] != std::numeric_limits::max(), - "Vector clock overflowed"); + assert_msg(loc, clock[tid] != kMaxEpoch, "Vector clock overflowed"); clock[tid] += 1; state->clockBufferDirty = 1; } @@ -241,7 +250,7 @@ struct Range { uintptr_t end; }; -__device__ Range roundRange(Range x) { +GSAN_DEVICE Range roundRange(Range x) { // Round start down to shadow granularity x.start = x.start - (x.start % kShadowMemGranularityBytes); // Round end up to shadow granularity @@ -250,7 +259,7 @@ __device__ Range roundRange(Range x) { return x; } -__device__ ShadowCell *acquireShadow(uintptr_t shadowAddr) { +GSAN_DEVICE ShadowCell *acquireShadow(uintptr_t shadowAddr) { auto cell = reinterpret_cast(shadowAddr); uint16_t actual = 0; @@ -262,21 +271,20 @@ __device__ ShadowCell *acquireShadow(uintptr_t shadowAddr) { return cell; } -__device__ void releaseShadow(ShadowCell *cell) { +GSAN_DEVICE void releaseShadow(ShadowCell *cell) { __scoped_atomic_store_n(&cell->lock, 0, __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); } -__device__ epoch_t appendClockBufferSnapshot(ThreadState *state, - const epoch_t *snapshot, - Location loc) { +GSAN_DEVICE epoch_t appendClockBufferSnapshot(ThreadState *state, + const epoch_t *snapshot, + Location loc) { auto *globals = getGlobalState(state); assert_msg(loc, globals->clockBufferSize != 0, "GSan clock buffer size must be non-zero"); uint32_t curHead = state->clockBufferHead; uint32_t nextHead = curHead + 1; - assert_msg(loc, nextHead <= std::numeric_limits::max(), - "GSan clock buffer token overflowed"); + assert_msg(loc, nextHead <= kMaxEpoch, "GSan clock buffer token overflowed"); epoch_t *slot = getClockBufferBase(state) + (nextHead % globals->clockBufferSize) * globals->numThreads; for (int i = 0; i < globals->numThreads; ++i) @@ -285,7 +293,8 @@ __device__ epoch_t appendClockBufferSnapshot(ThreadState *state, return static_cast(nextHead); } -__device__ epoch_t publishCurrentVectorClock(ThreadState *state, Location loc) { +GSAN_DEVICE epoch_t publishCurrentVectorClock(ThreadState *state, + Location loc) { if (state->clockBufferDirty) { auto token = appendClockBufferSnapshot(state, state->vectorClock, loc); state->clockBufferDirty = 0; @@ -294,18 +303,18 @@ __device__ epoch_t publishCurrentVectorClock(ThreadState *state, Location loc) { return state->clockBufferHead; } -__device__ const epoch_t *getSnapshotForWrite(ThreadState *state, - const ScalarClock &write, - Location loc) { +GSAN_DEVICE const epoch_t *getSnapshotForWrite(ThreadState *state, + const ScalarClock &write, + Location loc) { if (!write.isRelease) return nullptr; auto *writerState = getThreadStateById(getGlobalState(state), write.threadId); return getClockBufferSlot(writerState, write.epoch, loc); } -__device__ epoch_t propagateClockBufferSnapshot(ThreadState *state, - const ScalarClock &write, - Location loc) { +GSAN_DEVICE epoch_t propagateClockBufferSnapshot(ThreadState *state, + const ScalarClock &write, + Location loc) { auto *snapshot = getSnapshotForWrite(state, write, loc); assert_msg(loc, snapshot != nullptr, "Invalid GSan propagated clock token"); auto token = appendClockBufferSnapshot(state, snapshot, loc); @@ -313,16 +322,16 @@ __device__ epoch_t propagateClockBufferSnapshot(ThreadState *state, return token; } -__device__ void incrementThreadEpoch(ThreadState *state, Location loc) { +GSAN_DEVICE void incrementThreadEpoch(ThreadState *state, Location loc) { auto tid = state->threadId; auto *clock = state->vectorClock; - assert_msg(loc, clock[tid] != std::numeric_limits::max(), - "Vector clock overflowed"); + assert_msg(loc, clock[tid] != kMaxEpoch, "Vector clock overflowed"); clock[tid] += 1; state->clockBufferDirty = 1; } -__device__ bool dominatesSnapshot(ThreadState *state, const epoch_t *snapshot) { +GSAN_DEVICE bool dominatesSnapshot(ThreadState *state, + const epoch_t *snapshot) { auto *globals = getGlobalState(state); for (int i = 0; i < globals->numThreads; ++i) { if (state->vectorClock[i] < snapshot[i]) @@ -331,8 +340,8 @@ __device__ bool dominatesSnapshot(ThreadState *state, const epoch_t *snapshot) { return true; } -__device__ bool clockHappensBefore(ThreadState *state, const ScalarClock &clock, - Location loc) { +GSAN_DEVICE bool clockHappensBefore(ThreadState *state, + const ScalarClock &clock, Location loc) { if (clock.epoch == 0) return true; if (const epoch_t *snapshot = getSnapshotForWrite(state, clock, loc)) @@ -340,10 +349,10 @@ __device__ bool clockHappensBefore(ThreadState *state, const ScalarClock &clock, return state->vectorClock[clock.threadId] >= clock.epoch; } -__device__ void assertOrderedOrCompatible(ThreadState *state, - AtomicScope currentScope, - const ScalarClock &prior, - Location loc, const char *message) { +GSAN_DEVICE void assertOrderedOrCompatible(ThreadState *state, + AtomicScope currentScope, + const ScalarClock &prior, + Location loc, const char *message) { if (prior.epoch == 0) return; if (isAtomicScope(currentScope) && @@ -354,8 +363,8 @@ __device__ void assertOrderedOrCompatible(ThreadState *state, assert_msg(loc, clockHappensBefore(state, prior, loc), message); } -__device__ void maybeMergeAcquire(ThreadState *state, AtomicScope currentScope, - const ScalarClock &prior, Location loc) { +GSAN_DEVICE void maybeMergeAcquire(ThreadState *state, AtomicScope currentScope, + const ScalarClock &prior, Location loc) { if (!prior.isRelease) return; if (!areAtomicScopesCompatible(currentScope, state->threadId, prior.scope, @@ -375,20 +384,20 @@ __device__ void maybeMergeAcquire(ThreadState *state, AtomicScope currentScope, state->clockBufferDirty = 1; } -__device__ ScalarClock makeScalarClock(ThreadState *state, AtomicScope scope) { +GSAN_DEVICE ScalarClock makeScalarClock(ThreadState *state, AtomicScope scope) { auto tid = state->threadId; return ScalarClock{state->vectorClock[tid], tid, scope, false}; } -__device__ ScalarClock makePublishedClock(ThreadState *state, AtomicScope scope, - epoch_t token) { +GSAN_DEVICE ScalarClock makePublishedClock(ThreadState *state, + AtomicScope scope, epoch_t token) { return ScalarClock{token, state->threadId, scope, true}; } -__device__ void recordRead(ThreadState *state, ShadowCell *cell, - AtomicScope scope) { +GSAN_DEVICE void recordRead(ThreadState *state, ShadowCell *cell, + AtomicScope scope) { auto numReads = cell->numReads; - if (numReads < std::numeric_limitsnumReads)>::max()) + if (numReads < kMaxUint16) ++cell->numReads; auto scalarClock = makeScalarClock(state, scope); @@ -410,7 +419,7 @@ __device__ void recordRead(ThreadState *state, ShadowCell *cell, } } -__device__ void doWrite(ThreadState *state, ShadowCell *cell, Location loc) { +GSAN_DEVICE void doWrite(ThreadState *state, ShadowCell *cell, Location loc) { // Check WAR for (int iRead = 0; iRead < ShadowCell::kReadClockSize; ++iRead) { assertOrderedOrCompatible(state, AtomicScope::NonAtomic, @@ -424,8 +433,8 @@ __device__ void doWrite(ThreadState *state, ShadowCell *cell, Location loc) { cell->writeClock = makeScalarClock(state, AtomicScope::NonAtomic); } -__device__ void writeRange(ThreadState *state, uintptr_t write_addr, int nBytes, - Location loc) { +GSAN_DEVICE void writeRange(ThreadState *state, uintptr_t write_addr, + int nBytes, Location loc) { auto range = roundRange(Range{write_addr, write_addr + nBytes}); auto reserveBase = state->reserveBase; @@ -445,8 +454,8 @@ __device__ void writeRange(ThreadState *state, uintptr_t write_addr, int nBytes, } // Handles tl.store(ptrs, values, mask) -__device__ void tensorStore(ThreadState *state, const char *stackPtr, - int nElems, int bytesPerElem, Location loc) { +GSAN_DEVICE void tensorStore(ThreadState *state, const char *stackPtr, + int nElems, int bytesPerElem, Location loc) { const uintptr_t *ptrsPtr = reinterpret_cast(stackPtr); const char *maskPtr = stackPtr + nElems * sizeof(uintptr_t); for (int i = 0; i < nElems; ++i) { @@ -457,14 +466,14 @@ __device__ void tensorStore(ThreadState *state, const char *stackPtr, } } -__device__ void doRead(ThreadState *state, ShadowCell *cell, Location loc) { +GSAN_DEVICE void doRead(ThreadState *state, ShadowCell *cell, Location loc) { assertOrderedOrCompatible(state, AtomicScope::NonAtomic, cell->writeClock, loc, "Read after write race detected"); recordRead(state, cell, AtomicScope::NonAtomic); } -__device__ void readRange(ThreadState *state, uintptr_t read_addr, int nBytes, - Location loc) { +GSAN_DEVICE void readRange(ThreadState *state, uintptr_t read_addr, int nBytes, + Location loc) { auto range = roundRange(Range{read_addr, read_addr + nBytes}); auto reserveBase = state->reserveBase; @@ -484,8 +493,8 @@ __device__ void readRange(ThreadState *state, uintptr_t read_addr, int nBytes, } // Handles tl.load(ptrs, mask) -__device__ void tensorLoad(ThreadState *state, const char *stackPtr, int nElems, - int bytesPerElem, Location loc) { +GSAN_DEVICE void tensorLoad(ThreadState *state, const char *stackPtr, + int nElems, int bytesPerElem, Location loc) { const uintptr_t *ptrsPtr = reinterpret_cast(stackPtr); const char *maskPtr = stackPtr + nElems * sizeof(uintptr_t); for (int i = 0; i < nElems; ++i) { @@ -496,17 +505,17 @@ __device__ void tensorLoad(ThreadState *state, const char *stackPtr, int nElems, } } -__device__ void initAtomicEventState(AtomicEventState *event) { +GSAN_DEVICE void initAtomicEventState(AtomicEventState *event) { event->threadState = nullptr; event->numCells = 0; for (auto &cell : event->cells) cell = nullptr; } -__device__ void acquireAtomicShadowRange(ThreadState *state, - AtomicEventState *event, - uintptr_t address, int nBytes, - Location loc) { +GSAN_DEVICE void acquireAtomicShadowRange(ThreadState *state, + AtomicEventState *event, + uintptr_t address, int nBytes, + Location loc) { auto range = roundRange(Range{address, address + nBytes}); auto reserveBase = state->reserveBase; uint8_t numCells = 0; @@ -534,7 +543,7 @@ __device__ void acquireAtomicShadowRange(ThreadState *state, } } -__device__ void releaseAtomicShadowRange(AtomicEventState *event) { +GSAN_DEVICE void releaseAtomicShadowRange(AtomicEventState *event) { if (event->threadState == nullptr) return; for (uint8_t i = 0; i < event->numCells; ++i) @@ -543,10 +552,11 @@ __device__ void releaseAtomicShadowRange(AtomicEventState *event) { initAtomicEventState(event); } -__device__ void beginAtomicAccess(GlobalState *globals, AtomicEventState *event, - bool pred, uintptr_t address, int nBytes, - uint32_t semRaw, uint32_t scopeRaw, - Location loc) { +GSAN_DEVICE void beginAtomicAccess(GlobalState *globals, + AtomicEventState *event, bool pred, + uintptr_t address, int nBytes, + uint32_t semRaw, uint32_t scopeRaw, + Location loc) { initAtomicEventState(event); if (!pred) return; @@ -573,9 +583,9 @@ __device__ void beginAtomicAccess(GlobalState *globals, AtomicEventState *event, } } -__device__ void endAtomicAccess(AtomicEventState *event, bool pred, - bool didWrite, uint32_t semRaw, - uint32_t scopeRaw, Location loc) { +GSAN_DEVICE void endAtomicAccess(AtomicEventState *event, bool pred, + bool didWrite, uint32_t semRaw, + uint32_t scopeRaw, Location loc) { if (!pred || event->threadState == nullptr) return; @@ -621,7 +631,7 @@ __device__ void endAtomicAccess(AtomicEventState *event, bool pred, } // namespace } // namespace gsan -extern "C" __device__ void +extern "C" GSAN_DEVICE void __triton_gsan_load_tensor(void *globalState, const char *stackPtr, int numElems, int bytesPerElem, const char *file, unsigned line) { auto loc = gsan::Location{file, line}; @@ -630,13 +640,13 @@ __triton_gsan_load_tensor(void *globalState, const char *stackPtr, int numElems, gsan::tensorLoad(threadState, stackPtr, numElems, bytesPerElem, loc); } -extern "C" __device__ void __triton_gsan_init(void *globalState, - const char *file, unsigned line) { +extern "C" GSAN_DEVICE void +__triton_gsan_init(void *globalState, const char *file, unsigned line) { auto loc = gsan::Location{file, line}; gsan::initThread(reinterpret_cast(globalState), loc); } -extern "C" __device__ void +extern "C" GSAN_DEVICE void __triton_gsan_store_tensor(void *globalState, const char *stackPtr, int numElems, int bytesPerElem, const char *file, unsigned line) { @@ -646,10 +656,9 @@ __triton_gsan_store_tensor(void *globalState, const char *stackPtr, gsan::tensorStore(threadState, stackPtr, numElems, bytesPerElem, loc); } -extern "C" __device__ void -__triton_gsan_atomic_begin_scalar(void *globalState, void *eventState, int pred, - uintptr_t address, int bytesPerElem, int sem, - int scope, const char *file, unsigned line) { +extern "C" GSAN_DEVICE void __triton_gsan_atomic_begin_scalar( + void *globalState, void *eventState, int pred, gsan::uintptr_t address, + int bytesPerElem, int sem, int scope, const char *file, unsigned line) { auto loc = gsan::Location{file, line}; gsan::beginAtomicAccess( reinterpret_cast(globalState), @@ -657,7 +666,7 @@ __triton_gsan_atomic_begin_scalar(void *globalState, void *eventState, int pred, address, bytesPerElem, sem, scope, loc); } -extern "C" __device__ void +extern "C" GSAN_DEVICE void __triton_gsan_atomic_end_scalar(void *eventState, int pred, int didWrite, int sem, int scope, const char *file, unsigned line) { diff --git a/python/triton/experimental/gsan/src/Hash.cuh b/python/triton/experimental/gsan/src/Hash.cuh index 3f0a55b434c2..2be887948b21 100644 --- a/python/triton/experimental/gsan/src/Hash.cuh +++ b/python/triton/experimental/gsan/src/Hash.cuh @@ -1,20 +1,22 @@ -#include +#pragma once + +#include "GSan.h" namespace gsan { namespace { -__device__ uint32_t rotl32(uint32_t x, int r) { +GSAN_DEVICE uint32_t rotl32(uint32_t x, int r) { return (x << r) | (x >> (32 - r)); } -__device__ uint32_t value_hash32(uint32_t x) { +GSAN_DEVICE uint32_t value_hash32(uint32_t x) { x *= 0xcc9e2d51u; x = rotl32(x, 15); x *= 0x1b873593u; return x; } -__device__ uint32_t hash_finalize32(uint32_t h) { +GSAN_DEVICE uint32_t hash_finalize32(uint32_t h) { h ^= h >> 16; h *= 0x85ebca6bu; h ^= h >> 13; @@ -23,14 +25,14 @@ __device__ uint32_t hash_finalize32(uint32_t h) { return h; } -__device__ uint32_t hash_combine32(uint32_t h, uint32_t v) { +GSAN_DEVICE uint32_t hash_combine32(uint32_t h, uint32_t v) { h ^= value_hash32(v); h = rotl32(h, 13); h = h * 5u + 0xe6546b64u; return h; } -__device__ uint32_t hash2x32(uint32_t a, uint32_t b, uint32_t seed) { +GSAN_DEVICE uint32_t hash2x32(uint32_t a, uint32_t b, uint32_t seed) { uint32_t h = seed; h = hash_combine32(h, a); h = hash_combine32(h, b); diff --git a/third_party/nvidia/CMakeLists.txt b/third_party/nvidia/CMakeLists.txt index 339f7d323193..ff638103d921 100644 --- a/third_party/nvidia/CMakeLists.txt +++ b/third_party/nvidia/CMakeLists.txt @@ -6,9 +6,7 @@ if(TRITON_BUILD_PYTHON_MODULE) set(GSAN_RUNTIME_SRC "${triton_SOURCE_DIR}/python/triton/experimental/gsan/src/GSanLibrary.cu") set(GSAN_RUNTIME_HDRS "${triton_SOURCE_DIR}/python/triton/experimental/gsan/src/GSan.h" - "${triton_SOURCE_DIR}/python/triton/experimental/gsan/src/Hash.cuh" - "${CMAKE_CURRENT_SOURCE_DIR}/clang_cuda_shims/assert.h" - "${CMAKE_CURRENT_SOURCE_DIR}/clang_cuda_shims/curand_mtgp32_kernel.h") + "${triton_SOURCE_DIR}/python/triton/experimental/gsan/src/Hash.cuh") set(GSAN_RUNTIME_IR "${CMAKE_CURRENT_SOURCE_DIR}/backend/lib/gsan.ll") find_program(TRITON_GSAN_CLANGXX NAMES clang++ @@ -18,34 +16,8 @@ if(TRITON_BUILD_PYTHON_MODULE) message(FATAL_ERROR "clang++ is required to build gsan.ll") endif() - set(GSAN_RUNTIME_PLATFORM_FLAGS - "--cuda-path=${CMAKE_CURRENT_SOURCE_DIR}/backend") - if(APPLE AND CMAKE_OSX_SYSROOT) - list(APPEND GSAN_RUNTIME_PLATFORM_FLAGS -isysroot "${CMAKE_OSX_SYSROOT}") - endif() - - set(GSAN_RUNTIME_TOOLCHAIN_FLAGS) - # Detect the right include directory for C++ stdlib. - # The way this is currently done is GCC specific. - # TODO: Support libc++-based systems. - set(GSAN_HOST_GNU_CXX "${CMAKE_CXX_COMPILER}") - if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - # Force g++ to be used for querying -print-file-name on libstdc++.so - # clang++ does not print an absolute path in this case. - find_program(GSAN_HOST_GNU_CXX NAMES g++ c++) - endif() - if(GSAN_HOST_GNU_CXX) - execute_process( - COMMAND "${GSAN_HOST_GNU_CXX}" -print-file-name=libstdc++.so - OUTPUT_VARIABLE LIBSTDCXX_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - if(IS_ABSOLUTE "${LIBSTDCXX_PATH}") - get_filename_component(GCC_INSTALL_DIR "${LIBSTDCXX_PATH}" DIRECTORY) - list(APPEND GSAN_RUNTIME_TOOLCHAIN_FLAGS "--gcc-install-dir=${GCC_INSTALL_DIR}") - endif() - endif() - + # Keep the device runtime freestanding: do not pull in Clang's CUDA wrapper, + # vendored CUDA headers, host SDK headers, or the host C++ standard library. add_custom_command( OUTPUT "${GSAN_RUNTIME_IR}" COMMAND "${CMAKE_COMMAND}" -E make_directory @@ -54,16 +26,12 @@ if(TRITON_BUILD_PYTHON_MODULE) -x cuda -std=c++17 -O3 -S -emit-llvm --cuda-device-only + -nocudainc -nocudalib - --no-cuda-version-check + -nostdinc -fno-exceptions -fcuda-flush-denormals-to-zero --cuda-gpu-arch=sm_80 - -Wno-unknown-cuda-version - ${GSAN_RUNTIME_TOOLCHAIN_FLAGS} - ${GSAN_RUNTIME_PLATFORM_FLAGS} - -isystem "${CMAKE_CURRENT_SOURCE_DIR}/clang_cuda_shims" - -isystem "${CMAKE_CURRENT_SOURCE_DIR}/backend/include" "${GSAN_RUNTIME_SRC}" -o "${GSAN_RUNTIME_IR}" DEPENDS "${GSAN_RUNTIME_SRC}" ${GSAN_RUNTIME_HDRS} COMMENT "Building GSan runtime" diff --git a/third_party/nvidia/clang_cuda_shims/assert.h b/third_party/nvidia/clang_cuda_shims/assert.h deleted file mode 100644 index 7076dcc2eda5..000000000000 --- a/third_party/nvidia/clang_cuda_shims/assert.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#if defined(__APPLE__) && defined(__CUDA__) - -#undef assert - -#ifdef NDEBUG -#define assert(e) ((void)0) -#else -#ifdef __FILE_NAME__ -#define __ASSERT_FILE_NAME __FILE_NAME__ -#else -#define __ASSERT_FILE_NAME __FILE__ -#endif - -#ifdef __cplusplus -#define assert(e) \ - (static_cast(__builtin_expect(!!(e), 1) \ - ? 0 \ - : (__assert_fail(#e, __ASSERT_FILE_NAME, __LINE__, \ - __PRETTY_FUNCTION__), \ - 0))) -#else -#define assert(e) \ - ((void)(__builtin_expect(!!(e), 1) \ - ? 0 \ - : (__assert_fail(#e, __ASSERT_FILE_NAME, __LINE__, __func__), \ - 0))) -#endif - -#endif - -#include <_static_assert.h> - -#else - -#include_next - -#endif diff --git a/third_party/nvidia/clang_cuda_shims/curand_mtgp32_kernel.h b/third_party/nvidia/clang_cuda_shims/curand_mtgp32_kernel.h deleted file mode 100644 index d9d8dba6e3c3..000000000000 --- a/third_party/nvidia/clang_cuda_shims/curand_mtgp32_kernel.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -// Clang's CUDA runtime wrapper force-includes this header to work around -// conflicting builtin-variable redeclarations in the real CURAND header. -// GSan does not use CURAND, so an empty shim is sufficient for device-only -// LLVM IR generation against the vendored CUDA headers.