Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4187,6 +4187,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TURBO3_0], "dequant_turbo3_0", dequant_turbo3_0_len, dequant_turbo3_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);

// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
Expand All @@ -4212,6 +4213,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_TURBO3_0], "get_rows_turbo3_0", get_rows_turbo3_0_len, get_rows_turbo3_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
Expand All @@ -4237,6 +4239,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_TURBO3_0], "get_rows_turbo3_0_f32", get_rows_turbo3_0_f32_len, get_rows_turbo3_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
Expand Down Expand Up @@ -4304,13 +4307,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_rte_len, cpy_f32_turbo3_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_TURBO3_0], "cpy_f32_turbo3_0", cpy_f32_turbo3_0_len, cpy_f32_turbo3_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
}

#define SET_ROWS(itype, rte) \
Expand All @@ -4322,7 +4327,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_TURBO3_0], "set_rows_turbo3_0" #itype, set_rows_turbo3_0 ## itype ## rte ## _len, set_rows_turbo3_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);

if (device->float_controls_rte_fp16) {
SET_ROWS(_i32, _rte)
Expand All @@ -4340,6 +4346,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_TURBO3_0], "cpy_turbo3_0_f32", cpy_turbo3_0_f32_len, cpy_turbo3_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_TURBO3_0), 1, 1}, {}, 1);

auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
std::string s;
Expand Down Expand Up @@ -15391,6 +15398,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_MXFP4:
case GGML_TYPE_TURBO3_0:
case GGML_TYPE_I32:
return true;
default:
Expand All @@ -15409,6 +15417,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
return true;
default:
return false;
Expand All @@ -15432,6 +15441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
return true;
default:
break;
Expand All @@ -15446,6 +15456,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TURBO3_0:
return true;
default:
break;
Expand Down
58 changes: 58 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,64 @@ void quantize(uint dst_idx, uint src_idx)
}
#endif

#if defined(DATA_A_TURBO3_0)
void quantize(uint dst_idx, uint src_idx)
{
const float centroids[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);
const float midpoints[7] = float[7](
-0.154259, -0.091775, -0.043589, 0.0, 0.043589, 0.091775, 0.154259
);

// Compute L2 norm
float norm_sq = 0.0;
[[unroll]] for (int j = 0; j < 32; ++j) {
float v = data_s[src_idx + j];
norm_sq += v * v;
}
float norm = sqrt(norm_sq);
float inv_norm = (norm > 1e-10) ? (1.0 / norm) : 0.0;

// Clear output
[[unroll]] for (int j = 0; j < 8; ++j) data_q[dst_idx].qs[j] = uint8_t(0);
[[unroll]] for (int j = 0; j < 4; ++j) data_q[dst_idx].signs[j] = uint8_t(0);

// Accumulate centroid reconstruction norm for correction
float recon_norm_sq = 0.0;

// Quantize each element
[[unroll]] for (int j = 0; j < 32; ++j) {
float val = data_s[src_idx + j] * inv_norm;

// Find nearest centroid via midpoint comparison
uint idx = 0;
if (val < midpoints[0]) idx = 0;
else if (val < midpoints[1]) idx = 1;
else if (val < midpoints[2]) idx = 2;
else if (val < midpoints[3]) idx = 3;
else if (val < midpoints[4]) idx = 4;
else if (val < midpoints[5]) idx = 5;
else if (val < midpoints[6]) idx = 6;
else idx = 7;

recon_norm_sq += centroids[idx] * centroids[idx];

// Pack: low 2 bits to qs, high 1 bit to signs
uint low2 = idx & 0x3;
uint hi1 = (idx >> 2) & 0x1;
data_q[dst_idx].qs[j / 4] |= uint8_t(low2 << ((j % 4) * 2));
data_q[dst_idx].signs[j / 8] |= uint8_t(hi1 << (j % 8));
}

// Norm correction: scale so reconstruction matches original norm
float recon_norm = sqrt(recon_norm_sq);
float corrected_norm = (recon_norm > 1e-10) ? (norm / recon_norm) : norm;
data_q[dst_idx].norm = float16_t(corrected_norm);
}
#endif

#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
Expand Down
36 changes: 36 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,39 @@ vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif

#if defined(DATA_A_TURBO3_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
// PolarQuant 3-bit centroids (Lloyd-Max for Gaussian)
const float centroids[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);

// iqs is the element index within the block (0..31), we decode 2 consecutive elements
const uint j0 = iqs;
const uint j1 = iqs + 1;

// Extract 2-bit low indices from qs (4 per byte)
const uint low2_0 = (uint(data_a[a_offset + ib].qs[j0 / 4]) >> ((j0 % 4) * 2)) & 0x3;
const uint low2_1 = (uint(data_a[a_offset + ib].qs[j1 / 4]) >> ((j1 % 4) * 2)) & 0x3;

// Extract 1-bit high from signs (8 per byte)
const uint hi1_0 = (uint(data_a[a_offset + ib].signs[j0 / 8]) >> (j0 % 8)) & 0x1;
const uint hi1_1 = (uint(data_a[a_offset + ib].signs[j1 / 8]) >> (j1 % 8)) & 0x1;

// Combine to 3-bit index
const uint idx0 = low2_0 | (hi1_0 << 2);
const uint idx1 = low2_1 | (hi1_1 << 2);

return vec2(centroids[idx0], centroids[idx1]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 2, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].norm), 0);
}
#endif
29 changes: 29 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,33 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
}
#endif

#if defined(DATA_A_TURBO3_0)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTURBO3_0 {
block_turbo3_0 block;
};

float16_t dequantFuncTURBO3_0(const in decodeBufTURBO3_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float centroids[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);
const float norm = float(bl.block.norm);
const uint j = coordInBlock[1];

// Extract 2-bit low index from qs (4 per byte)
const uint low2 = (uint(bl.block.qs[j / 4]) >> ((j % 4) * 2)) & 0x3;

// Extract 1-bit high from signs (8 per byte)
const uint hi1 = (uint(bl.block.signs[j / 8]) >> (j % 8)) & 0x1;

// Combine to 3-bit index
const uint idx = low2 | (hi1 << 2);

return float16_t(centroids[idx] * norm);
}
#endif

#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
Expand Down Expand Up @@ -729,6 +756,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_TURBO3_0)
#define dequantFuncA dequantFuncTURBO3_0
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif
46 changes: 46 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_turbo3_0.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#version 450

#include "dequant_head.glsl"

layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {block_turbo3_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};

void main() {
const float centroids[8] = float[8](
-0.190685, -0.117832, -0.065717, -0.021460,
0.021460, 0.065717, 0.117832, 0.190685
);

const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;

const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}

const uint b_idx = 1024*i + 32*ir + 16*il;

const float norm = float(data_a[ib].norm);

const uint q_start = 16*il;

[[unroll]] for (uint l = 0; l < 16; ++l) {
const uint j = q_start + l;

// Extract 2-bit low index from qs (4 per byte)
const uint low2 = (uint(data_a[ib].qs[j / 4]) >> ((j % 4) * 2)) & 0x3;

// Extract 1-bit high from signs (8 per byte)
const uint hi1 = (uint(data_a[ib].signs[j / 8]) >> (j % 8)) & 0x1;

// Combine to 3-bit index
const uint idx = low2 | (hi1 << 2);

data_b[b_idx + l] = D_TYPE(centroids[idx] * norm);
}
}
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,23 @@ struct block_mxfp4
#define A_TYPE block_mxfp4
#endif

#define QUANT_K_TURBO3_0 32
#define QUANT_R_TURBO3_0 1

struct block_turbo3_0
{
float16_t norm;
uint8_t qs[8]; // 2-bit centroid indices (4 per byte)
uint8_t signs[4]; // 1-bit high bit of 3-bit index (8 per byte)
};

#if defined(DATA_A_TURBO3_0)
#define QUANT_K QUANT_K_TURBO3_0
#define QUANT_R QUANT_R_TURBO3_0
#define QUANT_AUXF 1
#define A_TYPE block_turbo3_0
#endif

#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
Expand Down
5 changes: 3 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ const std::vector<std::string> type_names = {
"iq4_nl",
"mxfp4",
"bf16",
"turbo3_0",
};

enum MatMulIdType {
Expand Down Expand Up @@ -757,13 +758,13 @@ void process_shaders() {
string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});

for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}

for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl", "turbo3_0"}) {
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
Expand Down
Loading