@@ -2175,67 +2175,92 @@ class tinyBLAS_PPC {
21752175                int  ith, int  nth)
21762176        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
21772177    }
2178- 
21792178    void  matmul (int64_t  m, int64_t  n) {
2180-        mnpack (0 , m, 0 , n);
2179+         int64_t  mc = 256 ; int64_t  nc = 256 ; int64_t  kc = 256 ;
2180+         if  ( m%mc == 0  && n%nc == 0  && k%kc == 0 ) {
2181+ 	   matmul_tiled (m, n, mc, nc, kc);
2182+         } else  {
2183+            mnpack (0 , m, 0 , n);
2184+         }
21812185    }
21822186
21832187  private: 
21842188
21852189    void  (tinyBLAS_PPC::*kernel)(int64_t , int64_t );
21862190
2191+     inline  void  save_acc (acc_t * ACC, int64_t  ii, int64_t  jj) {
2192+        vec_t  vec_C[4 ];
2193+        __builtin_mma_disassemble_acc (vec_C, ACC);
2194+        for  (int  I = 0 ; I < 4 ; I++) {
2195+           for  (int  J = 0 ; J < 4 ; J++) {
2196+              *((float *)(C+ii+((jj+J)*ldc)+I)) = *((float *)&vec_C[I]+J);
2197+           }
2198+        }
2199+     }
2200+ 
2201+     inline  void  add_save_acc (acc_t * ACC, int64_t  ii, int64_t  jj) {
2202+        vec_t  vec_C[4 ];
2203+        __builtin_mma_disassemble_acc (vec_C, ACC);
2204+        for  (int  I = 0 ; I < 4 ; I++) {
2205+           for  (int  J = 0 ; J < 4 ; J++) {
2206+              float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);//  += *((float*)&vec_C[I]+J);
2207+              *c_ptr += *((float *)&vec_C[I]+J);
2208+           }
2209+        }
2210+     }
2211+ 
    // Transpose a 4x4 block of floats held in src[0..3] and store the
    // result contiguously (16 floats) at vecOffset, so that consecutive
    // output vectors hold what were columns of the input block.
    inline void vector_permute_store_4(vector float *src, float *vecOffset) {
       vector float t1, t2, t3, t4, t5, t6, t7, t8;
       // Interleave low halves (t1, t2) and high halves (t3, t4) of the rows.
       t1 = vec_mergeh(src[0], src[1]);
       t2 = vec_mergeh(src[2], src[3]);
       t3 = vec_mergel(src[0], src[1]);
       t4 = vec_mergel(src[2], src[3]);

       // Combine doubleword halves to finish the transpose:
       // selector 0 = both low doublewords, 3 = both high doublewords.
       t5 = vec_xxpermdi(t1, t2, 0);
       t6 = vec_xxpermdi(t1, t2, 3);
       t7 = vec_xxpermdi(t3, t4, 0);
       t8 = vec_xxpermdi(t3, t4, 3);

       // Store the four transposed vectors back to back (offsets in floats).
       vec_xst(t5, 0, vecOffset);
       vec_xst(t6, 0, vecOffset + 4);
       vec_xst(t7, 0, vecOffset + 8);
       vec_xst(t8, 0, vecOffset + 12);
    }
22042229
    // Transpose an 8x4 block of floats held in src[0..7] and store the
    // result (32 floats) at vecOffset. The first 16 floats come from the
    // low halves of the eight input rows, the next 16 from the high halves.
    inline void vector_permute_store_8(vector float *src, float *vecOffset) {
       vector float t1, t2, t3, t4, t5, t6, t7, t8;
       // Pass 1: low halves of all eight rows.
       t1 = vec_mergeh(src[0], src[1]);
       t2 = vec_mergeh(src[2], src[3]);
       t3 = vec_mergeh(src[4], src[5]);
       t4 = vec_mergeh(src[6], src[7]);

       // selector 0 = both low doublewords, 3 = both high doublewords.
       t5 = vec_xxpermdi(t1, t2, 0);
       t6 = vec_xxpermdi(t3, t4, 0);
       t7 = vec_xxpermdi(t1, t2, 3);
       t8 = vec_xxpermdi(t3, t4, 3);

       vec_xst(t5, 0, vecOffset);
       vec_xst(t6, 0, vecOffset + 4);
       vec_xst(t7, 0, vecOffset + 8);
       vec_xst(t8, 0, vecOffset + 12);

       // Pass 2: high halves of all eight rows, stored after the first 16.
       t1 = vec_mergel(src[0], src[1]);
       t2 = vec_mergel(src[2], src[3]);
       t3 = vec_mergel(src[4], src[5]);
       t4 = vec_mergel(src[6], src[7]);

       t5 = vec_xxpermdi(t1, t2, 0);
       t6 = vec_xxpermdi(t3, t4, 0);
       t7 = vec_xxpermdi(t1, t2, 3);
       t8 = vec_xxpermdi(t3, t4, 3);

       vec_xst(t5, 0, vecOffset + 16);
       vec_xst(t6, 0, vecOffset + 20);
       vec_xst(t7, 0, vecOffset + 24);
       vec_xst(t8, 0, vecOffset + 28);
    }
2262+ 
2263+      void  packTranspose (const  float * a, int64_t  lda, int  rows, int  cols, float * vec) {
22392264        int64_t  i, j;
22402265        float  * aoffsets[8 ];
22412266        float  *aoffset = NULL , *boffset = NULL ;
@@ -2247,7 +2272,6 @@ class tinyBLAS_PPC {
22472272        boffset = vec;
22482273        j = (rows >> 3 );
22492274        if  (j > 0 ) {
2250- 
22512275            do  {
22522276                aoffsets[0 ] = aoffset;
22532277                for  (int  it = 1 ; it< 8 ; it++)
@@ -2265,10 +2289,13 @@ class tinyBLAS_PPC {
22652289
22662290                        vector_permute_store_8 (c1, boffset);
22672291                        vector_permute_store_8 (c2, boffset+32 );
2268-                         for  (int  it = 0 ; it < 4 ; it++)
2269-                             aoffsets[it] = aoffsets[it] + 8 *lda;
22702292                        boffset += 64 ;
22712293                        i--;
2294+                         if  (i > 0 ) {
2295+                            for  (int  it = 0 ; it < 8 ; it++) {
2296+                                aoffsets[it] = aoffsets[it] + 8 ;
2297+                            }
2298+                         }
22722299                    } while (i > 0 );
22732300                }
22742301                if  (cols & 4 ) {
@@ -2401,6 +2428,83 @@ class tinyBLAS_PPC {
24012428        SAVE_ACC (&acc_3, ii+4 , jj+4 );
24022429    }
24032430
    // Accumulate a 16x8 output tile: eight loop iterations, each issuing
    // eight 4x4 outer-product accumulates (xvf32gerpp).
    // acc[0..3] cover rows 0..7 (from vec_A0), acc[4..7] rows 8..15 (from
    // vec_A1); even/odd acc indices cover columns j..j+3 / j+4..j+7.
    // NOTE(review): assumes vec_A0/vec_A1/vec_B each hold 16 packed vec_t
    // produced by packTranspose for this k-slice — confirm packing layout.
    inline void MMA_16x8(vec_t *vec_A0, vec_t* vec_A1, vec_t *vec_B, acc_t * acc) {
        for (int x = 0; x < 16; x += 2) {
            __builtin_mma_xvf32gerpp(&acc[0], vec_A0[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[1], vec_A0[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[2], vec_A0[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[3], vec_A0[x + 1], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[4], vec_A1[x + 0], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[5], vec_A1[x + 0], vec_B[x + 1]);
            __builtin_mma_xvf32gerpp(&acc[6], vec_A1[x + 1], vec_B[x]);
            __builtin_mma_xvf32gerpp(&acc[7], vec_A1[x + 1], vec_B[x + 1]);
        }
    }
2443+ 
2444+     void  KERNEL (int64_t  ii, int64_t  jj, int64_t  mc, int64_t  nc, int64_t  kc, vec_t  * vec_A, vec_t * vec_B, int64_t  kk) {
2445+         for  (int64_t  i = 0 ; i <mc;  i += 16 ) {
2446+             int  A_base_addr = (mc/8 )* (i/8 )*16 ;
2447+             for  (int64_t  j = 0 ; j < nc; j += 8 ) {
2448+                  int  B_base_addr = (nc/8 )* (j/8 )*16 ;
2449+                  acc_t  acc[8 ];
2450+                  vec_t  A0_block[16 ]; vec_t  A1_block[16 ];
2451+                  for  (int  x = 0 ; x < 8 ; x++)
2452+                      __builtin_mma_xxsetaccz (&acc[x]);
2453+                  for  (int64_t  l = 0 ; l < kc; l+=8 ) {
2454+                      int  A0_block_idx = A_base_addr + (l/8 )*16 ;
2455+                      int  A1_block_idx = A0_block_idx + (mc/ 8 ) * 16 ;
2456+                      int  B_block_idx = B_base_addr + (l/8 )*16 ;
2457+                      vec_t * A0_block = &vec_A[A0_block_idx];
2458+                      vec_t * A1_block = &vec_A[A1_block_idx];
2459+                      vec_t * B_block = &vec_B[B_block_idx];
2460+                      MMA_16x8 (A0_block, A1_block, B_block, acc);
2461+                  }
2462+                  if  ( kk == 0 ) {
2463+                     save_acc (&acc[0 ], ii + i, jj + j);
2464+                     save_acc (&acc[1 ], ii + i, jj + j + 4 );
2465+                     save_acc (&acc[2 ], ii + i + 4 , jj + j);
2466+                     save_acc (&acc[3 ], ii + i + 4 , jj + j + 4 );
2467+                     save_acc (&acc[4 ], ii + i + 8 , jj + j);
2468+                     save_acc (&acc[5 ], ii + i + 8 , jj + j + 4 );
2469+                     save_acc (&acc[6 ], ii + i + 12 , jj + j);
2470+                     save_acc (&acc[7 ], ii + i + 12 , jj + j + 4 );
2471+                  } else  {
2472+                     add_save_acc (&acc[0 ], ii + i, jj + j);
2473+                     add_save_acc (&acc[1 ], ii + i, jj + j + 4 );
2474+                     add_save_acc (&acc[2 ], ii + i + 4 , jj + j);
2475+                     add_save_acc (&acc[3 ], ii + i + 4 , jj + j + 4 );
2476+                     add_save_acc (&acc[4 ], ii + i + 8 , jj + j);
2477+                     add_save_acc (&acc[5 ], ii + i + 8 , jj + j + 4 );
2478+                     add_save_acc (&acc[6 ], ii + i + 12 , jj + j);
2479+                     add_save_acc (&acc[7 ], ii + i + 12 , jj + j + 4 );
2480+                  }
2481+             }
2482+         }
2483+     }
2484+ 
    // Tiled matmul driver: splits the (m x n) output into (mc x nc) tiles,
    // statically partitions the tiles across nth threads (this instance
    // handles tiles [start, end)), and for each tile packs the A and B
    // panels one k-slice at a time before calling KERNEL.
    // Precondition (guaranteed by matmul()): m % mc == 0, n % nc == 0,
    // k % kc == 0.
    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;  // tiles per thread, rounded up
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
           end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;    // tile origin row in C
            int64_t jj = (job % xtiles) * nc;    // tile origin column in C
            for (int64_t kk = 0; kk < k; kk += kc) {
                 // NOTE(review): variable-length arrays; with mc=nc=kc=256
                 // each holds 256*256/4 = 16384 vec_t (256 KiB) on the
                 // stack — confirm thread stack size is sufficient.
                 vec_t A_pack[kc*mc/4];
                 vec_t B_pack[kc*nc/4];
                 packTranspose(A+(ii*lda)+kk, lda, kc, mc, (float*)A_pack);
                 packTranspose(B+(jj*ldb)+kk, ldb, kc, nc, (float*)B_pack);
                 KERNEL(ii, jj, mc, nc, kc, A_pack, B_pack, kk);
            }
        }
    }
2507+ 
24042508    void  mnpack (int64_t  m0, int64_t  m, int64_t  n0, int64_t  n) {
24052509        int  m_rem = MIN (m - m0, 8 );
24062510        int  n_rem = MIN (n - n0, 8 );
0 commit comments