diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 08fd044ca03..3a8cd7d83a5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -441,6 +441,11 @@ static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGM GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }; +static constexpr std::initializer_list topk_moe_early_softmax_norm_bias{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; + static constexpr std::initializer_list topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, @@ -477,6 +482,33 @@ static constexpr std::initializer_list> topk_moe_early_softma { 9, 0, 8 }, // reshape->src[0] == div }; +//node #601 ( SOFT_MAX): ffn_moe_probs-8 ( 0K) [Vulka ] use=2,c=1: ffn_moe_logits-8 ( 0K) [Vulka ] +//node #602 ( RESHAPE): ffn_moe_probs-8 (res ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 ( 0K) [Vulka ] +//node #603 ( ADD): ffn_moe_probs_biased ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 ( 0K) [Vulka ] blk.8.exp_probs_b.bi ( 0K) [Vulka ] +//node #604 ( ARGSORT): ffn_moe_argsort-8 ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs_biased ( 0K) [Vulka ] +//node #605 ( VIEW): ffn_moe_topk-8 ( 0K) [Vulka ] use=4,c=1: ffn_moe_argsort-8 ( 0K) [Vulka ] +//node #606 ( GET_ROWS): ffn_moe_weights-8 ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 (res ( 0K) [Vulka ] ffn_moe_topk-8 ( 0K) [Vulka ] +//node #607 ( RESHAPE): ffn_moe_weights-8 (r ( 0K) [Vulka ] use=2,c=1: ffn_moe_weights-8 ( 0K) [Vulka ] +//node #608 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights-8 (r ( 0K) [Vulka ] +//node #609 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_sum- ( 0K) [Vulka ] +//node #610 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights-8 (r ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] +//node #611 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_norm ( 0K) [Vulka ] +//node #612 ( SCALE): ffn_moe_weights_scal ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_norm ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_early_softmax_norm_bias_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // add->src[0] == softmax + { 3, 0, 2 }, // argsort->src[0] == add + { 4, 0, 3 }, // view->src[0] == argsort + { 5, 0, 1 }, // get_rows->src[0] == reshape + { 5, 1, 4 }, // get_rows->src[1] == view + { 6, 0, 5 }, // reshape->src[0] == get_rows + { 7, 0, 6 }, // sum_rows->src[0] == reshape + { 8, 0, 7 }, // clamp->src[0] == sum_rows + { 9, 0, 6 }, // div->src[0] == reshape + { 9, 1, 8 }, // div->src[1] == clamp + {10, 0, 9 }, // reshape->src[0] == div +}; + //node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ] //node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] //node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ] @@ -529,6 +561,7 @@ static constexpr std::initializer_list> topk_moe_late_softmax enum topk_moe_mode { TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX_NORM, + TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS, TOPK_MOE_LATE_SOFTMAX, TOPK_MOE_SIGMOID_NORM_BIAS, TOPK_MOE_COUNT, @@ -10613,9 +10646,9 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { topk_moe_mode mode = ctx->fused_topk_moe_mode; ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits; + ggml_tensor * bias = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits; ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; - ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] : + ggml_tensor * ids = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] : (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; @@ -10651,7 +10684,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_max = ggml_get_op_params_f32(clamp, 1); } - if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) { + if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) { ggml_tensor * clamp = cgraph->nodes[node_idx + 8]; GGML_ASSERT(clamp->op == GGML_OP_CLAMP); pc.clamp_min = ggml_get_op_params_f32(clamp, 0); @@ -10665,8 +10698,8 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID : mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT : GATING_FUNC_SOFTMAX; - pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS; - pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS; + pc.has_bias = mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS; + pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS; if (ctx->fused_topk_moe_scale) { GGML_ASSERT(weights->op == GGML_OP_SCALE); pc.output_scale = ggml_get_op_params_f32(weights, 0); @@ -13393,6 +13426,17 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc get_rows = cgraph->nodes[node_idx + 4]; argsort = cgraph->nodes[node_idx + 2]; break; + case TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 9]; + get_rows = cgraph->nodes[node_idx + 5]; + argsort = cgraph->nodes[node_idx + 3]; + // bias is expected to be 1D + if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 || + !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) { + return false; + } + break; case TOPK_MOE_SIGMOID_NORM_BIAS: softmax = cgraph->nodes[node_idx + 0]; // really sigmoid weights = cgraph->nodes[node_idx + 10]; @@ -13434,7 +13478,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc probs = probs->src[0]; ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) { + if (probs != selection_probs && !(mode == TOPK_MOE_SIGMOID_NORM_BIAS || mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS)) { return false; } @@ -13777,6 +13821,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm_bias, { i + 4, i + 10 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_bias_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm_bias.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 4; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS; + fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { @@ -13988,6 +14040,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_early_softmax_norm)) { continue; } + if (keep_pattern(topk_moe_early_softmax_norm_bias)) { + continue; + } if (keep_pattern(topk_moe_sigmoid_norm_bias)) { continue; } @@ -14017,6 +14072,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } // Don't pull forward nodes from fusion patterns if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_early_softmax_norm_bias, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || match_pattern(topk_moe_late_softmax, j)) {