From cfbbfb25137e7338e3a5d907f42e3a33d071d5c8 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 15 Mar 2026 21:13:44 +0530 Subject: [PATCH 1/3] Increase per-thread work if the K-dimension is small With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MOEs. For example, Qwen3-30b-A3B has a K-dimension of 768, and Qwen3235B-A22B has k-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of K-dimension size, resulting in some of the threads being idle. This results in poor performance for these matrices. This change increases the number of output elements per block for such cases. --- ggml/src/ggml-cuda/mmvq.cu | 141 ++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 74 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43..49b850665 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,7 +173,10 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { + if (small_k) { + return nwarps; + } if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: @@ -193,7 +196,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +211,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +417,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +439,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +449,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -474,9 +479,27 @@ static void mul_mat_vec_q_switch_ncols_dst( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + if (has_ids && ncols_dst > 1) { // Multi-token MUL_MAT_ID path only - single-token goes through regular path below constexpr int c_ncols_dst = 1; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -485,76 +508,46 @@ static void mul_mat_vec_q_switch_ncols_dst( return; } +#define MMVQ_SWITCH_CASE(C_NCOLS_DST) \ + { \ + constexpr int c_ncols_dst = C_NCOLS_DST; \ + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); \ + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; \ + if (use_small_k) { \ + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, \ + warp_size, table_id, true); \ + mul_mat_vec_q_switch_fusion( \ + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, \ + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, \ + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, \ + dims.first, dims.second, 0, ids_stride, stream); \ + } else { \ + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, \ + warp_size, table_id); \ + mul_mat_vec_q_switch_fusion( \ + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, \ + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, \ + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, \ + dims.first, dims.second, 0, ids_stride, stream); \ + } \ + } + switch (ncols_dst) { - case 1: { - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 2: { - constexpr int c_ncols_dst = 2; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 3: { - constexpr int c_ncols_dst = 3; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 4: { - constexpr int c_ncols_dst = 4; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 5: { - constexpr int c_ncols_dst = 5; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 6: { - constexpr int c_ncols_dst = 6; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 7: { - constexpr int c_ncols_dst = 7; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; - case 8: { - constexpr int c_ncols_dst = 8; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - } break; + case 1: MMVQ_SWITCH_CASE(1) break; + case 2: MMVQ_SWITCH_CASE(2) break; + case 3: MMVQ_SWITCH_CASE(3) break; + case 4: MMVQ_SWITCH_CASE(4) break; + case 5: MMVQ_SWITCH_CASE(5) break; + case 6: MMVQ_SWITCH_CASE(6) break; + case 7: MMVQ_SWITCH_CASE(7) break; + case 8: MMVQ_SWITCH_CASE(8) break; default: GGML_ABORT("fatal error"); break; } +#undef MMVQ_SWITCH_CASE + GGML_UNUSED(has_fusion); } static void mul_mat_vec_q_switch_type( From 6374ae0ea62a84864ee59ef4e0a4d4e3559e6d4f Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Wed, 18 Mar 2026 18:33:11 +0530 Subject: [PATCH 2/3] Limit this change to ncols_dst = 1 --- ggml/src/ggml-cuda/mmvq.cu | 143 ++++++++++++++++++++++--------------- 1 file changed, 87 insertions(+), 56 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 49b850665..dec460c12 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -174,13 +174,10 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d } static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { - if (small_k) { - return nwarps; - } if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -479,27 +476,9 @@ static void mul_mat_vec_q_switch_ncols_dst( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; - // When K is small, increase rows_per_block to match nwarps so each warp has more work to do - // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_iter_1warp = vdr * warp_size / qi; - if (has_ids && ncols_dst > 1) { // Multi-token MUL_MAT_ID path only - single-token goes through regular path below constexpr int c_ncols_dst = 1; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); - const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; - if (use_small_k) { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id, true); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); - return; - } std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -508,46 +487,98 @@ static void mul_mat_vec_q_switch_ncols_dst( return; } -#define MMVQ_SWITCH_CASE(C_NCOLS_DST) \ - { \ - constexpr int c_ncols_dst = C_NCOLS_DST; \ - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); \ - const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; \ - if (use_small_k) { \ - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, \ - warp_size, table_id, true); \ - mul_mat_vec_q_switch_fusion( \ - vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, \ - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, \ - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, \ - dims.first, dims.second, 0, ids_stride, stream); \ - } else { \ - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, \ - warp_size, table_id); \ - mul_mat_vec_q_switch_fusion( \ - vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, \ - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, \ - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, \ - dims.first, dims.second, 0, ids_stride, stream); \ - } \ - } - switch (ncols_dst) { - case 1: MMVQ_SWITCH_CASE(1) break; - case 2: MMVQ_SWITCH_CASE(2) break; - case 3: MMVQ_SWITCH_CASE(3) break; - case 4: MMVQ_SWITCH_CASE(4) break; - case 5: MMVQ_SWITCH_CASE(5) break; - case 6: MMVQ_SWITCH_CASE(6) break; - case 7: MMVQ_SWITCH_CASE(7) break; - case 8: MMVQ_SWITCH_CASE(8) break; + case 1: { + constexpr int c_ncols_dst = 1; + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } + } break; + case 2: { + constexpr int c_ncols_dst = 2; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 3: { + constexpr int c_ncols_dst = 3; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 4: { + constexpr int c_ncols_dst = 4; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 5: { + constexpr int c_ncols_dst = 5; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 6: { + constexpr int c_ncols_dst = 6; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 7: { + constexpr int c_ncols_dst = 7; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 8: { + constexpr int c_ncols_dst = 8; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; default: GGML_ABORT("fatal error"); break; } -#undef MMVQ_SWITCH_CASE - GGML_UNUSED(has_fusion); } static void mul_mat_vec_q_switch_type( From fd9e3348263fd9f0c3ed2c606021c713e23a607f Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Wed, 18 Mar 2026 19:04:13 +0530 Subject: [PATCH 3/3] tab to space --- ggml/src/ggml-cuda/mmvq.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index dec460c12..024b3d8cf 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -492,13 +492,13 @@ static void mul_mat_vec_q_switch_ncols_dst( constexpr int c_ncols_dst = 1; // When K is small, increase rows_per_block to match nwarps so each warp has more work to do - // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_iter_1warp = vdr * warp_size / qi; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; if (use_small_k) { std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,