From 9b2129b05504fcadd4f3e715651c3f7d56f38402 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 15:56:50 +0000 Subject: [PATCH 01/14] Boilerplate for q5_Kx8 REPACK on ARM and fallback Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/arch-fallback.h | 38 +++++++---- ggml/src/ggml-cpu/arch/arm/repack.cpp | 54 +++++++++++++++ ggml/src/ggml-cpu/repack.cpp | 98 +++++++++++++++++++++++---- ggml/src/ggml-cpu/repack.h | 25 +++++-- 4 files changed, 187 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f8946ac701..ba6a19b9870 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -38,9 +38,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -48,9 +49,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -70,12 +72,14 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -94,9 +98,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic 
ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -104,9 +109,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -126,9 +132,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -136,9 +143,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -165,18 +173,20 @@ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic 
ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -202,9 +212,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -212,9 +223,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -242,9 +254,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -252,9 +265,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b61220a189a..13e51172b46 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ 
b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -786,6 +786,33 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if 0 && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + GGML_ABORT("ggml_gemv_q5_K_8x8_q8_K: ARM NEON DOTPROD implementation not yet available"); + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, @@ -2738,6 +2765,33 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if 0 && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ABORT("ggml_gemm_q5_K_8x8_q8_K: NEON+MATMUL_INT8 implementation not available yet"); + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index fbf7ed9432a..e95a7291d07 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -616,6 +616,16 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + GGML_ABORT("ggml_gemv_q5_K_8x8_q8_K_generic: not implemented yet"); +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1212,6 +1222,15 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + GGML_ABORT("ggml_gemm_q5_K_8x8_q8_K_generic: not implemented yet"); +} void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1622,7 +1641,10 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in out.scales[i] = in[src1].scales[src2]; } return out; +} +static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { + GGML_ABORT("make_block_q5_Kx8: not implemented yet"); 
} static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { @@ -1718,6 +1740,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -1936,6 +1990,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1973,6 +2031,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1981,8 +2043,8 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2013,20 +2075,24 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, 
vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2403,10 +2469,9 @@ template (ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + gemv( + ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW @@ -2432,6 +2497,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -2482,6 +2550,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + } else if (cur->type == GGML_TYPE_Q5_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index af98e703442..05e2425f1d5 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -44,6 +44,7 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); + struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -52,6 +53,18 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); + +struct block_q5_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * 8 / 2]; // 4--bit quants +}; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, + "wrong q5_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x4_q8_0(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, 
int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 7d944e996338555eda1849faf1b87a677d8d3957 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 16:17:41 +0000 Subject: [PATCH 02/14] Implements make_block_q5_Kx8 by extending make_block_q4_Kx8 Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/repack.cpp | 87 +++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index e95a7291d07..2d8f5e6e5e4 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1644,7 +1644,92 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in } static 
block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { - GGML_ABORT("make_block_q5_Kx8: not implemented yet"); + block_q5_Kx8 out; + //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q5_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // Repeat for low bits 8 bytes at a time as well, since + // the high bits are interleaved in Q5_K and the index is + // qh_idx = (qs_idx % 32); + // qh_val = qh[qs_idx] >> (qs_idx / 32); + for (int i = 0; i < end / 4; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is copied over from Q4_K + // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes. + // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q5_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) 
+ ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + return out; } static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { From 5ea06c3a37e822f31268c596617e99c934f7bfbb Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 18:25:17 +0000 Subject: [PATCH 03/14] q5_K repack gemm and gemv generics --- ggml/src/ggml-cpu/repack.cpp | 188 +++++++++++++++++++++++++++++++++-- 1 file changed, 179 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 2d8f5e6e5e4..f6e45d56d82 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (n % qk == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[8]; float sum_minf[8]; @@ -623,7 +616,91 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, const void * GGML_RESTRICT vy, int nr, int nc) { - GGML_ABORT("ggml_gemv_q5_K_8x8_q8_K_generic: not implemented yet"); + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 
32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1229,7 +1306,100 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, const void * GGML_RESTRICT vy, int nr, int nc) { - GGML_ABORT("ggml_gemm_q5_K_8x8_q8_K_generic: not implemented yet"); + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * 
bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } } void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { From f5341c60621fd260e75ba5838f829e4b4f0c7f7a Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 20:15:10 +0000 Subject: [PATCH 04/14] Gemm and Gemv ARM implementations (i8mm) --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 386 +++++++++++++++++++++++++- ggml/src/ggml-cpu/repack.h | 2 +- 2 files changed, 376 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 13e51172b46..0b09b351920 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -25,9 +25,8 @@ #define UNUSED GGML_UNUSED #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) -static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, - int16x8_t * out_mins, - int8_t * out_scales) { +// Helper for decoding scales and mins of Q4_K and Q5_K block formats +static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { constexpr uint32_t kmask1 = 0x3f3f3f3f; constexpr uint32_t kmask2 = 0x0f0f0f0f; constexpr uint32_t kmask3 = 0x03030303; @@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -806,8 +805,164 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); -#if 0 && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - GGML_ABORT("ggml_gemv_q5_K_8x8_q8_K: ARM NEON DOTPROD implementation not yet available"); +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mhb = vdupq_n_u8(0x10); // high bit mask for 5th bit position + + // Bit masks for extracting high bits based on subblock index + // sb=0: bit 0 (lo), bit 1 (hi); sb=1: bit 2 (lo), bit 3 (hi); etc. 
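// For reference, a scalar sketch of the 5-bit reconstruction that the masked vtstq_u8 /
// vorrq_u8 sequence below performs (illustrative only, not part of the patch; the helper
// name is made up): for sub-block sb, the low nibble of a qs byte takes its 5th bit from
// qh bit (2*sb) and the high nibble from qh bit (2*sb + 1), placed at bit position 4 (0x10)
// so each quant becomes a plain 0..31 value before the signed dot products with Q8 data.
static inline void q5_reconstruct_pair_sketch(uint8_t qs_byte, uint8_t qh_byte, int sb,
                                              int8_t * v_lo, int8_t * v_hi) {
    const uint8_t h_lo = (qh_byte & (uint8_t) (1u << (2 * sb)))     ? 0x10 : 0x00;
    const uint8_t h_hi = (qh_byte & (uint8_t) (1u << (2 * sb + 1))) ? 0x10 : 0x00;
    *v_lo = (int8_t) ((qs_byte & 0x0F) | h_lo);  // same value the v0 path computes in the generic kernel
    *v_hi = (int8_t) ((qs_byte >> 4)   | h_hi);  // same value the v1 path computes in the generic kernel
}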
+ const uint8_t bit_masks[8] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80 }; + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; + const uint8_t * qh_base = q5_ptr[b].qh; // qh is shared across all subblocks + + // Masks for extracting high bits for this subblock + const uint8x16_t lo_bit_mask = vdupq_n_u8(bit_masks[sb * 2]); + const uint8x16_t hi_bit_mask = vdupq_n_u8(bit_masks[sb * 2 + 1]); + + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + // Low bits + uint8x16_t qs_cp_0 = vld1q_u8(qs_base + 16 * cp); + uint8x16_t qs_cp_1 = vld1q_u8(qs_base + 16 * cp + 64); + uint8x16_t qs_cp_2 = vld1q_u8(qs_base + 16 * cp + 128); + uint8x16_t qs_cp_3 = vld1q_u8(qs_base + 16 * cp + 192); + + // High bits (Q5_K specific) + uint8x16_t qh_cp_0 = vld1q_u8(qh_base + 16 * cp); + uint8x16_t qh_cp_1 = vld1q_u8(qh_base + 16 * cp + 64); + uint8x16_t qh_cp_2 = vld1q_u8(qh_base + 16 * cp + 128); + uint8x16_t qh_cp_3 = vld1q_u8(qh_base + 16 * cp + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(vtstq_u8(qh_cp_0, lo_bit_mask), mhb); + uint8x16_t hbit_lo_1 = vandq_u8(vtstq_u8(qh_cp_1, lo_bit_mask), mhb); + uint8x16_t hbit_lo_2 = vandq_u8(vtstq_u8(qh_cp_2, 
lo_bit_mask), mhb); + uint8x16_t hbit_lo_3 = vandq_u8(vtstq_u8(qh_cp_3, lo_bit_mask), mhb); + + uint8x16_t hbit_hi_0 = vandq_u8(vtstq_u8(qh_cp_0, hi_bit_mask), mhb); + uint8x16_t hbit_hi_1 = vandq_u8(vtstq_u8(qh_cp_1, hi_bit_mask), mhb); + uint8x16_t hbit_hi_2 = vandq_u8(vtstq_u8(qh_cp_2, hi_bit_mask), mhb); + uint8x16_t hbit_hi_3 = vandq_u8(vtstq_u8(qh_cp_3, hi_bit_mask), mhb); + + // Combine 4-bit values with high bits to get 5-bit values + uint8x16_t q5_lo_0 = vorrq_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0); + uint8x16_t q5_lo_1 = vorrq_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1); + uint8x16_t q5_lo_2 = vorrq_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2); + uint8x16_t q5_lo_3 = vorrq_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3); + + uint8x16_t q5_hi_0 = vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0); + uint8x16_t q5_hi_1 = vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1); + uint8x16_t q5_hi_2 = vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2); + uint8x16_t q5_hi_3 = vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3); + + acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_0), q8_qs[0]); // 0 .. 7 + acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_1), q8_qs[1]); // 8 ..15 + acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_2), q8_qs[2]); // 16..23 + acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_3), q8_qs[3]); // 24..31 + + acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_0), q8_qs[4]); // 32..39 + acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_1), q8_qs[5]); // 40..47 + acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_2), q8_qs[6]); // 48..55 + acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_3), q8_qs[7]); // 56..63 + } + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + } + + // Multiply Acc bsum + mins + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). 
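// Context for the bias accumulation below, written out as a scalar sketch (illustrative
// only, not part of the patch; the names are made up): for one output column, a Q5_K
// super-block contributes
//   d * d8 * sum_sb(scale[sb] * dot[sb])  -  dmin * d8 * sum_sb(min[sb] * bsum[sb])
// where dot[sb] is the integer dot product of the 32 reconstructed quants of sub-block sb
// with the matching Q8 quants and bsum[sb] is the sum of those 32 Q8 quants. The vmlal_s16
// lines accumulate the min*bsum term; vmlsq_f32 applies dmin*d8 once per super-block.
static inline float q5k_column_block_sketch(float d, float dmin, float d8,
                                            const int32_t dot[8], const uint8_t scale[8],
                                            const uint8_t min[8], const int16_t bsum[8]) {
    int32_t qsum = 0;
    int32_t bias = 0;
    for (int sb = 0; sb < 8; ++sb) {  // 8 sub-blocks of 32 quants per QK_K super-block
        qsum += (int32_t) scale[sb] * dot[sb];
        bias += (int32_t) min[sb]   * bsum[sb];
    }
    return d * d8 * (float) qsum - dmin * d8 * (float) bias;
}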
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // cols 0-3 bias + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + + // cols 4-7 bias + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x return; #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); @@ -2458,7 +2613,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -2622,7 +2777,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later for (int i = 0; i < 2; i++) { const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); } // q8_ptr[b].qs has interleaved Q8 rows (01, 23) @@ -2786,8 +2941,217 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); -#if 0 && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ABORT("ggml_gemm_q5_K_8x8_q8_K: NEON+MATMUL_INT8 implementation not available yet"); +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mhb = vdupq_n_u8(0x10); // high bit mask for 5th bit position + + // Bit masks for extracting high bits based on subblock index + const uint8_t bit_masks[8] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80 }; + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store 
subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q5sb_scales[2][8]; + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], + q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], + q8_qs_23[7] }, + }; + + // Masks for extracting high bits for this subblock + const uint8x16_t lo_bit_mask = vdupq_n_u8(bit_masks[sb * 2]); + const uint8x16_t hi_bit_mask = vdupq_n_u8(bit_masks[sb * 2 + 1]); + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 
7 & 32..39 + uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + + // This is the only part of the algorithm that differs with Q4_K + // High bits + uint8x16_t qh_cp_0 = vld1q_u8(q5_ptr[b].qh + 16 * cp); + uint8x16_t qh_cp_1 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 64); + uint8x16_t qh_cp_2 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 128); + uint8x16_t qh_cp_3 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(vtstq_u8(qh_cp_0, lo_bit_mask), mhb); + uint8x16_t hbit_lo_1 = vandq_u8(vtstq_u8(qh_cp_1, lo_bit_mask), mhb); + uint8x16_t hbit_lo_2 = vandq_u8(vtstq_u8(qh_cp_2, lo_bit_mask), mhb); + uint8x16_t hbit_lo_3 = vandq_u8(vtstq_u8(qh_cp_3, lo_bit_mask), mhb); + + uint8x16_t hbit_hi_0 = vandq_u8(vtstq_u8(qh_cp_0, hi_bit_mask), mhb); + uint8x16_t hbit_hi_1 = vandq_u8(vtstq_u8(qh_cp_1, hi_bit_mask), mhb); + uint8x16_t hbit_hi_2 = vandq_u8(vtstq_u8(qh_cp_2, hi_bit_mask), mhb); + uint8x16_t hbit_hi_3 = vandq_u8(vtstq_u8(qh_cp_3, hi_bit_mask), mhb); + + // Combine 4-bit values with high bits to get 5-bit values + const int8x16_t q5_nibbles[2][4] = { + { + vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0)), + vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1)), + vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2)), + vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3)), + }, + { + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)), + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)), + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)), + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)), + } + }; + // From here, it's the same as Q4_K + + // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 + // for each of the internal 32 qs subblock (blk) + for (int rp = 0; rp < 2; rp++) { + for (int blk = 0; blk < 2; blk++) { + const int8x16_t * q8 = &q8s[rp][4 * blk]; + const int8x16_t * q5 = q5_nibbles[blk]; + int32x4_t acc = sb_acc[2 * rp + blk]; + // mul add for each qs in the same subblock + for (int qs_offset = 0; qs_offset < 4; qs_offset++) { + acc = vmmlaq_s32(acc, q5[qs_offset], q8[qs_offset]); + } + sb_acc[2 * rp + blk] = acc; + } + } + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + for (int blk = 0; blk < 2; blk++) { + const int32x4_t block_scale = { + (int32_t) q5sb_scales[blk][scale_offset], + (int32_t) q5sb_scales[blk][scale_offset], + (int32_t) q5sb_scales[blk][scale_offset + 1], + (int32_t) q5sb_scales[blk][scale_offset + 1], + }; + acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); + } + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). 
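// Before the bias is applied below, a quick scalar model of the vmmlaq_s32 (SMMLA)
// accumulation used in the column-pair loop above (illustrative only, not part of the
// patch; the helper name is made up): each call treats its two int8x16 inputs as 2x8
// row-major matrices A and B and accumulates the 2x2 int32 product A * B^T, so every lane
// of sb_acc ends up holding the partial dot product of one interleaved Q5 column with one
// interleaved Q8 row for the current sub-block.
static inline void smmla_2x2_sketch(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; ++i) {        // rows of A (here: two Q5 columns)
        for (int j = 0; j < 2; ++j) {    // rows of B (here: two Q8 rows)
            int32_t sum = 0;
            for (int k = 0; k < 8; ++k) {
                sum += (int32_t) a[i * 8 + k] * (int32_t) b[j * 8 + k];
            }
            acc[i * 2 + j] += sum;
        }
    }
}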
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d); + + float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q5_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. 
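+            // Store the 4x8 tile: row i goes to output row y * q8_k_blocklen + i, written as two float32x4 stores per row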
+ for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y return; #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 05e2425f1d5..da87103157e 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -59,7 +59,7 @@ struct block_q5_Kx8 { ggml_half dmin[8]; // super-block scale for quantized mins uint8_t scales[96]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants - uint8_t qs[QK_K * 8 / 2]; // 4--bit quants + uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) }; static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, From a8e2fdbd52a066b21a2003002ad1003a1968999a Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 20:30:19 +0000 Subject: [PATCH 05/14] Improved qh manipulation looking at non-repack vec_dot implementation --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 56 +++++++++++++++------------ 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 0b09b351920..83bdbe39299 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -808,11 +808,8 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) constexpr int col_pairs = ncols_interleaved / 2; const uint8x16_t m4b = vdupq_n_u8(0x0f); - const uint8x16_t mhb = vdupq_n_u8(0x10); // high bit mask for 5th bit position - - // Bit masks for extracting high bits based on subblock index - // sb=0: bit 0 (lo), bit 1 (hi); sb=1: bit 2 (lo), bit 3 (hi); etc. 
- const uint8_t bit_masks[8] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80 }; + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); // 1x8 tile = 2 x 4 float32x4_t acc_f32[ncols_interleaved / 4]; @@ -847,6 +844,17 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); int16_t bsums_arr[8]; vst1q_s16(bsums_arr, bsums); + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + for (int sb = 0; sb < QK_K / 64; sb++) { for (int i = 0; i < col_pairs; i++) { acc_lo[i] = vdupq_n_s32(0); @@ -864,11 +872,6 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, } const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; - const uint8_t * qh_base = q5_ptr[b].qh; // qh is shared across all subblocks - - // Masks for extracting high bits for this subblock - const uint8x16_t lo_bit_mask = vdupq_n_u8(bit_masks[sb * 2]); - const uint8x16_t hi_bit_mask = vdupq_n_u8(bit_masks[sb * 2 + 1]); // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns const int8_t * q8_base = q8_ptr[b].qs + sb * 64; @@ -879,27 +882,22 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, // Q5s columns iterated in pairs (01, 23, 45, 67) for (int cp = 0; cp < col_pairs; cp++) { - // Low bits + // Low 4 bits from qs uint8x16_t qs_cp_0 = vld1q_u8(qs_base + 16 * cp); uint8x16_t qs_cp_1 = vld1q_u8(qs_base + 16 * cp + 64); uint8x16_t qs_cp_2 = vld1q_u8(qs_base + 16 * cp + 128); uint8x16_t qs_cp_3 = vld1q_u8(qs_base + 16 * cp + 192); - // High bits (Q5_K specific) - uint8x16_t qh_cp_0 = vld1q_u8(qh_base + 16 * cp); - uint8x16_t qh_cp_1 = vld1q_u8(qh_base + 16 * cp + 64); - uint8x16_t qh_cp_2 = vld1q_u8(qh_base + 16 * cp + 128); - uint8x16_t qh_cp_3 = vld1q_u8(qh_base + 16 * cp + 192); - - uint8x16_t hbit_lo_0 = vandq_u8(vtstq_u8(qh_cp_0, lo_bit_mask), mhb); - uint8x16_t hbit_lo_1 = vandq_u8(vtstq_u8(qh_cp_1, lo_bit_mask), mhb); - uint8x16_t hbit_lo_2 = vandq_u8(vtstq_u8(qh_cp_2, lo_bit_mask), mhb); - uint8x16_t hbit_lo_3 = vandq_u8(vtstq_u8(qh_cp_3, lo_bit_mask), mhb); + // Extract high bits (mimics q5_k non-repack vec_dot) + uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mone), 4); + uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mone), 4); + uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mone), 4); + uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mone), 4); - uint8x16_t hbit_hi_0 = vandq_u8(vtstq_u8(qh_cp_0, hi_bit_mask), mhb); - uint8x16_t hbit_hi_1 = vandq_u8(vtstq_u8(qh_cp_1, hi_bit_mask), mhb); - uint8x16_t hbit_hi_2 = vandq_u8(vtstq_u8(qh_cp_2, hi_bit_mask), mhb); - uint8x16_t hbit_hi_3 = vandq_u8(vtstq_u8(qh_cp_3, hi_bit_mask), mhb); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); // Combine 4-bit values with high bits to get 5-bit values uint8x16_t q5_lo_0 = vorrq_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0); @@ -923,6 +921,14 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_3), q8_qs[7]); // 56..63 } + // Prepare next 
subblock + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + } + // Iterates over a pair of column pairs (4 columns) to use a single 128 register // p = 0 -> 0123 p2 -> 4567 for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { From 960689d2f18becc540e3dbb10aea92ae4f0f0646 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 20:36:17 +0000 Subject: [PATCH 06/14] Full unroll --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 165 ++++++++++++++++++-------- 1 file changed, 118 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 83bdbe39299..5f008696476 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -880,53 +880,124 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); } - // Q5s columns iterated in pairs (01, 23, 45, 67) - for (int cp = 0; cp < col_pairs; cp++) { - // Low 4 bits from qs - uint8x16_t qs_cp_0 = vld1q_u8(qs_base + 16 * cp); - uint8x16_t qs_cp_1 = vld1q_u8(qs_base + 16 * cp + 64); - uint8x16_t qs_cp_2 = vld1q_u8(qs_base + 16 * cp + 128); - uint8x16_t qs_cp_3 = vld1q_u8(qs_base + 16 * cp + 192); - - // Extract high bits (mimics q5_k non-repack vec_dot) - uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mone), 4); - uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mone), 4); - uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mone), 4); - uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mone), 4); - - uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); - uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); - uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); - uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); - - // Combine 4-bit values with high bits to get 5-bit values - uint8x16_t q5_lo_0 = vorrq_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0); - uint8x16_t q5_lo_1 = vorrq_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1); - uint8x16_t q5_lo_2 = vorrq_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2); - uint8x16_t q5_lo_3 = vorrq_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3); - - uint8x16_t q5_hi_0 = vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0); - uint8x16_t q5_hi_1 = vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1); - uint8x16_t q5_hi_2 = vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2); - uint8x16_t q5_hi_3 = vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3); - - acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_0), q8_qs[0]); // 0 .. 
7 - acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_1), q8_qs[1]); // 8 ..15 - acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_2), q8_qs[2]); // 16..23 - acc_lo[cp] = ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(q5_lo_3), q8_qs[3]); // 24..31 - - acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_0), q8_qs[4]); // 32..39 - acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_1), q8_qs[5]); // 40..47 - acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_2), q8_qs[6]); // 48..55 - acc_hi[cp] = ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(q5_hi_3), q8_qs[7]); // 56..63 - } - - // Prepare next subblock - for (int cp = 0; cp < col_pairs; cp++) { - qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); - qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); - qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); - qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + // Q5s columns iterated in pairs (01, 23, 45, 67) - fully unrolled + { + // Column pair 0 + uint8x16_t qs_0 = vld1q_u8(qs_base); + uint8x16_t qs_1 = vld1q_u8(qs_base + 64); + uint8x16_t qs_2 = vld1q_u8(qs_base + 128); + uint8x16_t qs_3 = vld1q_u8(qs_base + 192); + + uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[0][0], mone), 4); + uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[0][1], mone), 4); + uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[0][2], mone), 4); + uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[0][3], mone), 4); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); + + // Shift qh early to overlap with dot product execution + qh[0][0] = vshrq_n_u8(qh[0][0], 2); + qh[0][1] = vshrq_n_u8(qh[0][1], 2); + qh[0][2] = vshrq_n_u8(qh[0][2], 2); + qh[0][3] = vshrq_n_u8(qh[0][3], 2); + + acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + + // Column pair 1 + qs_0 = vld1q_u8(qs_base + 16); + qs_1 = vld1q_u8(qs_base + 80); + qs_2 = vld1q_u8(qs_base + 144); + qs_3 = vld1q_u8(qs_base + 208); + + hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[1][0], mone), 4); + hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[1][1], mone), 4); + hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[1][2], mone), 4); + hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[1][3], mone), 4); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3); + + qh[1][0] = vshrq_n_u8(qh[1][0], 2); + qh[1][1] = vshrq_n_u8(qh[1][1], 2); + qh[1][2] = vshrq_n_u8(qh[1][2], 2); + qh[1][3] = 
vshrq_n_u8(qh[1][3], 2); + + acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + + // Column pair 2 + qs_0 = vld1q_u8(qs_base + 32); + qs_1 = vld1q_u8(qs_base + 96); + qs_2 = vld1q_u8(qs_base + 160); + qs_3 = vld1q_u8(qs_base + 224); + + hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[2][0], mone), 4); + hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[2][1], mone), 4); + hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[2][2], mone), 4); + hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[2][3], mone), 4); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3); + + qh[2][0] = vshrq_n_u8(qh[2][0], 2); + qh[2][1] = vshrq_n_u8(qh[2][1], 2); + qh[2][2] = vshrq_n_u8(qh[2][2], 2); + qh[2][3] = vshrq_n_u8(qh[2][3], 2); + + acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + + // Column pair 3 + qs_0 = vld1q_u8(qs_base + 48); + qs_1 = vld1q_u8(qs_base + 112); + qs_2 = vld1q_u8(qs_base + 176); + qs_3 = vld1q_u8(qs_base + 240); + + hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[3][0], mone), 4); + hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[3][1], mone), 4); + hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[3][2], mone), 4); + hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[3][3], mone), 4); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3); + + qh[3][0] = vshrq_n_u8(qh[3][0], 2); + qh[3][1] = vshrq_n_u8(qh[3][1], 2); + qh[3][2] = vshrq_n_u8(qh[3][2], 2); + qh[3][3] = vshrq_n_u8(qh[3][3], 2); + + acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32(acc_lo[3], 
vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); } // Iterates over a pair of column pairs (4 columns) to use a single 128 register From 1d8c0bd8e399043f4ff74253ab14fc9763916fae Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 21:04:06 +0000 Subject: [PATCH 07/14] Apply Q5_K Gemv vand and vshl optimizations to gemm. Improve comments. Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 61 ++++++++++++++------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 5f008696476..ed2f9c668e6 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -880,9 +880,9 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); } - // Q5s columns iterated in pairs (01, 23, 45, 67) - fully unrolled + // Q5s column pair loop unrolled { - // Column pair 0 + // Cols 01 uint8x16_t qs_0 = vld1q_u8(qs_base); uint8x16_t qs_1 = vld1q_u8(qs_base + 64); uint8x16_t qs_2 = vld1q_u8(qs_base + 128); @@ -897,7 +897,6 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); - // Shift qh early to overlap with dot product execution qh[0][0] = vshrq_n_u8(qh[0][0], 2); qh[0][1] = vshrq_n_u8(qh[0][1], 2); qh[0][2] = vshrq_n_u8(qh[0][2], 2); @@ -912,7 +911,7 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); - // Column pair 1 + // Cols 23 qs_0 = vld1q_u8(qs_base + 16); qs_1 = vld1q_u8(qs_base + 80); qs_2 = vld1q_u8(qs_base + 144); @@ -941,7 +940,7 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); - // Column pair 2 + // Cols 45 qs_0 = vld1q_u8(qs_base + 32); qs_1 = vld1q_u8(qs_base + 96); qs_2 = vld1q_u8(qs_base + 160); @@ -970,7 +969,7 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); - // Column pair 3 + // Cols 45 qs_0 = vld1q_u8(qs_base + 48); qs_1 = vld1q_u8(qs_base + 112); qs_2 = vld1q_u8(qs_base + 176); @@ -3020,11 +3019,10 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) 
constexpr int q8_k_blocklen = 4; + constexpr int col_pairs = ncols_interleaved / 2; const uint8x16_t m4b = vdupq_n_u8(0x0f); - const uint8x16_t mhb = vdupq_n_u8(0x10); // high bit mask for 5th bit position - - // Bit masks for extracting high bits based on subblock index - const uint8_t bit_masks[8] = { 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80 }; + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); // 8 accumulators: 2 row pairs × 4 col pairs float32x4_t acc_f32[blocklen]; @@ -3060,6 +3058,16 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, bias_acc[i] = vdupq_n_s32(0); } + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + for (int sb = 0; sb < QK_K / 64; sb++) { // Need scales for the low and high nibbles // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total @@ -3090,12 +3098,8 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, q8_qs_23[7] }, }; - // Masks for extracting high bits for this subblock - const uint8x16_t lo_bit_mask = vdupq_n_u8(bit_masks[sb * 2]); - const uint8x16_t hi_bit_mask = vdupq_n_u8(bit_masks[sb * 2 + 1]); - // Q5s columns iterated in pairs (01, 23, 45, 67) - for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + for (int cp = 0; cp < col_pairs; cp++) { for (int i = 0; i < 4; i++) { sb_acc[i] = vdupq_n_s32(0); } @@ -3107,20 +3111,19 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, // This is the only part of the algorithm that differs with Q4_K // High bits - uint8x16_t qh_cp_0 = vld1q_u8(q5_ptr[b].qh + 16 * cp); - uint8x16_t qh_cp_1 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 64); - uint8x16_t qh_cp_2 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 128); - uint8x16_t qh_cp_3 = vld1q_u8(q5_ptr[b].qh + 16 * cp + 192); - - uint8x16_t hbit_lo_0 = vandq_u8(vtstq_u8(qh_cp_0, lo_bit_mask), mhb); - uint8x16_t hbit_lo_1 = vandq_u8(vtstq_u8(qh_cp_1, lo_bit_mask), mhb); - uint8x16_t hbit_lo_2 = vandq_u8(vtstq_u8(qh_cp_2, lo_bit_mask), mhb); - uint8x16_t hbit_lo_3 = vandq_u8(vtstq_u8(qh_cp_3, lo_bit_mask), mhb); - - uint8x16_t hbit_hi_0 = vandq_u8(vtstq_u8(qh_cp_0, hi_bit_mask), mhb); - uint8x16_t hbit_hi_1 = vandq_u8(vtstq_u8(qh_cp_1, hi_bit_mask), mhb); - uint8x16_t hbit_hi_2 = vandq_u8(vtstq_u8(qh_cp_2, hi_bit_mask), mhb); - uint8x16_t hbit_hi_3 = vandq_u8(vtstq_u8(qh_cp_3, hi_bit_mask), mhb); + uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mone), 4); + uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mone), 4); + uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mone), 4); + uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mone), 4); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); // Combine 4-bit values with high bits to get 5-bit values const int8x16_t q5_nibbles[2][4] = { From f9582a669782508062ecf7dd68963ebc70f4d4c7 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 21:05:01 +0000 Subject: [PATCH 08/14] Fix wrong fallback definitions of Q5_K 
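In two of the per-architecture fallback blocks the q5_K aliases used the wrong
kernel kind in the *_generic name (a gemv define where the gemm one was
intended, and vice versa), so the generic kernel was remapped under the wrong
symbol. Each alias should forward the *_generic name to the kernel of the same
kind, e.g.:

    #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
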
Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/arch-fallback.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index ba6a19b9870..7f2e4e0cab9 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -112,7 +112,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -215,7 +215,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 From 794e9ecdc8f9d5729ccc19f764d46320e2a727f3 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 14 Jan 2026 21:21:14 +0000 Subject: [PATCH 09/14] Fixed comments. Reverted unnecessary formatting Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/repack.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f6e45d56d82..19e021e59aa 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1840,7 +1840,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in // Repeat for low bits 8 bytes at a time as well, since // the high bits are interleaved in Q5_K and the index is // qh_idx = (qs_idx % 32); - // qh_val = qh[qs_idx] >> (qs_idx / 32); + // qh_val = qh[qh_idx] >> (qs_idx / 32); for (int i = 0; i < end / 4; ++i) { int src_id = i % 8; int src_offset = (i / 8) * blck_size_interleave; @@ -2724,9 +2724,10 @@ template ( - ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + gemv(ne00, + (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, + src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW From d65e2eae10b7a200a51c5a30a8e78723da519eea Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Thu, 15 Jan 2026 11:23:31 +0000 Subject: [PATCH 10/14] Fixed typo in generic definitions --- ggml/src/ggml-cpu/arch-fallback.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 7f2e4e0cab9..0a85a4cff30 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -41,7 +41,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define 
ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -72,7 +72,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -101,7 +101,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -135,7 +135,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -176,7 +176,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -215,7 +215,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -257,7 +257,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 From a6e2281946ed99e6b205b64ff75555154e326962 Mon 
Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Mon, 19 Jan 2026 17:17:34 +0000 Subject: [PATCH 11/14] Switching AND + Shift with Shift Insert. Better op interleaving. --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 96 ++++++++++++++------------- 1 file changed, 50 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index ed2f9c668e6..217902c776f 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3110,52 +3110,56 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 // This is the only part of the algorithm that differs with Q4_K - // High bits - uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mone), 4); - uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mone), 4); - uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mone), 4); - uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mone), 4); - uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); - uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); - uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); - uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); - - qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); - qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); - qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); - qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); - - // Combine 4-bit values with high bits to get 5-bit values - const int8x16_t q5_nibbles[2][4] = { - { - vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0)), - vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1)), - vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2)), - vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3)), - }, - { - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)), - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)), - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)), - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)), - } - }; - // From here, it's the same as Q4_K - - // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 - // for each of the internal 32 qs subblock (blk) - for (int rp = 0; rp < 2; rp++) { - for (int blk = 0; blk < 2; blk++) { - const int8x16_t * q8 = &q8s[rp][4 * blk]; - const int8x16_t * q5 = q5_nibbles[blk]; - int32x4_t acc = sb_acc[2 * rp + blk]; - // mul add for each qs in the same subblock - for (int qs_offset = 0; qs_offset < 4; qs_offset++) { - acc = vmmlaq_s32(acc, q5[qs_offset], q8[qs_offset]); - } - sb_acc[2 * rp + blk] = acc; - } - } + // Extract High bits and pack into 5 bit weights + uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + // Same as Q4_K, i8mm to dequantize the weights. 
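+                        // vsliq_n_u8 shifts the high bit to position 4 and inserts it over the low nibble in a single instruction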
+ const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); + int32x4_t acc_0 = sb_acc[0]; + acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); + int32x4_t acc_1 = sb_acc[1]; + acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); + int32x4_t acc_3 = sb_acc[3]; + acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]); + + // Repeat for the other 3 columns (8..15, 16..23, 24..31) + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]); + const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]); + + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]); + const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]); + + uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]); + sb_acc[0] = acc_0; + acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); + sb_acc[2] = acc_2; + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]); + sb_acc[1] = acc_1; + acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); + sb_acc[3] = acc_3; // Scales[i] corresponds to column i const int scale_offset = cp * 2; From 339734dcab4765b52ddb0a07ec529e8f799e0cb1 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Tue, 20 Jan 2026 15:53:07 +0000 Subject: [PATCH 12/14] Vectorize + unroll the block scales --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 217902c776f..7de94d3d791 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3155,24 +3155,26 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, sb_acc[0] = acc_0; acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); sb_acc[2] = acc_2; + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t s0 = q5sb_scales[0][scale_offset]; + const int32_t s1 = q5sb_scales[0][scale_offset + 1]; + const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale); + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); acc_1 = vmmlaq_s32(acc_1, 
qs_hi_3, q8s[0][7]); sb_acc[1] = acc_1; acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); sb_acc[3] = acc_3; - // Scales[i] corresponds to column i - const int scale_offset = cp * 2; - for (int blk = 0; blk < 2; blk++) { - const int32x4_t block_scale = { - (int32_t) q5sb_scales[blk][scale_offset], - (int32_t) q5sb_scales[blk][scale_offset], - (int32_t) q5sb_scales[blk][scale_offset + 1], - (int32_t) q5sb_scales[blk][scale_offset + 1], - }; - acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); - acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); - } + const int32_t s2 = q5sb_scales[1][scale_offset]; + const int32_t s3 = q5sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2); } // Multiply Acc bsum + mins From 365555deeb97dad9091bf672a81651144c3d91c7 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Thu, 22 Jan 2026 15:24:07 +0000 Subject: [PATCH 13/14] Apply gemm optimizations to gemv --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 132 ++++++++++++++++---------- 1 file changed, 82 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 7de94d3d791..a15ac73591c 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -808,8 +808,8 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) constexpr int col_pairs = ncols_interleaved / 2; const uint8x16_t m4b = vdupq_n_u8(0x0f); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); // 1x8 tile = 2 x 4 float32x4_t acc_f32[ncols_interleaved / 4]; @@ -888,10 +888,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, uint8x16_t qs_2 = vld1q_u8(qs_base + 128); uint8x16_t qs_3 = vld1q_u8(qs_base + 192); - uint8x16_t hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[0][0], mone), 4); - uint8x16_t hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[0][1], mone), 4); - uint8x16_t hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[0][2], mone), 4); - uint8x16_t hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[0][3], mone), 4); + uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone); + uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone); + uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone); + uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone); uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); @@ -902,14 +902,22 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qh[0][2] = vshrq_n_u8(qh[0][2], 2); qh[0][3] = vshrq_n_u8(qh[0][3], 2); - acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); - acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); - acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); - acc_lo[0] = ggml_vdotq_s32(acc_lo[0], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); - acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); - acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); - acc_hi[0] = ggml_vdotq_s32(acc_hi[0], 
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); - acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); // Cols 23 qs_0 = vld1q_u8(qs_base + 16); @@ -917,10 +925,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qs_2 = vld1q_u8(qs_base + 144); qs_3 = vld1q_u8(qs_base + 208); - hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[1][0], mone), 4); - hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[1][1], mone), 4); - hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[1][2], mone), 4); - hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[1][3], mone), 4); + hbit_lo_0 = vandq_u8(qh[1][0], mone); + hbit_lo_1 = vandq_u8(qh[1][1], mone); + hbit_lo_2 = vandq_u8(qh[1][2], mone); + hbit_lo_3 = vandq_u8(qh[1][3], mone); hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); @@ -931,14 +939,22 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qh[1][2] = vshrq_n_u8(qh[1][2], 2); qh[1][3] = vshrq_n_u8(qh[1][3], 2); - acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); - acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); - acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); - acc_lo[1] = ggml_vdotq_s32(acc_lo[1], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); - acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); - acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); - acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); - acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], 
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); // Cols 45 qs_0 = vld1q_u8(qs_base + 32); @@ -946,10 +962,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qs_2 = vld1q_u8(qs_base + 160); qs_3 = vld1q_u8(qs_base + 224); - hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[2][0], mone), 4); - hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[2][1], mone), 4); - hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[2][2], mone), 4); - hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[2][3], mone), 4); + hbit_lo_0 = vandq_u8(qh[2][0], mone); + hbit_lo_1 = vandq_u8(qh[2][1], mone); + hbit_lo_2 = vandq_u8(qh[2][2], mone); + hbit_lo_3 = vandq_u8(qh[2][3], mone); hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); @@ -960,14 +976,22 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qh[2][2] = vshrq_n_u8(qh[2][2], 2); qh[2][3] = vshrq_n_u8(qh[2][3], 2); - acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); - acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); - acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); - acc_lo[2] = ggml_vdotq_s32(acc_lo[2], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); - acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); - acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); - acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); - acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); // Cols 45 qs_0 = vld1q_u8(qs_base + 48); @@ -975,10 +999,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qs_2 = vld1q_u8(qs_base + 176); qs_3 = vld1q_u8(qs_base + 240); - hbit_lo_0 = vshlq_n_u8(vandq_u8(qh[3][0], mone), 4); - hbit_lo_1 = vshlq_n_u8(vandq_u8(qh[3][1], mone), 4); - hbit_lo_2 = vshlq_n_u8(vandq_u8(qh[3][2], mone), 4); - hbit_lo_3 = vshlq_n_u8(vandq_u8(qh[3][3], mone), 4); + hbit_lo_0 = vandq_u8(qh[3][0], mone); + hbit_lo_1 = vandq_u8(qh[3][1], mone); + hbit_lo_2 = vandq_u8(qh[3][2], mone); + hbit_lo_3 = vandq_u8(qh[3][3], mone); 
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); @@ -989,14 +1013,22 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, qh[3][2] = vshrq_n_u8(qh[3][2], 2); qh[3][3] = vshrq_n_u8(qh[3][3], 2); - acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_0, m4b), hbit_lo_0)), q8_qs[0]); - acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_1, m4b), hbit_lo_1)), q8_qs[1]); - acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_2, m4b), hbit_lo_2)), q8_qs[2]); - acc_lo[3] = ggml_vdotq_s32(acc_lo[3], vreinterpretq_s8_u8(vorrq_u8(vandq_u8(qs_3, m4b), hbit_lo_3)), q8_qs[3]); - acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), q8_qs[4]); - acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), q8_qs[5]); - acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), q8_qs[6]); - acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), q8_qs[7]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); } // Iterates over a pair of column pairs (4 columns) to use a single 128 register From 69b247789ecb4233184f5c37a59c0befb8be7dfa Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Thu, 22 Jan 2026 17:08:44 +0000 Subject: [PATCH 14/14] Improve bias calculation --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 34 ++++++++++++--------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index a15ac73591c..883d862901b 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -834,8 +834,6 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); - // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 - int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; // 2 sb each iteration int32x4_t acc_lo[col_pairs]; int32x4_t acc_hi[col_pairs]; @@ -1031,12 +1029,20 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, q8_qs[7]); } + // Prepare bsum vectors for bias computation + // Each pair of subblocks share the same bsums + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + // Iterates over a pair of column pairs (4 columns) to use a single 128 register // p = 0 -> 0123 p2 -> 4567 for (int i = 0, p = 0; p < col_pairs; 
i++, p += 2) { int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]); + int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]); float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1; // 0123 or 4567 float32x4_t sumf_0 = @@ -1046,25 +1052,15 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); - } - - // Multiply Acc bsum + mins - // Each pair of subblocks share the same bsums - // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). - int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); - int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); - - // cols 0-3 bias - bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); - bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); - // cols 4-7 bias - bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); - bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + // FUSED BIAS: Compute and subtract bias immediately + // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + float32x4_t bias_f32 = vcvtq_f32_s32(bias); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + } } // for sb - - acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0); - acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1); } // for b int base = x * ncols_interleaved;