Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions src/tl_templates/cuda/copy_sm100.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,72 +8,168 @@ namespace tl {
// 256-bit load for longlong4
__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
longlong4 ret;
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
#else
// CUDA < 12.9 fallback: two 128-bit loads (may have performance regression)
const char *base = reinterpret_cast<const char *>(ptr);
asm volatile("ld.global.v2.s64 {%0, %1}, [%2];"
: "=l"(ret.x), "=l"(ret.y)
: "l"(base));
asm volatile("ld.global.v2.s64 {%0, %1}, [%2];"
: "=l"(ret.z), "=l"(ret.w)
: "l"(base + 16));
#endif
return ret;
}

// 256-bit load for ulonglong4
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
ulonglong4 ret;
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
#else
// CUDA < 12.9 fallback: two 128-bit loads (may have performance regression)
const char *base = reinterpret_cast<const char *>(ptr);
asm volatile("ld.global.v2.u64 {%0, %1}, [%2];"
: "=l"(ret.x), "=l"(ret.y)
: "l"(base));
asm volatile("ld.global.v2.u64 {%0, %1}, [%2];"
: "=l"(ret.z), "=l"(ret.w)
: "l"(base + 16));
#endif
return ret;
}

// Generic 256-bit load for FP8 types (returns ulonglong4)
template <typename T>
__device__ __forceinline__ ulonglong4 ld_global_256(const T *ptr) {
ulonglong4 ret;
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
#else
// CUDA < 12.9 fallback: two 128-bit loads (may have performance regression)
const char *base = reinterpret_cast<const char *>(ptr);
asm volatile("ld.global.v2.u64 {%0, %1}, [%2];"
: "=l"(ret.x), "=l"(ret.y)
: "l"(base));
asm volatile("ld.global.v2.u64 {%0, %1}, [%2];"
: "=l"(ret.z), "=l"(ret.w)
: "l"(base + 16));
#endif
return ret;
}

// 256-bit store for longlong4
__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) {
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
#else
// CUDA < 12.9 fallback: two 128-bit stores (may have performance regression)
char *base = reinterpret_cast<char *>(ptr);
asm volatile("st.global.v2.s64 [%0], {%1, %2};"
:
: "l"(base), "l"(val.x), "l"(val.y));
asm volatile("st.global.v2.s64 [%0], {%1, %2};"
:
: "l"(base + 16), "l"(val.z), "l"(val.w));
#endif
}

// 256-bit store for ulonglong4 with non-const reference
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
ulonglong4 &val) {
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
#else
// CUDA < 12.9 fallback: two 128-bit stores (may have performance regression)
char *base = reinterpret_cast<char *>(ptr);
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base), "l"(val.x), "l"(val.y));
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base + 16), "l"(val.z), "l"(val.w));
#endif
}

// 256-bit store for ulonglong4 with const reference
// must be const &val, otherwise the compiler will generate a temporary variable
// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr))
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
const ulonglong4 &val) {
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
#else
// CUDA < 12.9 fallback: two 128-bit stores (may have performance regression)
char *base = reinterpret_cast<char *>(ptr);
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base), "l"(val.x), "l"(val.y));
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base + 16), "l"(val.z), "l"(val.w));
#endif
}

// Generic 256-bit store for FP8 types
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, const ulonglong4 &val) {
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
#else
// CUDA < 12.9 fallback: two 128-bit stores (may have performance regression)
char *base = reinterpret_cast<char *>(ptr);
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base), "l"(val.x), "l"(val.y));
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base + 16), "l"(val.z), "l"(val.w));
#endif
}

// Generic 256-bit store for FP8 types with non-const reference
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, T &val) {
ulonglong4 &val_u64 = *((ulonglong4 *)&val);
#if (__CUDACC_VER_MAJOR__ > 12) || \
(__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val_u64.x), "l"(val_u64.y), "l"(val_u64.z),
"l"(val_u64.w));
#else
// CUDA < 12.9 fallback: two 128-bit stores (may have performance regression)
char *base = reinterpret_cast<char *>(ptr);
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base), "l"(val_u64.x), "l"(val_u64.y));
asm volatile("st.global.v2.u64 [%0], {%1, %2};"
:
: "l"(base + 16), "l"(val_u64.z), "l"(val_u64.w));
#endif
}

__device__ __forceinline__ unsigned long long
Expand Down
Loading