@@ -17870,46 +17870,57 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }

 #if GGML_USE_IQK_MULMAT
-    if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
-        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
-        //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
-        // I keep changing my mind about the best strategy for splitting the threads when processing
-        // multiple heads. This is my current thinking; the commented-out code below was the previous approach.
-        int ntg = nth/simple_gcd(neq2*neq3, nth);
-        int64_t neq1g = (neq1 + ntg - 1)/ntg;
-        //int64_t work_per_slice = D*nek1*neq1;
-        //int ntg = 1;
-        //
-        // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix.
-        // But we also want each thread to process the same number of rows, so neq1 must be a multiple of
-        // the number of threads processing the (iq2, iq3) matrix.
-        //
-        //if (neq1 >= 8*nth) {
-        //    if      (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
-        //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
-        //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
-        //}
-        int counter = 0;
-        for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
-            for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
-                if (counter++ % (nth/ntg) == ith/ntg) {
-                    int iq1 = (ith%ntg)*neq1g;
-                    int this_neq1 = MIN(neq1g, neq1-iq1);
-                    if (!iqk_flash_attn_noalibi(k->type, v->type,
-                                Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
-                                (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
-                                (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
-                                (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
-                                (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
-                                scale, softcap,
-                                (float *)((char *)dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
-                }
-            }
-        }
-        return;
-IQK_Flash_Attn_NotAvailable:;
-        printf("iqk_flash was rejected\n");
-    }
+    if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
+                        q->ne[3], q->ne[2], q->nb[3], q->nb[2],
+                        k->ne[3], k->ne[2], k->nb[3], k->nb[2],
+                        v->ne[3], v->ne[2], v->nb[3], v->nb[2],
+                        dst->ne[2], dst->ne[1], dst->nb[1],
+                        k->type, v->type,
+                        Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
+                        q->data, k->data, v->data, mask->data,
+                        scale, softcap, (float *)dst->data,
+                        params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
+
+    // if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+    //     //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+    //     //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
+    //     // I keep changing my mind about the best strategy for splitting the threads when processing
+    //     // multiple heads. This is my current thinking; the commented-out code below was the previous approach.
+    //     int ntg = nth/simple_gcd(neq2*neq3, nth);
+    //     int64_t neq1g = (neq1 + ntg - 1)/ntg;
+    //     //int64_t work_per_slice = D*nek1*neq1;
+    //     //int ntg = 1;
+    //     //
+    //     // When neq1 is large, it is better to have more than one thread process one (iq2,iq3) matrix.
+    //     // But we also want each thread to process the same number of rows, so neq1 must be a multiple of
+    //     // the number of threads processing the (iq2, iq3) matrix.
+    //     //
+    //     //if (neq1 >= 8*nth) {
+    //     //    if      (nth%8 == 0 && neq1%8 == 0 && work_per_slice >= (1 << 23)) ntg = 8;
+    //     //    else if (nth%4 == 0 && neq1%4 == 0 && work_per_slice >= (1 << 21)) ntg = 4;
+    //     //    else if (nth%2 == 0 && neq1%2 == 0 && work_per_slice >= (1 << 19)) ntg = 2;
+    //     //}
+    //     int counter = 0;
+    //     for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
+    //         for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
+    //             if (counter++ % (nth/ntg) == ith/ntg) {
+    //                 int iq1 = (ith%ntg)*neq1g;
+    //                 int this_neq1 = MIN(neq1g, neq1-iq1);
+    //                 if (!iqk_flash_attn_noalibi(k->type, v->type,
+    //                             Dk, Dv, this_neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1], ne1*nb1/sizeof(float),
+    //                             (const float *)((const char *)q->data + iq2*q->nb[2] + iq3*q->nb[3] + iq1*q->nb[1]),
+    //                             (const void  *)((const char *)k->data + iq2/rk2*k->nb[2] + iq3/rk3*k->nb[3]),
+    //                             (const void  *)((const char *)v->data + iq2/rv2*v->nb[2] + iq3/rv3*v->nb[3]),
+    //                             (const void  *)((const char *)mask->data + iq1*mask->nb[1]),
+    //                             scale, softcap,
+    //                             (float *)((char *)dst->data + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1))) goto IQK_Flash_Attn_NotAvailable;
+    //             }
+    //         }
+    //     }
+    //     return;
+    //IQK_Flash_Attn_NotAvailable:;
+    //     printf("iqk_flash was rejected\n");
+    // }
 #endif

     const uint32_t n_head = neq2;
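
The removed block shows the thread-splitting strategy that the new `iqk_flash_attn_noalibi` entry point now encapsulates: threads are first spread over the `neq2*neq3` (head, batch) slices via `simple_gcd`, and the `ntg` threads left sharing one slice split its `neq1` query rows. Below is a minimal standalone sketch of that split, assuming `simple_gcd` is plain Euclid (as elsewhere in ggml.c) and using made-up toy sizes:

```c
#include <stdio.h>
#include <stdint.h>

// Euclid's algorithm; assumed to match the simple_gcd helper in ggml.c.
static int simple_gcd(int a, int b) {
    while (b != 0) { int t = a % b; a = b; b = t; }
    return a;
}

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int     nth  = 8;              // total threads (toy value)
    const int64_t neq2 = 4, neq3 = 1;    // (head, batch) slices
    const int64_t neq1 = 12;             // query rows per slice

    // Threads per slice: the part of nth that divides the slice count
    // handles whole slices; the remainder splits rows within a slice.
    int     ntg   = nth / simple_gcd((int)(neq2 * neq3), nth);
    int64_t neq1g = (neq1 + ntg - 1) / ntg;  // rows per thread in a slice

    for (int ith = 0; ith < nth; ++ith) {
        int64_t iq1  = (ith % ntg) * neq1g;
        int64_t rows = MIN(neq1g, neq1 - iq1);
        printf("thread %d: slices with counter %% %d == %d, rows [%lld, %lld)\n",
               ith, nth / ntg, ith / ntg, (long long)iq1, (long long)(iq1 + rows));
    }
    return 0;
}
```

Note the design change visible in the new call: instead of receiving per-slice pointers, it takes the full tensor geometry plus `params->wdata`, a barrier callback, and `ith`/`nth`, so work partitioning (and any cross-thread synchronization) now happens inside `iqk_flash_attn_noalibi` itself.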
@@ -21534,6 +21545,27 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
     const int64_t D = MAX(Dk, Dv);

     cur = 3*sizeof(float)*D*n_tasks; // 3x head size/thread
+#if GGML_USE_IQK_MULMAT
+    const struct ggml_tensor * q = node->src[0];
+    const struct ggml_tensor * k = node->src[1];
+    if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
+        int nstep_k = k->ne[1]/32;
+        int gcd_k   = simple_gcd(nstep_k, n_tasks);
+        if (gcd_k > 1) {
+            int nth_k = n_tasks/gcd_k;
+            int rk2   = q->ne[2]/k->ne[2];
+            if (rk2%nth_k == 0) {
+                size_t size = (Dv + 16)*rk2/nth_k*sizeof(float)*n_tasks;
+                if (ggml_is_quantized(k->type)) {
+                    enum ggml_type vec_dot_type = type_traits[k->type].vec_dot_type;
+                    size_t row_size = ggml_row_size(vec_dot_type, q->ne[0]);
+                    size += q->ne[2]*row_size;
+                }
+                cur = MAX(cur, size);
+            }
+        }
+    }
+#endif
 } break;
 case GGML_OP_FLASH_ATTN_BACK:
     {
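
The second hunk reserves extra scratch for a decode-time special case: a single query token (`q->ne[1] == 1`) with grouped-query attention (`q->ne[2]/k->ne[2] > 1`), where the K cache is split across threads in 32-row steps. The sketch below reproduces the size arithmetic with hypothetical shapes; the interpretation of the 16 extra floats and the Q8_0-like row size are assumptions for illustration, not taken from the commit:

```c
#include <stdio.h>
#include <stddef.h>

// Euclid's algorithm; assumed to match simple_gcd in ggml.c.
static int simple_gcd(int a, int b) {
    while (b != 0) { int t = a % b; a = b; b = t; }
    return a;
}

int main(void) {
    // Hypothetical decode-time shapes; not taken from the commit.
    const int  n_tasks = 16;
    const long nek1    = 8192;  // K-cache length
    const long neq2    = 32;    // query heads
    const long nek2    = 8;     // KV heads, so rk2 = 4 (GQA)
    const int  Dv      = 128;   // V head size
    const long ne0     = 128;   // Q/K row length

    const int nstep_k = (int)(nek1 / 32);  // K processed in 32-row steps
    const int gcd_k   = simple_gcd(nstep_k, n_tasks);
    if (gcd_k > 1) {
        const int nth_k = n_tasks / gcd_k; // threads sharing one K split
        const int rk2   = (int)(neq2 / nek2);
        if (rk2 % nth_k == 0) {
            // Per thread: rk2/nth_k partial outputs of Dv floats, plus 16
            // floats each (presumably softmax running max/sum bookkeeping).
            size_t size = (size_t)(Dv + 16) * (rk2 / nth_k) * sizeof(float) * n_tasks;
            // With quantized K, Q is first converted to K's vec_dot_type;
            // a Q8_0-like 34 bytes per 32 elements stands in here for
            // ggml_row_size(vec_dot_type, ne0).
            size += (size_t)neq2 * (ne0 / 32) * 34;
            printf("extra flash-attn scratch: %zu bytes\n", size);
        }
    }
    return 0;
}
```

The extra buffer only supersedes the default `3*sizeof(float)*D*n_tasks` estimate when it is larger, and only when the split is actually usable, i.e. `gcd_k > 1` and `rk2` divisible by `nth_k`, mirroring the guards in the hunk.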