Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
uint nbi1;
uint ne11;
#else
uint base_work_group_z;
uint num_batches;
uint k_split;
uint ne02;
uint ne12;
Expand Down Expand Up @@ -108,7 +110,7 @@ void main() {
const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
const uint expert_idx = gl_WorkGroupID.z;
if (ic * BN >= data_expert_count[expert_idx]) {
return;
}
Expand All @@ -118,7 +120,7 @@ void main() {
#endif

#ifndef MUL_MAT_ID
const uint batch_idx = gl_GlobalInvocationID.z;
const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;

const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
Expand Down Expand Up @@ -276,7 +278,7 @@ void main() {
const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
#endif

[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
Expand Down
Loading