diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index ba006d9b31a..5d4b10d34b9 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1732,6 +1732,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_IM2COL);
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
+
     GGML_ASSERT(ggml_is_contiguous(op->src[1]));
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
@@ -1739,7 +1741,13 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
+    // the regular kernel uses one thread per kernel element -> fall back to the "ext" kernel for large windows
+    // note: ggml_metal_op_im2col() must apply the same rule when choosing the dispatch geometry
+    if (ne00*ne01 <= 1024) {
+        snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
+    } else {
+        snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
+    }
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 206af227a2c..e2ce56e9e28 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -3635,16 +3635,31 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
 
     auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
 
-    GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-
-    const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
-    ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
+    // the pipeline getter picks the "ext" kernel when ne00*ne01 > 1024 - the dispatch geometry below
+    // has to follow the exact same rule, otherwise a kernel is launched with the wrong grid layout
+    const bool use_ext = op->src[0]->ne[0]*op->src[0]->ne[1] > 1024;
+
+    const auto nth_max = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
+    ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
+
+    if (!use_ext) {
+        // kernel_im2col: the KH*KW threads of one batch element must fit in a single threadgroup
+        GGML_ASSERT(KH*KW <= nth_max);
+
+        const uint64_t ntptg0 = std::min(nth_max/(KH*KW), N);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
+    } else {
+        // kernel_im2col_ext: the threads of a threadgroup span the batch, ceil(N/n_threads)*CHW groups in dim 0
+        const uint64_t n_threads = std::min(nth_max, N);
+        const int64_t quotient = (N + n_threads - 1)/n_threads;
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, quotient*CHW, OH, OW, n_threads, 1, 1);
+    }
 
     return 1;
 }
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index e772664ba91..4adf4614acb 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4696,57 +4696,59 @@ kernel void kernel_im2col(
 template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
 template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
 
-// TODO: obsolete -- remove
-//typedef void (im2col_ext_t)(
-//        constant ggml_metal_kargs_im2col & args,
-//        device const float * x,
-//        device char * dst,
-//        uint3 tgpig[[threadgroup_position_in_grid]],
-//        uint3 tgpg[[threadgroups_per_grid]],
-//        uint3 tpitg[[thread_position_in_threadgroup]],
-//        uint3 ntg[[threads_per_threadgroup]]);
-//
-//template <typename T>
-//kernel void kernel_im2col_ext(
-//        constant ggml_metal_kargs_im2col & args,
-//        device const float * x,
-//        device char * dst,
-//        uint3 tgpig[[threadgroup_position_in_grid]],
-//        uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
-//        uint3 tpitg[[thread_position_in_threadgroup]],
-//        uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
-//    const int64_t KHW = (int64_t)args.KHW;
-//
-//    const int64_t d = tgpig[0] / args.CHW;
-//    const int64_t chw = tgpig[0] % args.CHW;
-//    const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
-//    const int64_t HW = tgpig[0] % KHW;
-//
-//    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-//    if (tpitg_0 >= args.N) {
-//        return;
-//    }
-//
-//    const int64_t tpitg_1 = HW / args.KW;
-//    const int64_t tpitg_2 = HW % args.KW;
-//
-//    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
-//    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
-//
-//    const int64_t offset_dst =
-//        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
-//        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
-//
-//    device T * pdst = (device T *) (dst);
-//
-//    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
-//        pdst[offset_dst] = 0.0f;
-//    } else {
-//        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
-//        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
-//    }
-//}
-//
-//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
-//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+// fallback im2col kernel for windows where KH*KW does not fit in a single threadgroup:
+// the (batch chunk, IC, KH, KW) indices are unrolled into grid dim 0 and the threads of a threadgroup span the batch
+// TODO: optimize
+typedef void (im2col_ext_t)(
+        constant ggml_metal_kargs_im2col & args,
+        device const float * x,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        constant ggml_metal_kargs_im2col & args,
+        device const float * x,
+        device char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
+    const int64_t KHW = (int64_t)args.KHW;
+
+    const int64_t d = tgpig[0] / args.CHW;
+    const int64_t chw = tgpig[0] % args.CHW;
+    const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
+    const int64_t HW = tgpig[0] % KHW;
+
+    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= args.N) {
+        return;
+    }
+
+    const int64_t tpitg_1 = HW / args.KW;
+    const int64_t tpitg_2 = HW % args.KW;
+
+    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
+
+    const int64_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 0176599459f..58c5fdd10db 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7812,6 +7812,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {2, 2, 1536, 729}, {2, 2, 1536, 4096}, 1, 1, 0, 0, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {128, 128, 1, 2}, {32, 33, 1, 2}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {128, 128, 2, 1}, {33, 34, 2, 1}, 1, 1, 1, 1, 1, 1, true));
 
     // im2col 3D
     test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));