@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_prioritize_dmmv = 0;

 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -195,11 +196,13 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
     GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+    GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
     GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
@@ -2822,12 +2825,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     std::exit(1);
 }

+enum class mul_mat_algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
+
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
     GGML_UNUSED(type);
     return false;
 }

+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
@@ -2856,7 +2892,7 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t *)data_device + offset_blks * QK4_0 / 2;;
+    auto qs_ptr = (uint8_t *)data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;

     stream->parallel_for(
@@ -2884,25 +2920,44 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
     reorder_qw(data_device, ncols, nrows, size, 0, stream);
 }

-/*
-* This function could be called when the OP (mul_mat) function support reorder optimizition.
-*/
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
-                            ggml_tensor * dst) {
-    if (!g_ggml_sycl_disable_optimize && // allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
-        ctx->opt_feature.reorder &&      // allow this device due to good perf, skip the devices with bad perf.
-        dst->op == GGML_OP_MUL_MAT &&    // limit to some supported cases of Q4_0, to do for more cases.
-        src0->type == GGML_TYPE_Q4_0 &&
-        src1->ne[2]==1 && src1->ne[3]==1) {
+static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
+    return !g_ggml_sycl_disable_optimize &&  // allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+           ctx.opt_feature.reorder &&        // allow this device due to good perf, skip the devices with bad perf.
+           dst->op == GGML_OP_MUL_MAT &&     // limit to some supported cases of Q4_0, to do for more cases.
+           dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
+}

-        ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
-        if (!extra) return; // only happen in CI/UT permute case.
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+                            ggml_tensor * dst, mul_mat_algo mm_algorithm) {
+    if (!should_reorder_tensor(*ctx, dst)) {
+        return;
+    }

-        if (extra->optimized_feature.reorder) return; // skip the tensor which is handled for reorder.
+    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+    if (!extra || extra->optimized_feature.reorder) {
+        return;  // Skip permutations and already reordered tensors
+    }

-        reorder_qw(src0, ctx->stream());
-        extra->optimized_feature.reorder = true; // used to decode/dequan in next steps.
+    switch (mm_algorithm) {
+        case mul_mat_algo::DMMV:
+            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MMVQ:
+            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+                return;
+            }
+            break;
+        case mul_mat_algo::MUL_MAT_SYCL:
+            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+                return;
+            }
+            break;
     }
+
+    reorder_qw(src0, ctx->stream());
+    extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering
 }

 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2911,7 +2966,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     int64_t min_compute_capability = INT_MAX;

     if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx =
+            (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
         for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
             // skip devices that are not going to do any work:
@@ -2924,7 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
             }
         }
     } else {
-        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
+        min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
     }

     // check data types and tensor shapes for custom matrix multiplication kernels:
@@ -2946,9 +3002,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX

+
     // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
+        // Dispatch becomes obscure with the reorder: when the reorder optimization
+        // is enabled, MMVQ takes precedence over DMMV, so the current if-else
+        // implementation requires disabling DMMV when both conditions are met.
+        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+    }

     if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // TODO: Refactor and cleanup of mul mat dispatching.
@@ -2967,17 +3029,23 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-        // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
+        constexpr bool convert_src1_to_q8_1 = false;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+        constexpr bool convert_src1_to_q8_1 = true;
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
-        opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+        constexpr bool convert_src1_to_q8_1 = false;
+        // MUL_MAT_SYCL supports reorder
+        opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }


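For reference only, and not part of the diff above: a minimal C++ sketch of the integer environment-variable pattern that the new GGML_SYCL_PRIORITIZE_DMMV flag relies on. It assumes a helper like the get_sycl_env used in the hunks above reads an int with a fallback default; the name get_sycl_env_sketch and the exact parsing are illustrative, and the real helper in ggml-sycl may differ.

// Hypothetical stand-in for ggml-sycl's get_sycl_env helper: read an integer
// environment variable and fall back to a default when unset or unparsable.
#include <cstdio>
#include <cstdlib>

static int get_sycl_env_sketch(const char * env_name, int default_val) {
    const char * str = std::getenv(env_name);  // nullptr when the variable is not set
    int val = default_val;
    if (str != nullptr) {
        std::sscanf(str, "%d", &val);          // val keeps the default if parsing fails
    }
    return val;
}

int main() {
    // e.g. run with: GGML_SYCL_PRIORITIZE_DMMV=1 ./a.out
    int prioritize_dmmv = get_sycl_env_sketch("GGML_SYCL_PRIORITIZE_DMMV", 0);
    std::printf("GGML_SYCL_PRIORITIZE_DMMV: %d\n", prioritize_dmmv);
    return 0;
}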