@@ -4802,14 +4802,17 @@ static void ggml_compute_forward_soft_max_f32(
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));

-    // TODO: handle transposed/permuted matrices
-
     const int ith = params->ith;
     const int nth = params->nth;

     GGML_TENSOR_UNARY_OP_LOCALS

-    // const int64_t ne11 = src1 ? src1->ne[1] : 1;
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;

     // TODO: is this supposed to be ceil instead of floor?
     //   https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
@@ -4819,68 +4822,66 @@ static void ggml_compute_forward_soft_max_f32(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
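+    // per-thread scratch row, padded by one cache line to avoid false sharing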
+    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

-    for (int i1 = ir0; i1 < ir1; i1++) {
-        // ALiBi
-        const uint32_t h = (i1/ne01)%ne02; // head
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);
-
-        // broadcast the mask across rows
-        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
-
-        ggml_vec_cpy_f32  (nc, wp, sp);
-        ggml_vec_scale_f32(nc, wp, scale);
-        if (mp_f32) {
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*mp_f32[i];
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
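+                // matching mask row: dim 1 follows the softmax row, dims 2/3 broadcast via modulo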
+                const int64_t i11 = i01;
+                const int64_t i12 = i02%ne12;
+                const int64_t i13 = i03%ne13;
+
+                // ALiBi
+                const uint32_t h = i02; // head
+                const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
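+                // i.e. slope = m0^(h+1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) after that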
+
+                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
+
+                // broadcast the mask across rows
+                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
+
+                ggml_vec_cpy_f32  (ne00, wp, sp);
+                ggml_vec_scale_f32(ne00, wp, scale);
+                if (mp_f32) {
+                    if (use_f16) {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
+                        }
+                    } else {
+                        for (int i = 0; i < ne00; ++i) {
+                            wp[i] += slope*mp_f32[i];
+                        }
+                    }
+                }
                 }
-            }
-        }

 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(wp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    //printf("p[%d] = %f\n", i, p[i]);
+                    assert(!isnan(wp[i]));
+                }
 #endif

-        float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, wp);
+                float max = -INFINITY;
+                ggml_vec_max_f32(ne00, &max, wp);

-        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
-        assert(sum > 0.0);
+                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
+                assert(sum > 0.0);

-        sum = 1.0/sum;
-        ggml_vec_scale_f32(nc, dp, sum);
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(ne00, dp, sum);
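+                // dp now holds softmax(scale*x + slope*mask) for this row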

 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
-            assert(!isnan(dp[i]));
-            assert(!isinf(dp[i]));
-        }
+                for (int i = 0; i < ne00; ++i) {
+                    assert(!isnan(dp[i]));
+                    assert(!isinf(dp[i]));
+                }
 #endif
+            }
+        }
     }
 }

@@ -7151,7 +7152,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
     ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
     ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
@@ -7183,7 +7184,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }

-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
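+        // iq3 wraps modulo mask->ne[2], so a mask with a single dim-2 slice broadcasts across the batch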

         // k indices
         const int ik3 = iq3 / rk3;
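For reference, the indexing pattern both hunks rely on is ggml's usual broadcast convention: an index into a higher dimension is reduced modulo the (possibly smaller) extent of the other tensor, and byte strides (nb) rather than element counts locate each row, so permuted or non-contiguous layouts keep working. Below is a minimal standalone sketch of that pattern; the names are illustrative only (ne/nb/data stand in for the corresponding ggml tensor fields), not the ggml API:

    #include <stdint.h>
    #include <stddef.h>

    // locate row (i1, i2, i3) of a 4-D f32 tensor that may be broadcast along dims 2 and 3;
    // ne[] holds extents, nb[] holds byte strides, as in ggml tensors
    static const float * broadcast_row(const char * data,
                                       const int64_t ne[4], const size_t nb[4],
                                       int64_t i1, int64_t i2, int64_t i3) {
        // indices wrap modulo this tensor's own extents: an extent of 1 repeats the slice
        return (const float *)(data + i1*nb[1] + (i2 % ne[2])*nb[2] + (i3 % ne[3])*nb[3]);
    }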