-
Notifications
You must be signed in to change notification settings - Fork 14.9k
ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel #19132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
0a0a010
c74d605
cde6298
3b9b4df
1d4d342
b392a2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3038,6 +3038,317 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, | |||||||||||||||||||||||||||||||||||||||||
| UNUSED(ncols_interleaved); | ||||||||||||||||||||||||||||||||||||||||||
| UNUSED(blocklen); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) | ||||||||||||||||||||||||||||||||||||||||||
| if (svcntb() * 8 == 256) { | ||||||||||||||||||||||||||||||||||||||||||
| constexpr int q8_k_blocklen = 4; | ||||||||||||||||||||||||||||||||||||||||||
| const svuint8_t m4b_1 = svdup_n_u8(0x0f); | ||||||||||||||||||||||||||||||||||||||||||
| // 8 accumulators: 2 row pairs × 4 col pairs | ||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; | ||||||||||||||||||||||||||||||||||||||||||
| uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; | ||||||||||||||||||||||||||||||||||||||||||
| svbool_t pg = svptrue_pat_b32(SV_VL8); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t idx = svld1(pg, idx_arr); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for (int y = 0; y < nr / q8_k_blocklen; y++) { | ||||||||||||||||||||||||||||||||||||||||||
| const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for (int x = 0; x < nc / ncols_interleaved; x++) { | ||||||||||||||||||||||||||||||||||||||||||
| const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_f32_01 = svdup_n_f32(0); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_23 = svdup_n_f32(0); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_45 = svdup_n_f32(0); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_67 = svdup_n_f32(0); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for (int b = 0; b < nb; b++) { | ||||||||||||||||||||||||||||||||||||||||||
| // bsums pairs belongs to the same q8_k subblock | ||||||||||||||||||||||||||||||||||||||||||
| // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum | ||||||||||||||||||||||||||||||||||||||||||
| const int16x8_t bsums[4]{ | ||||||||||||||||||||||||||||||||||||||||||
| vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), | ||||||||||||||||||||||||||||||||||||||||||
| vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), | ||||||||||||||||||||||||||||||||||||||||||
| vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), | ||||||||||||||||||||||||||||||||||||||||||
| vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), | ||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| int32_t bsums_arr32[4][8]; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for (int q8_row = 0; q8_row < 4; q8_row++) { | ||||||||||||||||||||||||||||||||||||||||||
| int16x8_t v16 = bsums[q8_row]; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // low 4 | ||||||||||||||||||||||||||||||||||||||||||
| int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); | ||||||||||||||||||||||||||||||||||||||||||
| vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // high 4 | ||||||||||||||||||||||||||||||||||||||||||
| int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); | ||||||||||||||||||||||||||||||||||||||||||
| vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t sb_acc_0 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t sb_acc_2 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_00 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_11 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_22 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_33 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_44 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_55 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_66 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t acc_77 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t bias_acc_00 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t bias_acc_22 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t bias_acc_44 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t bias_acc_66 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| for (int sb = 0; sb < QK_K / 64; sb++) { | ||||||||||||||||||||||||||||||||||||||||||
| // Need scales for the low and high nibbles | ||||||||||||||||||||||||||||||||||||||||||
| // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t q4sb_mins_0, q4sb_mins_1; | ||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||
| // 2-superblock I am working on | ||||||||||||||||||||||||||||||||||||||||||
| const int offset = sb * 24 + 0 * 12; | ||||||||||||||||||||||||||||||||||||||||||
| const uint8_t * scales_in = &q4_ptr[b].scales[offset]; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| const int offset1 = sb * 24 + 12; | ||||||||||||||||||||||||||||||||||||||||||
| const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| constexpr uint32_t kmask1 = 0x3f3f3f3f; | ||||||||||||||||||||||||||||||||||||||||||
| constexpr uint32_t kmask2 = 0x0f0f0f0f; | ||||||||||||||||||||||||||||||||||||||||||
| constexpr uint32_t kmask3 = 0x03030303; | ||||||||||||||||||||||||||||||||||||||||||
| constexpr uint8_t scales_size = 12; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| uint32_t sm[3]; | ||||||||||||||||||||||||||||||||||||||||||
| memcpy(sm, scales_in, scales_size); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| uint32_t sm1[3]; | ||||||||||||||||||||||||||||||||||||||||||
| memcpy(sm1, scales_in1, scales_size); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| const uint32_t mins_0_3 = sm[1] & kmask1; | ||||||||||||||||||||||||||||||||||||||||||
| const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| const uint32_t mins_0_3_1 = sm1[1] & kmask1; | ||||||||||||||||||||||||||||||||||||||||||
| const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /* reinterpret u32 → u8 */ | ||||||||||||||||||||||||||||||||||||||||||
| svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); | ||||||||||||||||||||||||||||||||||||||||||
| svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+3140
to
+3141
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| /* widen u8 → u16->u32 (lower half only) */ | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+3147
to
+3149
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| uint32_t scales_u32_0 = sm[0] & kmask1; | ||||||||||||||||||||||||||||||||||||||||||
| uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); | ||||||||||||||||||||||||||||||||||||||||||
| uint32_t scales_u32_2 = sm1[0] & kmask1; | ||||||||||||||||||||||||||||||||||||||||||
| uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svuint32_t S01 = svdup_n_u32(scales_u32_0); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t S23 = svdup_n_u32(scales_u32_1); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t R01 = svdup_n_u32(scales_u32_2); | ||||||||||||||||||||||||||||||||||||||||||
| svuint32_t R23 = svdup_n_u32(scales_u32_3); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint8_t S01_b = svreinterpret_s8_u32(S01); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t S23_b = svreinterpret_s8_u32(S23); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t R01_b = svreinterpret_s8_u32(R01); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t R23_b = svreinterpret_s8_u32(R23); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); | ||||||||||||||||||||||||||||||||||||||||||
| block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); | ||||||||||||||||||||||||||||||||||||||||||
| block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); | ||||||||||||||||||||||||||||||||||||||||||
| block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // Load 32-byte per row pair, 1 subblock each time | ||||||||||||||||||||||||||||||||||||||||||
| // predicate for activating higher lanes for 16 int8 elements | ||||||||||||||||||||||||||||||||||||||||||
| const svbool_t ph16 = svptrue_pat_b8(SV_VL16); | ||||||||||||||||||||||||||||||||||||||||||
| // predicate for activating lower lanes for 16 int8 elements | ||||||||||||||||||||||||||||||||||||||||||
| const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+3184
to
+3193
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // Q4s columns iterated in pairs (01, 23, 45, 67) | ||||||||||||||||||||||||||||||||||||||||||
| for (int cp = 0; cp < ncols_interleaved / 2; cp++) { | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| sb_acc_0 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
| sb_acc_2 = svdup_n_s32(0); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); | ||||||||||||||||||||||||||||||||||||||||||
| svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); | ||||||||||||||||||||||||||||||||||||||||||
| svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); | ||||||||||||||||||||||||||||||||||||||||||
| svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); | ||||||||||||||||||||||||||||||||||||||||||
| svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+3206
to
+3209
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); | ||||||||||||||||||||||||||||||||||||||||||
| sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); | ||||||||||||||||||||||||||||||||||||||||||
| sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); | ||||||||||||||||||||||||||||||||||||||||||
| sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); | ||||||||||||||||||||||||||||||||||||||||||
| sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if(cp == 0) { | ||||||||||||||||||||||||||||||||||||||||||
| acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); | ||||||||||||||||||||||||||||||||||||||||||
| acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if(cp == 1) { | ||||||||||||||||||||||||||||||||||||||||||
| acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); | ||||||||||||||||||||||||||||||||||||||||||
| acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if(cp == 2) { | ||||||||||||||||||||||||||||||||||||||||||
| acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); | ||||||||||||||||||||||||||||||||||||||||||
| acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| if(cp == 3) { | ||||||||||||||||||||||||||||||||||||||||||
| acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); | ||||||||||||||||||||||||||||||||||||||||||
| acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); | ||||||||||||||||||||||||||||||||||||||||||
| bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); | ||||||||||||||||||||||||||||||||||||||||||
| bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); | ||||||||||||||||||||||||||||||||||||||||||
| bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); | ||||||||||||||||||||||||||||||||||||||||||
| bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); | ||||||||||||||||||||||||||||||||||||||||||
| } // for sb | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); | ||||||||||||||||||||||||||||||||||||||||||
| svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // Broadcast q8 scalar | ||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
| svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| q8_d = svdup_f32(q8_ptr[b].d[1]); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
| dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| q8_d = svdup_f32(q8_ptr[b].d[2]); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
| dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| q8_d = svdup_f32(q8_ptr[b].d[3]); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
| dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); | ||||||||||||||||||||||||||||||||||||||||||
| acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| } // for b | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| // With the previous reorder, the tile is already in the correct memory layout. | ||||||||||||||||||||||||||||||||||||||||||
| // Predicate for exactly 4 lanes | ||||||||||||||||||||||||||||||||||||||||||
| svbool_t pg4 = svptrue_pat_b32(SV_VL4); | ||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < q8_k_blocklen; i++) { | ||||||||||||||||||||||||||||||||||||||||||
| int row = y * q8_k_blocklen + i; | ||||||||||||||||||||||||||||||||||||||||||
| for (int j = 0; j < 2; j++) { | ||||||||||||||||||||||||||||||||||||||||||
| int col = x * ncols_interleaved + j * 4; | ||||||||||||||||||||||||||||||||||||||||||
| int offset = row * bs + col; | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| if (i == 0 && j == 0) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_0 → lower half of acc_f32_01 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, acc_f32_01); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 0 && j == 1) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_1 → upper half of acc_f32_01 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 1 && j == 0) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_2 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, acc_f32_23); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 1 && j == 1) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_3 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 2 && j == 0) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_4 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, acc_f32_45); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 2 && j == 1) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_5 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 3 && j == 0) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_6 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, acc_f32_67); | ||||||||||||||||||||||||||||||||||||||||||
| } else if (i == 3 && j == 1) { | ||||||||||||||||||||||||||||||||||||||||||
| // acc_f32_7 | ||||||||||||||||||||||||||||||||||||||||||
| svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| } // for x | ||||||||||||||||||||||||||||||||||||||||||
| } // for y | ||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||
| #endif // SVE compile-time end | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) | ||||||||||||||||||||||||||||||||||||||||||
| constexpr int q8_k_blocklen = 4; | ||||||||||||||||||||||||||||||||||||||||||
| const uint8x16_t m4b = vdupq_n_u8(0x0f); | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.