@@ -2030,7 +2030,15 @@ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, const ggml
20302030 return false ;
20312031 }
20322032
2033- if (glu->src [0 ] != ffn_up && glu->src [1 ] != ffn_gate) {
2033+ if (glu->src [0 ] != ffn_gate && glu->src [1 ] != ffn_up) {
2034+ return false ;
2035+ }
2036+
2037+ if (ggml_get_glu_op (glu) != GGML_GLU_OP_SWIGLU) {
2038+ return false ;
2039+ }
2040+
2041+ if (const bool swapped = ggml_get_op_params_i32 (glu, 1 ); swapped) {
20342042 return false ;
20352043 }
20362044
@@ -2938,11 +2946,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
29382946 return false ;
29392947 }
29402948
2941- const ggml_tensor * ffn_up = cgraph->nodes [node_idx];
2942- const ggml_tensor * ffn_gate = cgraph->nodes [node_idx+1 ];
2949+ const ggml_tensor * ffn_gate = cgraph->nodes [node_idx];
2950+ const ggml_tensor * ffn_up = cgraph->nodes [node_idx+1 ];
29432951 const ggml_tensor * glu = cgraph->nodes [node_idx+2 ];
29442952
2945- if (ggml_cuda_should_fuse_mul_mat (ffn_up, ffn_gate, glu)) {
2953+ if (ggml_cuda_should_fuse_mul_mat (ffn_up, ffn_gate, glu)) {
29462954 return true ;
29472955 }
29482956 }
@@ -3088,22 +3096,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
30883096
30893097 for (ggml_op op : {GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID}) {
30903098 if (ggml_cuda_can_fuse (cgraph, i, {op, op, GGML_OP_GLU}, {})) {
3091- const ggml_tensor * up = cgraph->nodes [i];
3092- const ggml_tensor * gate = cgraph->nodes [i+1 ]->src [0 ];
30933099 ggml_tensor * glu = cgraph->nodes [i+2 ];
3100+ ggml_tensor * gate = glu->src [0 ];
3101+ ggml_tensor * up = glu->src [1 ];
30943102
30953103 const ggml_tensor * src0 = up->src [0 ];
30963104 const ggml_tensor * src1 = up->src [1 ];
30973105 const ggml_tensor * ids = up->src [2 ];
30983106
30993107 if (ggml_cuda_should_fuse_mul_mat_vec_f (up)) {
3100- ggml_cuda_mul_mat_vec_f (*cuda_ctx, src0, src1, ids, glu, gate, glu);
3108+ ggml_cuda_mul_mat_vec_f (*cuda_ctx, src0, src1, ids, glu, gate-> src [ 0 ] , glu);
31013109 fused_mul_mat_vec = true ;
31023110 break ;
31033111 }
31043112
31053113 if (ggml_cuda_should_fuse_mul_mat_vec_q (up)) {
3106- ggml_cuda_mul_mat_vec_q (*cuda_ctx, src0, src1, ids, glu, gate, glu);
3114+ ggml_cuda_mul_mat_vec_q (*cuda_ctx, src0, src1, ids, glu, gate-> src [ 0 ] , glu);
31073115 fused_mul_mat_vec = true ;
31083116 break ;
31093117 }
0 commit comments