From 55727bd0824d185e1c9142de71ebb1f53c7c096c Mon Sep 17 00:00:00 2001
From: cmikeh2
Date: Wed, 28 Sep 2022 16:03:28 +0000
Subject: [PATCH] Add predicated global load

---
NOTE(review): the 16-byte CacheStreaming overload issues ld.global.cs, in line with the
8- and 4-byte CacheStreaming overloads. Template parameter lists and reinterpret_cast
target types are written out in full; the load_global/load_shared template headers in
the first hunk are context lines and are assumed to match the target header - confirm
before applying. Hunk offsets and the diffstat account for the added comments; the
post-image blob id on the index line has not been regenerated.

 csrc/includes/memory_access_utils.h | 284 ++++++++++++++++++++++++++++
 1 file changed, 284 insertions(+)

diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h
index a3f858f1f870..6d7759e28233 100644
--- a/csrc/includes/memory_access_utils.h
+++ b/csrc/includes/memory_access_utils.h
@@ -25,6 +25,11 @@ enum class StorePolicy {
 template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
 __device__ __forceinline__ void load_global(void* dst, const void* src);
 
+// Predicated variant: when do_access is true, AccessSize bytes are loaded from src into
+// dst; when it is false, src is not read and dst is zero-filled instead.
+template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
+__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access);
+
 // Shared accesses have no cache policy
 template <int AccessSize>
 __device__ __forceinline__ void load_shared(void* dst, const void* src);
@@ -98,6 +103,39 @@ __device__ __forceinline__ void load_global<16>(void* dst, const void* src)
 #endif
 }
 
+// Predicated 16-byte load (ld.global, no cache operator); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %5, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\tmov.b32 %2, 0;\n"
+        "\tmov.b32 %3, 0;\n"
+        "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+        data[0].z = 0;
+        data[0].w = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src)
 {
@@ -112,6 +150,41 @@ __device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* d
 #endif
 }
 
+// Predicated 16-byte load (ld.global.cg on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst,
+                                                                         const void* src,
+                                                                         bool do_access)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %5, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\tmov.b32 %2, 0;\n"
+        "\tmov.b32 %3, 0;\n"
+        "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+        data[0].z = 0;
+        data[0].w = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
                                                                             const void* src)
@@ -127,6 +200,41 @@ __device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void
 #endif
 }
 
+// Predicated 16-byte load (ld.global.cs on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
+                                                                            const void* src,
+                                                                            bool do_access)
+{
+    uint4* data = reinterpret_cast<uint4*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %5, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\tmov.b32 %2, 0;\n"
+        "\tmov.b32 %3, 0;\n"
+        "\t@p ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+        data[0].z = 0;
+        data[0].w = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<8>(void* dst, const void* src)
 {
@@ -141,6 +249,35 @@ __device__ __forceinline__ void load_global<8>(void* dst, const void* src)
 #endif
 }
 
+// Predicated 8-byte load (ld.global, no cache operator); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %3, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src)
 {
@@ -155,6 +292,37 @@ __device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* ds
 #endif
 }
 
+// Predicated 8-byte load (ld.global.cg on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst,
+                                                                        const void* src,
+                                                                        bool do_access)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %3, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src)
@@ -170,6 +338,37 @@ __device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void*
 #endif
 }
 
+// Predicated 8-byte load (ld.global.cs on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
+                                                                           const void* src,
+                                                                           bool do_access)
+{
+    uint2* data = reinterpret_cast<uint2*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the outputs first; the predicated ld overwrites them only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %3, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\tmov.b32 %1, 0;\n"
+        "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
+        "}\n"
+        : "=r"(data[0].x), "=r"(data[0].y)
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0].x = 0;
+        data[0].y = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<4>(void* dst, const void* src)
 {
@@ -182,6 +381,33 @@ __device__ __forceinline__ void load_global<4>(void* dst, const void* src)
 #endif
 }
 
+// Predicated 4-byte load (ld.global, no cache operator); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the output first; the predicated ld overwrites it only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %2, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\t@p ld.global.u32 {%0}, [%1];\n"
+        "}\n"
+        : "=r"(data[0])
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0] = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src)
 {
@@ -194,6 +420,35 @@ __device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* ds
 #endif
 }
 
+// Predicated 4-byte load (ld.global.cg on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst,
+                                                                        const void* src,
+                                                                        bool do_access)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the output first; the predicated ld overwrites it only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %2, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\t@p ld.global.cg.u32 {%0}, [%1];\n"
+        "}\n"
+        : "=r"(data[0])
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0] = 0;
+    }
+#endif
+}
+
 template <>
 __device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src)
@@ -207,6 +462,35 @@ __device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void*
 #endif
 }
 
+// Predicated 4-byte load (ld.global.cs on the PTX path); zero-fills dst if !do_access.
+template <>
+__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
+                                                                           const void* src,
+                                                                           bool do_access)
+{
+    int32_t* data = reinterpret_cast<int32_t*>(dst);
+#ifdef PTX_AVAILABLE
+    // Zero the output first; the predicated ld overwrites it only when p is set.
+    asm volatile(
+        "{\n"
+        "\t.reg .pred p;\n"
+        "\tsetp.ne.b32 p, %2, 0;\n"
+        "\tmov.b32 %0, 0;\n"
+        "\t@p ld.global.cs.u32 {%0}, [%1];\n"
+        "}\n"
+        : "=r"(data[0])
+        : "l"(src), "r"((int)do_access));
+#else
+    // PTX_AVAILABLE not defined: plain conditional copy, otherwise zero-fill.
+    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
+    if (do_access) {
+        data[0] = src_cast[0];
+    } else {
+        data[0] = 0;
+    }
+#endif
+}
+
 /////////// Load Shared ///////////
 namespace internal {
 