diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3f01ff5bfb08..06953627994b 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -118,8 +118,11 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) - || GGML_CUDA_CC_IS_CDNA(cc); + // Stream-k decomposition is the wrong schedule for MoE expert matmuls + // (per-expert N is small, fixup overhead dominates). Disable when ids != nullptr. + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc)) + && (ids == nullptr); // TODO: tighter pool buffer size vs q8 path const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b1a319de9be9..b47362adcfec 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3463,7 +3463,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 -template +template #if defined(GGML_USE_HIP) #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) @@ -3514,9 +3514,7 @@ static __global__ void mul_mat_q( } __syncthreads(); - // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - { + if constexpr (!use_stream_k) { const int wt = blockIdx.z / nchannels_y; const int zt = blockIdx.z - wt*nchannels_y; const int jt = blockIdx.y; @@ -3569,7 +3567,6 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA constexpr int ITER_K = get_iter_k(type); @@ -3909,8 +3906,10 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; @@ -3925,7 +3924,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + constexpr bool use_stream_k = false; + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3933,7 +3933,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a args.ncols_max); } else { constexpr bool need_check = true; - mul_mat_q<<>> + constexpr bool use_stream_k = false; + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3954,7 +3955,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + constexpr bool use_stream_k = true; + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3971,7 +3973,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a args.ncols_max); } else { constexpr bool need_check = true; - mul_mat_q<<>> + constexpr bool use_stream_k = true; + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -4110,4 +4113,3 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); -