 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_KHR_cooperative_matrix : require
+#extension GL_KHR_memory_scope_semantics : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
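+// GL_KHR_cooperative_matrix supplies the coopmat type and the coopMatLoad /
+// coopMatMulAdd / coopMatStore builtins used below; the other two extensions
+// provide the gl_Scope* enums and the explicit float16_t type they depend on.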
 
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
@@ -152,12 +155,10 @@ void main() {
     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
 #endif
 
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
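+    // Replace the per-thread scalar accumulators with subgroup-scope cooperative
+    // matrices: one 16x16 accumulator per 16x16 tile of the warp's WM x WN output.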
+    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sums[WM * WN / 16 / 16];
 
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
+    [[unroll]] for (uint i = 0; i < WM * WN / 16 / 16; i++) {
+        sums[i] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
     }
 
     [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
@@ -446,27 +447,14 @@ void main() {
         pos_a += BK / LOAD_VEC_A;
         pos_b += BK / LOAD_VEC_B;
 
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
-                        }
-                    }
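+        // Tile the warp's WM x WN output into 16x16 cooperative matrices and walk
+        // the shared-memory K dimension 16 columns at a time, issuing one
+        // coopMatMulAdd per (i, j, k) tile triple.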
+        [[unroll]] for (uint i = 0; i < WM; i += 16) {
+            [[unroll]] for (uint j = 0; j < WN; j += 16) {
+                [[unroll]] for (uint k = 0; k < BK; k += 16) {
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
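+                    // buf_a keeps A rows contiguous along K and buf_b keeps B columns
+                    // contiguous along K (both padded to a stride of BK+1, presumably to
+                    // avoid bank conflicts), so A loads row-major and B column-major.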
+                    coopMatLoad(matA, buf_a, (warp_r * WM + i) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutRowMajor);
+                    coopMatLoad(matB, buf_b, (warp_c * WN + j) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutColumnMajor);
+                    sums[(i / 16) * (WN / 16) + (j / 16)] = coopMatMulAdd(matA, matB, sums[(i / 16) * (WN / 16) + (j / 16)]);
                 }
             }
         }
@@ -481,6 +469,19 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
+#if 1
+#ifndef MUL_MAT_ID
+    // XXX TODO this is missing bounds checking against p.M and p.N,
+    // which probably requires spilling to shared memory and doing scalar stores.
+    // But sums[] may not all fit in shared memory...
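+    // Write each 16x16 accumulator tile straight to data_d: with a column-major
+    // layout and stride p.stride_d, element (row, col) of tile (i, j) lands at
+    // offsets + (dc + j + col) * p.stride_d + dr + i + row.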
+    [[unroll]] for (uint i = 0; i < WM; i += 16) {
+        [[unroll]] for (uint j = 0; j < WN; j += 16) {
+            coopMatStore(sums[(i / 16) * (WN / 16) + (j / 16)], data_d, offsets + (dc + j) * p.stride_d + dr + i, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+        }
+    }
+#endif
+#else
+
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
@@ -505,4 +506,5 @@ void main() {
             }
         }
     }
+#endif
 }