
Commit 3d85a1d

ikawrakow and Iwan Kawrakow authored
Better FlashMLA (#243)
* This is a better FA for TG. It should benefit MLA and GQA. Tested to work with DeepSeek-Lite MLA, not yet for GQA. For tg64@pp8192 it is ~13% faster than MLA without FA, and 57% faster than the main-branch FA.
* WIP
* Cleanup

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c67a37b commit 3d85a1d

7 files changed (+582, -179 lines)

ggml/src/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -258,8 +258,8 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
 if (GGML_IQK_MUL_MAT)
     message(STATUS "Using optimized iqk matrix multiplications")
     add_compile_definitions(GGML_USE_IQK_MULMAT)
-    set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp)
-    set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h)
+    set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
+    set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
     if (GGML_IQK_FA_ALL_QUANTS)
         message(STATUS "Including all IQK FA kernels")
         add_compile_definitions(GGML_IQK_FA_ALL_QUANTS)

ggml/src/ggml.c

Lines changed: 72 additions & 40 deletions
@@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 
 #if GGML_USE_IQK_MULMAT
-    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
-        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
-        //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
-        // I keep changing my mind what is the best strategy to split the threads when processing
-        // multiple heads. This is my current thinking, the commented out code below was the previous.
-        int ntg = nth/simple_gcd(neq2*neq3, nth);
-        int64_t neq1g = (neq1 + ntg - 1)/ntg;
-        //int64_t work_per_slice = D*nek1*neq1;
-        //int ntg = 1;
-        //
-        // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
-        // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
-        // the number of threads processing the (iq2, iq3) matrix.
-        //
-        //if (neq1 >= 8*nth) {
-        //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
-        //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
-        //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
-        //}
-        int counter = 0;
-        for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
-            for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
-                if (counter++ % (nth/ntg) == ith/ntg) {
-                    int iq1 = (ith%ntg)*neq1g;
-                    int this_neq1 = MIN(neq1g, neq1-iq1);
-                    if (!iqk_flash_attn_noalibi(k->type, v->type,
-                                Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
-                                (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
-                                (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
-                                (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
-                                (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
-                                scale, softcap,
-                                (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
-                }
-            }
-        }
-        return;
-IQK_Flash_Attn_NotAvailable:;
-        printf("iqk_flash was rejected\n");
-    }
+    if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
+                        q->ne[3], q->ne[2], q->nb[3], q->nb[2],
+                        k->ne[3], k->ne[2], k->nb[3], k->nb[2],
+                        v->ne[3], v->ne[2], v->nb[3], v->nb[2],
+                        dst->ne[2], dst->ne[1], dst->nb[1],
+                        k->type, v->type,
+                        Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
+                        q->data, k->data, v->data, mask->data,
+                        scale, softcap, (float *)dst->data,
+                        params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+
+    // if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+    //     //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+    //     //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
+    //     // I keep changing my mind what is the best strategy to split the threads when processing
+    //     // multiple heads. This is my current thinking, the commented out code below was the previous.
+    //     int ntg = nth/simple_gcd(neq2*neq3, nth);
+    //     int64_t neq1g = (neq1 + ntg - 1)/ntg;
+    //     //int64_t work_per_slice = D*nek1*neq1;
+    //     //int ntg = 1;
+    //     //
+    //     // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix
+    //     // But we also want each thread to process the same amount of rows, so neq1 must be a multiple of
+    //     // the number of threads processing the (iq2, iq3) matrix.
+    //     //
+    //     //if (neq1 >= 8*nth) {
+    //     //    if (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+    //     //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+    //     //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+    //     //}
+    //     int counter = 0;
+    //     for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+    //         for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+    //             if (counter++ % (nth/ntg) == ith/ntg) {
+    //                 int iq1 = (ith%ntg)*neq1g;
+    //                 int this_neq1 = MIN(neq1g, neq1-iq1);
+    //                 if (!iqk_flash_attn_noalibi(k->type, v->type,
+    //                             Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+    //                             (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
+    //                             (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
+    //                             (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
+    //                             (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
+    //                             scale, softcap,
+    //                             (float *)((char *) dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
+    //             }
+    //         }
+    //     }
+    //     return;
+    //IQK_Flash_Attn_NotAvailable:;
+    //     printf("iqk_flash was rejected\n");
+    // }
 #endif
 
     const uint32_t n_head = neq2;
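Editor's note on the hunk above: the old in-place code distributed work by assigning each (iq2, iq3) head slice to a group of ntg threads, with the group size derived from a GCD of the head count and the thread count; the new single call hands that splitting, the thread indices and the thread barrier to iqk_flash_attn_noalibi. The standalone C++ sketch below only illustrates that splitting rule with made-up shapes (nth, neq1, neq2, neq3 are assumptions); it is not part of the commit.

// Illustrative sketch of the per-thread split used by the commented-out code
// above. Thread `ith` of `nth` picks its (iq2, iq3) head slices round-robin
// and, within a slice, a contiguous block of neq1 query rows.
#include <cstdint>
#include <cstdio>
#include <numeric>   // std::gcd
#include <algorithm> // std::min

int main() {
    int     nth = 8, neq1 = 16;   // threads, query rows per head (assumed)
    int64_t neq2 = 32, neq3 = 1;  // heads, batch (assumed)

    int     ntg   = nth/std::gcd<int64_t>(neq2*neq3, nth); // threads per head slice
    int64_t neq1g = (neq1 + ntg - 1)/ntg;                  // query rows per thread

    for (int ith = 0; ith < nth; ++ith) {
        int counter = 0;
        for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
            for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
                if (counter++ % (nth/ntg) != ith/ntg) continue;  // slice owned by another group
                int64_t iq1       = (ith%ntg)*neq1g;             // first query row for this thread
                int64_t this_neq1 = std::min<int64_t>(neq1g, neq1 - iq1);
                if (this_neq1 > 0)
                    printf("thread %d: head (%d,%d), rows [%d, %d)\n",
                           ith, (int)iq2, (int)iq3, (int)iq1, (int)(iq1 + this_neq1));
            }
        }
    }
    return 0;
}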
@@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     const int64_t D = MAX(Dk, Dv);
 
                     cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
+#if GGML_USE_IQK_MULMAT
+                    const struct ggml_tensor * q = node->src[0];
+                    const struct ggml_tensor * k = node->src[1];
+                    if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
+                        int nstep_k = k->ne[1]/32;
+                        int gcd_k = simple_gcd(nstep_k, n_tasks);
+                        if (gcd_k > 1) {
+                            int nth_k = n_tasks/gcd_k;
+                            int rk2 = q->ne[2]/k->ne[2];
+                            if (rk2%nth_k == 0) {
+                                size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
+                                if (ggml_is_quantized(k->type)) {
+                                    enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+                                    size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+                                    size += q->ne[2]*row_size;
+                                }
+                                cur = MAX(cur, size);
+                            }
+                        }
+                    }
+#endif
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
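Editor's note: the ggml_graph_plan branch above only enlarges the work buffer for the single-token (TG) case where one K/V head serves several Q heads and the K cache (in units of 32 rows) can be split across threads. Below is a small standalone sketch of the same arithmetic with purely illustrative numbers (Dv, rk2, n_tasks and the K-cache length are assumptions, not values from the commit); the extra space for a quantized K cache is omitted.

// Mirrors the work-buffer sizing in the hunk above, with made-up shapes.
#include <cstdio>
#include <numeric>  // std::gcd

int main() {
    int Dv = 512, rk2 = 16;        // V head size, Q heads per K/V head (assumed)
    int n_tasks = 8, k_ne1 = 4096; // threads, rows in the K cache (assumed)

    size_t cur = 0;
    int nstep_k = k_ne1/32;                   // K cache viewed in steps of 32 rows
    int gcd_k   = std::gcd(nstep_k, n_tasks); // how the threads can split along K
    if (gcd_k > 1) {
        int nth_k = n_tasks/gcd_k;            // threads sharing one group of Q heads
        if (rk2 % nth_k == 0) {
            // (Dv + 16) floats per Q head handled by a thread, for all n_tasks threads.
            cur = (size_t)(Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
        }
    }
    printf("extra FA work buffer: %zu bytes\n", cur); // 270336 bytes for these numbers
    return 0;
}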

ggml/src/iqk/iqk_common.h

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
+// vi: set et ft=cpp fenc=utf-8 :vi
+//
+//
+// Copyright (C) 2024 Iwan Kawrakow
+// MIT license
+// SPDX-License-Identifier: MIT
+//
+
+#include "iqk_config.h"
+
+#if defined IQK_IMPLEMENT
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "iqk_mul_mat.h"
+#include "iqk_quantize.h"
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#define FA_TIMING 0
+
+#include <utility>
+#include <array>
+#if FA_TIMING
+#include <chrono>
+#include <mutex>
+struct Perf {
+    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
+    std::array<double, 5> times = {};
+    std::mutex mutex;
+    bool report;
+    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
+    inline void accum(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        std::lock_guard<std::mutex> lock(mutex);
+        times[what] += dt;
+    }
+    inline void accum_nolock(int what, const TimePoint& t1) {
+        auto t2 = cur_time();
+        auto dt = delta(t1, t2);
+        times[what] += dt;
+    }
+    inline void add(const Perf& other) {
+        std::lock_guard<std::mutex> lock(mutex);
+        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
+    }
+    Perf(bool r) : report(r) {}
+    ~Perf() {
+        if (report) {
+            double tot = 0;
+            for (auto& t : times) tot += t;
+            if (!tot) return;
+            printf("======================= Timing: %g ms in total\n", tot);
+            for (int i = 0; i < int(times.size()); ++i) {
+                if (times[i]) {
+                    printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
+                }
+            }
+        }
+    }
+    static Perf& instance() {
+        static Perf p(true);
+        return p;
+    }
+    static double delta(const TimePoint& t1, const TimePoint& t2) {
+        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
+    }
+};
+#endif
+
+#ifdef __AVX2__
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+#endif
+
+namespace {
+
+typedef struct {
+    int32_t i1;
+    int32_t i2;
+} mmid_row_mapping;
+
+struct DataInfo {
+    float * s;
+    const char * cy;
+    size_t bs;
+    size_t by;
+    int cur_y = 0;
+    int ne11;
+    const mmid_row_mapping * row_mapping = nullptr;
+    size_t bs2 = 0;
+
+    inline const char * src1_row(int iy) const {
+        if (!row_mapping) return cy + (cur_y + iy)*by;
+        int i11 = row_mapping[cur_y + iy].i1 % ne11;
+        int i12 = row_mapping[cur_y + iy].i2;
+        return cy + (i11 + i12*ne11)*by;
+    }
+
+    inline void store(int ix, int iy, float result) const {
+        *(dst_row(iy) + ix) = result;
+    }
+#ifdef __AVX__
+    inline void store(int ix, int iy, __m128 result) const {
+        _mm_storeu_ps(dst_row(iy) + ix, result);
+    }
+    inline void store(int ix, int iy, __m256 result) const {
+        _mm256_storeu_ps(dst_row(iy) + ix, result);
+    }
+#endif
+#ifdef __AVX512F__
+    inline void store(int ix, int iy, __m512 result) const {
+        _mm512_storeu_ps(dst_row(iy) + ix, result);
+    }
+#endif
+#ifdef __ARM_NEON
+    inline void store(int ix, int iy, float32x4_t result) const {
+        vst1q_f32(dst_row(iy) + ix, result);
+    }
+#endif
+    inline float * dst_row(int iy) const {
+        if (!row_mapping) return s + (cur_y + iy)*bs;
+        int i12 = row_mapping[cur_y + iy].i2;
+        int i1 = row_mapping[cur_y + iy].i1;
+        int i2 = i12;
+        return s + i1*bs + i2*bs2;
+    }
+};
+
+typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
+
+#endif
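Editor's note: the new header collects helpers shared by the FA code, notably the optional Perf timer (only compiled when FA_TIMING is switched to 1) and the DataInfo/mul_mat_t plumbing for scattering results into the destination tensor. A minimal usage sketch of Perf, assuming this header is included and FA_TIMING has been set to 1; the phase index and the timed work are placeholders, not code from the commit:

#if FA_TIMING
// Time one phase of a kernel and fold it into the global accumulator.
static void timed_phase_example() {
    auto t0 = Perf::cur_time();        // timestamp before the phase
    // ... the work to be measured would run here ...
    Perf::instance().accum(0, t0);     // thread-safe: adds the elapsed ms to slot 0
}
// Per-slot totals and percentages are printed when the process exits, because
// instance() owns a function-local static Perf constructed with report = true.
#endif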
