@@ -200,22 +200,72 @@ __global__ void __launch_bounds__(128) gemm_forward_4bit_cuda_m128n64k32(int spl
200
200
201
201
for (int i_0_3 = 0 ; i_0_3 < 4 ; ++i_0_3) {
202
202
for (int j_0_4 = 0 ; j_0_4 < 2 ; ++j_0_4) {
203
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
204
+ {
205
+ __asm__ __volatile__ (
206
+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
207
+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
208
+ : " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
209
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]),
210
+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]),
211
+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
212
+ );
213
+ }
214
+
215
+ {
216
+ __asm__ __volatile__ (
217
+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
218
+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
219
+ : " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
220
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
221
+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]),
222
+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
223
+ );
224
+ }
203
225
226
+ {
227
+ __asm__ __volatile__ (
228
+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
229
+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
230
+ : " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
231
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]),
232
+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]),
233
+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
234
+ );
235
+ }
236
+ {
237
+ __asm__ __volatile__ (
238
+ " mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
239
+ " {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n "
240
+ : " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
241
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
242
+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]),
243
+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
244
+ );
245
+ }
246
+ #else
204
247
{
205
248
__asm__ __volatile__ (
206
249
" mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
207
250
" {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n "
208
251
: " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " =f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
209
- : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ]));
252
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
253
+ " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[0 ]), " r" (((unsigned *)(B_shared_warp + (j_0_4 * 8 )))[1 ]),
254
+ " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[0 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[1 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[2 ]), " f" (((float *)(C_warp + ((i_0_3 * 16 ) + (j_0_4 * 8 ))))[3 ])
255
+ );
210
256
}
211
257
212
258
{
213
259
__asm__ __volatile__ (
214
260
" mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
215
261
" {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n "
216
262
: " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " =f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
217
- : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ]));
263
+ : " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[0 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[1 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[2 ]), " r" (((unsigned *)(A_shared_warp + (i_0_3 * 8 )))[3 ]),
264
+ " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[0 ]), " r" (((unsigned *)(B_shared_warp + ((j_0_4 * 8 ) + 4 )))[1 ]),
265
+ " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[0 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[1 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[2 ]), " f" (((float *)(C_warp + (((i_0_3 * 16 ) + (j_0_4 * 8 )) + 4 )))[3 ])
266
+ );
218
267
}
268
+ #endif
219
269
}
220
270
}
221
271
}
0 commit comments