@@ -6868,43 +6868,52 @@ static void mul_mat_vec_q_cuda(
68686868
68696869 const int32_t config = ncols_y | (nwarps << 16 );
68706870
6871- switch (config) {
6872- case 0x00010001 :
6873- mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6874- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6875- break ;
6876- case 0x00010002 :
6877- mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6878- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6879- break ;
6880- case 0x00010003 :
6881- mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6882- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6883- break ;
6884- case 0x00010004 :
6885- mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6886- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6887- break ;
6888- case 0x00040001 :
6889- mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6890- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6891- break ;
6892- case 0x00040002 :
6893- mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6894- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6895- break ;
6896- case 0x00040003 :
6897- mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6898- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6899- break ;
6900- case 0x00040004 :
6901- mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6902- <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6903- break ;
6871+ switch (nwarps) {
6872+ case 1 : switch (ncols_y) {
6873+ case 1 :
6874+ mul_mat_vec_q<1 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6875+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6876+ break ;
6877+ case 2 :
6878+ mul_mat_vec_q<1 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6879+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6880+ break ;
6881+ case 3 :
6882+ mul_mat_vec_q<1 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6883+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6884+ break ;
6885+ case 4 :
6886+ mul_mat_vec_q<1 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6887+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6888+ break ;
6889+ default :
6890+ GGML_ASSERT (false );
6891+ break ;
6892+ } break ;
6893+ case 4 : switch (ncols_y) {
6894+ case 1 :
6895+ mul_mat_vec_q<4 , 1 , qk, qi, block_q_t , vdr, vec_dot>
6896+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6897+ break ;
6898+ case 2 :
6899+ mul_mat_vec_q<4 , 2 , qk, qi, block_q_t , vdr, vec_dot>
6900+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6901+ break ;
6902+ case 3 :
6903+ mul_mat_vec_q<4 , 3 , qk, qi, block_q_t , vdr, vec_dot>
6904+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6905+ break ;
6906+ case 4 :
6907+ mul_mat_vec_q<4 , 4 , qk, qi, block_q_t , vdr, vec_dot>
6908+ <<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
6909+ break ;
6910+ default :
6911+ GGML_ASSERT (false );
6912+ break ;
6913+ } break ;
6914+
69046915 default :
69056916 GGML_ASSERT (false );
6906- // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
6907- // <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
69086917 break ;
69096918 }
69106919}
0 commit comments