From fe5b11e79cbf24465fd4824a17bf8b66eb8c8edf Mon Sep 17 00:00:00 2001 From: marksverdhei Date: Sun, 12 Apr 2026 11:01:09 +0200 Subject: [PATCH] feat(turboq): 128-block TBQ with native flash attention kernels Switch TBQ3_0/TBQ4_0 from 256-element (QK_K) to 128-element blocks via TBQ_BLK_SIZE constant. This halves block size for better granularity at 3.125/4.125 bpw. Add native CUDA flash attention dot-product and dequantize kernels for both TBQ types, eliminating the dequantize-to-f16 intermediate step. In flash attention mode, K/V tensors are now consumed directly in their quantized format instead of being cast through codebook dequant first. Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-common.h | 14 +- ggml/src/ggml-cpu/quants.c | 4 +- ggml/src/ggml-cuda/fattn-common.cuh | 214 ++++++++++++++++++++++ ggml/src/ggml-cuda/fattn.cu | 27 +++ ggml/src/ggml-cuda/turboq.cu | 269 +++++++++++++--------------- ggml/src/ggml-turboq.c | 78 ++++---- ggml/src/ggml.c | 4 +- src/llama-graph.cpp | 24 +-- 8 files changed, 434 insertions(+), 200 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 2b81164b249f..a5947e8d288d 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -279,19 +279,21 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 // TurboQuant blocks -// 3.0625 bpw +#define TBQ_BLK_SIZE 128 + +// 3.125 bpw typedef struct { - uint8_t qs[QK_K * 3 / 8]; + uint8_t qs[TBQ_BLK_SIZE * 3 / 8]; ggml_half d; } block_tbq3_0; -static_assert(sizeof(block_tbq3_0) == sizeof(ggml_half) + QK_K * 3 / 8, "wrong tbq3_0 block size/padding"); +static_assert(sizeof(block_tbq3_0) == sizeof(ggml_half) + TBQ_BLK_SIZE * 3 / 8, "wrong tbq3_0 block size/padding"); -// 4.0625 bpw +// 4.125 bpw typedef struct { - uint8_t qs[QK_K / 2]; + uint8_t qs[TBQ_BLK_SIZE / 2]; ggml_half d; } block_tbq4_0; -static_assert(sizeof(block_tbq4_0) == sizeof(ggml_half) + QK_K / 2, "wrong tbq4_0 block size/padding"); +static_assert(sizeof(block_tbq4_0) == sizeof(ggml_half) + TBQ_BLK_SIZE / 2, "wrong tbq4_0 block size/padding"); // // Super-block quantization structures diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 9b16f1972550..378eeba72107 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -113,13 +113,13 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, } void quantize_row_tbq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); + assert(k % TBQ_BLK_SIZE == 0); block_tbq3_0 * GGML_RESTRICT y = vy; quantize_row_tbq3_0_ref(x, y, k); } void quantize_row_tbq4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); + assert(k % TBQ_BLK_SIZE == 0); block_tbq4_0 * GGML_RESTRICT y = vy; quantize_row_tbq4_0_ref(x, y, k); } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb52389464..f29ebae75145 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -577,6 +577,212 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict } } + +// 3-bit unpacking helper for TBQ3_0 +static __device__ __forceinline__ int unpack_3bit_index(const uint8_t * qs, int elem) { + const int group = elem / 8; + const int bit_offset = (elem % 8) * 3; + const uint8_t * src = qs + group * 3; + uint32_t bits = (uint32_t)src[0] | ((uint32_t)src[1] << 8) | ((uint32_t)src[2] << 16); + return (bits >> bit_offset) & 0x7; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tbq3_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_tbq3_0 * K_tbq3_0 = (const block_tbq3_0 *) K_c; + GGML_UNUSED(Q_v); + + const float cb[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, + 0.2451f, 0.7560f, 1.3440f, 2.1520f, + }; + const float scale_down = 0.08838834764f; // 1/sqrt(128) + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int elem_start = k_KQ * 4; + const int ib = elem_start / 128; + const int elem = elem_start % 128; + + const uint8_t* qs = K_tbq3_0[ib].qs; + + const int idx0 = unpack_3bit_index(qs, elem + 0); + const int idx1 = unpack_3bit_index(qs, elem + 1); + const int idx2 = unpack_3bit_index(qs, elem + 2); + const int idx3 = unpack_3bit_index(qs, elem + 3); + + const float v0 = cb[idx0]; + const float v1 = cb[idx1]; + const float v2 = cb[idx2]; + const float v3 = cb[idx3]; + + const int u = Q_q8[k_KQ_0/nthreads]; + const int8_t * q8 = (const int8_t *) &u; + + float sumi = v0 * q8[0] + v1 * q8[1] + v2 * q8[2] + v3 * q8[3]; + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + sum += __half2float(K_tbq3_0[ib].d) * sumi * Q_ds.x * scale_down; + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tbq4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_tbq4_0 * K_tbq4_0 = (const block_tbq4_0 *) K_c; + GGML_UNUSED(Q_v); + + const float cb[16] = { + -2.7326f, -2.0690f, -1.6180f, -1.2562f, + -0.9424f, -0.6568f, -0.3881f, -0.1284f, + 0.1284f, 0.3881f, 0.6568f, 0.9424f, + 1.2562f, 1.6180f, 2.0690f, 2.7326f, + }; + const float scale_down = 0.08838834764f; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int elem_start = k_KQ * 4; + const int ib = elem_start / 128; + const int elem = elem_start % 128; + + const uint8_t* qs = K_tbq4_0[ib].qs; + + const uint8_t byte0 = qs[(elem + 0) / 2]; + const uint8_t byte1 = qs[(elem + 2) / 2]; + + const int idx0 = byte0 & 0x0F; + const int idx1 = byte0 >> 4; + const int idx2 = byte1 & 0x0F; + const int idx3 = byte1 >> 4; + + const float v0 = cb[idx0]; + const float v1 = cb[idx1]; + const float v2 = cb[idx2]; + const float v3 = cb[idx3]; + + const int u = Q_q8[k_KQ_0/nthreads]; + const int8_t * q8 = (const int8_t *) &u; + + float sumi = v0 * q8[0] + v1 * q8[1] + v2 * q8[2] + v3 * q8[3]; + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + sum += __half2float(K_tbq4_0[ib].d) * sumi * Q_ds.x * scale_down; + } + + return sum; +} + +template +static __device__ __forceinline__ void dequantize_V_tbq3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_tbq3_0 * x = (const block_tbq3_0 *) vx; + + const int64_t ib = i0 / 128; + const int elem_start = i0 % 128; + + const float cb[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, + 0.2451f, 0.7560f, 1.3440f, 2.1520f, + }; + const float scale_down = 0.08838834764f; + + static_assert(ne == 2 || ne == 4, "bad ne"); + + float vals[4]; +#pragma unroll + for (int l = 0; l < ne; ++l) { + int elem = elem_start + l; + int group = elem / 8; + int bit_offset = (elem % 8) * 3; + const uint8_t * src = x[ib].qs + group * 3; + uint32_t bits = (uint32_t)src[0] | ((uint32_t)src[1] << 8) | ((uint32_t)src[2] << 16); + int idx = (bits >> bit_offset) & 0x7; + vals[l] = cb[idx] * scale_down; + } + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = __half2float(x[ib].d); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * vals[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + +template +static __device__ __forceinline__ void dequantize_V_tbq4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_tbq4_0 * x = (const block_tbq4_0 *) vx; + + const int64_t ib = i0 / 128; + const int elem_start = i0 % 128; + + const float cb[16] = { + -2.7326f, -2.0690f, -1.6180f, -1.2562f, + -0.9424f, -0.6568f, -0.3881f, -0.1284f, + 0.1284f, 0.3881f, 0.6568f, 0.9424f, + 1.2562f, 1.6180f, 2.0690f, 2.7326f, + }; + const float scale_down = 0.08838834764f; + + static_assert(ne == 2 || ne == 4, "bad ne"); + + float vals[4]; +#pragma unroll + for (int l = 0; l < ne; ++l) { + int elem = elem_start + l; + uint8_t byte = x[ib].qs[elem / 2]; + int idx = (elem % 2 == 0) ? (byte & 0x0F) : (byte >> 4); + vals[l] = cb[idx] * scale_down; + } + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = __half2float(x[ib].d); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * vals[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + template constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { if constexpr (type_K == GGML_TYPE_F16) { @@ -593,6 +799,10 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q8_0; } else if constexpr (type_K == GGML_TYPE_BF16) { return vec_dot_fattn_vec_KQ_bf16; + } else if constexpr (type_K == GGML_TYPE_TBQ3_0) { + return vec_dot_fattn_vec_KQ_tbq3_0; + } else if constexpr (type_K == GGML_TYPE_TBQ4_0) { + return vec_dot_fattn_vec_KQ_tbq4_0; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -615,6 +825,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q8_0; } else if constexpr (type_V == GGML_TYPE_BF16) { return dequantize_V_bf16; + } else if constexpr (type_V == GGML_TYPE_TBQ3_0) { + return dequantize_V_tbq3_0; + } else if constexpr (type_V == GGML_TYPE_TBQ4_0) { + return dequantize_V_tbq4_0; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337e..e42bfcca5749 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -237,6 +237,8 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -285,6 +287,29 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ4_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ3_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ4_0) + #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -391,6 +416,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: case GGML_TYPE_BF16: + case GGML_TYPE_TBQ3_0: + case GGML_TYPE_TBQ4_0: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/turboq.cu b/ggml/src/ggml-cuda/turboq.cu index d349e968f9ce..f5a6db46efae 100644 --- a/ggml/src/ggml-cuda/turboq.cu +++ b/ggml/src/ggml-cuda/turboq.cu @@ -126,19 +126,17 @@ __global__ void dequantize_block_tbq3_0_kernel( const void * __restrict__ vx, dst_t * __restrict__ y, const float * __restrict__ Q_rot, int64_t nb) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; // 0 = first 128, 1 = second 128 + const int tbq_block = blockIdx.x; const int tid = threadIdx.x; // 0..127 if (tbq_block >= nb) return; const block_tbq3_0 * x = (const block_tbq3_0 *)vx; const float norm = __half2float(x[tbq_block].d); - const float scale_down = 0.0625f; // 1/sqrt(256) = 1/16 + const float scale_down = 0.08838834764f; // 1/sqrt(128) // Step 1: Unpack 3-bit index and look up codebook - const int elem = sub_half * 128 + tid; + const int elem = tid; const int idx = unpack_3bit_index(x[tbq_block].qs, elem); __shared__ float s_rotated[128]; @@ -153,7 +151,7 @@ __global__ void dequantize_block_tbq3_0_kernel( } // Step 3: Scale by norm and write - y[tbq_block * 256 + sub_half * 128 + tid] = (dst_t)(sum * norm); + y[tbq_block * 128 + tid] = (dst_t)(sum * norm); } template @@ -161,19 +159,17 @@ __global__ void dequantize_block_tbq4_0_kernel( const void * __restrict__ vx, dst_t * __restrict__ y, const float * __restrict__ Q_rot, int64_t nb) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; + const int tbq_block = blockIdx.x; const int tid = threadIdx.x; if (tbq_block >= nb) return; const block_tbq4_0 * x = (const block_tbq4_0 *)vx; const float norm = __half2float(x[tbq_block].d); - const float scale_down = 0.0625f; + const float scale_down = 0.08838834764f; // Step 1: Unpack 4-bit nibble and look up codebook - const int elem = sub_half * 128 + tid; + const int elem = tid; uint8_t idx; if (elem % 2 == 0) { idx = x[tbq_block].qs[elem / 2] & 0x0F; @@ -192,7 +188,7 @@ __global__ void dequantize_block_tbq4_0_kernel( } // Step 3: Scale and write - y[tbq_block * 256 + sub_half * 128 + tid] = (dst_t)(sum * norm); + y[tbq_block * 128 + tid] = (dst_t)(sum * norm); } // Ensure rotation matrix is initialized (lazy init, thread-safe via CUDA stream ordering) @@ -205,15 +201,15 @@ static void turboq_ensure_init(cudaStream_t stream) { template void dequantize_row_tbq3_0_cuda(const void * vx, dst_t * y, int64_t k, cudaStream_t stream) { turboq_ensure_init(stream); - const int64_t nb = k / 256; - dequantize_block_tbq3_0_kernel<<>>(vx, y, d_turboq_Q, nb); + const int64_t nb = k / 128; + dequantize_block_tbq3_0_kernel<<>>(vx, y, d_turboq_Q, nb); } template void dequantize_row_tbq4_0_cuda(const void * vx, dst_t * y, int64_t k, cudaStream_t stream) { turboq_ensure_init(stream); - const int64_t nb = k / 256; - dequantize_block_tbq4_0_kernel<<>>(vx, y, d_turboq_Q, nb); + const int64_t nb = k / 128; + dequantize_block_tbq4_0_kernel<<>>(vx, y, d_turboq_Q, nb); } // Explicit template instantiations @@ -233,39 +229,35 @@ template __global__ void dequantize_block_tbq3_0_codebook_kernel( const void * __restrict__ vx, dst_t * __restrict__ y, int64_t nb) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; + const int tbq_block = blockIdx.x; const int tid = threadIdx.x; if (tbq_block >= nb) return; const block_tbq3_0 * x = (const block_tbq3_0 *)vx; const float norm = __half2float(x[tbq_block].d); - const float scale_down = 0.0625f; + const float scale_down = 0.08838834764f; - const int elem = sub_half * 128 + tid; + const int elem = tid; const int idx = unpack_3bit_index(x[tbq_block].qs, elem); - y[tbq_block * 256 + elem] = (dst_t)(d_codebook_3bit[idx] * scale_down * norm); + y[tbq_block * 128 + elem] = (dst_t)(d_codebook_3bit[idx] * scale_down * norm); } template __global__ void dequantize_block_tbq4_0_codebook_kernel( const void * __restrict__ vx, dst_t * __restrict__ y, int64_t nb) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; + const int tbq_block = blockIdx.x; const int tid = threadIdx.x; if (tbq_block >= nb) return; const block_tbq4_0 * x = (const block_tbq4_0 *)vx; const float norm = __half2float(x[tbq_block].d); - const float scale_down = 0.0625f; + const float scale_down = 0.08838834764f; - const int elem = sub_half * 128 + tid; + const int elem = tid; uint8_t idx; if (elem % 2 == 0) { idx = x[tbq_block].qs[elem / 2] & 0x0F; @@ -273,19 +265,19 @@ __global__ void dequantize_block_tbq4_0_codebook_kernel( idx = (x[tbq_block].qs[elem / 2] >> 4) & 0x0F; } - y[tbq_block * 256 + elem] = (dst_t)(d_codebook_4bit[idx] * scale_down * norm); + y[tbq_block * 128 + elem] = (dst_t)(d_codebook_4bit[idx] * scale_down * norm); } template void dequantize_row_tbq3_0_codebook_cuda(const void * vx, dst_t * y, int64_t k, cudaStream_t stream) { - const int64_t nb = k / 256; - dequantize_block_tbq3_0_codebook_kernel<<>>(vx, y, nb); + const int64_t nb = k / 128; + dequantize_block_tbq3_0_codebook_kernel<<>>(vx, y, nb); } template void dequantize_row_tbq4_0_codebook_cuda(const void * vx, dst_t * y, int64_t k, cudaStream_t stream) { - const int64_t nb = k / 256; - dequantize_block_tbq4_0_codebook_kernel<<>>(vx, y, nb); + const int64_t nb = k / 128; + dequantize_block_tbq4_0_codebook_kernel<<>>(vx, y, nb); } template void dequantize_row_tbq3_0_codebook_cuda(const void * vx, float * y, int64_t k, cudaStream_t stream); @@ -306,7 +298,7 @@ __global__ void turboq_compute_norms_kernel( const int block_id = blockIdx.x; if (block_id >= num_blocks) return; - const int tid = threadIdx.x; // 0..255 + const int tid = threadIdx.x; // 0..127 const float * src = x + block_id * 256; float val = src[tid]; @@ -315,15 +307,18 @@ __global__ void turboq_compute_norms_kernel( // Warp-level reduction (full 32-lane warp) sum_sq = warp_reduce_sum(sum_sq); - __shared__ float warp_sums[8]; // 256/32 = 8 warps + __shared__ float warp_sums[4]; // 128/32 = 4 warps if (tid % 32 == 0) { warp_sums[tid / 32] = sum_sq; } __syncthreads(); - // Final reduction across warps (8-lane sub-warp) - if (tid < 8) { - float s = warp_reduce_sum<8>(warp_sums[tid]); + // Final reduction across warps + if (tid < 32) { + float s = (tid < 4) ? warp_sums[tid] : 0.0f; + for (int offset = 2; offset > 0; offset >>= 1) { + s += __shfl_down_sync(0xFFFFFFFF, s, offset); + } if (tid == 0) { float norm = sqrtf(s); norms[block_id] = (norm < 1e-10f) ? 1e-10f : norm; @@ -337,21 +332,19 @@ __global__ void quantize_f32_tbq3_0_kernel( const float * __restrict__ Q_rot, const float * __restrict__ norms, int64_t num_blocks) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; - const int tid = threadIdx.x; // 0..127 + const int tbq_block = blockIdx.x; + const int tid = threadIdx.x; // 0..127 if (tbq_block >= num_blocks) return; block_tbq3_0 * y = (block_tbq3_0 *)vy; - const float * src = x + tbq_block * 256; + const float * src = x + tbq_block * 128; const float norm = norms[tbq_block]; const float inv_norm = 1.0f / norm; - const float scale_up = 16.0f; // sqrt(256) + const float scale_up = 11.3137085f; // sqrt(128) // Step 1: Load and normalize - const int elem = sub_half * 128 + tid; + const int elem = tid; float unit_val = src[elem] * inv_norm; __shared__ float s_unit[128]; @@ -393,7 +386,7 @@ __global__ void quantize_f32_tbq3_0_kernel( // Step 5: Write packed bytes to output // Each sub-block writes 48 bytes (128 * 3 / 8) - const int qs_offset = sub_half * 48; // 128*3/8 + const int qs_offset = 0; // 128*3/8 if (pos_in_group < 3) { // 3 bytes per group, thread 0-2 in each group write one byte uint32_t packed = s_packed[group]; @@ -401,8 +394,8 @@ __global__ void quantize_f32_tbq3_0_kernel( (uint8_t)((packed >> (pos_in_group * 8)) & 0xFF); } - // First sub-block, first thread writes the norm - if (sub_half == 0 && tid == 0) { + // First thread writes the norm + if (tid == 0) { y[tbq_block].d = __float2half(norm); } } @@ -413,21 +406,19 @@ __global__ void quantize_f32_tbq4_0_kernel( const float * __restrict__ Q_rot, const float * __restrict__ norms, int64_t num_blocks) { - const int sub_block = blockIdx.x; - const int tbq_block = sub_block / 2; - const int sub_half = sub_block % 2; + const int tbq_block = blockIdx.x; const int tid = threadIdx.x; if (tbq_block >= num_blocks) return; block_tbq4_0 * y = (block_tbq4_0 *)vy; - const float * src = x + tbq_block * 256; + const float * src = x + tbq_block * 128; const float norm = norms[tbq_block]; const float inv_norm = 1.0f / norm; - const float scale_up = 16.0f; + const float scale_up = 11.3137085f; // Step 1: Load and normalize - const int elem = sub_half * 128 + tid; + const int elem = tid; float unit_val = src[elem] * inv_norm; __shared__ float s_unit[128]; @@ -464,7 +455,7 @@ __global__ void quantize_f32_tbq4_0_kernel( } // Write norm - if (sub_half == 0 && tid == 0) { + if (tid == 0) { y[tbq_block].d = __float2half(norm); } } @@ -496,14 +487,14 @@ void ggml_cpy_f32_tbq3_0_cuda( (void)ne00; (void)ne01; (void)ne02; (void)nb00; (void)nb01; (void)nb02; (void)nb03; (void)ne10; (void)ne11; (void)ne12; (void)nb10; (void)nb11; (void)nb12; (void)nb13; - GGML_ASSERT(ne % 256 == 0); + GGML_ASSERT(ne % 128 == 0); turboq_ensure_init(stream); - const int64_t num_blocks = ne / 256; + const int64_t num_blocks = ne / 128; float * norms = turboq_get_norms_buffer(num_blocks, stream); // Pass 1: compute norms - turboq_compute_norms_kernel<<>>( + turboq_compute_norms_kernel<<>>( (const float *)cx, norms, num_blocks); // Pass 2: quantize @@ -522,13 +513,13 @@ void ggml_cpy_f32_tbq4_0_cuda( (void)ne00; (void)ne01; (void)ne02; (void)nb00; (void)nb01; (void)nb02; (void)nb03; (void)ne10; (void)ne11; (void)ne12; (void)nb10; (void)nb11; (void)nb12; (void)nb13; - GGML_ASSERT(ne % 256 == 0); + GGML_ASSERT(ne % 128 == 0); turboq_ensure_init(stream); - const int64_t num_blocks = ne / 256; + const int64_t num_blocks = ne / 128; float * norms = turboq_get_norms_buffer(num_blocks, stream); - turboq_compute_norms_kernel<<>>( + turboq_compute_norms_kernel<<>>( (const float *)cx, norms, num_blocks); quantize_f32_tbq4_0_kernel<<>>( @@ -559,7 +550,7 @@ __global__ void set_rows_tbq3_0_kernel( // Each CUDA block handles one source row const int row_idx = blockIdx.x; - const int tid = threadIdx.x; // 0..255 + const int tid = threadIdx.x; // 0..127 // Decompose row_idx into (i01, i02, i03) const int64_t total_rows = ne01 * ne02 * ne03; @@ -581,27 +572,30 @@ __global__ void set_rows_tbq3_0_kernel( // Destination pointer char * dst_row = dst + dst_row_idx * nb1 + i02 * nb2 + i03 * nb3; - // ne00 should be 256 (one TBQ block per row) — assert in host code - // For simplicity, handle exactly 1 TBQ block (256 elements) - const int64_t num_blocks_per_row = ne00 / 256; + // ne00 should be 128 (one TBQ block per row) — assert in host code + // For simplicity, handle exactly 1 TBQ block (128 elements) + const int64_t num_blocks_per_row = ne00 / 128; for (int64_t blk = 0; blk < num_blocks_per_row; blk++) { - const float * blk_src = src_row + blk * 256; + const float * blk_src = src_row + blk * 128; block_tbq3_0 * blk_dst = (block_tbq3_0 *)dst_row + blk; - // Phase 1: Compute L2 norm (all 256 threads) + // Phase 1: Compute L2 norm (all 128 threads) float val = blk_src[tid]; float sum_sq = val * val; sum_sq = warp_reduce_sum(sum_sq); - __shared__ float warp_sums[8]; + __shared__ float warp_sums[4]; if (tid % 32 == 0) warp_sums[tid / 32] = sum_sq; __syncthreads(); __shared__ float s_norm; - if (tid < 8) { - float s = warp_reduce_sum<8>(warp_sums[tid]); + if (tid < 32) { + float s = (tid < 4) ? warp_sums[tid] : 0.0f; + for (int offset = 2; offset > 0; offset >>= 1) { + s += __shfl_down_sync(0xFFFFFFFF, s, offset); + } if (tid == 0) { float n = sqrtf(s); s_norm = (n < 1e-10f) ? 1e-10f : n; @@ -612,51 +606,45 @@ __global__ void set_rows_tbq3_0_kernel( float norm = s_norm; float inv_norm = 1.0f / norm; - // Phase 2: Normalize, rotate, quantize, pack (2 sub-blocks of 128) - // Each thread handles one element per sub-block - for (int sub = 0; sub < 2; sub++) { - int elem = sub * 128 + (tid % 128); - if (tid >= 128 && sub == 0) continue; // first 128 threads do sub 0 - if (tid < 128 && sub == 1) continue; // last 128 threads do sub 1 - int ltid = tid % 128; // local thread id within sub-block + // Phase 2: Normalize, rotate, quantize, pack (1 sub-block of 128) + int elem = tid; + int ltid = tid; - float unit_val = blk_src[elem] * inv_norm; + float unit_val = blk_src[elem] * inv_norm; - __shared__ float s_unit[256]; // use different halves for each sub-block - s_unit[elem] = unit_val; - __syncthreads(); + __shared__ float s_unit[128]; + s_unit[elem] = unit_val; + __syncthreads(); - // Forward rotation - float rotated = 0.0f; - for (int j = 0; j < 128; j++) { - rotated += Q_rot[ltid * 128 + j] * s_unit[sub * 128 + j]; - } + // Forward rotation + float rotated = 0.0f; + for (int j = 0; j < 128; j++) { + rotated += Q_rot[ltid * 128 + j] * s_unit[j]; + } - // Quantize - float scaled = rotated * 16.0f; - int idx = 7; - #pragma unroll - for (int b = 0; b < 7; b++) { - if (scaled < d_boundaries_3bit[b]) { idx = b; break; } - } + // Quantize + float scaled = rotated * 11.3137085f; + int idx = 7; + #pragma unroll + for (int b = 0; b < 7; b++) { + if (scaled < d_boundaries_3bit[b]) { idx = b; break; } + } - // Pack 3-bit - const int group = ltid / 8; - const int pos_in_group = ltid % 8; + // Pack 3-bit + const int group = ltid / 8; + const int pos_in_group = ltid % 8; - __shared__ uint32_t s_packed[32]; // 16 groups per sub-block × 2 - if (ltid < 16) s_packed[sub * 16 + ltid] = 0; - __syncthreads(); + __shared__ uint32_t s_packed[16]; + if (ltid < 16) s_packed[ltid] = 0; + __syncthreads(); - atomicOr(&s_packed[sub * 16 + group], ((uint32_t)idx) << (pos_in_group * 3)); - __syncthreads(); + atomicOr(&s_packed[group], ((uint32_t)idx) << (pos_in_group * 3)); + __syncthreads(); - const int qs_offset = sub * 48; - if (pos_in_group < 3) { - uint32_t packed = s_packed[sub * 16 + group]; - blk_dst->qs[qs_offset + group * 3 + pos_in_group] = - (uint8_t)((packed >> (pos_in_group * 8)) & 0xFF); - } + if (pos_in_group < 3) { + uint32_t packed = s_packed[group]; + blk_dst->qs[group * 3 + pos_in_group] = + (uint8_t)((packed >> (pos_in_group * 8)) & 0xFF); } if (tid == 0) { @@ -695,10 +683,10 @@ __global__ void set_rows_tbq4_0_kernel( char * dst_row = dst + dst_row_idx * nb1 + i02 * nb2 + i03 * nb3; - const int64_t num_blocks_per_row = ne00 / 256; + const int64_t num_blocks_per_row = ne00 / 128; for (int64_t blk = 0; blk < num_blocks_per_row; blk++) { - const float * blk_src = src_row + blk * 256; + const float * blk_src = src_row + blk * 128; block_tbq4_0 * blk_dst = (block_tbq4_0 *)dst_row + blk; // Phase 1: L2 norm @@ -707,13 +695,16 @@ __global__ void set_rows_tbq4_0_kernel( sum_sq = warp_reduce_sum(sum_sq); - __shared__ float warp_sums[8]; + __shared__ float warp_sums[4]; if (tid % 32 == 0) warp_sums[tid / 32] = sum_sq; __syncthreads(); __shared__ float s_norm; - if (tid < 8) { - float s = warp_reduce_sum<8>(warp_sums[tid]); + if (tid < 32) { + float s = (tid < 4) ? warp_sums[tid] : 0.0f; + for (int offset = 2; offset > 0; offset >>= 1) { + s += __shfl_down_sync(0xFFFFFFFF, s, offset); + } if (tid == 0) { float n = sqrtf(s); s_norm = (n < 1e-10f) ? 1e-10f : n; @@ -724,40 +715,36 @@ __global__ void set_rows_tbq4_0_kernel( float norm = s_norm; float inv_norm = 1.0f / norm; - // Phase 2: two sub-blocks - for (int sub = 0; sub < 2; sub++) { - int elem = sub * 128 + (tid % 128); - if (tid >= 128 && sub == 0) continue; - if (tid < 128 && sub == 1) continue; - int ltid = tid % 128; + // Phase 2 + int elem = tid; + int ltid = tid; - float unit_val = blk_src[elem] * inv_norm; + float unit_val = blk_src[elem] * inv_norm; - __shared__ float s_unit[256]; - s_unit[elem] = unit_val; - __syncthreads(); + __shared__ float s_unit[128]; + s_unit[elem] = unit_val; + __syncthreads(); - float rotated = 0.0f; - for (int j = 0; j < 128; j++) { - rotated += Q_rot[ltid * 128 + j] * s_unit[sub * 128 + j]; - } + float rotated = 0.0f; + for (int j = 0; j < 128; j++) { + rotated += Q_rot[ltid * 128 + j] * s_unit[j]; + } - float scaled = rotated * 16.0f; - int idx = 15; - #pragma unroll - for (int b = 0; b < 15; b++) { - if (scaled < d_boundaries_4bit[b]) { idx = b; break; } - } + float scaled = rotated * 11.3137085f; + int idx = 15; + #pragma unroll + for (int b = 0; b < 15; b++) { + if (scaled < d_boundaries_4bit[b]) { idx = b; break; } + } - // Pack 4-bit nibbles - __shared__ uint8_t s_indices[256]; - s_indices[elem] = (uint8_t)idx; - __syncthreads(); + // Pack 4-bit nibbles + __shared__ uint8_t s_indices[128]; + s_indices[elem] = (uint8_t)idx; + __syncthreads(); - if (ltid % 2 == 0) { - const int qs_idx = elem / 2; - blk_dst->qs[qs_idx] = s_indices[elem] | (s_indices[elem + 1] << 4); - } + if (ltid % 2 == 0) { + const int qs_idx = elem / 2; + blk_dst->qs[qs_idx] = s_indices[elem] | (s_indices[elem + 1] << 4); } if (tid == 0) { @@ -790,7 +777,7 @@ void ggml_set_rows_tbq3_0_cuda( // ne11/ne12 for index wrapping — for simple KV cache, these are 1 // We pass them as the last two args - set_rows_tbq3_0_kernel<<>>( + set_rows_tbq3_0_kernel<<>>( src0_d, src1_d, dst_d, d_turboq_Q, ne00, ne01, ne02, ne03, s01, s02, s03, s10, s11, s12, @@ -817,7 +804,7 @@ void ggml_set_rows_tbq4_0_cuda( const int64_t s11 = nb11 / sizeof(idx_t); const int64_t s12 = nb12 / sizeof(idx_t); - set_rows_tbq4_0_kernel<<>>( + set_rows_tbq4_0_kernel<<>>( src0_d, src1_d, dst_d, d_turboq_Q, ne00, ne01, ne02, ne03, s01, s02, s03, s10, s11, s12, diff --git a/ggml/src/ggml-turboq.c b/ggml/src/ggml-turboq.c index e8bdc46579dd..be8344194039 100644 --- a/ggml/src/ggml-turboq.c +++ b/ggml/src/ggml-turboq.c @@ -365,7 +365,7 @@ static float * turboq_get_scratch3(int64_t n) { #define TURBOQ_KV_DIM 128 static inline float turboq_block_scale_up(void) { - return sqrtf((float) QK_K); + return sqrtf((float) TBQ_BLK_SIZE); } static inline float turboq_block_scale_down(void) { @@ -448,19 +448,19 @@ static void unpack_3bit(uint8_t * indices, const uint8_t * src, int64_t n) { // --------------------------------------------------------------------------- void quantize_row_tbq3_0_ref(const float * GGML_RESTRICT x, block_tbq3_0 * GGML_RESTRICT y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - float * unit = turboq_get_scratch(QK_K); - float * rotated = turboq_get_scratch2(QK_K); + assert(k % TBQ_BLK_SIZE == 0); + const int64_t nb = k / TBQ_BLK_SIZE; + float * unit = turboq_get_scratch(TBQ_BLK_SIZE); + float * rotated = turboq_get_scratch2(TBQ_BLK_SIZE); const uint64_t seed = turboq_seed_from_row(0); const float scale_up = turboq_block_scale_up(); - uint8_t indices[QK_K]; + uint8_t indices[TBQ_BLK_SIZE]; for (int64_t b = 0; b < nb; b++) { - const float * xb = x + b * QK_K; + const float * xb = x + b * TBQ_BLK_SIZE; float norm_sq = 0.0f; - for (int64_t j = 0; j < QK_K; ++j) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { norm_sq += xb[j] * xb[j]; } @@ -469,51 +469,51 @@ void quantize_row_tbq3_0_ref(const float * GGML_RESTRICT x, block_tbq3_0 * GGML_ norm = 1e-10f; } - for (int64_t j = 0; j < QK_K; ++j) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { unit[j] = xb[j] / norm; } turboq_rotate_block_forward(rotated, unit, seed); - for (int64_t j = 0; j < QK_K; j++) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; j++) { float val = rotated[j] * scale_up; indices[j] = quantize_scalar_3bit(val); } - pack_3bit(y[b].qs, indices, QK_K); + pack_3bit(y[b].qs, indices, TBQ_BLK_SIZE); y[b].d = GGML_FP32_TO_FP16(norm); } } void dequantize_row_tbq3_0(const block_tbq3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - float * rotated = turboq_get_scratch(QK_K); - float * unit_approx = turboq_get_scratch2(QK_K); + assert(k % TBQ_BLK_SIZE == 0); + const int64_t nb = k / TBQ_BLK_SIZE; + float * rotated = turboq_get_scratch(TBQ_BLK_SIZE); + float * unit_approx = turboq_get_scratch2(TBQ_BLK_SIZE); const uint64_t seed = turboq_seed_from_row(0); const float scale_down = turboq_block_scale_down(); - uint8_t indices[QK_K]; + uint8_t indices[TBQ_BLK_SIZE]; for (int64_t b = 0; b < nb; b++) { const float norm = GGML_FP16_TO_FP32(x[b].d); - unpack_3bit(indices, x[b].qs, QK_K); - for (int64_t j = 0; j < QK_K; j++) { + unpack_3bit(indices, x[b].qs, TBQ_BLK_SIZE); + for (int64_t j = 0; j < TBQ_BLK_SIZE; j++) { rotated[j] = turboq_codebook_3bit[indices[j]] * scale_down; } turboq_rotate_block_inverse(unit_approx, rotated, seed); - for (int64_t j = 0; j < QK_K; ++j) { - y[b * QK_K + j] = unit_approx[j] * norm; + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { + y[b * TBQ_BLK_SIZE + j] = unit_approx[j] * norm; } } } size_t quantize_tbq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { (void)imatrix; - assert(n_per_row % QK_K == 0); + assert(n_per_row % TBQ_BLK_SIZE == 0); - const int64_t nb_per_row = n_per_row / QK_K; + const int64_t nb_per_row = n_per_row / TBQ_BLK_SIZE; const size_t row_size = nb_per_row * sizeof(block_tbq3_0); for (int64_t row = 0; row < nrows; row++) { @@ -529,18 +529,18 @@ size_t quantize_tbq3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst // --------------------------------------------------------------------------- void quantize_row_tbq4_0_ref(const float * GGML_RESTRICT x, block_tbq4_0 * GGML_RESTRICT y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - float * unit = turboq_get_scratch(QK_K); - float * rotated = turboq_get_scratch2(QK_K); + assert(k % TBQ_BLK_SIZE == 0); + const int64_t nb = k / TBQ_BLK_SIZE; + float * unit = turboq_get_scratch(TBQ_BLK_SIZE); + float * rotated = turboq_get_scratch2(TBQ_BLK_SIZE); const uint64_t seed = turboq_seed_from_row(0); const float scale_up = turboq_block_scale_up(); for (int64_t b = 0; b < nb; b++) { - const float * xb = x + b * QK_K; + const float * xb = x + b * TBQ_BLK_SIZE; float norm_sq = 0.0f; - for (int64_t j = 0; j < QK_K; ++j) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { norm_sq += xb[j] * xb[j]; } @@ -549,14 +549,14 @@ void quantize_row_tbq4_0_ref(const float * GGML_RESTRICT x, block_tbq4_0 * GGML_ norm = 1e-10f; } - for (int64_t j = 0; j < QK_K; ++j) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { unit[j] = xb[j] / norm; } turboq_rotate_block_forward(rotated, unit, seed); memset(y[b].qs, 0, sizeof(y[b].qs)); - for (int64_t j = 0; j < QK_K; j++) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; j++) { float val = rotated[j] * scale_up; uint8_t idx = quantize_scalar_4bit(val); if (j % 2 == 0) { @@ -570,17 +570,17 @@ void quantize_row_tbq4_0_ref(const float * GGML_RESTRICT x, block_tbq4_0 * GGML_ } void dequantize_row_tbq4_0(const block_tbq4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - float * rotated = turboq_get_scratch(QK_K); - float * unit_approx = turboq_get_scratch2(QK_K); + assert(k % TBQ_BLK_SIZE == 0); + const int64_t nb = k / TBQ_BLK_SIZE; + float * rotated = turboq_get_scratch(TBQ_BLK_SIZE); + float * unit_approx = turboq_get_scratch2(TBQ_BLK_SIZE); const uint64_t seed = turboq_seed_from_row(0); const float scale_down = turboq_block_scale_down(); for (int64_t b = 0; b < nb; b++) { const float norm = GGML_FP16_TO_FP32(x[b].d); - for (int64_t j = 0; j < QK_K; j++) { + for (int64_t j = 0; j < TBQ_BLK_SIZE; j++) { uint8_t idx; if (j % 2 == 0) { idx = x[b].qs[j / 2] & 0x0F; @@ -592,17 +592,17 @@ void dequantize_row_tbq4_0(const block_tbq4_0 * GGML_RESTRICT x, float * GGML_RE turboq_rotate_block_inverse(unit_approx, rotated, seed); - for (int64_t j = 0; j < QK_K; ++j) { - y[b * QK_K + j] = unit_approx[j] * norm; + for (int64_t j = 0; j < TBQ_BLK_SIZE; ++j) { + y[b * TBQ_BLK_SIZE + j] = unit_approx[j] * norm; } } } size_t quantize_tbq4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix) { (void)imatrix; - assert(n_per_row % QK_K == 0); + assert(n_per_row % TBQ_BLK_SIZE == 0); - const int64_t nb_per_row = n_per_row / QK_K; + const int64_t nb_per_row = n_per_row / TBQ_BLK_SIZE; const size_t row_size = nb_per_row * sizeof(block_tbq4_0); for (int64_t row = 0; row < nrows; row++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index cf64bf320693..5d881c693984 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -924,7 +924,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { }, [GGML_TYPE_TBQ3_0] = { .type_name = "tbq3_0", - .blck_size = QK_K, + .blck_size = TBQ_BLK_SIZE, .type_size = sizeof(block_tbq3_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_tbq3_0, @@ -932,7 +932,7 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { }, [GGML_TYPE_TBQ4_0] = { .type_name = "tbq4_0", - .blck_size = QK_K, + .blck_size = TBQ_BLK_SIZE, .type_size = sizeof(block_tbq4_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_tbq4_0, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 6fcc301ed390..73e1281cdfdd 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2027,11 +2027,13 @@ ggml_tensor * llm_graph_context::build_attn_mha( GGML_ASSERT(n_head_kv > 0); GGML_ASSERT(n_embd_k_gqa % n_head_kv == 0); - // Codebook-only dequant: skip rotation (values stay in rotated domain) - k = ggml_cast(ctx0, k, tbq_attn_type); - int32_t codebook_flag = 1; - memcpy(k->op_params, &codebook_flag, sizeof(int32_t)); - cb(k, use_flash_attn ? "k_tbq_cb_f16" : "k_tbq_cb_f32", il); + if (!use_flash_attn) { + // Codebook-only dequant: skip rotation (values stay in rotated domain) + k = ggml_cast(ctx0, k, tbq_attn_type); + int32_t codebook_flag = 1; + memcpy(k->op_params, &codebook_flag, sizeof(int32_t)); + cb(k, "k_tbq_cb_f32", il); + } k = ggml_reshape_4d(ctx0, k, n_embd_k_gqa / n_head_kv, n_head_kv, k->ne[1], k->ne[2]); cb(k, "k_tbq_reshaped", il); @@ -2044,11 +2046,13 @@ ggml_tensor * llm_graph_context::build_attn_mha( GGML_ASSERT(n_head_kv > 0); GGML_ASSERT(n_embd_v_gqa % n_head_kv == 0); - // Codebook-only dequant: skip rotation - v = ggml_cast(ctx0, v, tbq_attn_type); - int32_t codebook_flag = 1; - memcpy(v->op_params, &codebook_flag, sizeof(int32_t)); - cb(v, use_flash_attn ? "v_tbq_cb_f16" : "v_tbq_cb_f32", il); + if (!use_flash_attn) { + // Codebook-only dequant: skip rotation + v = ggml_cast(ctx0, v, tbq_attn_type); + int32_t codebook_flag = 1; + memcpy(v->op_params, &codebook_flag, sizeof(int32_t)); + cb(v, "v_tbq_cb_f32", il); + } v = ggml_reshape_4d(ctx0, v, n_embd_v_gqa / n_head_kv, n_head_kv, v->ne[1], v->ne[2]); cb(v, "v_tbq_reshaped", il);