Skip to content

Commit 0b2acfc

Browse files
committed
cont: fix gate ordering
1 parent 70b8802 commit 0b2acfc

File tree

5 files changed: +22 additions, -14 deletions

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,7 +2030,15 @@ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, const ggml
20302030
return false;
20312031
}
20322032

2033-
if (glu->src[0] != ffn_up && glu->src[1] != ffn_gate) {
2033+
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
2034+
return false;
2035+
}
2036+
2037+
if (ggml_get_glu_op(glu) != GGML_GLU_OP_SWIGLU) {
2038+
return false;
2039+
}
2040+
2041+
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
20342042
return false;
20352043
}
20362044

@@ -2938,11 +2946,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
29382946
return false;
29392947
}
29402948

2941-
const ggml_tensor * ffn_up = cgraph->nodes[node_idx];
2942-
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx+1];
2949+
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
2950+
const ggml_tensor * ffn_up = cgraph->nodes[node_idx+1];
29432951
const ggml_tensor * glu = cgraph->nodes[node_idx+2];
29442952

2945-
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
2953+
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
29462954
return true;
29472955
}
29482956
}
@@ -3088,22 +3096,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
30883096

30893097
for (ggml_op op : {GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID}) {
30903098
if (ggml_cuda_can_fuse(cgraph, i, {op, op, GGML_OP_GLU}, {})) {
3091-
const ggml_tensor * up = cgraph->nodes[i];
3092-
const ggml_tensor * gate = cgraph->nodes[i+1]->src[0];
30933099
ggml_tensor * glu = cgraph->nodes[i+2];
3100+
ggml_tensor * gate = glu->src[0];
3101+
ggml_tensor * up = glu->src[1];
30943102

30953103
const ggml_tensor * src0 = up->src[0];
30963104
const ggml_tensor * src1 = up->src[1];
30973105
const ggml_tensor * ids = up->src[2];
30983106

30993107
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3100-
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate, glu);
3108+
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
31013109
fused_mul_mat_vec = true;
31023110
break;
31033111
}
31043112

31053113
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3106-
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate, glu);
3114+
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, gate->src[0], glu);
31073115
fused_mul_mat_vec = true;
31083116
break;
31093117
}

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ static __global__ void mul_mat_vec_f(
197197
}
198198

199199
if constexpr (has_gate) {
200-
dst[tid*stride_col_dst + row] = op(sumf[tid]) * sumf_gate[tid];
200+
dst[tid*stride_col_dst + row] = sumf[tid] * op(sumf_gate[tid]);
201201
} else {
202202
dst[tid*stride_col_dst + row] = sumf[tid];
203203
}

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ static __global__ void mul_mat_vec_q(
242242
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
243243
float result = tmp[j][threadIdx.x];
244244
if constexpr (has_gate) {
245-
result = op(result) * tmp_gate[j][threadIdx.x];
245+
result = result * op(tmp_gate[j][threadIdx.x]);
246246
}
247247
dst[j*stride_col_dst + threadIdx.x] = result;
248248
}

ggml/src/ggml-cuda/unary.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
8080

8181
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
8282
return x / (1.0f + expf(-x));
83-
}
83+
}

tests/test-backend-ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4630,7 +4630,7 @@ struct test_fused_ffn_gate : public test_case {
46304630
ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
46314631
ggml_tensor * ffn_gate = ggml_mul_mat(ctx, gate, cur);
46324632

4633-
ggml_tensor * out = ggml_glu_split(ctx, ffn_up, ffn_gate, glu_op);
4633+
ggml_tensor * out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
46344634

46354635
ggml_set_name(out, "out");
46364636
return out;
@@ -4652,10 +4652,10 @@ struct test_fused_ffn_gate : public test_case {
46524652
ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
46534653
ggml_set_name(cur, "cur");
46544654

4655-
ggml_tensor * ffn_gate = ggml_mul_mat_id(ctx, gates, cur, ids);
46564655
ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
4656+
ggml_tensor * ffn_gate = ggml_mul_mat_id(ctx, gates, cur, ids);
46574657

4658-
ggml_tensor * out = ggml_glu_split(ctx, ffn_up, ffn_gate, glu_op);
4658+
ggml_tensor * out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
46594659

46604660
ggml_set_name(out, "out");
46614661
return out;

Comments: 0 commit comments