Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -682,12 +682,16 @@ static __global__ void mul_mat_vec_q(
template <ggml_type type, int c_rows_per_block>
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q_moe(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
float * __restrict__ dst,
const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr,
float * dst_ptr,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
const uint32_t ncols_dst, const uint32_t ids_stride) {
const void * GGML_CUDA_RESTRICT vx = vx_ptr;
const void * GGML_CUDA_RESTRICT vy = vy_ptr;
const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr;
float * GGML_CUDA_RESTRICT dst = dst_ptr;

constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
Expand All @@ -707,6 +711,7 @@ static __global__ void mul_mat_vec_q_moe(
return;
}

ggml_cuda_pdl_sync();
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);

Expand All @@ -726,6 +731,8 @@ static __global__ void mul_mat_vec_q_moe(
}
}

ggml_cuda_pdl_lc();

// Warp-level reduction only - no shared memory needed
#pragma unroll
for (int i = 0; i < c_rows_per_block; ++i) {
Expand Down Expand Up @@ -794,8 +801,9 @@ static void mul_mat_vec_q_moe_launch(
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
const dim3 block_nums(nblocks_rows, nchannels_dst);
const dim3 block_dims(warp_size, ncols_dst);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);

mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params,
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
stride_row_x, stride_col_y, stride_col_dst,
stride_channel_x, stride_channel_y, stride_channel_dst,
Expand Down
Loading