diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h
index a573886b3..0bbc41711 100644
--- a/src/tl_templates/cuda/atomic.h
+++ b/src/tl_templates/cuda/atomic.h
@@ -12,7 +12,11 @@ using cutlass::bfloat16_t;
 using cutlass::half_t;

 #define TL_DEVICE __forceinline__ __device__
-
+#define TL_NOT_IMPLEMENTED()                                                   \
+  {                                                                            \
+    printf("%s not implemented\n", __PRETTY_FUNCTION__);                       \
+    asm volatile("brkpt;\n");                                                  \
+  }
 template <typename T> struct normalize_atomic_type {
   using type = T;
 };
@@ -63,8 +67,12 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
     aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -89,9 +97,13 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
     }
     return static_cast(*reinterpret_cast(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
     return static_cast(
         aref.fetch_max(cuda_cast(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -117,8 +129,13 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
-    aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order));
+    return static_cast(
+        aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -143,9 +160,13 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
     }
     return static_cast(*reinterpret_cast(&old_val_ushort));
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
     return static_cast(
         aref.fetch_min(cuda_cast(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -216,8 +237,12 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
     aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -290,9 +315,13 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
       }
     }
   } else {
+#if CUDART_VERSION >= 11080
     cuda::atomic_ref aref(*address);
     return static_cast(
         aref.fetch_add(cuda_cast(val), cuda::memory_order(memory_order)));
+#else
+    TL_NOT_IMPLEMENTED();
+#endif
   }
 }

@@ -618,13 +647,21 @@ AtomicAddx4Ret(float *ref, float *val,
 #endif

 template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref aref(ref);
   return aref.load(cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }

 template <typename T1, typename T2>
 TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
   using NT1 = typename normalize_atomic_type<T1>::type;
+#if CUDART_VERSION >= 11080
   cuda::atomic_ref aref(ref);
   aref.store(cuda_cast(value), cuda::memory_order(memory_order));
+#else
+  TL_NOT_IMPLEMENTED();
+#endif
 }
diff --git a/src/tl_templates/cuda/debug.h b/src/tl_templates/cuda/debug.h
index e8976874c..2724a814c 100644
--- a/src/tl_templates/cuda/debug.h
+++ b/src/tl_templates/cuda/debug.h
@@ -1,6 +1,9 @@
 #pragma once

+#if __CUDA_ARCH_LIST__ >= 890
 #include "./cuda_fp8.h"
+#endif
+
 #include "common.h"

 #ifndef __CUDACC_RTC__
@@ -117,6 +120,7 @@ __device__ void debug_print_var(const char *msg, double var) {
          threadIdx.z, var);
 }

+#if __CUDA_ARCH_LIST__ >= 890
 // Specialization for fp8_e4_t type
 template <>
 __device__ void debug_print_var(const char *msg, fp8_e4_t var) {
@@ -137,6 +141,8 @@ __device__ void debug_print_var(const char *msg, fp8_e5_t var) {
          threadIdx.z, (float)var);
 }
+#endif
+
 // Template declaration for device-side debug printing (buffer only)
 template <typename T>
 __device__ void debug_print_buffer_value(const char *msg,
                                          const char *buf_name,
@@ -242,6 +248,7 @@ __device__ void debug_print_buffer_value(const char *msg,
 }

 // Specialization for fp8_e4_t type
+#if __CUDA_ARCH_LIST__ >= 890
 template <>
 __device__ void debug_print_buffer_value(const char *msg,
                                          const char *buf_name,
@@ -263,6 +270,8 @@ __device__ void debug_print_buffer_value(const char *msg,
          threadIdx.z, buf_name, index, (float)var);
 }
+#endif
+
 // Specialization for int16 type
 template <>
 __device__ void debug_print_buffer_value(const char *msg,
                                          const char *buf_name,
diff --git a/src/tl_templates/cuda/gemm_mma.h b/src/tl_templates/cuda/gemm_mma.h
index 712831732..25841a3b6 100644
--- a/src/tl_templates/cuda/gemm_mma.h
+++ b/src/tl_templates/cuda/gemm_mma.h
@@ -8,7 +8,6 @@
 #include

 #include "common.h"
-#include "cuda_fp8.h"
 #include "intrin.h"

 namespace cute::tl_mma {
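Note (illustration, not part of the patch): every atomic.h hunk above follows the same shape, so a minimal standalone CUDA sketch of that shape is given below. The names guarded_atomic_max and demo are hypothetical, the element type is fixed to int, and the memory order is hard-coded to relaxed for brevity; the patched functions instead cast a runtime int memory_order with cuda::memory_order(...). The cuda::atomic_ref path is only compiled when the toolkit is CUDA 11.8 or newer (CUDART_VERSION >= 11080); on older toolkits the new TL_NOT_IMPLEMENTED() trap fires at run time instead of producing a silent miscompile.

// Standalone sketch of the guard pattern; names below are hypothetical.
#include <cstdio>
#include <cuda/atomic> // libcu++: cuda::atomic_ref, cuda::std::memory_order

#define TL_DEVICE __forceinline__ __device__
// Print the offending function and trap via the PTX `brkpt` instruction.
#define TL_NOT_IMPLEMENTED()                                                   \
  {                                                                            \
    printf("%s not implemented\n", __PRETTY_FUNCTION__);                       \
    asm volatile("brkpt;\n");                                                  \
  }

// Atomic max on a device-scope int, compiled only on CUDA 11.8+ toolkits.
TL_DEVICE void guarded_atomic_max(int *addr, int val) {
#if CUDART_VERSION >= 11080
  cuda::atomic_ref<int, cuda::thread_scope_device> aref(*addr);
  aref.fetch_max(val, cuda::std::memory_order_relaxed);
#else
  TL_NOT_IMPLEMENTED(); // older toolkit: report and break into the debugger
#endif
}

__global__ void demo(int *out) {
  guarded_atomic_max(out, static_cast<int>(threadIdx.x));
}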