@@ -659,6 +659,11 @@ struct vk_op_push_constants {
659659    float param2;
660660};
661661
662+ struct vk_op_glu_push_constants {
663+     uint32_t ne00;
664+     uint32_t mode;  // 0: default, 1: swapped, 2: split
665+ };
666+ 
662667struct vk_op_unary_push_constants {
663668    uint32_t ne;
664669    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27332738#undef CREATE_UNARY
27342739
27352740#define CREATE_GLU(name)  \
2736-     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2 , sizeof(vk_op_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);  \
2737-     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2 , sizeof(vk_op_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);
2741+     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3 , sizeof(vk_op_glu_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);  \
2742+     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3 , sizeof(vk_op_glu_push_constants ), {1, 1, 1}, { device->subgroup_size }, 1);
27382743
27392744    CREATE_GLU(geglu)
27402745    CREATE_GLU(reglu)
@@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69476952        }
69486953    }
69496954
6950-     if (op == GGML_OP_SOFT_MAX) {
6955+     if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU ) {
69516956        // Empty src1 is possible in soft_max, but the shader needs a buffer
69526957        vk_subbuffer subbuf_y;
69536958        if (use_src1) {
@@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
75397544    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
75407545}
75417546
7542- static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7543-     GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7547+ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7548+     const bool swapped = (bool)dst->op_params[1];
7549+     const bool split = src1 != nullptr;
7550+ 
7551+     GGML_ASSERT(ggml_is_contiguous(src0));
7552+ 
7553+     if (!split) {
7554+         GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7555+     } else {
7556+         GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7557+         GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7558+         GGML_ASSERT(src0->type == src1->type);
7559+     }
75447560
7545-     const uint32_t swapped  = (uint32_t)dst->op_params[1] ;
7561+     const uint32_t mode  = split ? 2 : (swapped ? 1 : 0) ;
75467562
7547-     ggml_vk_op_f32<vk_op_push_constants >(ctx, subctx, src0, nullptr , nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f  }, dryrun);
7563+     ggml_vk_op_f32<vk_op_glu_push_constants >(ctx, subctx, src0, src1 , nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode  }, dryrun);
75487564}
75497565
75507566static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
90039019        case GGML_GLU_OP_GEGLU:
90049020        case GGML_GLU_OP_REGLU:
90059021        case GGML_GLU_OP_SWIGLU:
9006-             ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
9022+             ggml_vk_glu(ctx, compute_ctx, src0, src1,  node, dryrun);
90079023            break;
90089024        default:
90099025            return false;
@@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1072510741            GGML_ABORT("fatal error");
1072610742        }
1072710743    } else if (tensor->op == GGML_OP_GLU) {
10728-         tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10744+         if (src_clone[1] == nullptr) {
10745+             tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10746+         } else {
10747+             tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
10748+         }
1072910749    } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
1073010750        if (src1 == nullptr) {
1073110751            tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
0 commit comments