Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,9 @@ static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3/4 + QK_TUR

// TurboQuant 4-bit: 3-bit PolarQuant indices + 1-bit QJL signs
// TURBO4_USE_4BIT: switch between 4-bit PolarQuant (new) and 3-bit+QJL (legacy)
// Default: 4-bit on Metal (validated), legacy on CUDA (until ported)
// Default: 4-bit on all backends (Metal + CUDA validated)
#ifndef TURBO4_USE_4BIT
# if defined(GGML_METAL)
# define TURBO4_USE_4BIT 1
# else
# define TURBO4_USE_4BIT 0
# endif
# define TURBO4_USE_4BIT 1
#endif

#define QK_TURBO4 128
Expand Down
28 changes: 28 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ static void ggml_vec_dot_turbo3_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);
static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
const void * GGML_RESTRICT vx, size_t bx,
const void * GGML_RESTRICT vy, size_t by, int nrc);

static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = {
Expand Down Expand Up @@ -414,6 +417,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_TURBO4_0] = {
.from_float = (ggml_from_float_t) quantize_row_turbo4_0_ref,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_turbo4_0_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
Expand Down Expand Up @@ -3379,6 +3388,25 @@ static void ggml_vec_dot_turbo2_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
*s = sum;
}

// TurboQuant4 vec_dot: dequantize turbo4 data to f32 one block at a time,
// then accumulate the dot product with the f32 operand.
//
// s   : output scalar (the dot product)
// vx  : row of block_turbo4_0 data (n must be a whole number of blocks)
// vy  : row of n floats
// nrc : must be 1 (single row pair)
//
// Processing per block removes the previous fixed 4096-element stack buffer
// and its GGML_ASSERT(n <= 4096) limit, so arbitrarily long rows work and
// only QK_TURBO4 floats of stack scratch are needed.
static void ggml_vec_dot_turbo4_0_f32(int n, float * GGML_RESTRICT s, size_t bs,
                                      const void * GGML_RESTRICT vx, size_t bx,
                                      const void * GGML_RESTRICT vy, size_t by, int nrc) {
    GGML_ASSERT(nrc == 1);
    GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);

    // turbo4 rows are stored as whole blocks of QK_TURBO4 elements
    GGML_ASSERT(n % QK_TURBO4 == 0);

    const struct ggml_type_traits * traits = ggml_get_type_traits(GGML_TYPE_TURBO4_0);

    const char  * x = (const char  *) vx;
    const float * y = (const float *) vy;

    float tmp[QK_TURBO4]; // scratch for one dequantized block
    float sum = 0.0f;

    for (int i0 = 0; i0 < n; i0 += QK_TURBO4) {
        traits->to_float(x, tmp, QK_TURBO4);
        x += traits->type_size; // advance one block_turbo4_0 per QK_TURBO4 elements
        for (int i = 0; i < QK_TURBO4; i++) {
            sum += tmp[i] * y[i0 + i];
        }
    }

    *s = sum;
}

void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
memcpy(y, x, n * sizeof(float));
}
Expand Down
11 changes: 10 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,16 @@ if (CUDAToolkit_FOUND)
template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu)
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu)
endif()

ggml_add_backend_library(ggml-cuda
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_TURBO4_0:
return dequantize_block_cont_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
Expand Down Expand Up @@ -818,6 +820,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_block_cont_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cont_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_TURBO4_0:
return dequantize_block_cont_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
Expand Down Expand Up @@ -845,6 +849,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_TURBO4_0:
return dequantize_block_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16>;
default:
Expand Down Expand Up @@ -891,6 +897,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
case GGML_TYPE_TURBO2_0:
return dequantize_block_cuda<QK_TURBO2, QR_TURBO2, dequantize_turbo2_0>;
case GGML_TYPE_TURBO4_0:
return dequantize_block_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
case GGML_TYPE_BF16:
return convert_unary_cuda<nv_bfloat16, float>;
default:
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/dequantize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
v.y *= d;
}

// Turbo4: 4-bit PolarQuant (nibble packed), block size 128.
// iqs is the (even) element index within block ib; writes the dequantized
// values of elements iqs and iqs+1 into v.x and v.y respectively.
static __device__ __forceinline__ void dequantize_turbo4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
    const block_turbo4_0 * blk = (const block_turbo4_0 *) vx + ib;
    const float scale = __half2float(blk->norm); // per-block scale stored as half
    v.x = turbo4_dequant_element(blk, iqs,     scale);
    v.y = turbo4_dequant_element(blk, iqs + 1, scale);
}

// Turbo3: 3-bit PolarQuant (2-bit qs + 1-bit sign), block size 32
// iqs is the element index within the block (even), produces elements iqs and iqs+1
static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
Expand Down
113 changes: 113 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,55 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo2_0(
return sum;
}

// Turbo4 KQ dot product: dequantize K from turbo4 blocks, dot with Q (float2/half2).
// 4-bit nibble packed: qs[j/2] >> ((j%2)*4) & 0xF.
//
// K_c    : one K row as block_turbo4_0 data
// Q_v    : pre-loaded Q values for this thread (half2 or float2, depending on
//          V_DOT2_F32_F16_AVAILABLE); indexed per-thread, not per-row —
//          NOTE(review): index k_KQ_0/nthreads + k_KQ_1 assumes the same
//          register layout the other vec_dot_fattn_vec_KQ_* variants use;
//          confirm against the fattn-vec kernel's Q staging.
// Q_q8/Q_ds_v : quantized-Q path inputs, unused here (K is dequantized to float)
// Returns the partial dot product for the D/2 element pairs owned by this thread.
template <int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_turbo4_0 * K_turbo = (const block_turbo4_0 *) K_c;
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

    // Each thread handles cpy_ne consecutive element pairs per outer iteration.
    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
    constexpr int cpy_ne = cpy_nb / 4;

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
#pragma unroll
        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
            // Global pair index for this thread within the K row.
            const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1;

            const int elem0 = k_KQ * 2;          // first element of the pair; always even
            const int ib = elem0 / QK_TURBO4;    // block index
            const int j0 = elem0 % QK_TURBO4;    // element offset within the block; always even

            const float norm = __half2float(K_turbo[ib].norm); // per-block scale
            // Both j0 and j0+1 are adjacent nibbles: j0/2 == (j0+1)/2 when j0 is even
            const uint8_t qs_byte = K_turbo[ib].qs[j0 / 2];

            const uint8_t idx0 = (qs_byte >> 0) & 0xF;  // low nibble = j0
            const uint8_t idx1 = (qs_byte >> 4) & 0xF;  // high nibble = j0+1

            // Centroid lookup + per-block scale yields the dequantized K pair.
            float2 kv;
            kv.x = TURBO_CENTROIDS_4BIT[idx0] * norm;
            kv.y = TURBO_CENTROIDS_4BIT[idx1] * norm;

#ifdef V_DOT2_F32_F16_AVAILABLE
            // Q staged as half2: widen to float2 and fuse the multiply-add.
            const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];
            ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv));
#else
            // Q staged as float2: plain FMA accumulation.
            const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1];
            sum += kv.x * qv.x + kv.y * qv.y;
#endif // V_DOT2_F32_F16_AVAILABLE
        }
    }

    return sum;
}

template <typename Tds, int ni>
static __device__ __forceinline__ void quantize_q8_1_to_shared(
const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
Expand Down Expand Up @@ -807,6 +856,66 @@ static __device__ __forceinline__ void dequantize_V_turbo2_0(const void * __rest
}
}

// Turbo4 V dequantize: extract `ne` float/half values at position i0.
// 4-bit nibble packed, block size 128.
//
// vx  : block_turbo4_0 data for one V row
// dst : receives ne dequantized values of type T (half or float)
// i0  : absolute element index of the first value to extract
//
// NOTE(review): the ne == 4 fast path requires i0 % 4 == 0 (so the 4 elements
// occupy exactly 2 qs bytes of one block) — guaranteed by the VEC kernel's
// access pattern per the comment below, but not checked here.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_turbo4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    const block_turbo4_0 * x = (const block_turbo4_0 *) vx;

    const int64_t ib = i0 / QK_TURBO4;           // block containing element i0
    const int j0 = i0 % QK_TURBO4;               // element offset within that block
    const float norm = __half2float(x[ib].norm); // per-block scale stored as half

    static_assert(ne == 2 || ne == 4, "bad ne");

    if constexpr (ne == 4) {
        // j0 is always a multiple of 4 from the VEC kernel access pattern.
        // 4 consecutive elements span 2 qs bytes: j0/2 and j0/2+1.
        const uint8_t qs_byte0 = x[ib].qs[j0 / 2];     // elements j0, j0+1
        const uint8_t qs_byte1 = x[ib].qs[j0 / 2 + 1]; // elements j0+2, j0+3

        const uint8_t idx0 = (qs_byte0 >> 0) & 0xF;
        const uint8_t idx1 = (qs_byte0 >> 4) & 0xF;
        const uint8_t idx2 = (qs_byte1 >> 0) & 0xF;
        const uint8_t idx3 = (qs_byte1 >> 4) & 0xF;

#ifdef FP16_AVAILABLE
        if constexpr (std::is_same_v<T, half>) {
            // Emit as two half2 stores (4 halfs total).
            ((half2 *) dst)[0] = make_half2(
                __float2half(TURBO_CENTROIDS_4BIT[idx0] * norm),
                __float2half(TURBO_CENTROIDS_4BIT[idx1] * norm));
            ((half2 *) dst)[1] = make_half2(
                __float2half(TURBO_CENTROIDS_4BIT[idx2] * norm),
                __float2half(TURBO_CENTROIDS_4BIT[idx3] * norm));
        } else
#endif // FP16_AVAILABLE
        if constexpr (std::is_same_v<T, float>) {
            // Emit as two float2 stores (4 floats total).
            ((float2 *) dst)[0] = make_float2(
                TURBO_CENTROIDS_4BIT[idx0] * norm,
                TURBO_CENTROIDS_4BIT[idx1] * norm);
            ((float2 *) dst)[1] = make_float2(
                TURBO_CENTROIDS_4BIT[idx2] * norm,
                TURBO_CENTROIDS_4BIT[idx3] * norm);
        } else {
            // Unsatisfiable for any instantiated T: compile-time rejection of other types.
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    } else { // ne == 2
#ifdef FP16_AVAILABLE
        if constexpr (std::is_same_v<T, half>) {
            // Scalar helper per element, then pack into one half2 store.
            float v0 = turbo4_dequant_element(&x[ib], j0, norm);
            float v1 = turbo4_dequant_element(&x[ib], j0+1, norm);
            ((half2 *) dst)[0] = make_half2(__float2half(v0), __float2half(v1));
        } else
#endif // FP16_AVAILABLE
        if constexpr (std::is_same_v<T, float>) {
            ((float *) dst)[0] = turbo4_dequant_element(&x[ib], j0, norm);
            ((float *) dst)[1] = turbo4_dequant_element(&x[ib], j0+1, norm);
        } else {
            // Unsatisfiable for any instantiated T: compile-time rejection of other types.
            static_assert(std::is_same_v<T, void>, "unsupported type");
        }
    }
}

template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
if constexpr (type_K == GGML_TYPE_F16) {
Expand All @@ -827,6 +936,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_turbo3_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_TURBO2_0) {
return vec_dot_fattn_vec_KQ_turbo2_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_TURBO4_0) {
return vec_dot_fattn_vec_KQ_turbo4_0<D, nthreads>;
} else {
static_assert(type_K == -1, "bad type");
return nullptr;
Expand All @@ -853,6 +964,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_turbo3_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_TURBO2_0) {
return dequantize_V_turbo2_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_TURBO4_0) {
return dequantize_V_turbo4_0<T, ne>;
} else {
static_assert(type_V == -1, "bad type");
return nullptr;
Expand Down
24 changes: 24 additions & 0 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
}

Expand All @@ -85,6 +90,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}

Expand All @@ -94,6 +104,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);

// TODO tune specifically for Volta
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
Expand All @@ -107,6 +122,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(640, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);

// TODO tune specifically for RDNA
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
Expand Down Expand Up @@ -1826,3 +1845,8 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

// D=640: padded turbo KV cache for GLM-4.7 Flash (ncols2=16 only)
extern DECL_FATTN_MMA_F16_CASE(640, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(640, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(640, 512, 4, 16);
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-tile.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
} break;
case 640: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<640, 512>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
Expand Down
Loading