@@ -1195,8 +1195,8 @@ void ggml_vk_rope(
11951195 const std::shared_ptr<kp::Tensor>& inB,
11961196 const std::shared_ptr<kp::Tensor>& out,
11971197 uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1198- ggml_type src0t, int32_t n_dims, int32_t mode,
1199- float freq_base, float freq_scale,
1198+ ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
1199+ float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
12001200 int32_t ne01, int32_t ne02, int32_t ne03,
12011201 uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
12021202 int32_t ne0,
@@ -1224,15 +1224,15 @@ void ggml_vk_rope(
12241224
12251225 struct PushConstants {
12261226 uint32_t inAOff, inBOff, outOff;
1227- int32_t n_dims, mode;
1228- float freq_base, freq_scale;
1227+ int32_t n_dims, mode, n_orig_ctx ;
1228+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ;
12291229 uint32_t nb00, nb01, nb02, nb03;
12301230 int32_t ne0;
12311231 uint32_t nb0, nb1, nb2, nb3;
12321232 } pushConsts {
12331233 safe_divide (inAOff, type_size), safe_divide (inBOff, 4 ), safe_divide (outOff, type_size),
1234- n_dims, mode,
1235- freq_base, freq_scale,
1234+ n_dims, mode, n_orig_ctx,
1235+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
12361236 nb00, nb01, nb02, nb03,
12371237 ne0,
12381238 nb0, nb1, nb2, nb3
@@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
15451545 GGML_ASSERT (ne10 == ne02);
15461546 GGML_ASSERT (src0t == dstt);
15471547 // const int n_past = ((int32_t *) dst->op_params)[0];
1548- const int n_dims = ((int32_t *) dst->op_params )[1 ];
1549- const int mode = ((int32_t *) dst->op_params )[2 ];
1550- float freq_base;
1551- float freq_scale;
1552- memcpy (&freq_base, (int32_t *) dst->op_params + 4 , sizeof (float ));
1553- memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
1554- ggml_vk_rope (seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
1548+ const int n_dims = ((int32_t *) dst->op_params )[1 ];
1549+ const int mode = ((int32_t *) dst->op_params )[2 ];
1550+ // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1551+ const int n_orig_ctx = ((int32_t *) dst->op_params )[4 ];
1552+
1553+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1554+ memcpy (&freq_base, (int32_t *) dst->op_params + 5 , sizeof (float ));
1555+ memcpy (&freq_scale, (int32_t *) dst->op_params + 6 , sizeof (float ));
1556+ memcpy (&ext_factor, (int32_t *) dst->op_params + 7 , sizeof (float ));
1557+ memcpy (&attn_factor, (int32_t *) dst->op_params + 8 , sizeof (float ));
1558+ memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
1559+ memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
1560+ ggml_vk_rope (
1561+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
1562+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1563+ ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1564+ );
15551565 } break ;
15561566 case GGML_OP_DUP:
15571567 case GGML_OP_CPY:
0 commit comments