ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel #19132
base: master
Conversation
cc @Alcpz
Regarding the CI failure: when I ran the same command on my system, it built correctly with no issues. Can we check or rerun the CI pipeline? We have not made any changes to CMake or the x86 code. I am attaching the logs.
Overall I don't see any issues with the existing implementation, so all good from my perspective. Please also run clang-format on your changes; there are some inconsistencies in the style.
```cpp
constexpr int q8_k_blocklen = 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (svcntb()*8 == 256) {
```
Format
```cpp
}

// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
// const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
```
There is redundant commented code. Some comments could be improved a bit as well.
```cpp
for (int y = 0; y < nr / q8_k_blocklen; y++) {
    const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
    const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb);
```
I don't understand the need for the same variable twice; I don't see them being used in a way that makes this necessary. Either clarify or clean up.
```cpp
acc_f32_67 = svdup_n_f32(0);

for (int b = 0; b < nb; b++) {
    // bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
```
Suggested change:
```diff
- // bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
+ // bsums pairs belongs to the same q8_k subblock
+ // 64 elements loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
```
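For readers unfamiliar with the layout: in q8_K, bsums holds one int16 partial sum per group of 16 quants, so adjacent entries cover the same 32-element subblock. A minimal NEON sketch of the pairwise fold this comment describes (names are illustrative, not the PR's code):
```cpp
#include <arm_neon.h>

// Illustrative helper: fold 16 per-16-quant sums into 8 per-subblock sums.
static inline int16x8_t q8k_subblock_sums(const int16_t * bsums) {
    const int16x8_t b0 = vld1q_s16(bsums);      // sums for quant groups 0..7
    const int16x8_t b1 = vld1q_s16(bsums + 8);  // sums for quant groups 8..15
    return vpaddq_s16(b0, b1);                  // adjacent pairs -> one sum per 32-quant subblock
}
```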
The server failures are due to changes in the CI. If you rebase on top of master you should get rid of those. I also saw some issues with the x86 high performance job failing on other pipelines, but as you say, this is not caused by this PR.
c75f491 to 1d4d342 (force-push)
@Alcpz rebase and format-related changes are pushed. Thank you!
Hi @ggerganov and @Alcpz, please review the code. I don't know why @loci-dev is not showing the performance gain; my guess is it might not be utilizing the SVE code. I am happy to help if you need anything else. Thank you.
I've already given my review. Unfortunately, I don't have access to an SVE system, so I can't really dig further into it. I'm trusting everything was tested accordingly; since it's GEMM, perplexity should be enough to detect failures. Sorry I can't be of more help.
…nce slowing the performance. So added code so that if SVE 256 is not present, the NEON code is used.
Hi @taronaeo, thanks for running the benchmark. Graviton 4 has an SVE vector length of 128 bits, and the current code is written for a 256-bit SVE vector length. So when you run with this PR, it utilizes neither the NEON code nor the SVE code and falls back to a slower path. I have now modified the code so that if 256-bit SVE is not present, the NEON code is used instead. So now, with these changes, you will see similar performance for both NEON and this PR. I am attaching my runtime results. I hope this clears up your doubt. Thank you.
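A minimal sketch of the fallback described above (assumed structure; the actual PR code differs in detail). svcntb() returns the runtime SVE vector width in bytes, so svcntb() * 8 is the width in bits:
```cpp
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#endif

void gemm_q4_K_dispatch_sketch(void) {
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
    if (svcntb() * 8 == 256) {   // svcntb() = vector width in bytes, so *8 = bits
        // ... run the 256-bit SVE kernel (e.g. Graviton 3 / 3E) ...
        return;
    }
#endif
    // ... otherwise run the NEON kernel from PR #16739 ...
    // (this is the path 128-bit SVE parts such as Graviton 4 now take)
}
```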
Great catch, I forgot about that. I can retest it on Graviton 3 where 256-bit SVE is available and update the benchmarks again :)
I've tested this using an AWS
taronaeo left a comment
Minor code cleanup
```cpp
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);

const uint32_t mins_0_3_1 = sm1[1] & kmask1;
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
```
Suggested change (formatting only):
```cpp
const uint32_t mins_0_3 = sm[1] & kmask1;
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
```
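For context, Q4_K packs 8 scales and 8 mins at 6 bits each into 12 bytes per super-block, and sm/sm1 above are uint32 views of those bytes for the two interleaved rows. A scalar sketch of what the masked extraction computes, assuming ggml's usual kmask constants:
```cpp
#include <stdint.h>
#include <string.h>

static void unpack_q4_K_mins(const uint8_t scales[12], uint8_t mins[8]) {
    const uint32_t kmask1 = 0x3f3f3f3f;  // low 6 bits of each byte
    const uint32_t kmask2 = 0x0f0f0f0f;  // low 4 bits of each byte
    const uint32_t kmask3 = 0x03030303;  // 2 spill bits of each byte
    uint32_t sm[3];
    memcpy(sm, scales, 12);
    // mins 0..3: the low 6 bits of bytes 4..7
    const uint32_t mins_0_3 = sm[1] & kmask1;
    // mins 4..7: 4 bits from bytes 8..11 plus 2 bits spilled from bytes 4..7
    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
    memcpy(mins,     &mins_0_3, 4);
    memcpy(mins + 4, &mins_4_7, 4);
}
```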
```cpp
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
```
Suggested change (formatting only):
```cpp
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
```
```cpp
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);

q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
```
Suggested change (formatting only):
```cpp
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
```
```cpp
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));

svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
```
Suggested change (formatting only):
```cpp
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
```
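The ph16/pl16 loads above rely on svld1 zeroing inactive lanes: with a 256-bit vector, one predicate selects lanes 0..15 and the other lanes 16..31, so adding two predicated loads splices two 16-byte rows into a single register. A sketch of the trick (predicate roles assumed from the naming):
```cpp
#include <arm_sve.h>

svint8_t splice_two_rows(const int8_t * src) {
    const svbool_t ph16 = svwhilelt_b8_s32(0, 16);        // lanes 0..15 active
    const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);  // lanes 16..31 active
    // svld1 indexes memory by lane, so the pl16 load reads src[112+16 .. 112+31],
    // i.e. bytes 128..143; the kernel's offsets (+112, +128, ...) account for this.
    return svadd_s8_x(svptrue_b8(),
                      svld1_s8(ph16, src + 0),            // bytes 0..15    -> lanes 0..15
                      svld1_s8(pl16, src + 112));         // bytes 128..143 -> lanes 16..31
}
```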
```cpp
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
```
Suggested change (formatting only):
```cpp
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
```
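Similarly, the merging (_m) forms above unpack both nibble halves of the packed Q4 bytes in just two instructions. A sketch of what happens per lane (same assumed predicate roles):
```cpp
#include <arm_sve.h>

// Unpack low nibbles into lanes 0..15 and high nibbles into lanes 16..31.
svint8_t unpack_nibbles(svuint8_t q4_raw, svbool_t ph16, svbool_t pl16) {
    const svuint8_t m4b = svdup_n_u8(0x0f);
    // _m (merging) forms: inactive lanes keep the first operand's value.
    svuint8_t t = svand_u8_m(ph16, q4_raw, m4b);           // lanes 0..15: q4 & 0x0f; lanes 16..31: raw bytes
    return svreinterpret_s8_u8(svlsr_n_u8_m(pl16, t, 4));  // lanes 16..31: t >> 4; lanes 0..15 kept
}
```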
This PR introduces support for SVE (Scalable Vector Extension) kernels for the q4_K_q8_K GEMM using i8mm and vector instructions. ARM NEON support for this kernel was added in PR #16739.
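For background, the i8mm building block such kernels revolve around is the SMMLA instruction, exposed in the SVE ACLE as svmmla_s32: per 128-bit segment it multiplies a 2x8 int8 tile by the transpose of another 2x8 tile and accumulates into a 2x2 int32 tile. A minimal sketch (not the PR's code; the guard mirrors the PR's):
```cpp
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
#include <arm_sve.h>

// Per 128-bit segment: acc(2x2, int32) += A(2x8, int8) * B(2x8, int8)^T
svint32_t mmla_tile(svint32_t acc, svint8_t a, svint8_t b) {
    return svmmla_s32(acc, a, b);
}
#endif
```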
Verifying Feature
----------------------------------------------------------------------------
This PR contains the SVE implementation of the GEMM used to compute the Q4_K quantization.
Kernel: ggml_gemm_q4_K_8x8_q8_K()
By running a Q4_K_M quantized model of Llama-3.1-8B, I checked the generation output.
I also verified that the perplexity matches between the NEON and SVE implementations.
This change does not appear to have any impact on accuracy.
The command used to measure the perplexity is:
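(The original command block did not survive page extraction; a typical llama.cpp perplexity invocation, with a hypothetical model filename, looks like:)
```
./llama-perplexity -m Llama-3.1-8B-Q4_K_M.gguf -f wikitext-2-raw/wiki.test.raw -t 64
```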
Performance Check
----------------------------------------------------------------------------
This PR improves the prompt eval time (TTFT) of LLM inference by 17-20% compared to NEON (PR #16739).
The performance was measured on Graviton3E @ 64 cores.
Performance is improved as follows; values are tokens per second.
The command used to measure the performance is:
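(Again, the original command block was lost in extraction; an illustrative llama-bench run, with a hypothetical model filename, would be:)
```
./llama-bench -m Llama-3.1-8B-Q4_K_M.gguf -p 512 -n 128 -t 64
```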
This work is a contribution of @Vithulep and @abhijain1204fujitsu