@@ -103,6 +103,8 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 struct ggml_backend_vk_context;
 
 #define MAX_PARAMETER_COUNT 8
+// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 2)
 
 struct vk_pipeline_struct {
     std::string name;
@@ -368,6 +370,7 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
     bool subgroup_shuffle;
+    bool multi_add;
 
     bool integer_dot_product;
 
@@ -449,6 +452,9 @@ struct vk_device_struct {
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
 
+    // indexed by num_additional_fused_ops == num_adds - 1
+    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
+
     vk_pipeline pipeline_add_id_f32;
 
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -801,6 +807,14 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_multi_add_push_constants {
+    // shape for dst
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
+
+    // strides for srcs+dst
+    uint32_t nb[8][4];
+};
+
 struct vk_op_add_id_push_constants {
     uint32_t ne0;
     uint32_t ne1;
@@ -2994,6 +3008,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY
 
+    if (device->multi_add) {
+        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
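+            // i indexes num_additional_fused_ops; the spec constant (i+2) is the number of src tensors (num_adds + 1)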
+            ggml_vk_create_pipeline(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3533,6 +3553,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 
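+        // multi_add needs 144 bytes of push constants (above Vulkan's 128-byte minimum guarantee) and,
+        // presumably for indexing the shader's buffer array, runtimeDescriptorArray; GGML_VK_DISABLE_MULTI_ADD opts out.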
+        device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
+                            device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
+                            vk12_features.runtimeDescriptorArray &&
+                            device->vendor_id != VK_VENDOR_ID_INTEL &&
+                            getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
+
         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
             device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -6887,6 +6913,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         switch (op) {
         case GGML_OP_ADD:
         {
+            if (ctx->num_additional_fused_ops > 0) {
+                return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+            }
             auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
             return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
@@ -7743,6 +7772,107 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+
+    // Make a list of all the tensors used by the op.
+    // Last element of the list is the dest tensor.
+    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
+    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
+    uint32_t num_tensors = num_srcs + 1;
+    GGML_ASSERT(num_tensors <= MAX_PARAMETER_COUNT);
+
+    tensors[0] = first_node->src[0];
+    tensors[1] = first_node->src[1];
+    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
+        // check whether the previous result is src[0] or src[1]
+        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
+        } else {
+            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
+        }
+    }
+    tensors[num_srcs] = dst;
+
+    vk_op_multi_add_push_constants pc;
+    pc.ne20 = (uint32_t)dst->ne[0];
+    pc.ne21 = (uint32_t)dst->ne[1];
+    pc.ne22 = (uint32_t)dst->ne[2];
+    pc.ne23 = (uint32_t)dst->ne[3];
+
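+    // byte strides become element strides; the division is exact because the fused path only allows f32 tensors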
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        const ggml_tensor *t = tensors[i];
+        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
+        pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
+        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
+        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
+    }
+
+    vk_pipeline pipeline = ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
+
+    if (pipeline == nullptr) {
+        std::cerr << "ggml_vulkan: Error: Missing multi_add" << std::endl;
+        GGML_ABORT("fatal error");
+    }
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
+    vk_buffer buf[MAX_PARAMETER_COUNT];
+    size_t offset[MAX_PARAMETER_COUNT];
+    bool uma[MAX_PARAMETER_COUNT];
+
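+    // resolve each tensor to a Vulkan buffer: prefer a UMA host mapping, otherwise use the device buffer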
+    for (uint32_t i = 0; i < num_tensors; ++i) {
+        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+        buf[i] = nullptr;
+        offset[i] = 0;
+        uma[i] = false;
+
+        if (ctx->device->uma) {
+            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+            uma[i] = buf[i] != nullptr;
+        }
+        if (!uma[i]) {
+            buf[i] = buf_ctx[i]->dev_buffer;
+            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+        }
+        GGML_ASSERT(buf[i] != nullptr);
+    }
+    // If any remaining descriptors are unused, just point them at src[0]
+    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
+        buf[i] = buf[0];
+        offset[i] = 0;
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    uint32_t ne = ggml_nelements(dst);
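+    // split the element count across up to three dispatch dimensions (262144 = 512 * 512)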
+    if (ne > 262144) {
+        elements = { 512, 512, CEIL_DIV(ne, 262144) };
+    } else if (ne > 512) {
+        elements = { 512, CEIL_DIV(ne, 512), 1 };
+    } else {
+        elements = { ne, 1, 1 };
+    }
+
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
+            vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
+        }, pc, elements);
+}
+
 static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
@@ -9703,8 +9833,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ADD:
-        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
-
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);
+        }
         break;
     case GGML_OP_SUB:
         ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -10586,6 +10719,58 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }
 
+static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
+
+    const ggml_tensor *first_node = cgraph->nodes[node_idx];
+    if (first_node->op != GGML_OP_ADD) {
+        return 0;
+    }
+
+    if (!ctx->device->multi_add) {
+        return 0;
+    }
+
+    int32_t num_adds = 1;
+    while (node_idx + num_adds < cgraph->n_nodes &&
+           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
+           num_adds < MAX_FUSED_ADDS) {
+        num_adds++;
+    }
+
+    // The shader currently requires same shapes (but different strides are allowed),
+    // everything f32, and no misalignment
+    for (int32_t i = 0; i < num_adds; ++i) {
+        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
+        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
+            !ggml_are_same_shape(first_node, next_node->src[1]) ||
+            next_node->type != GGML_TYPE_F32 ||
+            next_node->src[0]->type != GGML_TYPE_F32 ||
+            next_node->src[1]->type != GGML_TYPE_F32 ||
+            get_misalign_bytes(ctx, next_node) ||
+            get_misalign_bytes(ctx, next_node->src[0]) ||
+            get_misalign_bytes(ctx, next_node->src[1])) {
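+            // truncate the run at the first failing node; the loop condition (i < num_adds) then ends the scan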
+            num_adds = i;
+        }
+    }
+
+    // Verify we can fuse these
+    ggml_op adds[MAX_FUSED_ADDS];
+    for (int32_t i = 0; i < num_adds; ++i) {
+        adds[i] = GGML_OP_ADD;
+    }
+
+    // decrease num_adds if they can't all be fused
+    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
+        num_adds--;
+    }
+
+    // a single add is not "fused", so just return zero
+    if (num_adds == 1) {
+        return 0;
+    }
+    return num_adds;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -10599,8 +10784,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
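+            // multi-add fusion is checked first; num_additional_fused_ops counts the fused nodes after the first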
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
@@ -10675,8 +10865,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
+        if (!ctx->device->disable_fusion) {
+            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
+            if (num_adds) {
+                ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
+            }
         }
 
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)