Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
b2a5a88
feat: CUDA port of TurboQuant3 KV cache compression (RTX 5090 / SM 12.0)
signalnine Mar 26, 2026
eb9a589
perf: enable MMA/TILE flash attention for turbo3 — 0.97x q8_0 prefill
signalnine Mar 26, 2026
8b36e47
perf: parallel k_set_rows_turbo3 + optimise KQ/V dequant — +31% decod…
signalnine Mar 27, 2026
4c91451
fix: VEC flash-attn Q/K stride mismatch in vec_dot_fattn_vec_KQ_turbo3_0
signalnine Mar 27, 2026
972c76e
fix: graceful fallback for turbo3 with non-128-aligned head dims (iss…
signalnine Mar 28, 2026
9cdb872
fix: graceful fallback for turbo3 on non-128-aligned head dims (issue…
signalnine Mar 28, 2026
75e2769
feat: 64-element WHT groups + MLA Q rotation fix (issue #13)
signalnine Mar 28, 2026
d0d37b3
feat: mixed turbo3/q8_0 KV cache types (-ctk turbo3 -ctv q8_0 and vic…
signalnine Mar 28, 2026
53f1298
fix: implement CPU turbo3 quantize (was a stub that zeroed qs/signs)
signalnine Mar 28, 2026
da6b0fd
feat: GGML_TYPE_TURBO2_0 — 2-bit TurboQuant KV cache (6.4x compression)
signalnine Mar 28, 2026
00ecbbe
fix: MLA inverse WHT group_size derived from K (not V) — fixes GLM-4.7
signalnine Mar 28, 2026
6fb85a6
feat: InnerQ per-channel equalization + turbo2 64-group fallback
signalnine Mar 28, 2026
a5efe54
perf: sparse V dequant — skip negligible attention weights in VEC kernel
signalnine Mar 28, 2026
4c4511c
fix: require head_dim % 128 for turbo KV — fall back to q8_0 otherwise
signalnine Mar 29, 2026
b74119a
feat: zero-pad non-128 heads for full 7-stage WHT (replaces q8_0 fall…
signalnine Mar 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_TURBO2_0,
GGML_TYPE_TURBO3_0,
GGML_TYPE_TURBO4_0,
};
Expand Down
7 changes: 5 additions & 2 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ extern "C" {
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_TURBO3_0 = 41, // TurboQuant 3-bit KV cache: 2-bit PolarQuant + 1-bit QJL
GGML_TYPE_TURBO4_0 = 42, // TurboQuant 4-bit KV cache: 3-bit PolarQuant + 1-bit QJL
GGML_TYPE_COUNT = 43,
GGML_TYPE_TURBO2_0 = 43, // TurboQuant 2-bit KV cache: 2-bit PolarQuant (no QJL)
GGML_TYPE_COUNT = 44,
};

// precision
Expand Down Expand Up @@ -2490,7 +2491,9 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_turbo_wht(
struct ggml_context * ctx,
struct ggml_tensor * a,
int direction);
int direction,
int group_size, // 0 = auto (64 or 128 from ne[0])
struct ggml_tensor * scale); // NULL = no InnerQ scaling

// custom operators

Expand Down
12 changes: 12 additions & 0 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,18 @@ typedef struct {
} block_turbo4_0; // 68 bytes total
static_assert(sizeof(block_turbo4_0) == 2*sizeof(ggml_half) + QK_TURBO4*3/8 + QK_TURBO4/8, "wrong turbo4_0 block size/padding");

// TurboQuant 2-bit: 2-bit PolarQuant magnitude indices only — unlike turbo3/turbo4
// there is no 1-bit QJL sign component.
// Per block: norm(fp16, 2 bytes) + 2-bit indices (8 bytes) = 10 bytes per 32 values
// = 2.5 bits/value → 6.4× compression vs fp16
// 4 centroids (Lloyd-Max for N(0, 1/128)): {-0.133462, -0.039994, 0.039994, 0.133462}
#define QK_TURBO2 32 // Block size 32
#define QK_TURBO2_GROUP 128 // default rotation group size = head_dim; NOTE(review): 64-element groups are also selected at runtime elsewhere — confirm consumers use the runtime value, not this macro
typedef struct {
ggml_half norm; // 2 bytes: corrected L2 norm
uint8_t qs[QK_TURBO2 / 4]; // 8 bytes: 2-bit indices (4 per byte)
} block_turbo2_0; // 10 bytes total
static_assert(sizeof(block_turbo2_0) == sizeof(ggml_half) + QK_TURBO2/4, "wrong turbo2_0 block size/padding");

//
// Super-block quantization structures
//
Expand Down
61 changes: 61 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "ggml-cpu-impl.h"
#include "ggml-impl.h"
#include "quants.h"
#include "ggml-quants.h"
#include "ggml-threading.h"
#include "unary-ops.h"
#include "binary-ops.h"
Expand Down Expand Up @@ -204,6 +205,14 @@ typedef pthread_t ggml_thread_t;
#include <TargetConditionals.h>
#endif

// Forward declarations — defined below, after utility functions
static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);

static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
Expand Down Expand Up @@ -393,6 +402,18 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
[GGML_TYPE_TURBO3_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo3_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo3_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_TURBO2_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo2_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo2_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
Expand Down Expand Up @@ -3318,6 +3339,46 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
return ggml_graph_compute(cgraph, &cplan);
}

// TurboQuant3 vec_dot: dequantize one turbo3 row to f32, then dot it with an f32 operand.
// Used by CPU flash attention for models with D not supported by CUDA FA (e.g. D=192).
// s   : output scalar (single dot product; nrc must be 1)
// vx  : turbo3_0-quantized row of n elements
// vy  : f32 row of n elements
// bs/bx/by are the usual vec_dot strides; unused here since nrc == 1.
static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    // scratch buffer for the dequantized row — 4096 covers the max head_dim
    GGML_ASSERT(n <= 4096);
    float dq[4096];
    ggml_get_type_traits(GGML_TYPE_TURBO3_0)->to_float(vx, dq, n);

    const float * f = (const float *) vy;
    float acc = 0.0f;
    int i = 0;
    while (i < n) {
        acc += dq[i] * f[i];
        ++i;
    }
    *s = acc;
}

// TurboQuant2 vec_dot: dequantize one turbo2 row to f32, then dot it with an f32 operand.
// s   : output scalar (single dot product; nrc must be 1)
// vx  : turbo2_0-quantized row of n elements
// vy  : f32 row of n elements
// bs/bx/by are the usual vec_dot strides; unused here since nrc == 1.
static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs);
    GGML_UNUSED(bx);
    GGML_UNUSED(by);
    GGML_UNUSED(nrc);

    // scratch buffer for the dequantized row — 4096 covers the max head_dim
    GGML_ASSERT(n <= 4096);
    float dq[4096];
    ggml_get_type_traits(GGML_TYPE_TURBO2_0)->to_float(vx, dq, n);

    const float * f = (const float *) vy;
    float acc = 0.0f;
    int i = 0;
    while (i < n) {
        acc += dq[i] * f[i];
        ++i;
    }
    *s = acc;
}

// Identity "conversion" from fp32 to fp32: source and destination formats match,
// so this is a straight copy of n float values (buffers must not overlap).
void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
    memcpy(y, x, (size_t) n * sizeof(float));
}
Expand Down
72 changes: 57 additions & 15 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4926,6 +4926,14 @@ static void ggml_compute_forward_set_rows_f32(

ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;

// For turbo types: communicate WHT group size to the quantize function via global
if (dst->type == GGML_TYPE_TURBO3_0 || dst->type == GGML_TYPE_TURBO4_0 || dst->type == GGML_TYPE_TURBO2_0) {
extern int turbo3_cpu_wht_group_size;
int gs = 0;
memcpy(&gs, dst->op_params, sizeof(int));
turbo3_cpu_wht_group_size = (gs == 64 || gs == 128) ? gs : 0;
}

for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
for (int64_t i = ir0; i < ir1; ++i) {
Expand Down Expand Up @@ -10626,34 +10634,55 @@ static void ggml_compute_forward_turbo_wht_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0];
const ggml_tensor * scale_tensor = dst->src[1]; // InnerQ scale_inv (may be NULL)
const float * src_data = (const float *) src->data;
float * dst_data = (float *) dst->data;
const float * scale_inv = scale_tensor ? (const float *) scale_tensor->data : NULL;

int direction;
memcpy(&direction, dst->op_params, sizeof(int));
int group_size;
memcpy(&direction, dst->op_params + 0, sizeof(int));
memcpy(&group_size, dst->op_params + sizeof(int), sizeof(int));

const float * s_first = (direction == 0) ? turbo_wht_s1 : turbo_wht_s2;
const float * s_second = (direction == 0) ? turbo_wht_s2 : turbo_wht_s1;
const int64_t head_dim = src->ne[0];
const int64_t n_heads = ggml_nelements(src) / head_dim;
const int64_t groups_per_head = head_dim / group_size;
const int tail_size = (int)(head_dim % group_size);
const int64_t n_groups = groups_per_head * n_heads;

const int64_t n_total = ggml_nelements(src);
const int64_t n_groups = n_total / 128;
const float inv_sqrt = 1.0f / sqrtf((float)group_size);

// Parallel over groups
const int64_t ith = params->ith;
const int64_t nth = params->nth;
const int64_t grp_start = (n_groups * ith) / nth;
const int64_t grp_end = (n_groups * (ith + 1)) / nth;

// Select sign arrays: for 64-group, use first 64 elements of the 128-element arrays
const float * s_first = (direction == 0) ? turbo_wht_s1 : turbo_wht_s2;
const float * s_second = (direction == 0) ? turbo_wht_s2 : turbo_wht_s1;

for (int64_t g = grp_start; g < grp_end; g++) {
float x[128];
const float * in = src_data + g * 128;
const int64_t head_idx = g / groups_per_head;
const int64_t grp_in_head = g % groups_per_head;
const int64_t base = head_idx * head_dim + grp_in_head * group_size;

float x[128]; // max group_size
const float * in = src_data + base;

// InnerQ forward: apply scale_inv BEFORE signs+WHT (for Q pre-rotation)
if (direction == 0 && scale_inv != NULL) {
for (int i = 0; i < group_size; i++) x[i] = in[i] * scale_inv[i % group_size];
} else {
for (int i = 0; i < group_size; i++) x[i] = in[i];
}

// Apply first signs
for (int i = 0; i < 128; i++) x[i] = in[i] * s_first[i];
for (int i = 0; i < group_size; i++) x[i] *= s_first[i];

// WHT butterfly (7 stages)
for (int h = 1; h < 128; h *= 2) {
for (int i = 0; i < 128; i += h * 2) {
// WHT butterfly (log2(group_size) stages)
for (int h = 1; h < group_size; h *= 2) {
for (int i = 0; i < group_size; i += h * 2) {
for (int j = i; j < i + h; j++) {
float a = x[j], b = x[j + h];
x[j] = a + b;
Expand All @@ -10663,10 +10692,23 @@ static void ggml_compute_forward_turbo_wht_f32(
}

// Normalize + second signs
const float inv_sqrt_128 = 0.08838834764831845f;
float * out = dst_data + g * 128;
for (int i = 0; i < 128; i++) {
out[i] = x[i] * inv_sqrt_128 * s_second[i];
float * out = dst_data + base;
for (int i = 0; i < group_size; i++) {
float val = x[i] * inv_sqrt * s_second[i];
// InnerQ inverse: apply scale_inv AFTER WHT+signs (for V un-rotation)
if (direction == 1 && scale_inv != NULL) {
val *= scale_inv[i % group_size];
}
out[i] = val;
}
}

// Copy tail elements unchanged (identity pass-through)
if (tail_size > 0 && ith == 0) {
const int64_t tail_offset = groups_per_head * group_size;
for (int64_t h = 0; h < n_heads; h++) {
const int64_t base = h * head_dim + tail_offset;
memcpy(dst_data + base, src_data + base, tail_size * sizeof(float));
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ if (CUDAToolkit_FOUND)
template-instances/fattn-vec-instance-f16-f16.cu
template-instances/fattn-vec-instance-q4_0-q4_0.cu
template-instances/fattn-vec-instance-q8_0-q8_0.cu
template-instances/fattn-vec-instance-bf16-bf16.cu)
template-instances/fattn-vec-instance-bf16-bf16.cu
template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu)
endif()

ggml_add_backend_library(ggml-cuda
Expand Down
17 changes: 17 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "convert.cuh"
#include "dequantize.cuh"
#include "turbo-quant.cuh"

#include <cstdint>

Expand Down Expand Up @@ -756,6 +757,10 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
Expand Down Expand Up @@ -809,6 +814,10 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
Expand All @@ -832,6 +841,10 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
default:
Expand Down Expand Up @@ -874,6 +887,10 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_TURBO3_0:
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16, float>;
default:
Expand Down
18 changes: 18 additions & 0 deletions ggml/src/ggml-cuda/dequantize.cuh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "common.cuh"
#include "turbo-quant.cuh"

static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
Expand Down Expand Up @@ -75,3 +76,20 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
v.x *= d;
v.y *= d;
}

// Turbo3: 3-bit PolarQuant (2-bit magnitude index + 1-bit sign), block size 32.
// iqs is an (even) element index within block ib; fills v with elements iqs and iqs+1.
static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_turbo3_0 * blk = (const block_turbo3_0 *) vx + ib;
    const float norm = __half2float(blk->norm);
    v = make_float2(turbo3_dequant_element(blk, iqs + 0, norm),
                    turbo3_dequant_element(blk, iqs + 1, norm));
}

// Turbo2: 2-bit PolarQuant (2-bit magnitude index only, no sign bit), block size 32.
// iqs is an (even) element index within block ib; fills v with elements iqs and iqs+1.
static __device__ __forceinline__ void dequantize_turbo2_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_turbo2_0 * blk = (const block_turbo2_0 *) vx + ib;
    const float norm = __half2float(blk->norm);
    v = make_float2(turbo2_dequant_element(blk, iqs + 0, norm),
                    turbo2_dequant_element(blk, iqs + 1, norm));
}
Loading