@@ -341,7 +341,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra                 = extra;
         ctx->tensor_extras.push_back(extra);  // used to release it when destroy ctx.
@@ -2841,6 +2841,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
         default:
             return false;
     }
@@ -2858,6 +2860,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2883,16 +2886,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr      = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2906,18 +2909,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        });
+        }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr     = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x  = (const block_q4_K *) tmp_buf;
+        const int          ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char * data_device = (char *) src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
@@ -2960,8 +3004,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -2984,13 +3038,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
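
As an aside to the reorder_qw_q4_k() hunk above: the reordered buffer splits each interleaved block_q4_K into three contiguous arrays (packed quants, scales, dm pairs). The standalone sketch below only illustrates how byte offsets into such a reordered buffer could be computed for a given block index; it assumes QK_K == 256, K_SCALE_SIZE == 12 and a 4-byte half2, matching ggml's block_q4_K, and offsets_for_block is a hypothetical helper, not part of this patch.

// Minimal host-side sketch of the reordered Q4_K layout (illustration only).
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr size_t QK_K         = 256;  // elements per Q4_K block (assumed)
constexpr size_t K_SCALE_SIZE = 12;   // scale/min bytes per block (assumed)

struct reordered_q4_k_offsets {
    size_t qs;      // byte offset of the 128 packed 4-bit quant bytes of block ib
    size_t scales;  // byte offset of the 12-byte scales/mins of block ib
    size_t dm;      // byte offset of the half2 (d, dmin) pair of block ib
};

// After reordering, all quants come first, then all scales, then all dm pairs,
// so each field is contiguous across blocks instead of interleaved per block.
static reordered_q4_k_offsets offsets_for_block(size_t nblocks, size_t ib) {
    const size_t qs_bytes     = QK_K / 2 * nblocks;
    const size_t scales_bytes = K_SCALE_SIZE * nblocks;
    return {
        /*qs=*/     ib * (QK_K / 2),
        /*scales=*/ qs_bytes + ib * K_SCALE_SIZE,
        /*dm=*/     qs_bytes + scales_bytes + ib * 2 * sizeof(uint16_t),
    };
}

int main() {
    const size_t nblocks = 4;  // e.g. a 1024-element row -> 4 blocks of 256
    for (size_t ib = 0; ib < nblocks; ++ib) {
        const reordered_q4_k_offsets off = offsets_for_block(nblocks, ib);
        std::printf("block %zu: qs @ %zu, scales @ %zu, dm @ %zu\n", ib, off.qs, off.scales, off.dm);
    }
    return 0;
}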