Skip to content

Commit 4f3ca60

Browse files
am17anyael-works
authored andcommitted
CUDA: use fastdiv + ggml_cuda_mad for mmvf (ggml-org#16557)
* CUDA: use fastdiv + ggml_cuda_mad for mmvf * use bf16 directly + fix formatting * Add exception for HIP code
1 parent e66ddac commit 4f3ca60

File tree

1 file changed

+44
-28
lines changed

1 file changed

+44
-28
lines changed

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
77
static __global__ void mul_mat_vec_f(
88
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
99
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
10-
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11-
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
10+
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
11+
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
1212
const int row = blockIdx.x;
1313
const int channel_dst = blockIdx.y;
14-
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
14+
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
1515
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
1616
const int sample_dst = blockIdx.z;
17-
const int sample_x = sample_dst / sample_ratio;
17+
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
1818
const int sample_y = sample_dst;
1919
const int tid = threadIdx.x;
2020

@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
4747
#pragma unroll
4848
for (int j = 0; j < ncols_dst; ++j) {
4949
const float2 tmpy = y2[j*stride_col_y2 + col2];
50-
sumf[j] += tmpx.x*tmpy.x;
51-
sumf[j] += tmpx.y*tmpy.y;
50+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
51+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
5252
}
5353
}
5454
} else if constexpr (std::is_same_v<T, half>) {
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
6161
#pragma unroll
6262
for (int j = 0; j < ncols_dst; ++j) {
6363
const float2 tmpy = y2[j*stride_col_y2 + col2];
64-
sumf[j] += tmpx.x * tmpy.x;
65-
sumf[j] += tmpx.y * tmpy.y;
64+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
65+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
6666
}
6767
}
6868
} else {
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
8888
#endif // FP16_AVAILABLE
8989
}
9090
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
91+
//TODO: add support for ggml_cuda_mad for hip_bfloat162
92+
#if defined(GGML_USE_HIP)
9193
const int * x2 = (const int *) x;
9294
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
9395
const int tmpx = x2[col2];
9496
#pragma unroll
9597
for (int j = 0; j < ncols_dst; ++j) {
9698
const float2 tmpy = y2[j*stride_col_y2 + col2];
97-
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
98-
sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
99+
const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
100+
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
101+
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
102+
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
99103
}
100104
}
105+
#else
106+
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
107+
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
108+
const nv_bfloat162 tmpx = x2[col2];
109+
#pragma unroll
110+
for (int j = 0; j < ncols_dst; ++j) {
111+
const float2 tmpy = y2[j*stride_col_y2 + col2];
112+
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
113+
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
114+
}
115+
}
116+
#endif
101117
} else {
102118
static_assert(std::is_same_v<T, void>, "unsupported type");
103119
}
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
140156
GGML_ASSERT(stride_col_y % 2 == 0);
141157
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
142158
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
143-
const int64_t channel_ratio = nchannels_dst / nchannels_x;
144-
const int64_t sample_ratio = nsamples_dst / nsamples_x;
159+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
160+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
145161

146162
const int device = ggml_cuda_get_device();
147163
const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
167183
case 32: {
168184
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
169185
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
170-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
171-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
186+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
187+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
172188
} break;
173189
case 64: {
174190
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
175191
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
176-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
177-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
192+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
193+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
178194
} break;
179195
case 96: {
180196
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
181197
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
182-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
183-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
198+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
199+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
184200
} break;
185201
case 128: {
186202
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
187203
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
188-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
189-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
204+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
205+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
190206
} break;
191207
case 160: {
192208
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
193209
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
194-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
195-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
210+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
211+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
196212
} break;
197213
case 192: {
198214
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
199215
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
200-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
201-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
216+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
217+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
202218
} break;
203219
case 224: {
204220
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
205221
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
206-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
207-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
222+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
223+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
208224
} break;
209225
case 256: {
210226
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
211227
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
212-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
213-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
228+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
229+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
214230
} break;
215231
default: {
216232
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)