Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 53 additions & 27 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ struct vk_mat_mat_push_constants {
uint32_t M; uint32_t N; uint32_t K;
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
uint32_t base_work_group_z; uint32_t num_batches;
uint32_t k_split;
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
Expand All @@ -961,6 +962,7 @@ struct vk_mat_vec_push_constants {
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t fusion_flags;
uint32_t base_work_group_y;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
Expand Down Expand Up @@ -6766,8 +6768,16 @@ static void ggml_vk_matmul(
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));

uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);

const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
base_work_group_z += groups_z;
}
return;
}

Expand All @@ -6781,9 +6791,17 @@ static void ggml_vk_matmul(
uint32_t k_split = CEIL_DIV(k, split_k);
k_split = ROUNDUP_POW2(k_split, 256);

const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));

uint32_t base_work_group_z = 0;
while (base_work_group_z < batch) {
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);

const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
// Make sure enough workgroups get assigned for split k to work
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
base_work_group_z += groups_z;
}
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
Expand Down Expand Up @@ -7179,7 +7197,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}

// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (qx_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
}
Expand Down Expand Up @@ -7477,7 +7494,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}

vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
Expand Down Expand Up @@ -7572,22 +7588,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}

// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));

uint32_t base_work_group_y = 0;
while (base_work_group_y < ne12 * ne13) {

uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
stride_batch_x, stride_batch_y, stride_batch_d,
fusion_flags, base_work_group_y,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
},
pc, { groups_x, groups_y, groups_z });
base_work_group_y += groups_y;
}

if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
Expand Down Expand Up @@ -7825,10 +7848,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
src1->nb[2] <= src1->nb[1] &&
src1->nb[1] <= src1->nb[3] &&
src0->ne[3] == 1 &&
src1->ne[3] == 1) {
src1->ne[3] == 1 &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
// when ne12 and ne13 are one.
Expand Down Expand Up @@ -11543,7 +11571,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}

ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);

Expand Down Expand Up @@ -12052,7 +12079,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
// y[i] = i % k;
}

ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);

Expand Down
5 changes: 3 additions & 2 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ layout (push_constant) uniform parameter
uint expert_i1;
uint nbi1;
#else
uint base_work_group_y;
uint ne02;
uint ne12;
uint broadcast2;
Expand All @@ -45,9 +46,9 @@ uint expert_id;

void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_i0 = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_WorkGroupID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
#endif

#ifndef MUL_MAT_ID
Expand Down
8 changes: 5 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
Expand Down Expand Up @@ -139,7 +141,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
Expand All @@ -149,7 +151,7 @@ void main() {
#endif

#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;

const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
Expand Down Expand Up @@ -366,7 +368,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif

#ifdef COOPMAT
Expand Down
8 changes: 5 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
Expand Down Expand Up @@ -197,7 +199,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
Expand All @@ -215,7 +217,7 @@ void main() {
#endif

#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;

const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
Expand Down Expand Up @@ -255,7 +257,7 @@ void main() {
#else
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif

uint stride_a = p.stride_a / QUANT_K;
Expand Down
Loading