Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm_bias{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
Expand Down Expand Up @@ -477,6 +482,33 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
{ 9, 0, 8 }, // reshape->src[0] == div
};

//node #601 ( SOFT_MAX): ffn_moe_probs-8 ( 0K) [Vulka ] use=2,c=1: ffn_moe_logits-8 ( 0K) [Vulka ]
//node #602 ( RESHAPE): ffn_moe_probs-8 (res ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 ( 0K) [Vulka ]
//node #603 ( ADD): ffn_moe_probs_biased ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 ( 0K) [Vulka ] blk.8.exp_probs_b.bi ( 0K) [Vulka ]
//node #604 ( ARGSORT): ffn_moe_argsort-8 ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs_biased ( 0K) [Vulka ]
//node #605 ( VIEW): ffn_moe_topk-8 ( 0K) [Vulka ] use=4,c=1: ffn_moe_argsort-8 ( 0K) [Vulka ]
//node #606 ( GET_ROWS): ffn_moe_weights-8 ( 0K) [Vulka ] use=1,c=1: ffn_moe_probs-8 (res ( 0K) [Vulka ] ffn_moe_topk-8 ( 0K) [Vulka ]
//node #607 ( RESHAPE): ffn_moe_weights-8 (r ( 0K) [Vulka ] use=2,c=1: ffn_moe_weights-8 ( 0K) [Vulka ]
//node #608 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights-8 (r ( 0K) [Vulka ]
//node #609 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
//node #610 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights-8 (r ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
//node #611 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_norm ( 0K) [Vulka ]
//node #612 ( SCALE): ffn_moe_weights_scal ( 0K) [Vulka ] use=1,c=1: ffn_moe_weights_norm ( 0K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_bias_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
{ 2, 0, 0 }, // add->src[0] == softmax
{ 3, 0, 2 }, // argsort->src[0] == add
{ 4, 0, 3 }, // view->src[0] == argsort
{ 5, 0, 1 }, // get_rows->src[0] == reshape
{ 5, 1, 4 }, // get_rows->src[1] == view
{ 6, 0, 5 }, // reshape->src[0] == get_rows
{ 7, 0, 6 }, // sum_rows->src[0] == reshape
{ 8, 0, 7 }, // clamp->src[0] == sum_rows
{ 9, 0, 6 }, // div->src[0] == reshape
{ 9, 1, 8 }, // div->src[1] == clamp
{10, 0, 9 }, // reshape->src[0] == div
};

//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
Expand Down Expand Up @@ -529,6 +561,7 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax
enum topk_moe_mode {
TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS,
TOPK_MOE_LATE_SOFTMAX,
TOPK_MOE_SIGMOID_NORM_BIAS,
TOPK_MOE_COUNT,
Expand Down Expand Up @@ -10613,9 +10646,9 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
topk_moe_mode mode = ctx->fused_topk_moe_mode;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
ggml_tensor * bias = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
ggml_tensor * ids = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
(mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
cgraph->nodes[node_idx + 3];

Expand Down Expand Up @@ -10651,7 +10684,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
Expand All @@ -10665,8 +10698,8 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
GATING_FUNC_SOFTMAX;
pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
pc.has_bias = mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS || mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
if (ctx->fused_topk_moe_scale) {
GGML_ASSERT(weights->op == GGML_OP_SCALE);
pc.output_scale = ggml_get_op_params_f32(weights, 0);
Expand Down Expand Up @@ -13393,6 +13426,17 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
get_rows = cgraph->nodes[node_idx + 5];
argsort = cgraph->nodes[node_idx + 3];
// bias is expected to be 1D
if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
!ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
return false;
}
break;
case TOPK_MOE_SIGMOID_NORM_BIAS:
softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
weights = cgraph->nodes[node_idx + 10];
Expand Down Expand Up @@ -13434,7 +13478,7 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];

if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
if (probs != selection_probs && !(mode == TOPK_MOE_SIGMOID_NORM_BIAS || mode == TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS)) {
return false;
}

Expand Down Expand Up @@ -13777,6 +13821,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->fused_ops_write_mask |= 1 << 3;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm_bias, { i + 4, i + 10 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_bias_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm_bias.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 4;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM_BIAS";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
Expand Down Expand Up @@ -13988,6 +14040,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (keep_pattern(topk_moe_early_softmax_norm)) {
continue;
}
if (keep_pattern(topk_moe_early_softmax_norm_bias)) {
continue;
}
if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
continue;
}
Expand Down Expand Up @@ -14017,6 +14072,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
// Don't pull forward nodes from fusion patterns
if (match_pattern(topk_moe_early_softmax_norm, j) ||
match_pattern(topk_moe_early_softmax_norm_bias, j) ||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
match_pattern(topk_moe_early_softmax, j) ||
match_pattern(topk_moe_late_softmax, j)) {
Expand Down
Loading