Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ struct vk_device_struct {
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_turbo_wht;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
Expand Down Expand Up @@ -3447,11 +3448,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, )
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_TURBO3_0, turbo3_0, FA_SCALAR, _fp32)
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
Expand Down Expand Up @@ -4187,7 +4190,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {128, 1, 1}, {}, 1);

// TurboQuant WHT
ggml_vk_create_pipeline(device, device->pipeline_turbo_wht, "turbo_wht", turbo_wht_len, turbo_wht_data, "main", 2, 3 * sizeof(uint32_t), {128, 1, 1}, {}, 1);

// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -4307,15 +4313,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
}

#define SET_ROWS(itype, rte) \
Expand Down Expand Up @@ -7278,6 +7282,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
return ctx->device->pipeline_cpy_quant_f32[src->type];
default:
break;
Expand Down Expand Up @@ -10063,7 +10068,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
if (ggml_is_quantized(dst->type)) {
if (dst->type == GGML_TYPE_TURBO3_0) {
ne = ne / 128;
} else if (ggml_is_quantized(dst->type)) {
// quants run 32 threads each doing QUANT_K elements
ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
} else {
Expand Down Expand Up @@ -10834,6 +10841,32 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
});
}

static void ggml_vk_turbo_wht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    // Dispatch the TurboQuant Walsh-Hadamard transform compute shader over src0,
    // writing the transformed result into dst.
    //
    // op_params layout (consecutive int32 values, standard ggml packing):
    //   [0] direction  (forward / inverse transform)
    //   [1] group_size (elements per transform group)
    int direction, group_size;
    memcpy(&direction, dst->op_params + 0, sizeof(int));
    // Fix: op_params is an int32_t array, so the previous "+ sizeof(int)"
    // advanced the pointer by four *elements* (16 bytes) and read op_params[4].
    // The second parameter lives at op_params[1].
    memcpy(&group_size, dst->op_params + 1, sizeof(int));
    GGML_ASSERT(group_size > 0);

    // Push-constant struct must match the 3 x uint32 layout the pipeline was
    // created with (see ggml_vk_load_shaders).
    struct { uint32_t ne; uint32_t direction; uint32_t group_size; } pc = {
        (uint32_t)ggml_nelements(src0), (uint32_t)direction, (uint32_t)group_size,
    };
    vk_pipeline pipeline = ctx->device->pipeline_turbo_wht;
    GGML_ASSERT(pipeline != nullptr);
    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0, false);
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
    // Spread workgroups across Y/Z to stay within maxComputeWorkGroupCount[0].
    // elements[0] / group_size = wg0; each row of 512 workgroups uses one Y slice.
    const uint32_t n_groups = pc.ne / (uint32_t)group_size;
    std::array<uint32_t, 3> elements;
    if (n_groups > 262144) {
        elements = { 512 * (uint32_t)group_size, 512, CEIL_DIV(n_groups, 262144) };
    } else if (n_groups > 512) {
        elements = { 512 * (uint32_t)group_size, CEIL_DIV(n_groups, 512), 1 };
    } else {
        elements = { pc.ne, 1, 1 };
    }
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, elements);
}

static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    // Backward pass of SiLU: forwards src0/src1 to the generic f32 op
    // dispatcher with only the element count populated in the push constants.
    const uint32_t n_elems = (uint32_t)ggml_nelements(src0);
    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, vk_op_push_constants{ n_elems, 0, 0.0f, 0.0f, 0.0f, 0.0f });
}
Expand Down Expand Up @@ -13015,6 +13048,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SET_ROWS:
ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);

break;
case GGML_OP_TURBO_WHT:
ggml_vk_turbo_wht(ctx, compute_ctx, src0, node);

break;
case GGML_OP_SILU_BACK:
ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
Expand Down Expand Up @@ -15338,6 +15375,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TURBO3_0:
// supported in scalar and coopmat2 paths
break;
case GGML_TYPE_Q4_1:
Expand Down Expand Up @@ -15441,7 +15479,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
return true;
default:
break;
Expand Down Expand Up @@ -15710,6 +15747,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_is_contiguous(op->src[1]) &&
ggml_is_contiguous(op));
}
case GGML_OP_TURBO_WHT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[0]->ne[0] % 128 == 0;
default:
return false;
}
Expand Down
213 changes: 157 additions & 56 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
#version 450

#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "rte.glsl"
#include "types.glsl"

#if defined(SET_ROWS) && QUANT_K == 1
#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 128;
#elif defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
Expand Down Expand Up @@ -185,62 +191,67 @@ void quantize(uint dst_idx, uint src_idx)
#endif

#if defined(DATA_A_TURBO3_0)
void quantize(uint dst_idx, uint src_idx)
{
const float centroids[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);
const float midpoints[7] = float[7](
-0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259
);

// Compute L2 norm
float norm_sq = 0.0;
[[unroll]] for (int j = 0; j < 32; ++j) {
float v = data_s[src_idx + j];
norm_sq += v * v;
}
float norm = sqrt(norm_sq);
float inv_norm = (norm > 1e-10) ? (1.0 / norm) : 0.0;

// Clear output
[[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0);
[[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0);

// Accumulate centroid reconstruction norm for correction
float recon_norm_sq = 0.0;

// Quantize each element
[[unroll]] for (int j = 0; j < 32; ++j) {
float val = data_s[src_idx + j] * inv_norm;

// Find nearest centroid via midpoint comparison
uint idx = 0;
if (val < midpoints[0]) idx = 0;
else if (val < midpoints[1]) idx = 1;
else if (val < midpoints[2]) idx = 2;
else if (val < midpoints[3]) idx = 3;
else if (val < midpoints[4]) idx = 4;
else if (val < midpoints[5]) idx = 5;
else if (val < midpoints[6]) idx = 6;
else idx = 7;

recon_norm_sq += centroids[idx] * centroids[idx];

// Pack: low 2 bits to qs, high 1 bit to signs
uint low2 = idx & 0x3;
uint hi1 = (idx >> 2) & 0x1;
data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2));
data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8));
}
const float TS1[128] = float[128](
-1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1,
1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, -1,
-1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1,
1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1,
-1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, 1,
1, -1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, 1, -1,
-1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1,
1, -1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, 1, -1, 1
);

const float TS2[128] = float[128](
1, 1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1, -1,
1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, 1, 1,
1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1, 1, -1,
1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, 1, 1,
1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, 1, 1,
-1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1,
1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1,
-1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1
);

const float TINV = 0.08838834764831845; // 1 / sqrt(128)

const float TC[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);

const float TM[7] = float[7](
-0.154259, -0.091775, -0.043589,
0.0,
0.043589, 0.091775, 0.154259
);

// Norm correction: scale so reconstruction matches original norm
float recon_norm = sqrt(recon_norm_sq);
float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm;
data_q[dst_idx].norm = float16_t(corrected_norm);
#if defined(SET_ROWS)

shared float wht[128];
shared float sg_acc[16];
shared float gnrm;

// Quantize 128 WHT-transformed values from shared wht[o .. o+127] into output
// block b. Each value maps to the nearest of the 8 centroids in TC; the 3-bit
// centroid index is split into 2 low bits (packed into qs[]) and 1 high bit
// (packed into signs[]). The stored norm is rescaled so dequantization
// reproduces the block's original L2 norm (gnrm, held in shared memory).
void quantize_block(uint b, uint o) {
    // Zero the packed output so the |= accumulation below starts clean.
    [[unroll]] for (int k = 0; k < 32; ++k) data_q[b].qs[k] = uint8_t(0);
    [[unroll]] for (int k = 0; k < 16; ++k) data_q[b].signs[k] = uint8_t(0);

    float recon_sq = 0.0;
    [[unroll]] for (int k = 0; k < 128; ++k) {
        float x = wht[o + k];
        // Nearest-centroid search via the 7 midpoints in TM
        // (non-finite input falls through to index 7, like the original chain).
        uint q;
        if      (x < TM[0]) q = 0;
        else if (x < TM[1]) q = 1;
        else if (x < TM[2]) q = 2;
        else if (x < TM[3]) q = 3;
        else if (x < TM[4]) q = 4;
        else if (x < TM[5]) q = 5;
        else if (x < TM[6]) q = 6;
        else                q = 7;
        recon_sq += TC[q] * TC[q];
        // Pack: 4 x 2-bit low indices per qs byte, 8 x 1-bit high flags per signs byte.
        data_q[b].qs[k / 4]    |= uint8_t((q & 0x3) << ((k % 4) * 2));
        data_q[b].signs[k / 8] |= uint8_t(((q >> 2) & 0x1) << (k % 8));
    }
    // Norm correction: scale so the reconstructed block matches the original norm.
    float recon_norm = sqrt(recon_sq);
    data_q[b].norm = float16_t((recon_norm > 1e-10) ? (gnrm / recon_norm) : gnrm);
}
#endif

#endif // defined(SET_ROWS)
#endif // defined(DATA_A_TURBO3_0)

#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
Expand Down Expand Up @@ -304,7 +315,97 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif

#if defined(SET_ROWS)
#if defined(SET_ROWS) && defined(DATA_A_TURBO3_0)
// SET_ROWS path for TURBO3_0: one 128-thread workgroup handles one
// 128-element block end to end: L2-normalize -> sign flip (TS1) ->
// Walsh-Hadamard butterfly -> scale + sign flip (TS2) -> 3-bit centroid
// quantization, with the packed bytes assembled via subgroup shuffle/ballot.
void main() {
    const uint t = gl_LocalInvocationID.x;
    // Flattened workgroup index; must match the host-side Y/Z spreading
    // (512 workgroups per Y row, 512*512 = 262144 per Z slice).
    const uint g = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint gpr = p.ne00 / 128;

    if (gpr == 0) return;
    if (g >= p.ne / 128) return; // guard against over-dispatch from the spreading

    // Decompose g into (group-in-row, row, plane, batch).
    uint tmp = g;
    const uint ig = tmp % gpr; tmp /= gpr;
    const uint i01 = tmp % p.ne01; tmp /= p.ne01;
    // NOTE(review): i02/i03 are decomposed with p.ne12 (a src1 extent) rather
    // than p.ne02 — confirm this matches the host's linearization order.
    const uint i02 = tmp % p.ne12;
    const uint i03 = tmp / p.ne12;

    // sb: source element base; i1: destination row from the index tensor;
    // db: destination block index.
    const uint sb = src0_idx(ig * 128, i01, i02, i03) + get_aoffset();
    const uint i1 = data_i[src1_idx(i01, fastmod(i02, p.ne11), fastmod(i03, p.ne12), 0) + get_boffset()] DATA_I_SWIZZLE;
    const uint db = dst_idx(ig, i1, i02, i03) + get_doffset();

    // Step 1: load into shared memory
    wht[t] = data_s[sb + t];
    barrier();

    // Step 2: L2 norm via subgroup reduction
    // Per-subgroup partial sums land in sg_acc[]; thread 0 folds them into gnrm.
    float v2 = wht[t] * wht[t];
    v2 = subgroupAdd(v2);
    if (gl_SubgroupInvocationID == 0) sg_acc[gl_SubgroupID] = v2;
    barrier();
    if (t == 0) {
        float total = 0.0;
        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
        gnrm = sqrt(total);
    }
    barrier();

    // Step 3: normalize, then apply forward WHT: signs1 -> butterfly -> signs2
    wht[t] *= (gnrm > 1e-10) ? (1.0 / gnrm) : 0.0;
    barrier();

    wht[t] *= TS1[t];
    barrier();

    // Butterfly: at stride h, the lower thread of each 2h pair updates both
    // wht[t] and wht[t+h]; the loop bound is uniform so the barrier is valid.
    [[unroll]] for (uint h = 1; h < 128; h *= 2) {
        if ((t % (2 * h)) < h) {
            float a = wht[t];
            float b = wht[t + h];
            wht[t] = a + b;
            wht[t + h] = a - b;
        }
        barrier();
    }

    // Step 5: apply signs2 + scaling
    float rv = wht[t] * TINV * TS2[t];

    // Step 6: quantize -- all 128 threads participate
    // Nearest centroid by midpoint comparison; non-finite values fall to 7.
    uint idx = rv < TM[0] ? 0u : rv < TM[1] ? 1u : rv < TM[2] ? 2u : rv < TM[3] ? 3u :
               rv < TM[4] ? 4u : rv < TM[5] ? 5u : rv < TM[6] ? 6u : 7u;

    // Pack qs: 4 elements per byte via subgroup shuffle
    // (gathers the 2-bit indices of the 4 lanes in each aligned quad).
    uint sg_lane = gl_SubgroupInvocationID;
    uint my_low2 = idx & 0x3u;
    uint qs_byte = 0u;
    [[unroll]] for (uint k = 0; k < 4; k++) {
        uint contrib = subgroupShuffle(my_low2, (sg_lane & ~3u) + k);
        qs_byte |= contrib << (k * 2u);
    }
    if (sg_lane % 4u == 0u) {
        data_q[db].qs[t / 4u] = uint8_t(qs_byte);
    }

    // Pack signs: 8 elements per byte via subgroup ballot
    // NOTE(review): only ballot.x is consumed, so lanes >= 32 (subgroup size
    // 64/128, e.g. AMD wave64) would need ballot.y/z/w — confirm the pipeline
    // pins the subgroup size to <= 32, or select the uvec4 component by lane.
    uvec4 ballot = subgroupBallot(((idx >> 2u) & 1u) != 0u);
    if (sg_lane % 8u == 0u) {
        uint local_byte = sg_lane / 8u;
        data_q[db].signs[t / 8u] = uint8_t((ballot.x >> (local_byte * 8u)) & 0xFFu);
    }

    // Step 7: reconstruction norm via subgroup reduction
    // Stored norm is corrected so dequantization reproduces gnrm exactly.
    float rc = TC[idx] * TC[idx];
    rc = subgroupAdd(rc);
    if (sg_lane == 0u) sg_acc[gl_SubgroupID] = rc;
    barrier();
    if (t == 0u) {
        float total = 0.0;
        for (uint w = 0; w < gl_NumSubgroups; w++) total += sg_acc[w];
        float rn = sqrt(total);
        data_q[db].norm = float16_t((rn > 1e-10) ? (gnrm / rn) : gnrm);
    }
}
#elif defined(SET_ROWS)

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
Expand Down
Loading