@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
77static __global__ void mul_mat_vec_f (
88 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
99 const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
10- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
10+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1212 const int row = blockIdx .x ;
1313 const int channel_dst = blockIdx .y ;
14- const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14+ const int channel_x = ids ? ids[channel_dst] : fastdiv (( uint32_t ) channel_dst, channel_ratio) ;
1515 const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1616 const int sample_dst = blockIdx .z ;
17- const int sample_x = sample_dst / sample_ratio;
17+ const int sample_x = fastdiv (( uint32_t ) sample_dst, sample_ratio) ;
1818 const int sample_y = sample_dst;
1919 const int tid = threadIdx .x ;
2020
@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
4747#pragma unroll
4848 for (int j = 0 ; j < ncols_dst; ++j) {
4949 const float2 tmpy = y2[j*stride_col_y2 + col2];
50- sumf[j] += tmpx.x * tmpy.x ;
51- sumf[j] += tmpx.y * tmpy.y ;
50+ ggml_cuda_mad ( sumf[j], tmpx.x , tmpy.x ) ;
51+ ggml_cuda_mad ( sumf[j], tmpx.y , tmpy.y ) ;
5252 }
5353 }
5454 } else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
6161#pragma unroll
6262 for (int j = 0 ; j < ncols_dst; ++j) {
6363 const float2 tmpy = y2[j*stride_col_y2 + col2];
64- sumf[j] += tmpx.x * tmpy.x ;
65- sumf[j] += tmpx.y * tmpy.y ;
64+ ggml_cuda_mad ( sumf[j], tmpx.x , tmpy.x ) ;
65+ ggml_cuda_mad ( sumf[j], tmpx.y , tmpy.y ) ;
6666 }
6767 }
6868 } else {
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
8888#endif // FP16_AVAILABLE
8989 }
9090 } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
91+ // TODO: add support for ggml_cuda_mad for hip_bfloat162
92+ #if defined(GGML_USE_HIP)
9193 const int * x2 = (const int *) x;
9294 for (int col2 = tid; col2 < ncols2; col2 += block_size) {
9395 const int tmpx = x2[col2];
9496#pragma unroll
9597 for (int j = 0 ; j < ncols_dst; ++j) {
9698 const float2 tmpy = y2[j*stride_col_y2 + col2];
97- sumf[j] += ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]) * tmpy.x ;
98- sumf[j] += ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]) * tmpy.y ;
99+ const float tmpx0 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]);
100+ const float tmpx1 = ggml_cuda_cast<float >(reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]);
101+ ggml_cuda_mad (sumf[j], tmpx0, tmpy.x );
102+ ggml_cuda_mad (sumf[j], tmpx1, tmpy.y );
99103 }
100104 }
105+ #else
106+ const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
107+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
108+ const nv_bfloat162 tmpx = x2[col2];
109+ #pragma unroll
110+ for (int j = 0 ; j < ncols_dst; ++j) {
111+ const float2 tmpy = y2[j*stride_col_y2 + col2];
112+ ggml_cuda_mad (sumf[j], tmpx.x , tmpy.x );
113+ ggml_cuda_mad (sumf[j], tmpx.y , tmpy.y );
114+ }
115+ }
116+ #endif
101117 } else {
102118 static_assert (std::is_same_v<T, void >, " unsupported type" );
103119 }
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
140156 GGML_ASSERT (stride_col_y % 2 == 0 );
141157 GGML_ASSERT (ids || nchannels_dst % nchannels_x == 0 );
142158 GGML_ASSERT ( nsamples_dst % nsamples_x == 0 );
143- const int64_t channel_ratio = nchannels_dst / nchannels_x;
144- const int64_t sample_ratio = nsamples_dst / nsamples_x;
159+ const uint3 channel_ratio_fd = ids ? make_uint3 ( 0 , 0 , 0 ) : init_fastdiv_values ( nchannels_dst / nchannels_x) ;
160+ const uint3 sample_ratio_fd = init_fastdiv_values ( nsamples_dst / nsamples_x) ;
145161
146162 const int device = ggml_cuda_get_device ();
147163 const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
167183 case 32 : {
168184 mul_mat_vec_f<T, type_acc, ncols_dst, 32 ><<<block_nums, block_dims, nbytes_shared, stream>>>
169185 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
170- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
171- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
186+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
187+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
172188 } break ;
173189 case 64 : {
174190 mul_mat_vec_f<T, type_acc, ncols_dst, 64 ><<<block_nums, block_dims, nbytes_shared, stream>>>
175191 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
176- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
177- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
192+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
193+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
178194 } break ;
179195 case 96 : {
180196 mul_mat_vec_f<T, type_acc, ncols_dst, 96 ><<<block_nums, block_dims, nbytes_shared, stream>>>
181197 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
182- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
183- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
198+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
199+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
184200 } break ;
185201 case 128 : {
186202 mul_mat_vec_f<T, type_acc, ncols_dst, 128 ><<<block_nums, block_dims, nbytes_shared, stream>>>
187203 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
188- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
189- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
204+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
205+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
190206 } break ;
191207 case 160 : {
192208 mul_mat_vec_f<T, type_acc, ncols_dst, 160 ><<<block_nums, block_dims, nbytes_shared, stream>>>
193209 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
194- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
195- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
210+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
211+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
196212 } break ;
197213 case 192 : {
198214 mul_mat_vec_f<T, type_acc, ncols_dst, 192 ><<<block_nums, block_dims, nbytes_shared, stream>>>
199215 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
200- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
201- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
216+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
217+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
202218 } break ;
203219 case 224 : {
204220 mul_mat_vec_f<T, type_acc, ncols_dst, 224 ><<<block_nums, block_dims, nbytes_shared, stream>>>
205221 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
206- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
207- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
222+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
223+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
208224 } break ;
209225 case 256 : {
210226 mul_mat_vec_f<T, type_acc, ncols_dst, 256 ><<<block_nums, block_dims, nbytes_shared, stream>>>
211227 (x, y, ids, dst, ncols/2 , nchannels_y, stride_row, stride_col_y/2 , stride_col_dst,
212- channel_ratio , stride_channel_x, stride_channel_y, stride_channel_dst,
213- sample_ratio , stride_sample_x, stride_sample_y, stride_sample_dst);
228+ channel_ratio_fd , stride_channel_x, stride_channel_y, stride_channel_dst,
229+ sample_ratio_fd , stride_sample_x, stride_sample_y, stride_sample_dst);
214230 } break ;
215231 default : {
216232 GGML_ABORT (" fatal error" );
0 commit comments