
Commit ce2b029

ikawrakow and Iwan Kawrakow authored
CUDA: faster FA TG for GQA models (#370)
* cuda: WIP MMA FA
* Use MMA for TG also when quantized

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent b890e01 commit ce2b029

23 files changed (+2158, -11 lines)

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -321,6 +321,8 @@ if (GGML_CUDA)
     list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
     file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    file(GLOB SRCS "ggml-cuda/template-instances/fattn-mma*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})

ggml/src/ggml-cuda/common.cuh

Lines changed: 56 additions & 0 deletions
@@ -46,10 +46,14 @@
 #define CC_VOLTA 700
 #define CC_TURING 750
 #define CC_AMPERE 800
+#define CC_ADA_LOVELACE 890
 #define CC_OFFSET_AMD 1000000
+#define CC_OFFSET_MTHREADS 0x0100000
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < CC_OFFSET_MTHREADS)
+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= CC_OFFSET_AMD)
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -134,6 +138,49 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+    if (cur == 0) {
+        GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+    }
+    return cur;
+}
+
+template<class ... Archs>
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
+    if (first <= arch && first > cur) {
+        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
+    } else {
+        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
+    }
+}
+
+constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
+    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
@@ -146,6 +193,15 @@ static constexpr bool int8_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_TURING;
 }
 
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool new_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= CC_TURING;
+}
+
+static bool cp_async_available(const int cc) {
+    return cc < CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= CC_AMPERE;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
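
Compile-time note: when nvcc defines __CUDA_ARCH_LIST__ (newer CUDA toolkits), ggml_cuda_highest_compiled_arch(cc) recursively walks the list of architectures the binary was built for and returns the highest one that does not exceed the device's compute capability, so new_mma_available and cp_async_available only report a feature when a kernel variant using it was actually compiled. The following standalone C++ sketch replays that recursion with a hypothetical arch list of 600, 700, 750, 800 (an illustrative example, not the build configuration of this commit):

// Standalone C++ sketch of the __CUDA_ARCH_LIST__ recursion from common.cuh above.
// The arch list 600,700,750,800 stands in for what nvcc would pass via __CUDA_ARCH_LIST__.
#include <cstdio>

// Base case: no candidates left, return the best architecture found so far.
// (The real code aborts via GGML_ABORT when nothing <= arch was compiled; omitted here.)
constexpr int highest_compiled_arch_impl(const int /*arch*/, const int cur) {
    return cur;
}

// Recursive case: keep `first` only if it is usable on the device (first <= arch)
// and higher than the current best (first > cur).
template <class... Archs>
constexpr int highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
    return (first <= arch && first > cur)
        ? highest_compiled_arch_impl(arch, first, rest...)
        : highest_compiled_arch_impl(arch, cur, rest...);
}

constexpr int highest_compiled_arch(const int arch) {
    return highest_compiled_arch_impl(arch, 0, 600, 700, 750, 800);
}

int main() {
    // An Ada Lovelace device (cc 890) has no exact match, so it maps to the Ampere build.
    static_assert(highest_compiled_arch(890) == 800, "cc 890 -> 800");
    // A Turing device (cc 750) picks its own build, which is what enables new_mma_available.
    static_assert(highest_compiled_arch(750) == 750, "cc 750 -> 750");
    printf("cc 890 -> %d, cc 750 -> %d\n", highest_compiled_arch(890), highest_compiled_arch(750));
    return 0;
}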

ggml/src/ggml-cuda/cp-async.cuh

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 byte copy is exposed because 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
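
Usage note: cp_async_cg_16 wraps the Ampere cp.async.cg instruction, asynchronously staging 16 bytes from global into shared memory per call (with an optional L2 prefetch hint selected by the preload template parameter), and cp_async_wait_all later blocks the calling thread until its own copies have landed. Below is a minimal, hypothetical kernel sketch (not part of this commit) of the intended call pattern: convert the shared-memory pointer with __cvta_generic_to_shared, issue one 16-byte copy per thread, wait, then __syncthreads before the staged tile is read.

// Hypothetical usage sketch for cp_async_cg_16 / cp_async_wait_all; this kernel
// is illustrative only and does not appear in the commit.
#include "cp-async.cuh" // pulls in common.cuh (half2, NO_DEVICE_CODE, etc.)

// Stage a tile of ne half2 values into shared memory, then copy it back out.
// Assumes ne == 8*blockDim.x so that every thread moves exactly one 16-byte chunk.
template <int ne>
static __global__ void stage_tile_example(const half2 * __restrict__ src, half2 * __restrict__ dst) {
    __shared__ half2 tile[ne];

    // cp.async addresses shared memory with 32 bit pointers.
    const unsigned int tile_base = __cvta_generic_to_shared(tile);

    const int offset = 8*threadIdx.x; // 8 half2 = 16 bytes per thread
    cp_async_cg_16<0>(tile_base + offset*sizeof(half2), src + offset);

    cp_async_wait_all(); // wait for this thread's own copies ...
    __syncthreads();     // ... and synchronize the block before the tile is read

    for (int i = 0; i < 8; ++i) {
        dst[offset + i] = tile[offset + i];
    }
}

// Example launch, e.g.: stage_tile_example<8*128><<<1, 128>>>(src, dst);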
