26 changes: 24 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2142,6 +2142,26 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
     return use_mul_mat_vec_f;
 }
 
+static bool ggml_cuda_should_use_mmvq(ggml_type type, int cc, int64_t ncols_dst) {
+    if (ncols_dst > MMVQ_MAX_BATCH_SIZE) {
+        return false;
+    }
+
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        switch (type) {
+            case GGML_TYPE_IQ2_S:
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+                return ncols_dst <= 4;
+            default:
+                break;
+        }
+    }
+
+    return true;
+}
+
 static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
@@ -2151,11 +2171,11 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
         ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
         src0->view_src;
 
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
-        dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+        dst->type == GGML_TYPE_F32 && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
 
     // fusion is not universally faster on Pascal
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     if (cc <= GGML_CUDA_CC_PASCAL) {
         return false;
     }
@@ -2212,6 +2232,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc = ggml_cuda_info().devices[id].cc;
             const int warp_size = ggml_cuda_info().devices[id].warp_size;
+            use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
             use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
             use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
@@ -2220,6 +2241,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else {
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+        use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]);
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
         use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
         use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
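Taken together, the ggml-cuda.cu changes route the MMVQ decision through a single helper instead of a bare `src1->ne[1] <= MMVQ_MAX_BATCH_SIZE` check, both in the fusion path and in `ggml_cuda_mul_mat` itself. Below is a minimal standalone mirror of that helper for experimenting with the cutoffs; the `toy_type` enum, the `is_rdna4` flag, and the value 8 for `MMVQ_MAX_BATCH_SIZE` are assumptions standing in for ggml's real `ggml_type`, `GGML_CUDA_CC_IS_RDNA4(cc)`, and the define in the CUDA headers, not the actual API.

```cpp
// Standalone sketch of the new MMVQ gate (not the ggml API).
#include <cstdint>
#include <cstdio>

enum toy_type { IQ2_S, IQ2_XXS, IQ3_XXS, IQ3_S, Q4_K };

constexpr int64_t MMVQ_MAX_BATCH_SIZE = 8; // assumed value of the ggml define

static bool should_use_mmvq(toy_type type, bool is_rdna4, int64_t ncols_dst) {
    if (ncols_dst > MMVQ_MAX_BATCH_SIZE) {
        return false; // wider batches never take the MMVQ kernels
    }
    if (is_rdna4) {
        switch (type) {
            // on RDNA4 these IQ types only win at very small batch sizes
            case IQ2_S: case IQ2_XXS: case IQ3_XXS: case IQ3_S:
                return ncols_dst <= 4;
            default:
                break;
        }
    }
    return true; // every other (type, batch) combination keeps the old behavior
}

int main() {
    printf("IQ3_S n=4: %d\n", (int) should_use_mmvq(IQ3_S, true, 4)); // 1 -> MMVQ
    printf("IQ3_S n=5: %d\n", (int) should_use_mmvq(IQ3_S, true, 5)); // 0 -> MMQ/BLAS instead
    printf("Q4_K  n=8: %d\n", (int) should_use_mmvq(Q4_K,  true, 8)); // 1, unaffected by the RDNA4 case
}
```

When the helper returns false, `ggml_cuda_mul_mat` simply leaves `use_mul_mat_vec_q` unset and the existing MMQ/hipBLAS selection takes over.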
35 changes: 33 additions & 2 deletions ggml/src/ggml-cuda/mmq.cu
@@ -356,8 +356,39 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
         }
     }
 
+    // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
+    // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        switch (type) {
+            case GGML_TYPE_IQ2_S:
+            case GGML_TYPE_Q6_K:
+                return ne11 <= 128;
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_MXFP4:
+                return true;
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+            case GGML_TYPE_IQ2_XS:
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_Q2_K:
+            case GGML_TYPE_Q3_K:
+            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_Q4_K:
+                return ne11 <= 256;
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_IQ4_XS:
+                return ne11 <= 512;
+            default:
+                return false;
+        }
+    }
     return true;
 }
 
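For a quicker read of the table above, the hypothetical helper below restates the RDNA4 cutoffs as a single per-type limit: MMQ is used while `ne11` stays at or below the limit, and dequantization + hipBLAS takes over beyond it. The names `toy_qtype` and `rdna4_mmq_limit` are illustrative only; the real entry point is `ggml_cuda_should_use_mmq`, and the numbers are copied from the switch above.

```cpp
// Toy restatement of the RDNA4 MMQ thresholds (not the ggml API).
#include <cstdint>
#include <cstdio>

enum toy_qtype { Q6_K, MXFP4, Q4_K, Q8_0 };

// Largest batch size (ne11) for which MMQ beats dequantization + hipBLAS.
static int64_t rdna4_mmq_limit(toy_qtype t) {
    switch (t) {
        case Q6_K:  return 128;       // also IQ2_S
        case MXFP4: return INT64_MAX; // also Q4_0/Q4_1/Q5_0/Q5_1: always MMQ
        case Q4_K:  return 256;       // also Q5_K, Q2_K/Q3_K, IQ1_S, and the IQ2/IQ3 variants
        case Q8_0:  return 512;       // also IQ4_NL/IQ4_XS
    }
    return 0; // unknown type: never use MMQ (mirrors the default branch)
}

int main() {
    printf("Q6_K  ne11=128  -> MMQ? %d\n", (int) (128  <= rdna4_mmq_limit(Q6_K)));  // 1
    printf("Q6_K  ne11=256  -> MMQ? %d\n", (int) (256  <= rdna4_mmq_limit(Q6_K)));  // 0
    printf("MXFP4 ne11=4096 -> MMQ? %d\n", (int) (4096 <= rdna4_mmq_limit(MXFP4))); // 1
}
```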