@@ -1195,8 +1195,8 @@ void ggml_vk_rope(
11951195    const  std::shared_ptr<kp::Tensor>& inB,
11961196    const  std::shared_ptr<kp::Tensor>& out,
11971197    uint32_t  inAOff, uint32_t  inBOff, uint32_t  outOff,
1198-     ggml_type src0t, int32_t  n_dims, int32_t  mode,
1199-     float  freq_base, float  freq_scale,
1198+     ggml_type src0t, int32_t  n_dims, int32_t  mode,  int32_t  n_orig_ctx, 
1199+     float  freq_base, float  freq_scale,  float  ext_factor,  float  attn_factor,  float  beta_fast,  float  beta_slow, 
12001200    int32_t  ne01, int32_t  ne02, int32_t  ne03,
12011201    uint32_t  nb00, uint32_t  nb01, uint32_t  nb02, uint32_t  nb03,
12021202    int32_t  ne0,
@@ -1224,15 +1224,15 @@ void ggml_vk_rope(
12241224
12251225    struct  PushConstants  {
12261226        uint32_t  inAOff, inBOff, outOff;
1227-         int32_t  n_dims, mode;
1228-         float  freq_base, freq_scale;
1227+         int32_t  n_dims, mode, n_orig_ctx ;
1228+         float  freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ;
12291229        uint32_t  nb00, nb01, nb02, nb03;
12301230        int32_t  ne0;
12311231        uint32_t  nb0, nb1, nb2, nb3;
12321232    } pushConsts {
12331233        safe_divide (inAOff, type_size), safe_divide (inBOff, 4 ), safe_divide (outOff, type_size),
1234-         n_dims, mode,
1235-         freq_base, freq_scale,
1234+         n_dims, mode, n_orig_ctx, 
1235+         freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, 
12361236        nb00, nb01, nb02, nb03,
12371237        ne0,
12381238        nb0, nb1, nb2, nb3
@@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
15451545                        GGML_ASSERT (ne10 == ne02);
15461546                        GGML_ASSERT (src0t == dstt);
15471547                        //  const int n_past = ((int32_t *) dst->op_params)[0];
1548-                         const  int  n_dims = ((int32_t  *) dst->op_params )[1 ];
1549-                         const  int  mode   = ((int32_t  *) dst->op_params )[2 ];
1550-                         float  freq_base;
1551-                         float  freq_scale;
1552-                         memcpy (&freq_base,  (int32_t  *) dst->op_params  + 4 , sizeof (float ));
1553-                         memcpy (&freq_scale, (int32_t  *) dst->op_params  + 5 , sizeof (float ));
1554-                         ggml_vk_rope (seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
1548+                         const  int  n_dims     = ((int32_t  *) dst->op_params )[1 ];
1549+                         const  int  mode       = ((int32_t  *) dst->op_params )[2 ];
1550+                         //  skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1551+                         const  int  n_orig_ctx = ((int32_t  *) dst->op_params )[4 ];
1552+ 
1553+                         float  freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1554+                         memcpy (&freq_base,   (int32_t  *) dst->op_params  +  5 , sizeof (float ));
1555+                         memcpy (&freq_scale,  (int32_t  *) dst->op_params  +  6 , sizeof (float ));
1556+                         memcpy (&ext_factor,  (int32_t  *) dst->op_params  +  7 , sizeof (float ));
1557+                         memcpy (&attn_factor, (int32_t  *) dst->op_params  +  8 , sizeof (float ));
1558+                         memcpy (&beta_fast,   (int32_t  *) dst->op_params  +  9 , sizeof (float ));
1559+                         memcpy (&beta_slow,   (int32_t  *) dst->op_params  + 10 , sizeof (float ));
1560+                         ggml_vk_rope (
1561+                             seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
1562+                             freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1563+                             ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1564+                         );
15551565                    } break ;
15561566                case  GGML_OP_DUP:
15571567                case  GGML_OP_CPY:
0 commit comments