 
 #include <iostream>
 
+#include "common/base.h"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include "common/mem.h"
+#endif
+
 template <typename T>
 inline std::string str(T x) {
   return std::to_string(x);
 }
 
 namespace marlin_dense {
 
-constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
-
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 
-// Instances of `Vec` are used to organize groups of >>registers<<, as needed
-// for instance as inputs to tensor core operations. Consequently, all
-// corresponding index accesses must be compile-time constants, which is why we
-// extensively use `#pragma unroll` throughout the kernel code to guarantee
-// this.
-template <typename T, int n>
-struct Vec {
-  T elems[n];
-  __device__ T& operator[](int i) { return elems[i]; }
-};
-
 using I4 = Vec<int, 4>;
-
 // Matrix fragments for tensor core instructions; their precise layout is
 // documented here:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;  // quantization scales
 
-// Predicated asynchronous global->shared copy; used for inputs A where we apply
-// predication to handle batchsizes that are not multiples of 16.
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
-                                      bool pred = true) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   .reg .pred p;\n"
-      "   setp.ne.b32 p, %0, 0;\n"
-      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-      "}\n" ::"r"((int)pred),
-      "r"(smem), "l"(glob_ptr), "n"(BYTES));
-}
-
-// Asynchronous global->shared copy
-__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
-      "}\n" ::"r"(smem),
-      "l"(glob_ptr), "n"(BYTES));
-}
-
-// Async copy fence.
-__device__ inline void cp_async_fence() {
-  asm volatile("cp.async.commit_group;\n" ::);
-}
-
-// Wait until at most `n` async copy stages are still pending.
-template <int n>
-__device__ inline void cp_async_wait() {
-  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
-}
-
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
 __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
-// Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int* lock, int count) {
-  if (threadIdx.x == 0) {
-    int state = -1;
-    do
-      // Guarantee that subsequent writes by this threadblock will be visible
-      // globally.
-      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
-                   : "=r"(state)
-                   : "l"(lock));
-    while (state != count);
-  }
-  __syncthreads();
-}
-
-// Release barrier and increment visitation count.
-__device__ inline void barrier_release(int* lock, bool reset = false) {
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    if (reset) {
-      lock[0] = 0;
-      return;
-    }
-    int val = 1;
-    // Make sure that all writes since acquiring this barrier are visible
-    // globally, while releasing the barrier.
-    asm volatile("fence.acq_rel.gpu;\n");
-    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
-                 :
-                 : "l"(lock), "r"(val));
-  }
-}
-
 template <const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the