 
 #if defined(GGML_USE_ACCELERATE)
 #   include <Accelerate/Accelerate.h>
-#elif defined(GGML_USE_BLAS)
-#   if defined(GGML_BLAS_USE_MKL)
-#       include <mkl.h>
-#   else
-#       include <cblas.h>
-#   endif
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#else
+#   include <cblas.h>
 #endif
 
 struct ggml_backend_blas_context {
@@ -21,7 +19,7 @@ struct ggml_backend_blas_context {
 
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
-static bool ggml_compute_forward_mul_mat_use_blas(const struct ggml_tensor * dst) {
+static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
@@ -72,11 +70,8 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    // nb01 >= nb00 - src0 is not transposed
-    // compute by src0 rows
-
     const int64_t ne_plane      = ne01*ne00;
-    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne13*ne12*ne_plane*sizeof(float);
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
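+    // note: the scratch buffer holds the dequantized copy of src0, so it is
+    // sized by src0's planes (ne03*ne02) rather than src1's (ne13*ne12)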
 
     if (ctx->work_size < desired_wsize) {
         free(ctx->work_data);
@@ -87,21 +82,19 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     void * wdata = ctx->work_data;
 
     // convert src0 to float
-    if (true) {
-        if (type != GGML_TYPE_F32) {
-            ggml_to_float_t const to_float = type_traits.to_float;
+    if (type != GGML_TYPE_F32) {
+        ggml_to_float_t const to_float = type_traits.to_float;
 
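+        // dequantize src0 one (i02, i03) plane at a time into the scratch buffer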
-            for (int64_t i03 = 0; i03 < ne03; i03++) {
-                for (int64_t i02 = 0; i02 < ne02; i02++) {
-                    const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
-                    float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+                float * const wplane = (float *) wdata + i03*ne12*ne_plane + i02*ne_plane;
 
 #ifdef GGML_USE_OPENMP
                 #pragma omp parallel for num_threads(ctx->n_threads)
 #endif
-                    for (int64_t i01 = 0; i01 < ne01; i01++) {
-                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
-                    }
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
             }
         }
@@ -129,6 +122,70 @@ static void ggml_backend_blas_mul_mat(struct ggml_backend_blas_context * ctx, st
     }
 }
 
+static void ggml_backend_blas_out_prod(struct ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    // compute by src0 rows
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
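+    //
+    // e.g. (assumed shapes, for illustration only): src0 = (k=4096, n=64) and
+    // src1 = (k=4096, m=32) give dst = (m=32, n=64), computed as
+    // c(m,n) = a(k,m)^T * b(k,n) with a = src1 and b = src0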
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    int transposeA;
+    int lda;
+
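+    // in row-major CBLAS, lda is the physical row stride of a's storage:
+    // m floats when src1 is stored (k,m)-contiguous and sgemm transposes it,
+    // k floats when src1->data is already laid out transposed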
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
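+    // c = 1.0f * op(a)*b + 0.0f * c, with op() applying the transpose chosen above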
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
 // backend interface
 
 GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
@@ -138,6 +195,9 @@ GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
 }
 
 GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
+    struct ggml_backend_blas_context * ctx = (struct ggml_backend_blas_context *)backend->context;
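+    // the context owns the scratch buffer, so release it together with the context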
+    free(ctx->work_data);
+    free(ctx);
     free(backend);
 }
 
@@ -158,8 +218,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
                 ggml_backend_blas_mul_mat(ctx, node);
                 break;
 
-            // TODO
-            //case GGML_OP_OUT_PROD:
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
 
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
@@ -180,7 +241,16 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
 }
 
 GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return op->op == GGML_OP_MUL_MAT && ggml_compute_forward_mul_mat_use_blas(op);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
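+    // OUT_PROD is handled only for plain F32 matrices; src1 may be contiguous
+    // or a transposed view (the view case is resolved in ggml_backend_blas_out_prod)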
+    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
+           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
+            op->src[1]->type == GGML_TYPE_F32 &&
+            ggml_is_matrix(src0) &&
+            ggml_is_matrix(src1) &&
+            ggml_is_contiguous(src0) &&
+            (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
 
     GGML_UNUSED(backend);
 }
@@ -229,9 +299,9 @@ ggml_backend_t ggml_backend_blas_init(void) {
         return NULL;
     }
 
-    ctx->n_threads = GGML_DEFAULT_N_THREADS;
-    ctx->work_data = NULL;
-    ctx->work_size = 0;
+    ctx->n_threads   = GGML_DEFAULT_N_THREADS;
+    ctx->work_data   = NULL;
+    ctx->work_size   = 0;
 
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_blas_guid(),