Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions ggml/src/ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,19 +279,21 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0

// TurboQuant blocks

// Number of elements per TurboQuant block (the earlier QK_K-based layout was 3.0625 bpw)
#define TBQ_BLK_SIZE 128

// 3.125 bpw
// TBQ3_0 block: TBQ_BLK_SIZE elements stored as packed 3-bit indices plus one ggml_half.
// The array is sized with TBQ_BLK_SIZE only; the former QK_K-sized member and its
// size check were leftovers of the previous layout and conflicted with this one.
typedef struct {
    uint8_t   qs[TBQ_BLK_SIZE * 3 / 8]; // 3 bits per element, packed into bytes
    ggml_half d;                        // per-block multiplier applied to the decoded values
} block_tbq3_0;
static_assert(sizeof(block_tbq3_0) == sizeof(ggml_half) + TBQ_BLK_SIZE * 3 / 8, "wrong tbq3_0 block size/padding");

// 4.125 bpw
// TBQ4_0 block: TBQ_BLK_SIZE elements stored as 4-bit indices (two per byte) plus one ggml_half.
// The array is sized with TBQ_BLK_SIZE only; the former QK_K-sized member and its
// size check were leftovers of the previous layout and conflicted with this one.
typedef struct {
    uint8_t   qs[TBQ_BLK_SIZE / 2]; // 4 bits per element, two elements per byte
    ggml_half d;                    // per-block multiplier applied to the decoded values
} block_tbq4_0;
static_assert(sizeof(block_tbq4_0) == sizeof(ggml_half) + TBQ_BLK_SIZE / 2, "wrong tbq4_0 block size/padding");

//
// Super-block quantization structures
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-cpu/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy,
}

// Quantize a row of k floats from x into TBQ3_0 blocks written to vy.
//
// x  : source floats.
// vy : destination buffer, treated as an array of block_tbq3_0.
// k  : number of elements; must be a multiple of TBQ_BLK_SIZE (checked by assert).
//
// The work is delegated to quantize_row_tbq3_0_ref. Only the TBQ_BLK_SIZE
// divisibility check is kept: the older QK_K check belonged to the previous
// block layout and would reject rows that are valid for TBQ_BLK_SIZE-sized blocks.
void quantize_row_tbq3_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % TBQ_BLK_SIZE == 0);
    block_tbq3_0 * GGML_RESTRICT y = vy;
    quantize_row_tbq3_0_ref(x, y, k);
}

// Quantize a row of k floats from x into TBQ4_0 blocks written to vy.
//
// x  : source floats.
// vy : destination buffer, treated as an array of block_tbq4_0.
// k  : number of elements; must be a multiple of TBQ_BLK_SIZE (checked by assert).
//
// The work is delegated to quantize_row_tbq4_0_ref. Only the TBQ_BLK_SIZE
// divisibility check is kept: the older QK_K check belonged to the previous
// block layout and would reject rows that are valid for TBQ_BLK_SIZE-sized blocks.
void quantize_row_tbq4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
    assert(k % TBQ_BLK_SIZE == 0);
    block_tbq4_0 * GGML_RESTRICT y = vy;
    quantize_row_tbq4_0_ref(x, y, k);
}
Expand Down
214 changes: 214 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,212 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict
}
}


// Return the 3-bit index of element `elem` from the packed byte array `qs` (TBQ3_0).
// Every group of 8 consecutive elements occupies 3 consecutive bytes, which are
// assembled little-endian into a 24-bit word; element j of a group is stored in
// bits [3*j, 3*j + 2] of that word. The result is in the range 0..7.
static __device__ __forceinline__ int unpack_3bit_index(const uint8_t * qs, int elem) {
    const uint8_t * triple = qs + 3 * (elem / 8);

    const uint32_t packed = (uint32_t) triple[0]
                          | ((uint32_t) triple[1] <<  8)
                          | ((uint32_t) triple[2] << 16);

    const int shift = 3 * (elem % 8);
    return (packed >> shift) & 0x7;
}

// Per-thread partial dot product between a TBQ3_0-quantized K row and int8-quantized Q values.
//
// K_c    : K data, reinterpreted as an array of block_tbq3_0 (128 elements per block here).
// Q_v    : unused for this K type.
// Q_q8   : Q values; each int is read as 4 packed int8 values.
// Q_ds_v : reinterpreted as an array of float2, indexed like Q_q8; only .x is used, as a
//          multiplier (presumably the Q quantization scale -- confirm against the caller).
//
// For k_KQ_0 = 0, nthreads, 2*nthreads, ... < D/sizeof(int) the thread decodes the 4 K elements
// starting at element 4*(k_KQ_0 + lane) and multiplies them with the 4 int8 values of
// Q_q8[k_KQ_0/nthreads]. No cross-thread reduction is performed in this function.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tbq3_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_tbq3_0 * blocks = (const block_tbq3_0 *) K_c;
    const float2       * Q_ds   = (const float2 *) Q_ds_v;

    // Values addressed by the 3-bit indices stored in the block.
    const float codebook[8] = {
        -2.1520f, -1.3440f, -0.7560f, -0.2451f,
         0.2451f,  0.7560f,  1.3440f,  2.1520f,
    };
    const float inv_sqrt_128 = 0.08838834764f; // 1/sqrt(128)

    // Position of this thread within its group of nthreads threads.
    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int i_q   = k_KQ_0/nthreads;     // slot in Q_q8 / Q_ds used in this iteration
        const int first = (k_KQ_0 + lane) * 4; // index of the first of the 4 K elements
        const int ib    = first / 128;         // block holding those elements
        const int e     = first % 128;         // offset of the first element inside the block

        const uint8_t * qs = blocks[ib].qs;

        const int packed_q = Q_q8[i_q];
        const int8_t * q = (const int8_t *) &packed_q; // view the int as 4 int8 values

        const float dot4 = codebook[unpack_3bit_index(qs, e + 0)] * q[0]
                         + codebook[unpack_3bit_index(qs, e + 1)] * q[1]
                         + codebook[unpack_3bit_index(qs, e + 2)] * q[2]
                         + codebook[unpack_3bit_index(qs, e + 3)] * q[3];

        acc += __half2float(blocks[ib].d) * dot4 * Q_ds[i_q].x * inv_sqrt_128;
    }

    return acc;
}

// Per-thread partial dot product between a TBQ4_0-quantized K row and int8-quantized Q values.
//
// K_c    : K data, reinterpreted as an array of block_tbq4_0 (128 elements per block here).
// Q_v    : unused for this K type.
// Q_q8   : Q values; each int is read as 4 packed int8 values.
// Q_ds_v : reinterpreted as an array of float2, indexed like Q_q8; only .x is used, as a
//          multiplier (presumably the Q quantization scale -- confirm against the caller).
//
// For k_KQ_0 = 0, nthreads, 2*nthreads, ... < D/sizeof(int) the thread decodes the 4 K elements
// starting at element 4*(k_KQ_0 + lane) and multiplies them with the 4 int8 values of
// Q_q8[k_KQ_0/nthreads]. No cross-thread reduction is performed in this function.
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tbq4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
    GGML_UNUSED(Q_v);

    const block_tbq4_0 * blocks = (const block_tbq4_0 *) K_c;
    const float2       * Q_ds   = (const float2 *) Q_ds_v;

    // Values addressed by the 4-bit indices stored in the block.
    const float codebook[16] = {
        -2.7326f, -2.0690f, -1.6180f, -1.2562f,
        -0.9424f, -0.6568f, -0.3881f, -0.1284f,
         0.1284f,  0.3881f,  0.6568f,  0.9424f,
         1.2562f,  1.6180f,  2.0690f,  2.7326f,
    };
    const float inv_sqrt_128 = 0.08838834764f;

    // Position of this thread within its group of nthreads threads.
    const int lane = nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads;

    float acc = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
        const int i_q   = k_KQ_0/nthreads;     // slot in Q_q8 / Q_ds used in this iteration
        const int first = (k_KQ_0 + lane) * 4; // index of the first of the 4 K elements
        const int ib    = first / 128;         // block holding those elements
        const int e     = first % 128;         // offset of the first element inside the block

        const uint8_t * qs = blocks[ib].qs;

        // Two bytes hold the 4 indices: low nibble first, then high nibble.
        const uint8_t pair_lo = qs[e / 2];
        const uint8_t pair_hi = qs[e / 2 + 1];

        const int packed_q = Q_q8[i_q];
        const int8_t * q = (const int8_t *) &packed_q; // view the int as 4 int8 values

        const float dot4 = codebook[pair_lo & 0x0F] * q[0]
                         + codebook[pair_lo >> 4]   * q[1]
                         + codebook[pair_hi & 0x0F] * q[2]
                         + codebook[pair_hi >> 4]   * q[3];

        acc += __half2float(blocks[ib].d) * dot4 * Q_ds[i_q].x * inv_sqrt_128;
    }

    return acc;
}

// Dequantize `ne` consecutive TBQ3_0 elements, starting at element index i0, into dst.
//
// T   : output element type; half (only when FP16_AVAILABLE is defined) or float.
// ne  : number of elements to produce, 2 or 4.
// vx  : source data, reinterpreted as an array of block_tbq3_0 (128 elements per block here).
// dst : output buffer, written as half2 pairs for T == half or as floats for T == float.
// i0  : index of the first element; all `ne` elements are read from block i0/128, so the
//       caller is assumed to keep i0 .. i0 + ne - 1 inside one block -- TODO confirm.
//
// Each value is codebook[index] * (1/sqrt(128)) * block.d; for T == half the final
// multiplication by d is done in half precision.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_tbq3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(ne == 2 || ne == 4, "bad ne");

    const block_tbq3_0 & blk = ((const block_tbq3_0 *) vx)[i0 / 128];
    const int first = i0 % 128;

    // Values addressed by the 3-bit indices stored in the block.
    const float codebook[8] = {
        -2.1520f, -1.3440f, -0.7560f, -0.2451f,
         0.2451f,  0.7560f,  1.3440f,  2.1520f,
    };
    const float inv_sqrt_128 = 0.08838834764f;

    float vals[4];
#pragma unroll
    for (int l = 0; l < ne; ++l) {
        const int e = first + l;

        // 8 indices share 3 bytes; assemble them little-endian and shift out index e % 8.
        const uint8_t * triple = blk.qs + 3 * (e / 8);
        const uint32_t packed = (uint32_t) triple[0]
                              | ((uint32_t) triple[1] <<  8)
                              | ((uint32_t) triple[2] << 16);

        vals[l] = codebook[(packed >> (3 * (e % 8))) & 0x7] * inv_sqrt_128;
    }

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        half2 * out = (half2 *) dst;
        const half2 d = __half2half2(blk.d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            out[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        float * out = (float *) dst;
        const float d = __half2float(blk.d);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            out[l] = d * vals[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

// Dequantize `ne` consecutive TBQ4_0 elements, starting at element index i0, into dst.
//
// T   : output element type; half (only when FP16_AVAILABLE is defined) or float.
// ne  : number of elements to produce, 2 or 4.
// vx  : source data, reinterpreted as an array of block_tbq4_0 (128 elements per block here).
// dst : output buffer, written as half2 pairs for T == half or as floats for T == float.
// i0  : index of the first element; all `ne` elements are read from block i0/128, so the
//       caller is assumed to keep i0 .. i0 + ne - 1 inside one block -- TODO confirm.
//
// Each value is codebook[index] * (1/sqrt(128)) * block.d; for T == half the final
// multiplication by d is done in half precision.
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_tbq4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
    static_assert(ne == 2 || ne == 4, "bad ne");

    const block_tbq4_0 & blk = ((const block_tbq4_0 *) vx)[i0 / 128];
    const int first = i0 % 128;

    // Values addressed by the 4-bit indices stored in the block.
    const float codebook[16] = {
        -2.7326f, -2.0690f, -1.6180f, -1.2562f,
        -0.9424f, -0.6568f, -0.3881f, -0.1284f,
         0.1284f,  0.3881f,  0.6568f,  0.9424f,
         1.2562f,  1.6180f,  2.0690f,  2.7326f,
    };
    const float inv_sqrt_128 = 0.08838834764f;

    float vals[4];
#pragma unroll
    for (int l = 0; l < ne; ++l) {
        const int e = first + l;
        const uint8_t packed = blk.qs[e / 2];

        // Even elements live in the low nibble, odd elements in the high nibble.
        int idx;
        if (e % 2 == 0) {
            idx = packed & 0x0F;
        } else {
            idx = packed >> 4;
        }
        vals[l] = codebook[idx] * inv_sqrt_128;
    }

#ifdef FP16_AVAILABLE
    if constexpr (std::is_same_v<T, half>) {
        half2 * out = (half2 *) dst;
        const half2 d = __half2half2(blk.d);

#pragma unroll
        for (int l0 = 0; l0 < ne; l0 += 2) {
            out[l0/2] = d * make_half2(vals[l0 + 0], vals[l0 + 1]);
        }
    } else
#endif // FP16_AVAILABLE
    if constexpr (std::is_same_v<T, float>) {
        float * out = (float *) dst;
        const float d = __half2float(blk.d);

#pragma unroll
        for (int l = 0; l < ne; ++l) {
            out[l] = d * vals[l];
        }
    } else {
        static_assert(std::is_same_v<T, void>, "bad type");
    }
}

template <ggml_type type_K, int D, int nthreads>
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
if constexpr (type_K == GGML_TYPE_F16) {
Expand All @@ -593,6 +799,10 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_BF16) {
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_TBQ3_0) {
return vec_dot_fattn_vec_KQ_tbq3_0<D, nthreads>;
} else if constexpr (type_K == GGML_TYPE_TBQ4_0) {
return vec_dot_fattn_vec_KQ_tbq4_0<D, nthreads>;
} else {
static_assert(type_K == -1, "bad type");
return nullptr;
Expand All @@ -615,6 +825,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
return dequantize_V_q8_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_BF16) {
return dequantize_V_bf16<float, ne>;
} else if constexpr (type_V == GGML_TYPE_TBQ3_0) {
return dequantize_V_tbq3_0<T, ne>;
} else if constexpr (type_V == GGML_TYPE_TBQ4_0) {
return dequantize_V_tbq4_0<T, ne>;
} else {
static_assert(type_V == -1, "bad type");
return nullptr;
Expand Down
27 changes: 27 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_F16)

FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
Expand Down Expand Up @@ -285,6 +287,29 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ4_0)

FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ3_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ3_0)

FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ3_0, GGML_TYPE_TBQ4_0)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_TBQ4_0, GGML_TYPE_TBQ4_0)

#else
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
Expand Down Expand Up @@ -391,6 +416,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_BF16:
case GGML_TYPE_TBQ3_0:
case GGML_TYPE_TBQ4_0:
break;
default:
return BEST_FATTN_KERNEL_NONE;
Expand Down
Loading