diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1dbe..f83f5ffda22f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -76,6 +76,8 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define YIELD() #endif +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -999,6 +1001,7 @@ struct vk_mat_mat_push_constants { uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; + uint32_t deltas_offset; }; #define MAT_VEC_FUSION_FLAGS_BIAS0 0x1 @@ -1020,6 +1023,7 @@ struct vk_mat_vec_push_constants { uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t deltas_offset; }; struct vk_mat_vec_p021_push_constants { @@ -1054,6 +1058,7 @@ struct vk_mat_mat_id_push_constants { uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; uint32_t padded_N; + uint32_t deltas_offset; }; struct vk_mat_vec_id_push_constants { uint32_t ncols; @@ -1068,6 +1073,7 @@ struct vk_mat_vec_id_push_constants { uint32_t ne11; uint32_t expert_i1; uint32_t nbi1; + uint32_t deltas_offset; }; struct vk_flash_attn_push_constants { @@ -1976,7 +1982,7 @@ static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t) { - return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1)); } template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { @@ -6674,6 +6680,8 @@ static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffe } } +static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor); + static vk_subbuffer ggml_vk_tensor_subbuffer( const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) { @@ -6689,7 +6697,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( } GGML_ASSERT(buffer != nullptr); - size_t size = ggml_nbytes(tensor); + size_t size = ggml_vk_repack_size_tensor(tensor); size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); // The shader must support misaligned offsets when indexing into the buffer @@ -7284,6 +7292,134 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_queue_command_pools_cleanup(dst->device); } +constexpr uint32_t VULKAN_REPACK_ALIGNMENT = 256; + +static void * ggml_vk_repack_scratch(size_t size) { + thread_local std::vector buf; + if (buf.size() < size) { + buf.resize(size); + } + return buf.data(); +} + +static size_t ggml_vk_get_num_blocks(const ggml_tensor * tensor) { + const size_t num_blocks_per_row = tensor->ne[0] / ggml_blck_size(tensor->type); + return num_blocks_per_row * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; +} + +struct vk_repack_type_info { + size_t quant_bytes; + size_t delta_bytes; + size_t delta_elem_size; +}; + +static const vk_repack_type_info * ggml_vk_get_repack_info(ggml_type type) { + static const vk_repack_type_info q4_0_info = { 16, 2, 2 }; + static const vk_repack_type_info q4_1_info = { 16, 4, 2 }; + static const vk_repack_type_info q8_0_info = { 32, 2, 2 }; + static const vk_repack_type_info iq4_nl_info = { 16, 2, 2 }; + static const vk_repack_type_info mxfp4_info = { 16, 1, 1 }; + + switch (type) { + case GGML_TYPE_Q4_0: return &q4_0_info; + case GGML_TYPE_Q4_1: return &q4_1_info; + case GGML_TYPE_Q8_0: return &q8_0_info; + case GGML_TYPE_IQ4_NL: return &iq4_nl_info; + case GGML_TYPE_MXFP4: return &mxfp4_info; + default: return nullptr; + } +} + +static size_t ggml_vk_repack_quants_region(const vk_repack_type_info * info, size_t n_blocks) { + return GGML_PAD(n_blocks * info->quant_bytes, VULKAN_REPACK_ALIGNMENT); +} + +static size_t ggml_vk_repack_size(const vk_repack_type_info * info, size_t n_blocks) { + return ggml_vk_repack_quants_region(info, n_blocks) + n_blocks * info->delta_bytes; +} + +static size_t ggml_vk_repack_size_tensor(const ggml_tensor * tensor) { + const auto * info = ggml_vk_get_repack_info(tensor->type); + if (info) { + return ggml_vk_repack_size(info, ggml_vk_get_num_blocks(tensor)); + } + return ggml_nbytes(tensor); +} + +static uint32_t ggml_vk_get_deltas_offset(const ggml_tensor * tensor) { + const auto * info = ggml_vk_get_repack_info(tensor->type); + if (!info) { + return 0; + } + return ggml_vk_repack_quants_region(info, ggml_vk_get_num_blocks(tensor)) / info->delta_elem_size; +} + +static void ggml_vk_repack_pack(const vk_repack_type_info * info, size_t n_blocks, + const void * data, void * quants_dst, void * deltas_dst) { + const size_t block_size = info->quant_bytes + info->delta_bytes; + uint8_t * dst_q = (uint8_t *)quants_dst; + uint8_t * dst_d = (uint8_t *)deltas_dst; + const uint8_t * src = (const uint8_t *)data; + + for (size_t i = 0; i < n_blocks; i++) { + memcpy(dst_q + info->quant_bytes * i, src + block_size * i + info->delta_bytes, info->quant_bytes); + memcpy(dst_d + info->delta_bytes * i, src + block_size * i, info->delta_bytes); + } +} + +static void ggml_vk_repack_unpack(const vk_repack_type_info * info, size_t n_blocks, + const void * quants_src, const void * deltas_src, void * data) { + const size_t block_size = info->quant_bytes + info->delta_bytes; + const uint8_t * src_q = (const uint8_t *)quants_src; + const uint8_t * src_d = (const uint8_t *)deltas_src; + uint8_t * dst = (uint8_t *)data; + + for (size_t i = 0; i < n_blocks; i++) { + memcpy(dst + block_size * i + info->delta_bytes, src_q + info->quant_bytes * i, info->quant_bytes); + memcpy(dst + block_size * i, src_d + info->delta_bytes * i, info->delta_bytes); + } +} + +static void ggml_vk_repack_write(vk_buffer & buf, const ggml_tensor * tensor, size_t offset, const void * data, size_t size) { + const auto * info = ggml_vk_get_repack_info(tensor->type); + GGML_ASSERT(info); + const size_t block_size = info->quant_bytes + info->delta_bytes; + const size_t first_block = offset / block_size; + const size_t n_blocks_chunk = size / block_size; + const size_t n_blocks_total = ggml_vk_get_num_blocks(tensor); + const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks_total); + const size_t scratch_size = n_blocks_chunk * info->quant_bytes + n_blocks_chunk * info->delta_bytes; + void * scratch = ggml_vk_repack_scratch(scratch_size); + uint8_t * scratch_q = (uint8_t *)scratch; + uint8_t * scratch_d = scratch_q + n_blocks_chunk * info->quant_bytes; + + ggml_vk_repack_pack(info, n_blocks_chunk, data, scratch_q, scratch_d); + + const size_t buf_base = vk_tensor_offset(tensor) + tensor->view_offs; + ggml_vk_buffer_write(buf, buf_base + first_block * info->quant_bytes, scratch_q, n_blocks_chunk * info->quant_bytes); + ggml_vk_buffer_write(buf, buf_base + quants_region + first_block * info->delta_bytes, scratch_d, n_blocks_chunk * info->delta_bytes); +} + +static void ggml_vk_repack_read(vk_buffer & buf, const ggml_tensor * tensor, size_t offset, void * data, size_t size) { + const auto * info = ggml_vk_get_repack_info(tensor->type); + GGML_ASSERT(info); + const size_t block_size = info->quant_bytes + info->delta_bytes; + const size_t first_block = offset / block_size; + const size_t n_blocks_chunk = size / block_size; + const size_t n_blocks_total = ggml_vk_get_num_blocks(tensor); + const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks_total); + const size_t scratch_size = n_blocks_chunk * info->quant_bytes + n_blocks_chunk * info->delta_bytes; + void * scratch = ggml_vk_repack_scratch(scratch_size); + uint8_t * scratch_q = (uint8_t *)scratch; + uint8_t * scratch_d = scratch_q + n_blocks_chunk * info->quant_bytes; + + const size_t buf_base = vk_tensor_offset(tensor) + tensor->view_offs; + ggml_vk_buffer_read(buf, buf_base + first_block * info->quant_bytes, scratch_q, n_blocks_chunk * info->quant_bytes); + ggml_vk_buffer_read(buf, buf_base + quants_region + first_block * info->delta_bytes, scratch_d, n_blocks_chunk * info->delta_bytes); + + ggml_vk_repack_unpack(info, n_blocks_chunk, scratch_q, scratch_d, data); +} + static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")"); @@ -7383,7 +7519,7 @@ static void ggml_vk_matmul( uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, - uint32_t padded_n) { + uint32_t padded_n, uint32_t deltas_offset) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); if (split_k == 1) { ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); @@ -7392,7 +7528,7 @@ static void ggml_vk_matmul( while (base_work_group_z < batch) { uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n, deltas_offset }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z }); base_work_group_z += groups_z; } @@ -7415,7 +7551,7 @@ static void ggml_vk_matmul( while (base_work_group_z < batch) { uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n, deltas_offset }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z }); base_work_group_z += groups_z; @@ -7470,13 +7606,13 @@ static void ggml_vk_matmul_id( uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, - uint32_t padded_n) { + uint32_t padded_n, uint32_t deltas_offset) { VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " << "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, - nei0, nei1, nbi1, ne11, padded_n }; + nei0, nei1, nbi1, ne11, padded_n, deltas_offset }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as }); } @@ -7777,7 +7913,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline); - const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qx_sz = ggml_vk_repack_size_tensor(src0); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); @@ -7925,6 +8061,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0); + // compute ggml_vk_matmul( ctx, subctx, pipeline, @@ -7932,7 +8070,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k }, ne01, ne11, ne10, ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, - split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n, deltas_offset ); // NOLINT if (x_non_contig || qx_needs_dequant) { @@ -8099,7 +8237,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t x_ne = ggml_nelements(src0); const uint64_t y_ne = ggml_nelements(src1); - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qx_sz = ggml_vk_align_size(ggml_vk_repack_size_tensor(src0), ctx->device->properties.limits.minStorageBufferOffsetAlignment); const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); @@ -8225,6 +8363,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1])); + const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0); + uint32_t base_work_group_y = 0; while (base_work_group_y < ne12 * ne13) { @@ -8234,6 +8374,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x, stride_batch_y, stride_batch_d, fusion_flags, base_work_group_y, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + deltas_offset, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { @@ -8610,7 +8751,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t y_ne = padded_n * ne10 * ne12 * ne13; const uint64_t d_ne = ggml_nelements(dst); - const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qx_sz = ggml_vk_repack_size_tensor(src0); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); @@ -8778,6 +8919,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0); + // compute ggml_vk_matmul_id( ctx, subctx, pipeline, @@ -8785,7 +8928,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n, deltas_offset ); // NOLINT if (x_non_contig || qx_needs_dequant) { @@ -8877,7 +9020,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t x_ne = ggml_nelements(src0); const uint64_t y_ne = ggml_nelements(src1); - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qx_sz = ggml_vk_align_size(ggml_vk_repack_size_tensor(src0), ctx->device->properties.limits.minStorageBufferOffsetAlignment); const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); @@ -9001,13 +9144,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } + const uint32_t deltas_offset = ggml_vk_get_deltas_offset(src0); + // Loop over the batch dimension for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { const vk_mat_vec_id_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1, + deltas_offset, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { @@ -12354,7 +12500,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1, n + split_k, batch, batch, batch, 1, 1, n, 0 ); } ggml_vk_ctx_end(subctx); @@ -12831,7 +12977,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, m, n, k, k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1, n + split_k, batch, batch, batch, 1, 1, n, 0 ); } } else { @@ -12840,7 +12986,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, m, n, k, k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1, n + split_k, batch, batch, batch, 1, 1, n, 0 ); } } @@ -13796,6 +13942,11 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml return; } + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_write(buf, tensor, offset, data, size); + return; + } + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -13810,6 +13961,11 @@ static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, g return; } + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_write(buf, tensor, offset, data, size); + return; + } + ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); } @@ -13823,6 +13979,11 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons vk_buffer buf = buf_ctx->dev_buffer; + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_read(buf, tensor, offset, data, size); + return; + } + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -13838,6 +13999,11 @@ static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, c vk_buffer buf = buf_ctx->dev_buffer; + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_read(buf, tensor, offset, data, size); + return; + } + ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies); } @@ -13916,9 +14082,8 @@ static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_ } static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - return ggml_nbytes(tensor); - UNUSED(buft); + return ggml_vk_repack_size_tensor(tensor); } ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { @@ -14043,6 +14208,12 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten } ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_write(buf, tensor, offset, data, size); + return; + } vk_context cpy_ctx; @@ -14058,8 +14229,6 @@ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_ten cpy_ctx = ggml_vk_get_compute_ctx(ctx); } - vk_buffer buf = buf_ctx->dev_buffer; - auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies); @@ -14112,10 +14281,14 @@ static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const gg } ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; - vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); + if (ggml_vk_get_repack_info(tensor->type)) { + ggml_vk_repack_read(buf, tensor, offset, data, size); + return; + } - vk_buffer buf = buf_ctx->dev_buffer; + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies); @@ -16639,6 +16812,17 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * for (int i = 2; i < GGML_MAX_DIMS; i++) { srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; } + } else if (ggml_vk_get_repack_info(srci->type)) { + const auto * info = ggml_vk_get_repack_info(srci->type); + const size_t n_blocks = ggml_vk_get_num_blocks(srci); + const size_t quants_region = ggml_vk_repack_quants_region(info, n_blocks); + const size_t repacked_size = ggml_vk_repack_size(info, n_blocks); + void * data_repacked = ggml_vk_repack_scratch(repacked_size); + ggml_vk_buffer_read(buffer_gpu, offset, data_repacked, repacked_size); + ggml_vk_repack_unpack(info, n_blocks, + data_repacked, (uint8_t *)data_repacked + quants_region, + srci_clone->data); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); } else { if (offset + srci_size >= buffer_gpu->size) { srci_size = buffer_gpu->size - offset; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 88d07d2dfd50..ed62091136e1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -23,6 +23,16 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { #endif #if defined(DATA_A_Q4_0) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); return (vec2(vui & 0xF, vui >> 4) - 8.0f); @@ -32,8 +42,19 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); } #endif +#endif #if defined(DATA_A_Q4_1) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); return vec2(vui & 0xF, vui >> 4); @@ -43,6 +64,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); } #endif +#endif #if defined(DATA_A_Q5_0) vec2 dequantize(uint ib, uint iqs, uint a_offset) { @@ -77,6 +99,16 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #endif #if defined(DATA_A_Q8_0) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const i8vec2 v = unpack8(int32_t(data_a_quants16[(a_offset + ib) * 16 + iqs/2])).xy; + return vec2(v); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const i8vec4 v = unpack8(int32_t(data_a_quants32[(a_offset + ib) * 8 + iqs/4])); + return vec4(v); +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } @@ -86,6 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return vec4(v0.x, v0.y, v1.x, v1.y); } #endif +#endif #if defined(DATA_A_Q1_0) vec2 dequantize(uint ib, uint iqs, uint a_offset) { @@ -428,6 +461,16 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #endif #if defined(DATA_A_IQ4_NL) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); @@ -437,8 +480,20 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); } #endif +#endif #if defined(DATA_A_MXFP4) +#if defined(A_TYPE_REPACKED) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants[(a_offset + ib) * 16 + iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_quants16[(a_offset + ib) * 8 + iqs/2]); + return vec4(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[(vui >> 4) & 0xF], + kvalues_mxfp4[(vui >> 8) & 0xF], kvalues_mxfp4[vui >> 12]) * 0.5; +} +#else vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; @@ -449,6 +504,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { return vec4(v0.x, v0.y, v1.x, v1.y); } #endif +#endif #if defined(DATA_A_NVFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { @@ -486,7 +542,11 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { +#if (defined(DATA_A_Q4_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)) && defined(A_TYPE_REPACKED) + return vec2(float(data_a_deltas[a_offset + p.deltas_offset + ib]), 0); +#else return vec2(float(data_a[a_offset + ib].d), 0); +#endif } #endif @@ -499,7 +559,11 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { +#if defined(A_TYPE_REPACKED) + return vec2(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + a_offset + ib])), 0); +#else return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +#endif } #endif @@ -511,8 +575,13 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { +#if defined(DATA_A_Q4_1) && defined(A_TYPE_REPACKED) + return vec2(float(data_a_deltas[p.deltas_offset + (a_offset + ib) * 2]), + float(data_a_deltas[p.deltas_offset + (a_offset + ib) * 2 + 1])); +#else const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); return dm; +#endif } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index c582aba87dcd..1634aa2653f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -25,36 +25,65 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2 return bit != 0u ? d : -d; } +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_0 { + uint32_t qs[4]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; +#endif float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float16_t d = data_a_deltas[p.deltas_offset + ib]; + uint32_t qs = bl.qs[(idx & 0xC) >> 2]; const uint shift = (idx & 0x10) >> 2; + qs >>= ((idx & 3) * 8 + shift); +#else + const float16_t d = bl.block.d; uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + const uint shift = (idx & 0x10) >> 2; qs >>= shift; qs &= 0x0F0F; qs = unpack8(qs)[idx & 1]; +#endif + qs &= 0xF; float16_t ret = (float16_t(qs) - float16_t(8)) * d; return ret; } +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_1 { + uint32_t qs[4]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { block_q4_1 block; }; +#endif float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float16_t d = bl.block.d; - const float16_t m = bl.block.m; const uint idx = coordInBlock[1]; const uint iqs = idx & 0xF; const uint shift = (idx & 0x10) >> 2; +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float16_t d = data_a_deltas[p.deltas_offset + ib * 2]; + const float16_t m = data_a_deltas[p.deltas_offset + ib * 2 + 1]; + uint32_t qs = bl.qs[(idx & 0xC) >> 2]; + qs >>= ((iqs & 3) * 8 + shift); +#else + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; uint32_t qs = bl.block.qs[iqs]; qs >>= shift; +#endif qs &= 0xF; float16_t ret = float16_t(qs) * d + m; return ret; @@ -105,18 +134,28 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2 return ret; } +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ8_0 { + int32_t qs[8]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { block_q8_0_packed16 block; }; +#endif float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; const uint iqs = idx; - - // Load 16b and select the byte for this element +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float16_t d = data_a_deltas[p.deltas_offset + ib]; + int32_t qs = unpack8(bl.qs[(iqs & 0x1C) >> 2])[iqs & 3]; +#else + const float16_t d = bl.block.d; int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; +#endif float16_t ret = float16_t(qs) * d; return ret; } @@ -660,18 +699,32 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor #endif #if defined(DATA_A_IQ4_NL) +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufIQ4_NL { + uint32_t qs[4]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { block_iq4_nl block; }; +#endif float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float16_t d = data_a_deltas[p.deltas_offset + ib]; + uint32_t qs = bl.qs[(idx & 0xC) >> 2]; + const uint shift = (idx & 0x10) >> 2; + qs >>= ((idx & 3) * 8 + shift); +#else + const float16_t d = bl.block.d; const uint iqs = idx & 0xF; const uint shift = (idx & 0x10) >> 2; uint32_t qs = bl.block.qs[iqs]; qs >>= shift; +#endif qs &= 0xF; float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; return ret; @@ -679,18 +732,31 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #endif #if defined(DATA_A_MXFP4) +#ifdef A_TYPE_REPACKED +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufMXFP4 { + uint32_t qs[4]; +}; +#else layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { block_mxfp4 block; }; +#endif float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const float d = e8m0_to_fp32(bl.block.e); const uint idx = coordInBlock[1]; const uint iqs = idx & 0xF; const uint shift = (idx & 0x10) >> 2; +#ifdef A_TYPE_REPACKED + const uint ib = pos_a + blockCoords[0] * (p.stride_a / QUANT_K) + blockCoords[1]; + const float d = e8m0_to_fp32(data_a_scales[p.deltas_offset + ib]); + uint32_t qs = bl.qs[(iqs & 0xC) >> 2]; + qs >>= ((iqs & 3) * 8 + shift); +#else + const float d = e8m0_to_fp32(bl.block.e); uint32_t qs = bl.block.qs[iqs]; qs >>= shift; +#endif qs &= 0xF; float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index 4aeda68c7f2d..8addb350d809 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -38,6 +38,8 @@ layout (push_constant) uniform parameter uint broadcast2; uint broadcast3; #endif + + uint deltas_offset; } p; #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index e8d053cdd432..64976051b4c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -15,6 +15,12 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16 #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif +#if defined(A_TYPE_REPACKED) +layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];}; +layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];}; +layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];}; +layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];}; +#endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; #ifdef B_TYPEV2 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index bc580aeeb834..1001b35ef9a1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -6,19 +6,32 @@ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) FLOAT_TYPE get_dm(uint ib) { +#if (defined(DATA_A_Q4_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)) && defined(A_TYPE_REPACKED) + return FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]); +#else return FLOAT_TYPE(data_a[ib].d); +#endif } #endif #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) FLOAT_TYPEV2 get_dm(uint ib) { +#if defined(DATA_A_Q4_1) && defined(A_TYPE_REPACKED) + return FLOAT_TYPEV2(data_a_deltas[p.deltas_offset + ib * 2], + data_a_deltas[p.deltas_offset + ib * 2 + 1]); +#else return FLOAT_TYPEV2(data_a_packed32[ib].dm); +#endif } #endif #if defined(DATA_A_MXFP4) FLOAT_TYPE get_dm(uint ib) { +#if defined(A_TYPE_REPACKED) + return FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib]))); +#else return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +#endif } #endif @@ -33,9 +46,13 @@ FLOAT_TYPEV2 get_dm(uint ib) { #if defined(DATA_A_Q4_0) // 2-byte loads for Q4_0 blocks (18 bytes) i32vec2 repack(uint ib, uint iqs) { +#if defined(DATA_A_Q4_0) && defined(A_TYPE_REPACKED) + const uint32_t vui = data_a_quants32[ib * 4 + iqs]; +#else const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); +#endif return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); } @@ -48,7 +65,11 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i #if defined(DATA_A_Q4_1) // 4-byte loads for Q4_1 blocks (20 bytes) i32vec2 repack(uint ib, uint iqs) { +#if defined(A_TYPE_REPACKED) + const uint32_t vui = data_a_quants32[ib * 4 + iqs]; +#else const uint32_t vui = data_a_packed32[ib].qs[iqs]; +#endif return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); } @@ -103,8 +124,12 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const i #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) int32_t repack(uint ib, uint iqs) { +#if defined(A_TYPE_REPACKED) + return int32_t(data_a_quants32[ib * 8 + iqs]); +#else return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], data_a_packed16[ib].qs[iqs * 2 + 1])); +#endif } FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { @@ -115,10 +140,14 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) i32vec2 repack(uint ib, uint iqs) { +#if defined(A_TYPE_REPACKED) + const uint32_t qs = data_a_quants32[ib * 4 + iqs]; +#else const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3])); +#endif const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 89346e48e061..db464785ac66 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -62,6 +62,12 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16 #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif +#if defined(A_TYPE_REPACKED) +layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];}; +layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];}; +layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];}; +layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];}; +#endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; @@ -98,6 +104,9 @@ layout (push_constant) uniform parameter uint broadcast2; uint broadcast3; #endif + + uint padded_N; + uint deltas_offset; } p; layout (constant_id = 0) const uint BLOCK_SIZE = 64; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 497a18ff8a7c..c18636cdf7af 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -63,13 +63,27 @@ layout (push_constant) uniform parameter #endif // N dimension for the B matrix can be >= p.N uint padded_N; + uint deltas_offset; } p; +#ifdef A_TYPE_REPACKED +#if defined(DATA_A_Q8_0) +struct block_repacked_quants { uint16_t qs[16]; }; +#else +struct block_repacked_quants { uint16_t qs[8]; }; +#endif +layout (binding = 0) readonly buffer A {block_repacked_quants data_a[];}; +layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];}; +layout (binding = 0) readonly buffer A_SCALES {uint8_t data_a_scales[];}; +#else layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +uint pos_a; + #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA @@ -254,10 +268,10 @@ void main() { #endif #ifdef MUL_MAT_ID - uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K); + pos_a = expert_idx * (p.batch_stride_a / QUANT_K); uint pos_b = 0; #else - uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); + pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 73595168984c..9e14d8fc2970 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -52,8 +52,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 4; const uint iqs = idx & 0x03; +#if defined(A_TYPE_REPACKED) + const float d = float(data_a_deltas[p.deltas_offset + ib]); + const uint vui = data_a_quants32[ib * 4 + iqs]; +#else const float d = float(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); +#endif const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; @@ -68,8 +73,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 4; const uint iqs = idx & 0x03; +#if defined(A_TYPE_REPACKED) + const vec2 dm = vec2(data_a_deltas[p.deltas_offset + ib * 2], + data_a_deltas[p.deltas_offset + ib * 2 + 1]); + const uint vui = data_a_quants32[ib * 4 + iqs]; +#else const vec2 dm = vec2(data_a_packed32[ib].dm); const uint vui = data_a_packed32[ib].qs[iqs]; +#endif const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; @@ -123,10 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = idx & 0x07; +#if defined(A_TYPE_REPACKED) + const float d = float(data_a_deltas[p.deltas_offset + ib]); + const vec4 v = vec4(unpack8(int32_t(data_a_quants32[ib * 8 + iqs]))) * d; +#else const float d = float(data_a_packed16[ib].d); const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; +#endif buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); @@ -481,8 +497,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = idx & 0x07; +#if defined(A_TYPE_REPACKED) + const FLOAT_TYPE d = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]); + const uint vui = uint(data_a_quants16[ib * 8 + iqs]); +#else const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); +#endif buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); @@ -495,9 +516,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; +#if defined(A_TYPE_REPACKED) + const float d = e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5; + const uint vui16 = uint(data_a_quants16[ib * 8 + iqs/2]); + const uint vui = vui16 & 0xFF; + const uint vui2 = vui16 >> 8; +#else const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); +#endif buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, kvalues_mxfp4[vui2 & 0xF] * d); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index aae1c2e8ae9f..c47f669c0f61 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -30,6 +30,13 @@ layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16 #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif +#if defined(A_TYPE_REPACKED) +layout (binding = 0) readonly buffer A_QUANTS {uint8_t data_a_quants[];}; +layout (binding = 0) readonly buffer A_QUANTS16 {uint16_t data_a_quants16[];}; +layout (binding = 0) readonly buffer A_QUANTS32 {uint32_t data_a_quants32[];}; +layout (binding = 0) readonly buffer A_DELTAS {float16_t data_a_deltas[];}; +#endif + layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; @@ -65,6 +72,9 @@ layout (push_constant) uniform parameter uint broadcast2; uint broadcast3; #endif + + uint padded_N; + uint deltas_offset; } p; layout (constant_id = 0) const uint BLOCK_SIZE = 64; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 59931b04b941..dcfd51519b3f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -11,19 +11,36 @@ // 4-byte loads for Q4_1 blocks (20 bytes) void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q4_0 +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].qs[iqs] = data_a_quants32[ib * 4 + iqs]; +#else buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); +#endif if (iqs == 0) { +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]); +#else buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); +#endif } #else // DATA_A_Q4_1 +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].qs[iqs] = data_a_quants32[ib * 4 + iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_deltas[p.deltas_offset + ib * 2], + data_a_deltas[p.deltas_offset + ib * 2 + 1]); + } +#else buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif +#endif } void block_a_to_registers(const uint reg_ib, const uint buf_ib) { @@ -115,12 +132,20 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].qs[iqs] = int32_t(data_a_quants32[ib * 8 + iqs]); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_deltas[p.deltas_offset + ib]); + } +#else buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); if (iqs == 0) { buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); } +#endif } void block_a_to_registers(const uint reg_ib, const uint buf_ib) { @@ -147,10 +172,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#if defined(A_TYPE_REPACKED) + const uint32_t qs = data_a_quants32[ib * 4 + iqs]; +#else const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 2], data_a[ib].qs[iqs * 4 + 3])); +#endif const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); @@ -159,7 +188,11 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); if (iqs == 0) { +#if defined(A_TYPE_REPACKED) + buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(uint8_t(data_a_quants[p.deltas_offset + ib])) * 0.5); +#else buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d99b2b5d802a..379c683da1bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -564,6 +564,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c continue; } + std::map mm_base_dict = base_dict; + if (tname == "q4_0" || tname == "q4_1" || tname == "q8_0" || tname == "iq4_nl" || tname == "mxfp4") { + mm_base_dict["A_TYPE_REPACKED"] = "1"; + } + std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; @@ -579,19 +584,19 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) // Integer dot mmq performs better with f32 accumulators if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { - string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(mm_base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif } @@ -665,33 +670,38 @@ void process_shaders() { std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { + std::map mmv_base_dict = base_dict; + if (tname == "q4_0" || tname == "q4_1" || tname == "q8_0" || tname == "iq4_nl" || tname == "mxfp4") { + mmv_base_dict["A_TYPE_REPACKED"] = "1"; + } + // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(mmv_base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif