Commit 39da6fe

cont: fix gate ordering
1 parent 70b8802 commit 39da6fe

File tree

6 files changed, +96 -29 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 16 additions & 8 deletions
@@ -2030,7 +2030,15 @@ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, const ggml
         return false;
     }
 
-    if (glu->src[0] != ffn_up && glu->src[1] != ffn_gate) {
+    if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
+        return false;
+    }
+
+    if (ggml_get_glu_op(glu) != GGML_GLU_OP_SWIGLU) {
+        return false;
+    }
+
+    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
         return false;
     }
 
@@ -2938,11 +2946,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
-        const ggml_tensor * ffn_up = cgraph->nodes[node_idx];
-        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx+1];
+        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_up = cgraph->nodes[node_idx+1];
         const ggml_tensor * glu = cgraph->nodes[node_idx+2];
 
-        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
            return true;
        }
    }
@@ -3088,22 +3096,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             for (ggml_op op : {GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID}) {
                 if (ggml_cuda_can_fuse(cgraph, i, {op, op, GGML_OP_GLU}, {})) {
-                    const ggml_tensor * up = cgraph->nodes[i];
-                    const ggml_tensor * gate = cgraph->nodes[i+1]->src[0];
                     ggml_tensor * glu = cgraph->nodes[i+2];
+                    ggml_tensor * gate = glu->src[0];
+                    ggml_tensor * up = glu->src[1];
 
                     const ggml_tensor * src0 = up->src[0];
                     const ggml_tensor * src1 = up->src[1];
                     const ggml_tensor * ids = up->src[2];
 
                     if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
-                        ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate, glu);
+                        ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
                         fused_mul_mat_vec = true;
                         break;
                     }
 
                     if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
-                        ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate, glu);
+                        ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
                         fused_mul_mat_vec = true;
                         break;
                     }
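
The checks above pin down the exact pattern the fused path accepts: the gate matmul at node_idx (glu->src[0]), the up matmul at node_idx+1 (glu->src[1]), a GGML_GLU_OP_SWIGLU op whose operands are not swapped, and the gate weight handed to the kernel via gate->src[0]. As a scalar reference for what the fused path computes (a sketch for illustration only, not the CUDA kernel; fused_swiglu is a hypothetical helper):

    #include <math.h>

    // Matches ggml_cuda_op_silu_single in unary.cuh.
    static float silu(float x) {
        return x / (1.0f + expf(-x));
    }

    // One output element of the fused FFN gate: the activation is applied to
    // the gate projection, and the result scales the up projection.
    static float fused_swiglu(float ffn_up, float ffn_gate) {
        return ffn_up * silu(ffn_gate); // mirrors sumf[tid] * op(sumf_gate[tid]) below
    }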

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ static __global__ void mul_mat_vec_f(
         }
 
         if constexpr (has_gate) {
-            dst[tid*stride_col_dst + row] = op(sumf[tid]) * sumf_gate[tid];
+            dst[tid*stride_col_dst + row] = sumf[tid] * op(sumf_gate[tid]);
         } else {
             dst[tid*stride_col_dst + row] = sumf[tid];
         }

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ static __global__ void mul_mat_vec_q(
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
             float result = tmp[j][threadIdx.x];
             if constexpr (has_gate) {
-                result = op(result) * tmp_gate[j][threadIdx.x];
+                result = result * op(tmp_gate[j][threadIdx.x]);
             }
             dst[j*stride_col_dst + threadIdx.x] = result;
         }
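
In both kernels the operand order changes results, not just naming: with ffn_up = 2 and ffn_gate = -1, the corrected epilogue yields 2 * silu(-1) ≈ 2 * (-0.269) ≈ -0.538, while the old op(result) * gate form computed silu(2) * (-1) ≈ -1.762.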

ggml/src/ggml-cuda/unary.cuh

Lines changed: 1 addition & 1 deletion
@@ -80,4 +80,4 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
-}
+}

ggml/src/ggml-impl.h

Lines changed: 74 additions & 15 deletions
@@ -567,7 +567,11 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
-static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+static inline bool ggml_node_has_n_uses_impl(
+        const struct ggml_cgraph * cgraph,
+        int node_idx,
+        int32_t n_uses,
+        bool allow_views) {
     const struct ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
@@ -579,7 +583,14 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
     // if node is a view, some other node might be using the intermediate result
     // via the view source.
     if (node->view_src) {
-        return false;
+        if (!allow_views) {
+            return false;
+        }
+
+        size_t src_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node->view_src);
+        if (!ggml_bitset_get(cgraph->visited_hash_set.used, src_hash_pos) || cgraph->use_counts[src_hash_pos] != 1) {
+            return false;
+        }
     }
 
     // If the user requested output for the node, can't fuse
@@ -590,35 +601,83 @@ static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int n
     return true;
 }
 
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    return ggml_node_has_n_uses_impl(cgraph, node_idx, n_uses, false);
+}
+
 // Returns true if nodes with indices { node_idxs } are the sequence of ggml_ops in ops[]
 // and are fusable. Nodes are considered fusable according to this function if:
-// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
-// - all nodes except the last are a src of the following node.
-// - all nodes are the same shape.
+// - all nodes except the last have only one use and their consumers are inside the fusion set.
+// - dependencies between nodes follow the order provided in node_idxs.
 // TODO: Consider allowing GGML_OP_NONE nodes in between
 static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const int * node_idxs, const enum ggml_op * ops, int num_ops) {
+    GGML_ASSERT(num_ops <= 32);
+
+    if (num_ops <= 0) {
+        return false;
+    }
+
+    struct ggml_tensor * nodes[32] = {0};
+
     for (int i = 0; i < num_ops; ++i) {
-        if (node_idxs[i] >= cgraph->n_nodes) {
+        const int idx = node_idxs[i];
+        if (idx >= cgraph->n_nodes) {
             return false;
         }
 
-        struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
-        if (node->op != ops[i]) {
-            return false;
-        }
-        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
+        nodes[i] = cgraph->nodes[idx];
+        if (nodes[i]->op != ops[i]) {
             return false;
         }
-        if (i > 0) {
-            struct ggml_tensor * prev = cgraph->nodes[node_idxs[i - 1]];
-            if (node->src[0] != prev && node->src[1] != prev) {
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct ggml_tensor * node = nodes[i];
+
+        if (i < num_ops - 1) {
+            const bool allow_views = node->view_src != NULL;
+            if (!ggml_node_has_n_uses_impl(cgraph, node_idxs[i], 1, allow_views)) {
                 return false;
             }
-            if (!ggml_are_same_shape(node, prev)) {
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            struct ggml_tensor * src = node->src[j];
+            if (!src) {
+                continue;
+            }
+
+            int src_pos = -1;
+            for (int k = 0; k < num_ops; ++k) {
+                if (nodes[k] == src) {
+                    src_pos = k;
+                    break;
+                }
+            }
+
+            if (src_pos != -1 && src_pos >= i) {
                 return false;
             }
         }
     }
+
+    for (int i = 0; i < num_ops - 1; ++i) {
+        bool has_consumer = false;
+        for (int k = i + 1; k < num_ops && !has_consumer; ++k) {
+            struct ggml_tensor * consumer = nodes[k];
+            for (int s = 0; s < GGML_MAX_SRC; ++s) {
+                if (consumer->src[s] == nodes[i]) {
+                    has_consumer = true;
+                    break;
+                }
+            }
+        }
+
+        if (!has_consumer) {
+            return false;
+        }
+    }
+
     return true;
 }
 
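The generalized check drops the old requirements that each node feed the next and that all nodes share a shape; instead, dependencies inside the set must respect the order of node_idxs, and every non-final node must be consumed within the set. A hypothetical caller, mirroring how the CUDA backend probes the MUL_MAT/MUL_MAT/GLU pattern above (cgraph and i come from the surrounding backend loop; idxs and ops are illustrative names):

    // Sketch: do nodes i, i+1, i+2 form a fusable mul_mat -> mul_mat -> glu
    // sequence? Under the new rules the two matmuls need not feed each other;
    // both just have to be consumed by the GLU node.
    const int          idxs[3] = { i, i + 1, i + 2 };
    const enum ggml_op ops[3]  = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };

    if (ggml_can_fuse_ext(cgraph, idxs, ops, 3)) {
        // safe to replace the three nodes with a single fused kernel
    }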

tests/test-backend-ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4630,7 +4630,7 @@ struct test_fused_ffn_gate : public test_case {
         ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
         ggml_tensor * ffn_gate = ggml_mul_mat(ctx, gate, cur);
 
-        ggml_tensor * out = ggml_glu_split(ctx, ffn_up, ffn_gate, glu_op);
+        ggml_tensor * out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
 
         ggml_set_name(out, "out");
         return out;
@@ -4652,10 +4652,10 @@
         ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
         ggml_set_name(cur, "cur");
 
-        ggml_tensor * ffn_gate = ggml_mul_mat_id(ctx, gates, cur, ids);
         ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
+        ggml_tensor * ffn_gate = ggml_mul_mat_id(ctx, gates, cur, ids);
 
-        ggml_tensor * out = ggml_glu_split(ctx, ffn_up, ffn_gate, glu_op);
+        ggml_tensor * out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
 
         ggml_set_name(out, "out");
         return out;
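
The test changes encode the same convention end to end: ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op) places the gate, which receives the activation, in src[0] and the up projection, which it scales, in src[1], matching the ordering the CUDA fusion path now detects.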
