
Commit 0ed6745

cern1710 authored and ggerganov committed
metal: optimise GGML_OP_SUM (ggml-org#16559)
* optimise GGML_OP_SUM
* add non-contiguous tests by permuting the input
* change tests to require full contiguity of OP_SUM
* cuda : add check GGML_OP_SUM

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 5248535 commit 0ed6745

File tree: 5 files changed, +71 −11 lines


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 1 deletion

@@ -3633,9 +3633,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
             // TODO: Support arbitrary column width
             return op->src[0]->ne[0] <= 1024;
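
For context, ggml_is_contiguous_rows is the weaker of ggml's contiguity predicates: it accepts tensors whose individual rows are packed even when the rows themselves are reordered by a permute of the higher dimensions, which is exactly the layout the new permuted tests produce. A minimal sketch of the idea, assuming ggml's convention that nb[0] is the byte stride between neighbouring elements of a row (the real check in ggml.c may handle more edge cases):

    // Sketch of the row-contiguity idea behind ggml_is_contiguous_rows.
    // The struct and helper here are illustrative, not ggml's real layout.
    #include <cstddef>
    #include <cstdio>

    struct view {
        size_t type_size; // bytes per element
        size_t nb0;       // byte stride along dim 0 (within a row)
    };

    static bool rows_contiguous(const view & v) {
        // Rows are packed if elements sit back-to-back, even when whole
        // rows are scattered by a permute of dims 1..3.
        return v.nb0 == v.type_size;
    }

    int main() {
        view packed  = { sizeof(float), sizeof(float) };     // e.g. after permute(0, 2, 1, 3)
        view strided = { sizeof(float), 5 * sizeof(float) }; // e.g. after permute(1, 0, 2, 3)
        printf("%d %d\n", rows_contiguous(packed), rows_contiguous(strided)); // prints: 1 0
    }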

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions

@@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 14 additions & 1 deletion

@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {

     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);

-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);

     return 1;
 }
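
The dispatch still launches a single threadgroup, but now sizes it to the work: start at one SIMD group (32 threads), double while more threads would still have elements to cover, then clamp to the pipeline limit and to n itself; nsg is the resulting number of simdgroups, which is also how many float partials the kernel stages in threadgroup memory. A standalone sketch of just that arithmetic, with hypothetical stand-in values for n and the pipeline limit:

    // Host-side sizing sketch; only the arithmetic mirrors the diff above.
    #include <algorithm>
    #include <cstdio>

    int main() {
        const int n        = 100000; // elements to reduce (example value)
        const int max_tptg = 1024;   // stand-in for the pipeline's max threads per threadgroup

        int nth = 32; // SIMD width
        while (nth < n && nth < max_tptg) {
            nth *= 2;
        }
        nth = std::min(nth, max_tptg);
        nth = std::min(nth, n);

        const int nsg = (nth + 31) / 32; // simdgroups -> floats of threadgroup memory

        printf("nth = %d, nsg = %d\n", nth, nsg); // prints: nth = 1024, nsg = 32
    }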

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 36 additions & 6 deletions

@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
         constant ggml_metal_kargs_sum & args,
         device const float * src0,
         device       float * dst,
-        ushort tiitg[[thread_index_in_threadgroup]]) {
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {

-    if (tiitg != 0) {
+    if (args.np == 0) {
         return;
     }

-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
     }

-    dst[0] = acc;
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }

 template <bool norm>
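
The kernel replaces the old single-thread loop (every thread but 0 exited immediately) with a classic two-stage threadgroup reduction: each thread accumulates a strided slice of src0, simd_sum folds each 32-thread simdgroup into one partial, the per-simdgroup partials pass through threadgroup memory, and the first simdgroup reduces those partials and writes dst[0]. A serial CPU emulation of that data flow, assuming a 32-wide simdgroup (sketch only, not Metal code):

    // Emulates the kernel's three phases for nth = 128 threads (4 simdgroups).
    #include <cstdio>
    #include <vector>

    int main() {
        const int nth = 128;             // threads per threadgroup
        const int nsg = (nth + 31) / 32; // simdgroups, as in the kernel
        std::vector<float> src(1000, 1.0f);

        // Phase 1: per-thread strided partial sums over the input.
        std::vector<float> thread_sum(nth, 0.0f);
        for (int t = 0; t < nth; ++t) {
            for (size_t i = t; i < src.size(); i += nth) {
                thread_sum[t] += src[i];
            }
        }

        // Phase 2: simd_sum within each simdgroup -> shmem_f32[sgitg].
        std::vector<float> shmem(nsg, 0.0f);
        for (int sg = 0; sg < nsg; ++sg) {
            for (int lane = 0; lane < 32; ++lane) {
                shmem[sg] += thread_sum[sg * 32 + lane];
            }
        }

        // Phase 3: the first simdgroup reduces the partials into dst[0].
        float total = 0.0f;
        for (int sg = 0; sg < nsg; ++sg) {
            total += shmem[sg];
        }
        printf("total = %.1f\n", total); // prints: total = 1000.0
    }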

tests/test-backend-ops.cpp

Lines changed: 18 additions & 3 deletions

@@ -4732,20 +4732,31 @@ struct test_topk_moe: public test_case {
 struct test_sum : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> permute;
+    bool _use_permute;

     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
     }

     test_sum(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+          _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");

+        if (_use_permute) {
+            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(a, "a_permuted");
+        }
+
         ggml_tensor * out = ggml_sum(ctx, a);
         ggml_set_name(out, "out");

@@ -6876,6 +6887,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));

@@ -6886,6 +6900,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
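
Outside the test harness, the same non-contiguous case can be reproduced directly against the ggml C API. A hypothetical standalone sketch (memory sizing, backend selection, and error handling are elided; the API names come from ggml.h, but exact behaviour depends on which backend picks up the op):

    // Sum over a permuted (row-contiguous, non-contiguous) view of a 4-D tensor.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024, // generous scratch for this toy graph
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // Same shape as the new tests; permute(0, 2, 1, 3) swaps dims 1 and 2,
        // so rows stay packed but the tensor as a whole is no longer contiguous.
        int64_t ne[4] = { 11, 5, 6, 3 };
        struct ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
        ggml_set_f32(a, 1.0f);
        struct ggml_tensor * p = ggml_permute(ctx, a, 0, 2, 1, 3);

        struct ggml_tensor * out = ggml_sum(ctx, p);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

        // 11 * 5 * 6 * 3 elements, all 1.0f.
        printf("sum = %.1f\n", ggml_get_f32_1d(out, 0)); // prints: sum = 990.0
        ggml_free(ctx);
    }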
