@@ -46,8 +46,8 @@ static __global__ void mul_mat_vec(
4646#pragma unroll
4747 for (int j = 0 ; j < ncols_dst; ++j) {
4848 const float2 tmpy = y2[j*stride_col_y2 + col2];
49- sumf[j] += tmpx.x * tmpy.x ;
50- sumf[j] += tmpx.y * tmpy.y ;
49+ sumf[j] = __fmaf_rn ( tmpx.x , tmpy.x , sumf[j]) ;
50+ sumf[j] = __fmaf_rn ( tmpx.y , tmpy.y , sumf[j]) ;
5151 }
5252 }
5353 } else if constexpr (std::is_same<T, half>::value) {
@@ -60,8 +60,8 @@ static __global__ void mul_mat_vec(
6060#pragma unroll
6161 for (int j = 0 ; j < ncols_dst; ++j) {
6262 const float2 tmpy = y2[j*stride_col_y2 + col2];
63- sumf[j] += tmpx.x * tmpy.x ;
64- sumf[j] += tmpx.y * tmpy.y ;
63+ sumf[j] = __fmaf_rn ( tmpx.x , tmpy.x , sumf[j]) ;
64+ sumf[j] = __fmaf_rn ( tmpx.y , tmpy.y , sumf[j]) ;
6565 }
6666 }
6767 } else {
@@ -90,11 +90,13 @@ static __global__ void mul_mat_vec(
9090 const int * x2 = (const int *) x;
9191 for (int col2 = tid; col2 < ncols2; col2 += block_size) {
9292 const int tmpx = x2[col2];
93+ const float x_low = float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[0 ]);
94+ const float x_high = float (reinterpret_cast <const nv_bfloat16 *>(&tmpx)[1 ]);
9395#pragma unroll
9496 for (int j = 0 ; j < ncols_dst; ++j) {
9597 const float2 tmpy = y2[j*stride_col_y2 + col2];
96- sumf[j] += float ( reinterpret_cast < const nv_bfloat16 *>(&tmpx)[ 0 ]) * tmpy. x ;
97- sumf[j] += float ( reinterpret_cast < const nv_bfloat16 *>(&tmpx)[ 1 ]) * tmpy. y ;
98+ sumf[j] = __fmaf_rn (x_low, tmpy. x , sumf[j]) ;
99+ sumf[j] = __fmaf_rn (x_high, tmpy. y , sumf[j]) ;
98100 }
99101 }
100102 } else {
0 commit comments