From be3a90c94959d4a07bf098cd2e3b42929ad892ff Mon Sep 17 00:00:00 2001 From: David Friehs Date: Sat, 14 Feb 2026 20:57:00 +0100 Subject: [PATCH 1/3] cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization - load all 8 int8 for a grid position in one load - calculate signs via popcnt instead of fetching from ksigns table - broadcast signs to drop individual shift/mask --- ggml/src/ggml-cuda/mmq.cuh | 31 ++++++++++++++---------- ggml/src/ggml-cuda/vecdotq.cuh | 44 ++++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f80f98cda2c..e9400194f45 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2715,14 +2715,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; + const uint signs = unpack_ksigns(aux32 >> (7 * l)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; @@ -2776,11 +2776,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; + const uint signs = unpack_ksigns(q2[l] >> 9); - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; @@ -2904,11 +2907,13 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + const uint signs = unpack_ksigns(aux32 >> (7*l)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6baab1176ff..b68179adeb4 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } +static __device__ __forceinline__ uint unpack_ksigns(const uint8_t v) { + // v is a 7 bit int, with the 8th sign being encodable as popcnt + // with xor we can "correct" the bit instead of having to mask + const uint p = __popc(v) & 1; + const uint s = v ^ p << 7; + // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors + return s * 0x01010101; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -905,16 +914,16 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( int sumi = 0; #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; + const uint signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); sumi = ggml_cuda_dp4a(grid0, u0, sumi); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); sumi = ggml_cuda_dp4a(grid1, u1, sumi); } @@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( int sumi1 = 0; #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; + const uint signs = unpack_ksigns(q2[l0/2] >> 9); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); if (l0 < 4) { @@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + const uint signs = unpack_ksigns(aux32 >> (7*l0/2)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); sumi = ggml_cuda_dp4a(grid_l, u0, sumi); From dfc0e2c9a8483ec8e9ef7f581ce8f504545abbd8 Mon Sep 17 00:00:00 2001 From: David Friehs Date: Sun, 15 Feb 2026 07:02:05 +0100 Subject: [PATCH 2/3] cuda: iq2xxs: simplify sum scaling express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8` express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)` saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight AFAICT no overflow can occur here as iq2xxs values are far too small --- ggml/src/ggml-cuda/mmq.cuh | 6 +++--- ggml/src/ggml-cuda/vecdotq.cuh | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index e9400194f45..5fe2820a6e7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2733,12 +2733,12 @@ template static __device__ __forceinline__ void loa #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } - const int ls = aux32 >> 28; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index b68179adeb4..94677be60c8 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -928,8 +928,8 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const int ls = aux32 >> 28; - sumi = (ls*sumi + sumi/2)/4; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) + sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4 const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; } From b82a9807c2fe513bfc515eadc12df3dd421ba1a1 Mon Sep 17 00:00:00 2001 From: David Friehs Date: Sun, 15 Feb 2026 15:46:14 +0100 Subject: [PATCH 3/3] uint -> uint32_t error: identifier "uint" is undefined --- ggml/src/ggml-cuda/mmq.cuh | 6 +++--- ggml/src/ggml-cuda/vecdotq.cuh | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 5fe2820a6e7..255e59f6fc6 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2716,7 +2716,7 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; - const uint signs = unpack_ksigns(aux32 >> (7 * l)); + const uint32_t signs = unpack_ksigns(aux32 >> (7 * l)); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); @@ -2777,7 +2777,7 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; - const uint signs = unpack_ksigns(q2[l] >> 9); + const uint32_t signs = unpack_ksigns(q2[l] >> 9); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); @@ -2907,7 +2907,7 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); - const uint signs = unpack_ksigns(aux32 >> (7*l)); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l)); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 94677be60c8..ab803aca21b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,11 +94,11 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } -static __device__ __forceinline__ uint unpack_ksigns(const uint8_t v) { +static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { // v is a 7 bit int, with the 8th sign being encodable as popcnt // with xor we can "correct" the bit instead of having to mask - const uint p = __popc(v) & 1; - const uint s = v ^ p << 7; + const uint32_t p = __popc(v) & 1; + const uint32_t s = v ^ p << 7; // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors return s * 0x01010101; } @@ -915,7 +915,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; - const uint signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); + const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); @@ -952,7 +952,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; - const uint signs = unpack_ksigns(q2[l0/2] >> 9); + const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); @@ -1039,7 +1039,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); - const uint signs = unpack_ksigns(aux32 >> (7*l0/2)); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2)); const int signs0 = __vcmpne4(signs & 0x08040201, 0); const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);