Skip to content

Commit 38dbdf4

Browse files
CUDA: Optimize PAD_REFLECT_1D (ggml-org#15957)
* CUDA: Optimize PAD_REFLECT_1D

  feat: add more test cases for PAD_REFLECT_1D

* use fast_div to improve performance

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <[email protected]>

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <[email protected]>

* optimize

* use a concise expression to further speedup the cuda kernel

---------

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 368560a commit 38dbdf4

File tree

3 files changed

+76
-54
lines changed

3 files changed

+76
-54
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,14 @@ static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fa
652652
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
653653
}
654654

655+
// Compute quotient and remainder in one step: returns <n / divisor, n % divisor>.
// fastdiv_values must be the <mp, L, divisor> triple produced by init_fastdiv_values;
// the plain (unpacked) divisor is carried in the .z component.
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
    const uint32_t q = fastdiv(n, fastdiv_values);
    const uint32_t r = n - q * fastdiv_values.z;
    return make_uint2(q, r);
}
662+
655663
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
656664

657665
static __device__ __forceinline__ float get_alibi_slope(
Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,89 @@
11
#include "pad_reflect_1d.cuh"
22

3-
// Reflect-pad the innermost (i0) dimension of an F32 tensor.
//
// Launch layout (set up by ggml_cuda_op_pad_reflect_1d):
//   grid.x = ne01 * tiles0   -- row index i1 and i0-tile index packed together
//   grid.y = ne02
//   grid.z = ne03
// ne01 arrives pre-packed for fastdiv; its .z member holds the plain value
// (see init_fastdiv_values in common.cuh).
static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
    pad_reflect_1d_kernel_f32(
        const void * __restrict__ src0,
        void * __restrict__ dst,
        const int64_t ne0,
        const int64_t ne00,
        const uint3   ne01,
        const int64_t ne02,
        const int64_t ne03,
        const int64_t nb00,
        const int64_t nb01,
        const int64_t nb02,
        const int64_t nb03,
        const int64_t nb0,
        const int64_t nb1,
        const int64_t nb2,
        const int64_t nb3,
        const int     p0,
        const int     p1) {
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;

    // blockIdx.x encodes both the i0 tile and the row: unpack with one fastdiv.
    const uint2   qr = fast_div_modulo(blockIdx.x, ne01);
    const int64_t i1 = qr.y;                                       // row (remainder)
    const int64_t i0 = (int64_t) qr.x * blockDim.x + threadIdx.x;  // qr.x = i0-tile index

    // ne01.z is the original (unpacked) value of ne01 (see init_fastdiv_values in common.cuh)
    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const char * src0_ptr = (const char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    char       * dst_ptr  = (char       *) dst  + i3*nb3  + i2*nb2  + i1*nb1;

    const int64_t rel_i0 = i0 - p0;  // position relative to the start of the src0 row
    int64_t       src_idx;

    if (rel_i0 < 0) {
        // left padding: mirror around element 0
        src_idx = -rel_i0;
    } else if (rel_i0 < ne00) {
        // interior: plain copy
        src_idx = rel_i0;
    } else {
        // right padding: mirror around element ne00 - 1
        src_idx = 2*ne00 - 2 - rel_i0;
    }

    *(float *) (dst_ptr + i0*nb0) = *(const float *) (src0_ptr + src_idx*nb00);
}
5155

5256
// Host-side launcher for PAD_REFLECT_1D (F32 only).
// Reads p0/p1 from dst->op_params and launches one thread per output element,
// packing the row index and the i0 tile index into grid.x.
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);

    // op_params = {p0, p1}: elements reflected before/after the row on dim 0
    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // pack ne01 so the kernel can do division/modulo with fastdiv
    const uint3 ne01_packed = init_fastdiv_values(ne01);

    const int64_t ne0 = dst->ne[0];

    // sanity: padded length matches
    GGML_ASSERT(ne0 == ne00 + p0 + p1);

    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)
    const     int64_t tiles0 = (ne0 + bx - 1) / bx;             // i0 tiles per row

    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
    // grid.y covers i2: [ne02]
    // grid.z covers i3: [ne03]
    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
    const dim3 block_dims((unsigned) bx, 1, 1);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03,
        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}

tests/test-backend-ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6507,6 +6507,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
65076507
test_cases.emplace_back(new test_pad());
65086508
test_cases.emplace_back(new test_pad_ext());
65096509
test_cases.emplace_back(new test_pad_reflect_1d());
6510+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
65106511
test_cases.emplace_back(new test_roll());
65116512
test_cases.emplace_back(new test_arange());
65126513
test_cases.emplace_back(new test_timestep_embedding());
@@ -6645,6 +6646,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
66456646
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
66466647
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
66476648

6649+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {512, 34, 2, 1}));
6650+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 1, 1}));
6651+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 80, 4, 1}));
6652+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 1, 1}));
6653+
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
6654+
66486655
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
66496656
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
66506657

0 commit comments

Comments
 (0)