diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 113eace357d..6d35c8c84b9 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -283,13 +283,9 @@ static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3/4 + QK_TUR // TurboQuant 4-bit: 3-bit PolarQuant indices + 1-bit QJL signs // TURBO4_USE_4BIT: switch between 4-bit PolarQuant (new) and 3-bit+QJL (legacy) -// Default: 4-bit on Metal (validated), legacy on CUDA (until ported) +// Default: 4-bit on all backends (Metal + CUDA validated) #ifndef TURBO4_USE_4BIT -# if defined(GGML_METAL) -# define TURBO4_USE_4BIT 1 -# else -# define TURBO4_USE_4BIT 0 -# endif +# define TURBO4_USE_4BIT 1 #endif #define QK_TURBO4 128 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 65c7c9931af..0ee540cb1dc 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -212,6 +212,9 @@ static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs, static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs, + const void * GGML_RESTRICT vx, size_t bx, + const void * GGML_RESTRICT vy, size_t by, int nrc); static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = { @@ -414,6 +417,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F32, .nrows = 1, }, + [GGML_TYPE_TURBO4_0] = { + .from_float = (ggml_from_float_t) quantize_row_turbo4_0_ref, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo4_0_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -3379,6 +3388,25 @@ static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs, *s = sum; } +// TurboQuant4 vec_dot: dequantize turbo4 block to f32, then dot with f32 operand. 
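// [Editor's note, not part of the patch] Since the traits entry above declares
// .vec_dot_type = GGML_TYPE_F32, the vy operand reaches this function unquantized;
// the turbo4 row is simply dequantized into a stack buffer (capped at 4096 elements
// by the assert below) and reduced with a plain f32 loop, i.e. a correctness-first
// reference path rather than a tuned SIMD kernel.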
+static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs, + const void * GGML_RESTRICT vx, size_t bx, + const void * GGML_RESTRICT vy, size_t by, int nrc) { + GGML_ASSERT(nrc == 1); + GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc); + + float tmp[4096]; + GGML_ASSERT(n <= 4096); + ggml_get_type_traits(GGML_TYPE_TURBO4_0)->to_float(vx, tmp, n); + + const float * y = (const float *)vy; + float sum = 0.0f; + for (int i = 0; i < n; i++) { + sum += tmp[i] * y[i]; + } + *s = sum; +} + void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) { memcpy(y, x, n * sizeof(float)); } diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 94e48651388..f2a797224af 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -126,7 +126,16 @@ if (CUDAToolkit_FOUND) template-instances/fattn-vec-instance-q8_0-turbo3_0.cu template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu template-instances/fattn-vec-instance-turbo2_0-q8_0.cu - template-instances/fattn-vec-instance-q8_0-turbo2_0.cu) + template-instances/fattn-vec-instance-q8_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-q8_0.cu + template-instances/fattn-vec-instance-q8_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index e7269f95931..98d1ad5dad8 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -761,6 +761,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_block_cont_cuda; case GGML_TYPE_TURBO2_0: return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cont_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -818,6 +820,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_block_cont_cuda; case GGML_TYPE_TURBO2_0: return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cont_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -845,6 +849,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_TURBO2_0: return dequantize_block_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: @@ -891,6 +897,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_TURBO2_0: return dequantize_block_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index 513d748522a..ceb2ad62e16 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -77,6 +77,15 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in v.y *= d; } +// Turbo4: 4-bit PolarQuant (nibble packed), block size 128 +// iqs is the element index within the block (even), produces elements 
iqs and iqs+1 +static __device__ __forceinline__ void dequantize_turbo4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_turbo4_0 * x = (const block_turbo4_0 *) vx; + const float norm = __half2float(x[ib].norm); + v.x = turbo4_dequant_element(&x[ib], iqs + 0, norm); + v.y = turbo4_dequant_element(&x[ib], iqs + 1, norm); +} + // Turbo3: 3-bit PolarQuant (2-bit qs + 1-bit sign), block size 32 // iqs is the element index within the block (even), produces elements iqs and iqs+1 static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 66e5ba620a1..06d97da0823 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -398,6 +398,55 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo2_0( return sum; } +// Turbo4 KQ dot product: dequantize K from turbo4 blocks, dot with Q (float2/half2) +// 4-bit nibble packed: qs[j/2] >> ((j%2)*4) & 0xF +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_turbo4_0 * K_turbo = (const block_turbo4_0 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { + const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1; + + const int elem0 = k_KQ * 2; // always even + const int ib = elem0 / QK_TURBO4; // block index + const int j0 = elem0 % QK_TURBO4; // always even + + const float norm = __half2float(K_turbo[ib].norm); + // Both j0 and j0+1 are adjacent nibbles: j0/2 == (j0+1)/2 when j0 is even + const uint8_t qs_byte = K_turbo[ib].qs[j0 / 2]; + + const uint8_t idx0 = (qs_byte >> 0) & 0xF; // low nibble = j0 + const uint8_t idx1 = (qs_byte >> 4) & 0xF; // high nibble = j0+1 + + float2 kv; + kv.x = TURBO_CENTROIDS_4BIT[idx0] * norm; + kv.y = TURBO_CENTROIDS_4BIT[idx1] * norm; + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv)); +#else + const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + sum += kv.x * qv.x + kv.y * qv.y; +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ void quantize_q8_1_to_shared( const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { @@ -807,6 +856,66 @@ static __device__ __forceinline__ void dequantize_V_turbo2_0(const void * __rest } } +// Turbo4 V dequantize: extract `ne` float/half values at position i0. +// 4-bit nibble packed, block size 128. +template +static __device__ __forceinline__ void dequantize_V_turbo4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_turbo4_0 * x = (const block_turbo4_0 *) vx; + + const int64_t ib = i0 / QK_TURBO4; + const int j0 = i0 % QK_TURBO4; + const float norm = __half2float(x[ib].norm); + + static_assert(ne == 2 || ne == 4, "bad ne"); + + if constexpr (ne == 4) { + // j0 is always a multiple of 4 from the VEC kernel access pattern. 
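// [Editor's note, not part of the patch] Worked example of the indexing, assuming the
// layout described here: i0 = 132 gives ib = 1 and j0 = 4, so elements 4..7 of block 1
// come from qs[2] (low nibble = elem 4, high nibble = elem 5) and qs[3] (elems 6 and 7).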
+ // 4 consecutive elements span 2 qs bytes: j0/2 and j0/2+1. + const uint8_t qs_byte0 = x[ib].qs[j0 / 2]; // elements j0, j0+1 + const uint8_t qs_byte1 = x[ib].qs[j0 / 2 + 1]; // elements j0+2, j0+3 + + const uint8_t idx0 = (qs_byte0 >> 0) & 0xF; + const uint8_t idx1 = (qs_byte0 >> 4) & 0xF; + const uint8_t idx2 = (qs_byte1 >> 0) & 0xF; + const uint8_t idx3 = (qs_byte1 >> 4) & 0xF; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((half2 *) dst)[0] = make_half2( + __float2half(TURBO_CENTROIDS_4BIT[idx0] * norm), + __float2half(TURBO_CENTROIDS_4BIT[idx1] * norm)); + ((half2 *) dst)[1] = make_half2( + __float2half(TURBO_CENTROIDS_4BIT[idx2] * norm), + __float2half(TURBO_CENTROIDS_4BIT[idx3] * norm)); + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((float2 *) dst)[0] = make_float2( + TURBO_CENTROIDS_4BIT[idx0] * norm, + TURBO_CENTROIDS_4BIT[idx1] * norm); + ((float2 *) dst)[1] = make_float2( + TURBO_CENTROIDS_4BIT[idx2] * norm, + TURBO_CENTROIDS_4BIT[idx3] * norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } else { // ne == 2 +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + float v0 = turbo4_dequant_element(&x[ib], j0, norm); + float v1 = turbo4_dequant_element(&x[ib], j0+1, norm); + ((half2 *) dst)[0] = make_half2(__float2half(v0), __float2half(v1)); + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((float *) dst)[0] = turbo4_dequant_element(&x[ib], j0, norm); + ((float *) dst)[1] = turbo4_dequant_element(&x[ib], j0+1, norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } +} + template constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { if constexpr (type_K == GGML_TYPE_F16) { @@ -827,6 +936,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_turbo3_0; } else if constexpr (type_K == GGML_TYPE_TURBO2_0) { return vec_dot_fattn_vec_KQ_turbo2_0; + } else if constexpr (type_K == GGML_TYPE_TURBO4_0) { + return vec_dot_fattn_vec_KQ_turbo4_0; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -853,6 +964,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_turbo3_0; } else if constexpr (type_V == GGML_TYPE_TURBO2_0) { return dequantize_V_turbo2_0; + } else if constexpr (type_V == GGML_TYPE_TURBO4_0) { + return dequantize_V_turbo4_0; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index fff70c8eb89..0905448b581 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -71,6 +71,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } @@ -85,6 +90,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 
160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } @@ -94,6 +104,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + // TODO tune specifically for Volta return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } @@ -107,6 +122,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + // TODO tune specifically for RDNA return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } @@ -1826,3 +1845,8 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); + +// D=640: padded turbo KV cache for GLM-4.7 Flash (ncols2=16 only) +extern DECL_FATTN_MMA_F16_CASE(640, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(640, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(640, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2b..88b71bc529a 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -42,6 +42,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; + case 640: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<640, 512>(ctx, dst); + } break; default: { GGML_ABORT("Unsupported head size"); } break; diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23d..d81012ef7a7 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -72,6 +72,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 16, 256, 2, 64, 64) + return 0; } @@ -128,6 +132,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 
64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 16, 256, 2, 32, 64) + return 0; } @@ -192,6 +200,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 32, 512, 1, 128, 64) + return 0; } @@ -256,6 +269,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(640, 512, 32, 256, 2, 128, 64) + return 0; } @@ -1187,7 +1205,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_ratio = Q->ne[2] / K->ne[2]; // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. - // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. + // However, for DKQ == 576/640, DV == 512 only the kernel variant with GQA optimizations is implemented. const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; @@ -1256,3 +1274,4 @@ extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); extern DECL_FATTN_TILE_CASE(576, 512); +extern DECL_FATTN_TILE_CASE(640, 512); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 78170cbab35..40270a192dd 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -76,15 +76,15 @@ static __global__ void flash_attn_ext_vec( constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); // Turbo3 uses the float Q path (like f16/bf16), not q8_1 integer path - constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0); - constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0); + constexpr bool K_is_unquantized = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16 || type_K == GGML_TYPE_TURBO3_0 || type_K == GGML_TYPE_TURBO2_0 || type_K == GGML_TYPE_TURBO4_0); + constexpr bool V_is_unquantized = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16 || type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0); constexpr int nthreads_KQ = K_is_unquantized ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0) ? 
nthreads_V_q : 128 / cpy_nb) : nthreads_V_q; + constexpr int nthreads_V = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? nthreads_V_q : 128 / cpy_nb) : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0) ? 4 : 2*cpy_ne) : 4; + constexpr int V_rows_per_thread = V_is_unquantized ? ((type_V == GGML_TYPE_TURBO3_0 || type_V == GGML_TYPE_TURBO2_0 || type_V == GGML_TYPE_TURBO4_0) ? 4 : 2*cpy_ne) : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); @@ -659,3 +659,44 @@ extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0); extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); + +// Mixed turbo3/turbo2 KV cache types +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); + +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); + +// TurboQuant4 — turbo4 K + turbo4 V (KV cache uses same type) +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); + +// Mixed turbo4/q8_0 KV cache types +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); + +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); + +// Mixed turbo4/turbo3 KV cache types +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); + +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); + +// Mixed turbo4/turbo2 KV cache types +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); + +extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 7e8ff3d8fb1..b5ed581816a 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -191,6 +191,17 @@ static void 
ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); } } break; + case 640: { + // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640). + // D=640 shared memory (Q storage = ncols*(DKQ/2+4)*4) exceeds hardware limit at ncols1>=4. + // Cap at ncols1=2 (ncols=32): Q=32*324*4=41KB + KV≈37KB = ~78KB total. + GGML_ASSERT(V->ne[0] == 512); + if (Q->ne[1] <= 1) { + ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 1, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 2, 16>(ctx, dst); + } + } break; default: GGML_ABORT("fatal error"); break; @@ -294,6 +305,25 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0) + // Mixed turbo3/turbo2 KV cache types + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0) + + // TurboQuant4 KV cache types (always enabled) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0) + + // Mixed turbo4/q8_0 KV cache types + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0) + + // Mixed turbo4/turbo3 KV cache types + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0) + + // Mixed turbo4/turbo2 KV cache types + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0) + GGML_ABORT("fatal error"); } @@ -355,6 +385,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } break; case 576: + case 640: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } @@ -368,12 +399,11 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #ifndef GGML_CUDA_FA_ALL_QUANTS if (K->type != V->type) { - // Allow mixed turbo/q8_0 KV types - const bool turbo_q8_mix = (K->type == GGML_TYPE_TURBO3_0 && V->type == GGML_TYPE_Q8_0) || - (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO3_0) || - (K->type == GGML_TYPE_TURBO2_0 && V->type == GGML_TYPE_Q8_0) || - (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_TURBO2_0); - if (!turbo_q8_mix) { + // Allow mixed turbo KV types (any combination of turbo2, turbo3, q8_0) + auto is_turbo = [](ggml_type t) { + return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0 || t == GGML_TYPE_Q8_0; + }; + if (!is_turbo(K->type) || !is_turbo(V->type)) { return BEST_FATTN_KERNEL_NONE; } } @@ -405,6 +435,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case GGML_TYPE_TURBO4_0: + // turbo4 VEC kernel instantiated for D in {64, 128, 256}. + if (K->ne[0] % 64 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; default: return BEST_FATTN_KERNEL_NONE; } @@ -443,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 640) ? 
16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { gqa_ratio_eff *= 2; } @@ -457,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576 && Q->ne[0] != 640) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -479,7 +515,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 640) ? 16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { gqa_ratio_eff *= 2; } @@ -490,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576 && Q->ne[0] != 640) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c89e64ef1c6..b5774abd09b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4839,14 +4839,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { - // turbo types require head_dim divisible by 64 (supports 64 and 128 WHT groups) + // turbo types require head_dim divisible by appropriate group size if ((op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO2_0) && op->src[0]->ne[0] % 64 != 0) { return false; } + // turbo4 block size is 128, so head_dim must be divisible by 128 + if (op->type == GGML_TYPE_TURBO4_0 && op->src[0]->ne[0] % 128 != 0) { + return false; + } return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL || - op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO2_0) && + op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO2_0 || op->type == GGML_TYPE_TURBO4_0) && op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 569185c4a2b..4e82676ce16 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -928,6 +928,203 @@ static void set_rows_cuda_turbo2( } } +// ---- TurboQuant4 set_rows: 128-element groups with WHT rotation + 4-bit quantization ---- +// +// turbo4 block size IS the WHT group size (128), so 1 CUDA block = 1 turbo4 block. +// 128 threads per block, thread j handles element j. 
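// [Editor's sketch, not part of the patch] Scalar reference for what one 128-element
// block goes through in the kernel below, assuming the helpers from turbo-quant.cuh
// (TURBO_WHT_SIGNS1/2, turbo_nearest_centroid_4bit, TURBO_CENTROIDS_4BIT); the InnerQ
// calibration/scaling steps are omitted:
//
//   void quantize_block_turbo4_ref(const float * src, block_turbo4_0 * blk) {
//       float x[QK_TURBO4], norm_sq = 0.0f;
//       for (int j = 0; j < QK_TURBO4; j++) { x[j] = src[j]; norm_sq += src[j]*src[j]; }
//       const float norm = sqrtf(norm_sq);
//       const float inv  = norm > 1e-10f ? 1.0f/norm : 0.0f;
//       for (int j = 0; j < QK_TURBO4; j++) x[j] *= inv;
//       wht128(x);   // signs1, 7 butterfly stages, 1/sqrt(128) scale, signs2 (orthonormal)
//       float recon_sq = 0.0f;
//       for (int j = 0; j < QK_TURBO4; j++) {
//           const uint8_t idx = turbo_nearest_centroid_4bit(x[j]);
//           if (j % 2 == 0) blk->qs[j/2]  = idx;                     // low nibble
//           else            blk->qs[j/2] |= (uint8_t)(idx << 4);     // high nibble
//           recon_sq += TURBO_CENTROIDS_4BIT[idx] * TURBO_CENTROIDS_4BIT[idx];
//       }
//       const float recon = sqrtf(recon_sq);
//       blk->norm  = __float2half(recon > 1e-10f ? norm/recon : norm);
//       blk->rnorm = __float2half(0.0f);
//   }
//
// wht128() stands in for the shared-memory butterfly network (WHT_STAGE_SHARED_T4 below).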
+// 4-bit centroids (16 values), nibble packed: qs[j/2] |= (idx & 0xF) << ((j%2)*4) + +template +__launch_bounds__(128) +static __global__ void k_set_rows_turbo4( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo4_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3) { + + // blockIdx.x = flat block index; threadIdx.x = element within block (0..127) + const int j = threadIdx.x; + + // Decode blockIdx.x → (i_blk, i01, i02, i03) + const int64_t n_blocks_per_row = ne00 / QK_TURBO4; + const int64_t g = blockIdx.x; + const int64_t i_blk = g % n_blocks_per_row; + int64_t tmp = g / n_blocks_per_row; + const int64_t i01 = tmp % ne01; + tmp = tmp / ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i12 = i02; + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo4_0 * dst_row_ptr = (block_turbo4_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + block_turbo4_0 * blk = dst_row_ptr + i_blk; + + // ---- Step 1: Load element j (coalesced) ---- + __shared__ float x[128]; + x[j] = src_row[i_blk * QK_TURBO4 + j]; + __syncthreads(); + + // ---- InnerQ: calibrate on original (unscaled) values ---- + if (d_innerq_calibrating) { + atomicAdd(&d_innerq_sq_accum[j], x[j] * x[j]); + if (j == 0) atomicAdd(&d_innerq_count, 1); + } + + // ---- InnerQ: apply channel scale (only when active) ---- + if (d_innerq_active) { + x[j] *= d_innerq_scale[j]; + } + __syncthreads(); + + // ---- Step 2: Parallel L2 norm ---- + constexpr int n_warps = 128 / WARP_SIZE; // = 4 + __shared__ float warp_accum[n_warps]; + float v = x[j]; + float v2 = v * v; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // ---- Step 3: Normalize ---- + x[j] *= inv_norm; + __syncthreads(); + + // ---- Step 4: Forward WHT (signs1 → butterfly → signs2, normalized) ---- + x[j] *= TURBO_WHT_SIGNS1[j]; + __syncthreads(); + +#define WHT_STAGE_SHARED_T4(h) \ + if (j % (2*(h)) < (h)) { float a = x[j], b = x[j+(h)]; x[j] = a+b; x[j+(h)] = a-b; } \ + __syncthreads(); + + WHT_STAGE_SHARED_T4(1) + WHT_STAGE_SHARED_T4(2) + WHT_STAGE_SHARED_T4(4) + WHT_STAGE_SHARED_T4(8) + WHT_STAGE_SHARED_T4(16) + WHT_STAGE_SHARED_T4(32) + WHT_STAGE_SHARED_T4(64) +#undef WHT_STAGE_SHARED_T4 + + constexpr float inv_sqrt_128 = 0.08838834764831845f; + x[j] = x[j] * inv_sqrt_128 * TURBO_WHT_SIGNS2[j]; + __syncthreads(); + + // ---- Step 5: Quantize element j to 4-bit centroid ---- + const float rv = x[j]; + const uint8_t idx = turbo_nearest_centroid_4bit(rv); + + // ---- Step 6: Pack qs (nibble packed, warp-cooperative) ---- + // 2 elements per byte, 4 bits each. + // Thread pairs (j, j+1) share a qs byte. 
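// [Editor's note, not part of the patch] The warp shuffle below is just a parallel
// form of the scalar pack (idx[j] = 4-bit centroid index chosen by thread j):
//
//   for (int j = 0; j < QK_TURBO4; j += 2) {
//       blk->qs[j/2] = (uint8_t)((idx[j] & 0xF) | ((idx[j+1] & 0xF) << 4));
//   }
//
// __shfl_sync(..., lane ^ 1) hands each thread its pair partner's nibble, so only the
// even-numbered thread of each pair performs the byte store.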
+ const int lane = j % WARP_SIZE; + const uint8_t my_nibble = idx & 0xF; + uint8_t qs_byte = 0; + // Gather nibble from partner thread + uint8_t partner_nibble = __shfl_sync(0xffffffff, my_nibble, lane ^ 1); + if (j % 2 == 0) { + qs_byte = my_nibble | (partner_nibble << 4); + blk->qs[j / 2] = qs_byte; + } + + // ---- Step 7: Reconstruction norm (parallel) ---- + const float c = TURBO_CENTROIDS_4BIT[idx]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + + // ---- Step 8: Write corrected norm and zero rnorm (one thread) ---- + if (j == 0) { + blk->norm = __float2half(corrected_norm); + blk->rnorm = __float2half(0.0f); + } + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +template +static void set_rows_cuda_turbo4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + + const float * src0_d = (const float *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne00 % QK_TURBO4 == 0); // must be block-aligned (128) + + cudaStream_t stream = ctx.stream(); + + // turbo4 block size = WHT group size = 128, always + const int64_t n_blocks = ne00 / QK_TURBO4; + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + + // InnerQ: check/finalize calibration before kernel launch + turbo_innerq_check_finalize(QK_TURBO4, ne00); + + if (n_blocks > 0) { + const int64_t ne_total = n_blocks * ne01 * ne02 * ne03; + k_set_rows_turbo4<<<(int)ne_total, 128, 0, stream>>>( + src0_d, src1_d, (block_turbo4_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } +} + template static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const src_t * src0_d = (const src_t *)src0->data; @@ -1032,6 +1229,8 @@ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * s set_rows_cuda_turbo3(ctx, src0, src1, dst); } else if (dst->type == GGML_TYPE_TURBO2_0) { set_rows_cuda_turbo2(ctx, src0, src1, dst); + } else if (dst->type == GGML_TYPE_TURBO4_0) { + set_rows_cuda_turbo4(ctx, src0, src1, dst); } else { GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0da..b643d1d2daa 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -3,3 +3,4 @@ #include "../fattn-mma-f16.cuh" DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +DECL_FATTN_MMA_F16_CASE(640, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu 
b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208cd2..2580167449c 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -3,3 +3,4 @@ #include "../fattn-mma-f16.cuh" DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +DECL_FATTN_MMA_F16_CASE(640, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e2369c..585e13d42ff 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -3,3 +3,4 @@ #include "../fattn-mma-f16.cuh" DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +DECL_FATTN_MMA_F16_CASE(640, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu new file mode 100644 index 00000000000..c68a841ad41 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(640, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu new file mode 100644 index 00000000000..1646ef05dd1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: q8_0 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu new file mode 100644 index 00000000000..334ad2098b5 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo2 K + turbo3 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu new file mode 100644 index 00000000000..56177a4cae9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo2 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu new file mode 100644 index 00000000000..f8a5932537e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu @@ 
-0,0 +1,7 @@ +// Mixed KV: turbo3 K + turbo2 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu new file mode 100644 index 00000000000..0634006716a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo3 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu new file mode 100644 index 00000000000..7e1e4c74340 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo4 K + q8_0 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu new file mode 100644 index 00000000000..a0cffecf86a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo4 K + turbo2 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu new file mode 100644 index 00000000000..54e93bd9cd3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu @@ -0,0 +1,7 @@ +// Mixed KV: turbo4 K + turbo3 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu new file mode 100644 index 00000000000..b49c683291d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu @@ -0,0 +1,7 @@ +// TurboQuant4 CUDA flash attention vec kernel instantiation + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/turbo-quant.cuh b/ggml/src/ggml-cuda/turbo-quant.cuh index 8590cae1145..7b848ee3287 100644 --- a/ggml/src/ggml-cuda/turbo-quant.cuh +++ b/ggml/src/ggml-cuda/turbo-quant.cuh @@ -16,6 +16,7 @@ // ---- Quantization 
ratios for dequantize_block template ---- #define QR_TURBO3 1 // Each dequantize call produces 2 consecutive elements (like q8_0) #define QR_TURBO2 1 // Each dequantize call produces 2 consecutive elements (like q8_0) +#define QR_TURBO4 1 // Each dequantize call produces 2 consecutive elements (like q8_0) // ---- 2-bit centroids (Lloyd-Max for N(0, 1/128)) ---- @@ -291,6 +292,65 @@ static bool turbo_innerq_is_active(void) { return innerq_enabled == 2; } +// ---- 4-bit centroids (Lloyd-Max for N(0, 1/128)) ---- + +static __constant__ float TURBO_CENTROIDS_4BIT[16] = { + -0.173926f, -0.117195f, -0.089527f, -0.068756f, + -0.051262f, -0.035597f, -0.020989f, -0.006938f, + 0.006938f, 0.020989f, 0.035597f, 0.051262f, + 0.068756f, 0.089527f, 0.117195f, 0.173926f +}; + +// ---- Midpoints for nearest 4-bit centroid lookup ---- + +static __constant__ float TURBO_MID_4BIT[15] = { + -0.145561f, -0.103361f, -0.079142f, -0.060009f, + -0.043430f, -0.028293f, -0.013964f, 0.000000f, + 0.013964f, 0.028293f, 0.043430f, 0.060009f, + 0.079142f, 0.103361f, 0.145561f +}; + +// ---- Nearest 4-bit centroid index ---- + +static __device__ __forceinline__ uint8_t turbo_nearest_centroid_4bit(float val) { + if (val < TURBO_MID_4BIT[ 0]) return 0; + else if (val < TURBO_MID_4BIT[ 1]) return 1; + else if (val < TURBO_MID_4BIT[ 2]) return 2; + else if (val < TURBO_MID_4BIT[ 3]) return 3; + else if (val < TURBO_MID_4BIT[ 4]) return 4; + else if (val < TURBO_MID_4BIT[ 5]) return 5; + else if (val < TURBO_MID_4BIT[ 6]) return 6; + else if (val < TURBO_MID_4BIT[ 7]) return 7; + else if (val < TURBO_MID_4BIT[ 8]) return 8; + else if (val < TURBO_MID_4BIT[ 9]) return 9; + else if (val < TURBO_MID_4BIT[10]) return 10; + else if (val < TURBO_MID_4BIT[11]) return 11; + else if (val < TURBO_MID_4BIT[12]) return 12; + else if (val < TURBO_MID_4BIT[13]) return 13; + else if (val < TURBO_MID_4BIT[14]) return 14; + else return 15; +} + +// ---- Per-block quantize for turbo4 (128 elements, expects already-rotated input) ---- + +static __device__ void quantize_f32_turbo4_0_block(const float * __restrict__ src, + block_turbo4_0 * __restrict__ dst) { + for (int j = 0; j < QK_TURBO4 / 2; j++) dst->qs[j] = 0; + + for (int j = 0; j < QK_TURBO4; j++) { + uint8_t idx = turbo_nearest_centroid_4bit(src[j]); + dst->qs[j / 2] |= (idx & 0xF) << ((j % 2) * 4); + } +} + +// ---- Inline dequant helper: extract one float from turbo4 block ---- + +static __device__ __forceinline__ float turbo4_dequant_element( + const block_turbo4_0 * __restrict__ x, int j, float norm) { + uint8_t idx = (x->qs[j / 2] >> ((j % 2) * 4)) & 0xF; + return TURBO_CENTROIDS_4BIT[idx] * norm; +} + // ---- Nearest 3-bit centroid index ---- static __device__ __forceinline__ uint8_t turbo_nearest_centroid_3bit(float val) { diff --git a/ggml/src/ggml-turbo-quant.c b/ggml/src/ggml-turbo-quant.c index 85933fb7815..e42258f8896 100644 --- a/ggml/src/ggml-turbo-quant.c +++ b/ggml/src/ggml-turbo-quant.c @@ -497,7 +497,9 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G #endif /* Pack */ +#if !TURBO4_USE_4BIT y[block].norm = GGML_FP32_TO_FP16(norm); +#endif #if TURBO4_USE_4BIT /* 4-bit PolarQuant: nibble pack into qs[64] */ @@ -506,7 +508,6 @@ void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * G y[block].qs[i / 2] |= (uint8_t)((indices[i] & 0xF) << ((i % 2) * 4)); } y[block].rnorm = GGML_FP32_TO_FP16(0.0f); - (void)projected; #else /* Legacy 3-bit + QJL: pack 3-bit indices + QJL signs */ memset(y[block].qs, 0, d * 3 / 8); diff 
--git a/src/llama-context.cpp b/src/llama-context.cpp index 7dbfdf74fe4..28073f1843a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2944,8 +2944,16 @@ llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); + const bool k_is_turbo = (params.type_k == GGML_TYPE_TURBO2_0 || + params.type_k == GGML_TYPE_TURBO3_0 || + params.type_k == GGML_TYPE_TURBO4_0); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { - if (model->hparams.n_embd_head_k(il) % blck_size != 0) { + uint32_t head_k = model->hparams.n_embd_head_k(il); + // Turbo types zero-pad heads to next multiple of 128 in llama-kv-cache.cpp + if (k_is_turbo && head_k % 128 != 0) { + head_k = ((head_k + 127) / 128) * 128; + } + if (head_k % blck_size != 0) { LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); return nullptr; @@ -2955,8 +2963,17 @@ llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); + const bool v_is_turbo = (params.type_v == GGML_TYPE_TURBO2_0 || + params.type_v == GGML_TYPE_TURBO3_0 || + params.type_v == GGML_TYPE_TURBO4_0); + const bool is_mla = model->hparams.is_mla(); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { - if (model->hparams.n_embd_head_v(il) % blck_size != 0) { + uint32_t head_v = model->hparams.n_embd_head_v(il); + // Turbo types zero-pad; MLA has no separate V cache (V = view of K) + if (v_is_turbo && !is_mla && head_v % 128 != 0) { + head_v = ((head_v + 127) / 128) * 128; + } + if (head_v % blck_size != 0) { LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); return nullptr; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1ac40635a92..4b1aad73b18 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2091,20 +2091,40 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = mctx_cur->get_v(ctx0, il); // TurboQuant pre-rotate-queries: O(d log d) WHT rotation via custom op - // Q shape: (n_embd_head, n_head, n_tokens) — ne[0] must be divisible by 64 + // Q shape: (n_embd_head, n_head, n_tokens) + // For zero-padded models (head_dim not 128-aligned), pad Q to match padded K dim first. 
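// [Editor's note, not part of the patch] The round-up below is the usual next-multiple-
// of-128 computation, pad = ceil(ne0/128)*128 - ne0: for example ((576 + 127)/128)*128 = 640
// for GLM-4.7 Flash's K head_dim, matching the padded cache allocated in llama-kv-cache.cpp,
// while already 128-aligned head dims never enter the branch and are left untouched.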
if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { - if (q->ne[0] % 64 == 0) { - if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } - ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); - q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); // 0 = forward, 0 = auto group size from q->ne[0] + // Pad Q per-head to next multiple of 128 if needed + if (q->ne[0] % 128 != 0) { + const int64_t pad = ((q->ne[0] + 127) / 128) * 128 - q->ne[0]; + q = ggml_pad(ctx0, q, pad, 0, 0, 0); } + if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } + ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); + q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); // 0 = forward, 0 = auto group size from q->ne[0] } ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); - // Note: TurboQuant inverse WHT is now applied inside build_attn_mha - // (after FA output, before v_mla) to handle both MLA and non-MLA models. + // TurboQuant: if V was padded, the output has padded dimensions. + // Extract original V head_dim after inverse WHT (applied inside build_attn_mha). + if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { + const int64_t orig_v_head = hparams.n_embd_head_v(il); + // cur is 2D: (n_embd_head * n_head, n_tokens) after build_attn_mha + const int64_t padded_v_head = v->ne[0]; + if (padded_v_head != orig_v_head) { + // Reshape to 4D, extract original head_dim, reshape back to 2D + const int64_t n_head_v = hparams.n_head_kv(il); + const int64_t n_tokens_cur = cur->ne[1]; + cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); + // ggml_view_3d to extract first orig_v_head elements per head + cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur, + cur->nb[1], cur->nb[2], 0); + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_2d(ctx0, cur, orig_v_head * n_head_v, n_tokens_cur); + } + } if (wo) { cur = build_lora_mm(wo, cur); @@ -2187,17 +2207,38 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); // TurboQuant: pre-rotate Q for K-only (MLA) attention + // For zero-padded models, pad Q to match padded K dim first. if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { - if (q->ne[0] % 64 == 0) { - if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } - ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); - q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); // 0 = forward, 0 = auto group size + // Pad Q per-head to next multiple of 128 if needed + if (q->ne[0] % 128 != 0) { + const int64_t pad = ((q->ne[0] + 127) / 128) * 128 - q->ne[0]; + q = ggml_pad(ctx0, q, pad, 0, 0, 0); } + if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } + ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); + q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); // 0 = forward, 0 = auto group size } ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + // TurboQuant: if V was padded (MLA: V is view of K, may have padded dim), + // extract original V head_dim after inverse WHT. 
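// [Editor's note, not part of the patch] Shape bookkeeping for the slice below, as in the
// non-MLA path above: after ggml_reshape_3d the row stride is still padded_v_head floats,
// so the ggml_view_3d that keeps only the first orig_v_head values per head is
// non-contiguous and needs the ggml_cont before it can be flattened back to 2-D.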
+ if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { + const int64_t orig_v_head = v_cur->ne[0]; // original V head_dim from model + const int64_t padded_v_head = v->ne[0]; // padded V head_dim in cache + if (padded_v_head != orig_v_head) { + // cur is 2D: (padded_v_head * n_head, n_tokens) after build_attn_mha + const int64_t n_head_v = hparams.n_head_kv(il); + const int64_t n_tokens_cur = cur->ne[1]; + cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); + cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur, + cur->nb[1], cur->nb[2], 0); + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_2d(ctx0, cur, orig_v_head * n_head_v, n_tokens_cur); + } + } + if (wo) { cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { @@ -2262,20 +2303,34 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - // TurboQuant Q rotation: rotate Q forward to match rotated K in cache + // TurboQuant: pre-rotate Q for ISWA attention (pad to 128-aligned if needed) if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { - if (q->ne[0] % 64 == 0) { - if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } - ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); - q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); // 0 = forward, 0 = auto group size + if (q->ne[0] % 128 != 0) { + const int64_t pad = ((q->ne[0] + 127) / 128) * 128 - q->ne[0]; + q = ggml_pad(ctx0, q, pad, 0, 0, 0); } + if (!ggml_is_contiguous(q)) { q = ggml_cont(ctx0, q); } + ggml_tensor * innerq_scale = mctx_cur->get_turbo_innerq_scale_inv(); + q = ggml_turbo_wht(ctx0, q, 0, 0, innerq_scale); } ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); - // Note: TurboQuant inverse WHT is now applied inside build_attn_mha - // (after FA output, before v_mla) to handle both MLA and non-MLA models. + // TurboQuant: if V was padded, extract original V head_dim after inverse WHT + if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) { + const int64_t orig_v_head = hparams.n_embd_head_v(il); + const int64_t padded_v_head = v->ne[0]; + if (padded_v_head != orig_v_head) { + const int64_t n_head_v = hparams.n_head_kv(il); + const int64_t n_tokens_cur = cur->ne[1]; + cur = ggml_reshape_3d(ctx0, cur, padded_v_head, n_head_v, n_tokens_cur); + cur = ggml_view_3d(ctx0, cur, orig_v_head, n_head_v, n_tokens_cur, + cur->nb[1], cur->nb[2], 0); + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_2d(ctx0, cur, orig_v_head * n_head_v, n_tokens_cur); + } + } if (wo) { cur = build_lora_mm(wo, cur); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 7ce067a55c0..6b7302e8fb8 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -152,21 +152,11 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - // TurboQuant requires head_dim divisible by 128 for quality. - // 64-group WHT (head_dim=192/576) passes NIAH but has catastrophic PPL on some - // models (DeepSeek 192-dim: 344K vs 9.9 baseline). Fall back to q8_0. 
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 7ce067a55c0..6b7302e8fb8 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -152,21 +152,11 @@ llama_kv_cache::llama_kv_cache(
         throw std::runtime_error("failed to create ggml context for kv cache");
     }

-    // TurboQuant requires head_dim divisible by 128 for quality.
-    // 64-group WHT (head_dim=192/576) passes NIAH but has catastrophic PPL on some
-    // models (DeepSeek 192-dim: 344K vs 9.9 baseline). Fall back to q8_0.
-    const bool is_turbo_type = (type_k == GGML_TYPE_TURBO3_0 || type_k == GGML_TYPE_TURBO4_0 || type_k == GGML_TYPE_TURBO2_0 ||
-                                type_v == GGML_TYPE_TURBO3_0 || type_v == GGML_TYPE_TURBO4_0 || type_v == GGML_TYPE_TURBO2_0);
+    // TurboQuant zero-padding: for models with non-128-aligned head_dim (e.g. DeepSeek
+    // head_dim_k=192), pad each head to the next multiple of 128. The padded zeros don't
+    // affect dot products since WHT preserves inner products:
+    //   <WHT(q_pad), WHT(k_pad)> = <q_pad, k_pad> = <q, k> + <0, 0> = <q, k>
     const uint32_t n_embd_head_k = hparams.n_embd_head_k(il);
-    if (is_turbo_type && n_embd_head_k % 128 != 0) {
-        if (il == 0) {
-            LLAMA_LOG_WARN("%s: turbo KV cache requires head_dim divisible by 128, "
-                           "but this model has n_embd_head_k=%u — falling back to q8_0\n",
-                           __func__, n_embd_head_k);
-        }
-        type_k = GGML_TYPE_Q8_0;
-        type_v = GGML_TYPE_Q8_0;
-    }

     const bool has_k = true;

@@ -201,8 +191,35 @@ llama_kv_cache::llama_kv_cache(
                 }
             }
         }
-        ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, layer_type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
-        ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, layer_type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
+        // For turbo types, pad K head_dim to next multiple of 128 for full WHT groups
+        uint32_t n_embd_k_gqa_eff = n_embd_k_gqa;
+        const bool k_is_turbo = (layer_type_k == GGML_TYPE_TURBO3_0 || layer_type_k == GGML_TYPE_TURBO4_0 || layer_type_k == GGML_TYPE_TURBO2_0);
+        if (k_is_turbo && n_embd_head_k % 128 != 0) {
+            const uint32_t padded_head_k = ((n_embd_head_k + 127) / 128) * 128;
+            const uint32_t n_head_kv = n_embd_k_gqa / n_embd_head_k;
+            n_embd_k_gqa_eff = n_head_kv * padded_head_k;
+            if (il == 0) {
+                LLAMA_LOG_INFO("%s: turbo zero-padding K head_dim %u -> %u (cache %u -> %u)\n",
+                               __func__, n_embd_head_k, padded_head_k, n_embd_k_gqa, n_embd_k_gqa_eff);
+            }
+        }
+
+        // For turbo types, pad V head_dim to next multiple of 128 if needed
+        const uint32_t n_embd_head_v = hparams.n_embd_head_v(il);
+        uint32_t n_embd_v_gqa_eff = n_embd_v_gqa;
+        const bool v_is_turbo = (layer_type_v == GGML_TYPE_TURBO3_0 || layer_type_v == GGML_TYPE_TURBO4_0 || layer_type_v == GGML_TYPE_TURBO2_0);
+        if (v_is_turbo && !is_mla && n_embd_head_v % 128 != 0) {
+            const uint32_t padded_head_v = ((n_embd_head_v + 127) / 128) * 128;
+            const uint32_t n_head_kv = n_embd_v_gqa / n_embd_head_v;
+            n_embd_v_gqa_eff = n_head_kv * padded_head_v;
+            if (il == 0) {
+                LLAMA_LOG_INFO("%s: turbo zero-padding V head_dim %u -> %u (cache %u -> %u)\n",
+                               __func__, n_embd_head_v, padded_head_v, n_embd_v_gqa, n_embd_v_gqa_eff);
+            }
+        }
+
+        ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, layer_type_k, n_embd_k_gqa_eff, kv_size, n_stream) : nullptr;
+        ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, layer_type_v, n_embd_v_gqa_eff, kv_size, n_stream) : nullptr;

         has_k && ggml_format_name(k, "cache_k_l%d", il);
         has_v && ggml_format_name(v, "cache_v_l%d", il);

@@ -211,8 +228,8 @@ llama_kv_cache::llama_kv_cache(
         std::vector<ggml_tensor *> v_stream;

         for (uint32_t s = 0; s < n_stream; ++s) {
-            k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
-            v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
+            k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa_eff, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
+            v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa_eff, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
         }

         map_layer_ids[il] = layers.size();
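For a concrete sense of the cache-size cost of this padding, the snippet below evaluates the same round-up formula used in the constructor for a DeepSeek-style K head of 192; the KV head count here is illustrative, not taken from any specific model:

```cpp
// Back-of-the-envelope sketch of the padded cache row computed above.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd_head_k = 192;   // per-head K dim, not a multiple of 128
    const uint32_t n_head_kv     = 16;    // illustrative KV head count
    const uint32_t n_embd_k_gqa  = n_embd_head_k * n_head_kv;

    const uint32_t padded_head_k    = ((n_embd_head_k + 127) / 128) * 128;  // 256
    const uint32_t n_embd_k_gqa_eff = n_head_kv * padded_head_k;            // 4096 vs 3072

    printf("K head %u -> %u, cache row %u -> %u (%.0f%% larger)\n",
           n_embd_head_k, padded_head_k, n_embd_k_gqa, n_embd_k_gqa_eff,
           100.0 * n_embd_k_gqa_eff / n_embd_k_gqa - 100.0);
    return 0;
}
```

For 192-wide heads the padded rows are a third larger, which is the memory price of always running full 128-element WHT groups instead of falling back to q8_0.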
@@ -1143,13 +1160,24 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k
     const uint64_t kv_size = get_size();

     const uint64_t n_embd_k_gqa = k->ne[0];
-    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+    // For turbo-padded caches, n_embd_k_gqa may be larger than hparams value
+    const bool k_is_turbo = (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0);
+    if (k_is_turbo) {
+        assert(n_embd_k_gqa >= hparams.n_embd_k_gqa(il));
+    } else {
+        assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+    }
+
+    // Use padded head_dim for turbo types so the full padded data is returned
+    const uint32_t head_k = hparams.n_embd_head_k(il);
+    const uint32_t head_k_eff = (k_is_turbo && head_k % 128 != 0)
+        ? ((head_k + 127) / 128) * 128 : head_k;

     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

     return ggml_view_4d(ctx, k,
-            hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns,
-            ggml_row_size(k->type, hparams.n_embd_head_k(il)),
+            head_k_eff, hparams.n_head_kv(il), n_kv, ns,
+            ggml_row_size(k->type, head_k_eff),
             ggml_row_size(k->type, n_embd_k_gqa),
             ggml_row_size(k->type, n_embd_k_gqa*kv_size),
             ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);

@@ -1163,27 +1191,33 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
     const uint64_t kv_size = get_size();

     const uint64_t n_embd_v_gqa = v->ne[0];
-    // [TAG_V_CACHE_VARIABLE]
+    // [TAG_V_CACHE_VARIABLE] — for turbo-padded V, cache may be larger
     assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));

+    // Use padded head_dim for turbo types
+    const bool v_is_turbo = (v->type == GGML_TYPE_TURBO3_0 || v->type == GGML_TYPE_TURBO4_0 || v->type == GGML_TYPE_TURBO2_0);
+    const uint32_t head_v = hparams.n_embd_head_v(il);
+    const uint32_t head_v_eff = (v_is_turbo && head_v % 128 != 0)
+        ? ((head_v + 127) / 128) * 128 : head_v;
+
     const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;

     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
         return ggml_view_4d(ctx, v,
-                hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns,
-                ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1]
-                ggml_row_size(v->type, n_embd_v_gqa),              // v->nb[2]
-                ggml_row_size(v->type, n_embd_v_gqa*kv_size),      // v->nb[3]
+                head_v_eff, hparams.n_head_kv(il), n_kv, ns,
+                ggml_row_size(v->type, head_v_eff),           // v->nb[1]
+                ggml_row_size(v->type, n_embd_v_gqa),         // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3]
                 ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0);
     }

     // note: v->nb[1] > v->nb[2]
     return ggml_view_4d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns,
-            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1]
-            ggml_row_size(v->type, kv_size),                           // v->nb[2]
-            ggml_row_size(v->type, kv_size*n_embd_v_gqa),              // v->nb[3]
+            n_kv, hparams.n_head_kv(il), head_v_eff, ns,
+            ggml_row_size(v->type, kv_size*head_v_eff),   // v->nb[1]
+            ggml_row_size(v->type, kv_size),              // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3]
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
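The views returned by get_k/get_v keep the padded head width in dim 0 while the between-cell and between-stream strides are still derived from the full padded cache row (k->ne[0] / v->ne[0]). The snippet below spells out that stride arithmetic with an F16 stand-in element type; the numbers are illustrative, and quantized turbo types would go through ggml_row_size the same way with block-compressed row sizes:

```cpp
// Stride sketch for the padded K view built above (illustrative dimensions).
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t head_k       = 192;                          // model head_dim
    const uint32_t head_k_eff   = ((head_k + 127) / 128) * 128; // 256, view dim 0
    const uint32_t n_head_kv    = 16;                           // illustrative
    const uint64_t kv_size      = 4096;
    const uint64_t n_embd_k_gqa = (uint64_t) head_k_eff * n_head_kv; // padded cache row

    const size_t elt = 2;                              // sizeof(ggml_half) stand-in
    const size_t nb1 = elt * head_k_eff;               // head -> head
    const size_t nb2 = elt * n_embd_k_gqa;             // cache cell -> cache cell
    const size_t nb3 = elt * n_embd_k_gqa * kv_size;   // stream -> stream

    // byte offset of element i of head h in cache cell c of stream s
    const uint64_t i = 10, h = 3, c = 100, s = 0;
    printf("offset = %zu bytes\n", (size_t) (i*elt + h*nb1 + c*nb2 + s*nb3));
    return 0;
}
```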
@@ -1194,11 +1228,22 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
     ggml_tensor * k = layers[ikv].k;

-    const int64_t n_embd_head = k_cur->ne[0];
+    int64_t n_embd_head = k_cur->ne[0];
     const int64_t n_head = k_cur->ne[1];
     const int64_t n_tokens = k_cur->ne[2];

-    const int64_t n_embd_gqa = n_embd_head*n_head;
+    // Turbo zero-padding: pad each head to next multiple of 128 before merging dims.
+    // k_cur shape here is (n_embd_head, n_head, n_tokens).
+    // ggml_pad pads ne[0] with zeros — exactly what we need per-head.
+    const bool k_is_turbo = (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0);
+    const bool k_needs_pad = k_is_turbo && (n_embd_head % 128 != 0);
+    if (k_needs_pad) {
+        const int64_t pad_amount = ((n_embd_head + 127) / 128) * 128 - n_embd_head;
+        k_cur = ggml_pad(ctx, k_cur, pad_amount, 0, 0, 0);
+        n_embd_head = k_cur->ne[0]; // now 128-aligned
+    }
+
+    int64_t n_embd_gqa = n_embd_head * n_head;

     // we can merge dims 0 and 1
     // TODO: add ggml helper function for this?

@@ -1221,10 +1266,10 @@ ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggm
     // store the current K values into the cache
     ggml_tensor * result = ggml_set_rows(ctx, k, k_cur, k_idxs);

-    // For turbo3: store WHT group size in op_params so the CUDA kernel knows
-    // whether to use 128-element or 64-element WHT groups.
-    if (k->type == GGML_TYPE_TURBO3_0 || k->type == GGML_TYPE_TURBO4_0 || k->type == GGML_TYPE_TURBO2_0) {
-        int32_t wht_group = (n_embd_head % 128 == 0) ? 128 : 64;
+    // For turbo: store the WHT group size in op_params for the CUDA kernel.
+    // With zero-padding, all groups are always full 128-element WHT groups.
+    if (k_is_turbo) {
+        int32_t wht_group = 128; // always 128 with padding
         memcpy(result->op_params, &wht_group, sizeof(int32_t));
     }
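What the ggml_pad call above produces in memory before ggml_set_rows writes it into the cache: each head is copied into a wider slot with a zero tail. The plain-C++ sketch below is an illustrative equivalent of that layout change, not the ggml op itself:

```cpp
// Illustrative only: copy (head_dim, n_head, n_tokens) into a
// (padded_dim, n_head, n_tokens) buffer, zero-filling the tail of every head.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

static void pad_heads(const float * src, float * dst,
                      int64_t head_dim, int64_t padded_dim,
                      int64_t n_head, int64_t n_tokens) {
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t h = 0; h < n_head; ++h) {
            const float * s = src + (t*n_head + h)*head_dim;
            float       * d = dst + (t*n_head + h)*padded_dim;
            std::memcpy(d, s, head_dim*sizeof(float));
            std::memset(d + head_dim, 0, (padded_dim - head_dim)*sizeof(float));
        }
    }
}

int main() {
    const int64_t head_dim = 192, padded = 256, n_head = 2, n_tokens = 3;
    std::vector<float> src(head_dim*n_head*n_tokens, 1.0f);
    std::vector<float> dst(padded*n_head*n_tokens, -1.0f);
    pad_heads(src.data(), dst.data(), head_dim, padded, n_head, n_tokens);
    printf("dst[192] = %f (expect 0, the start of the zero tail)\n", dst[192]);
    return 0;
}
```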
@@ -1238,11 +1283,20 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     auto * v = layers[ikv].v;

-    const int64_t n_embd_head = v_cur->ne[0];
+    int64_t n_embd_head = v_cur->ne[0];
     const int64_t n_head = v_cur->ne[1];
     const int64_t n_tokens = v_cur->ne[2];

-    const int64_t n_embd_gqa = n_embd_head*n_head;
+    // Turbo zero-padding: pad V head_dim to next multiple of 128
+    const bool v_is_turbo = (v->type == GGML_TYPE_TURBO3_0 || v->type == GGML_TYPE_TURBO4_0 || v->type == GGML_TYPE_TURBO2_0);
+    const bool v_needs_pad = v_is_turbo && (n_embd_head % 128 != 0);
+    if (v_needs_pad) {
+        const int64_t pad_amount = ((n_embd_head + 127) / 128) * 128 - n_embd_head;
+        v_cur = ggml_pad(ctx, v_cur, pad_amount, 0, 0, 0);
+        n_embd_head = v_cur->ne[0]; // now 128-aligned
+    }
+
+    int64_t n_embd_gqa = n_embd_head * n_head;

     // we can merge dims 0 and 1
     GGML_ASSERT(ggml_row_size(v_cur->type, n_embd_head) == v_cur->nb[1]);

@@ -1264,8 +1318,9 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     }

     ggml_tensor * result = ggml_set_rows(ctx, v, v_cur, v_idxs);

-    if (v->type == GGML_TYPE_TURBO3_0 || v->type == GGML_TYPE_TURBO4_0 || v->type == GGML_TYPE_TURBO2_0) {
-        int32_t wht_group = (n_embd_head % 128 == 0) ? 128 : 64;
+    // With zero-padding, all groups are always full 128-element WHT groups
+    if (v_is_turbo) {
+        int32_t wht_group = 128; // always 128 with padding
         memcpy(result->op_params, &wht_group, sizeof(int32_t));
     }

     return result;
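The wht_group value is passed from graph build to the backend through the node's op_params scratch area, exactly as the memcpy calls above do. A small sketch of both sides of that handshake, written as standalone helpers rather than the patch's own code, might look like this:

```cpp
// Sketch of the op_params handshake: the graph-build side stores the WHT group
// size on the ggml_set_rows node, and a backend kernel launch reads the same
// bytes back. store_wht_group/load_wht_group are illustrative helpers only.
#include "ggml.h"
#include <cstring>

static_assert(sizeof(int32_t) <= GGML_MAX_OP_PARAMS, "op_params too small for an int32_t");

static void store_wht_group(ggml_tensor * result, int32_t wht_group) {
    // op_params is a small per-node scratch area of GGML_MAX_OP_PARAMS bytes
    memcpy(result->op_params, &wht_group, sizeof(wht_group));
}

static int32_t load_wht_group(const ggml_tensor * result) {
    int32_t wht_group = 0;
    memcpy(&wht_group, result->op_params, sizeof(wht_group));
    return wht_group;
}
```

With zero-padding in place the stored value is always 128, so the consumer side effectively becomes a sanity check rather than a mode switch.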