Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
}
}

static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
Expand Down Expand Up @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
return 1;
}

static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
return 1;
return small_k ? nwarps : 1;
case 2:
case 3:
case 4:
Expand All @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}

template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
Expand All @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q(
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();

constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
Expand Down Expand Up @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q(
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
const int nwarps = calc_nwarps(type, ncols_dst, table_id);
const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims};
}

template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
Expand All @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
Expand All @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion(

GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
Expand Down Expand Up @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);

// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
}
} break;
case 2: {
constexpr int c_ncols_dst = 2;
Expand Down
Loading