Skip to content

Commit e8bbcaa

Browse files
committed
explicitly invoke __fmaf_rn
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 5c579cf commit e8bbcaa

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

ggml/src/ggml-cuda/mmv.cu

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,8 @@ static __global__ void mul_mat_vec(
46 46
#pragma unroll
47 47
for (int j = 0; j < ncols_dst; ++j) {
48 48
const float2 tmpy = y2[j*stride_col_y2 + col2];
49-
sumf[j] += tmpx.x*tmpy.x;
50-
sumf[j] += tmpx.y*tmpy.y;
49+
sumf[j] = __fmaf_rn(tmpx.x, tmpy.x, sumf[j]);
50+
sumf[j] = __fmaf_rn(tmpx.y, tmpy.y, sumf[j]);
51 51
}
52 52
}
53 53
} else if constexpr (std::is_same<T, half>::value) {
@@ -60,8 +60,8 @@ static __global__ void mul_mat_vec(
60 60
#pragma unroll
61 61
for (int j = 0; j < ncols_dst; ++j) {
62 62
const float2 tmpy = y2[j*stride_col_y2 + col2];
63-
sumf[j] += tmpx.x * tmpy.x;
64-
sumf[j] += tmpx.y * tmpy.y;
63+
sumf[j] = __fmaf_rn(tmpx.x, tmpy.x, sumf[j]);
64+
sumf[j] = __fmaf_rn(tmpx.y, tmpy.y, sumf[j]);
65 65
}
66 66
}
67 67
} else {
@@ -90,11 +90,13 @@ static __global__ void mul_mat_vec(
90 90
const int * x2 = (const int *) x;
91 91
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
92 92
const int tmpx = x2[col2];
93+
const float x_low = float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
94+
const float x_high = float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
93 95
#pragma unroll
94 96
for (int j = 0; j < ncols_dst; ++j) {
95 97
const float2 tmpy = y2[j*stride_col_y2 + col2];
96-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
97-
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
98+
sumf[j] = __fmaf_rn(x_low, tmpy.x, sumf[j]);
99+
sumf[j] = __fmaf_rn(x_high, tmpy.y, sumf[j]);
98 100
}
99 101
}
100 102
} else {

0 commit comments

Comments
 (0)