Skip to content

Commit 45491b7

Browse files
author
firecoperana
committed
llama : fix KV shift for qwen2vl #13870
1 parent 149b086 commit 45491b7

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

src/llama-build-context.cpp

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,18 @@ ggml_cgraph * llm_build_context::build_k_shift() {
99 99

100 100
GGML_ASSERT(kv_self.size == n_ctx);
101 101

102+
const auto & rope_type_shift = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE
103+
// @ngxson : this is a workaround
104+
// for M-RoPE, we want to rotate the whole vector when doing KV shift
105+
// a normal RoPE should work, we just need to use the correct ordering
106+
// ref: https://github.com/ggml-org/llama.cpp/pull/13870
107+
? LLAMA_ROPE_TYPE_NEOX
108+
: hparams.rope_type;
109+
110+
const float yarn_attn_factor_shift = model.arch == LLM_ARCH_DEEPSEEK2
111+
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
112+
: cparams.yarn_attn_factor;
113+
102 114
lctx.inp_K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
103 115
cb(lctx.inp_K_shift, "K_shift", -1);
104 116
ggml_set_input(lctx.inp_K_shift);
@@ -127,15 +139,15 @@ ggml_cgraph * llm_build_context::build_k_shift() {
127 139
}
128 140
}
129 141
tmp = ggml_rope_ext_inplace(ctx0, tmp,
130-
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
131-
ext_factor, attn_factor, beta_fast, beta_slow);
142+
lctx.inp_K_shift, rope_factors, n_rot, rope_type_shift, n_ctx_orig, freq_base, freq_scale,
143+
ext_factor, yarn_attn_factor_shift, beta_fast, beta_slow);
132 144
cb(tmp, "K_shifted_f32", il);
133 145
tmp = ggml_cpy(ctx0, tmp, k);
134 146
} else {
135 147
// we rotate only the first n_rot dimensions
136 148
tmp = ggml_rope_ext_inplace(ctx0, k,
137-
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
138-
ext_factor, attn_factor, beta_fast, beta_slow);
149+
lctx.inp_K_shift, rope_factors, n_rot, rope_type_shift, n_ctx_orig, freq_base, freq_scale,
150+
ext_factor, yarn_attn_factor_shift, beta_fast, beta_slow);
139 151
}
140 152
cb(tmp, "K_shifted", il);
141 153
ggml_build_forward_expand(gf, tmp);

0 commit comments

Comments
 (0)