-
Couldn't load subscription status.
- Fork 13.5k
llama: use sliding window for phi3 #8627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4974,6 +4974,8 @@ static void llm_load_hparams( | |
| } break; | ||
| case LLM_ARCH_PHI3: | ||
| { | ||
| hparams.n_swa = 2048; | ||
| ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); | ||
| ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); | ||
|
|
||
| switch (hparams.n_layer) { | ||
|
|
@@ -10843,7 +10845,7 @@ struct llm_build_context { | |
| struct ggml_tensor * inp_pos = build_inp_pos(); | ||
|
|
||
| // KQ_mask (mask for 1 head, it will be broadcasted to all heads) | ||
| struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); | ||
| struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); | ||
|
|
||
| for (int il = 0; il < n_layer; ++il) { | ||
| auto residual = inpL; | ||
|
|
@@ -10901,7 +10903,7 @@ struct llm_build_context { | |
|
|
||
| cur = llm_build_kv(ctx0, lctx, kv_self, gf, | ||
| model.layers[il].wo, model.layers[il].bo, | ||
| Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); | ||
| Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il); | ||
| } | ||
|
|
||
| if (il == n_layer - 1) { | ||
|
|
@@ -14108,18 +14110,23 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |
| "causal attention is not supported by this model" | ||
| ); | ||
|
|
||
| if (lctx.inp_KQ_mask) { | ||
| if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { | ||
| // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. | ||
| if (cparams.causal_attn && !lctx.is_encoding) { | ||
| const int64_t n_kv = kv_self.n; | ||
| const int64_t n_tokens = batch.n_tokens; | ||
|
|
||
| GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); | ||
|
|
||
| float * data = (float *) lctx.inp_KQ_mask->data; | ||
| float * data = nullptr; | ||
| float * data_swa = nullptr; | ||
|
|
||
| // Full (non-sliding-window) mask: point `data` at its host buffer so the fill and padding loops guarded by `if (data)` write to it. | ||
| // This must assign `data`, not `data_swa`; otherwise the full mask is never written and `data_swa` aliases the wrong tensor. | ||
| if (lctx.inp_KQ_mask) { | ||
| GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); | ||
| data = (float *) lctx.inp_KQ_mask->data; | ||
| } | ||
|
|
||
| if (lctx.inp_KQ_mask_swa) { | ||
| GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); | ||
| data_swa = (float *) lctx.inp_KQ_mask_swa->data; | ||
| } | ||
|
|
||
|
|
@@ -14142,7 +14149,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |
| f = 0.0f; | ||
| } | ||
| } | ||
| data[h*(n_kv*n_tokens) + j*n_kv + i] = f; | ||
|
|
||
| if (data) { | ||
| data[h*(n_kv*n_tokens) + j*n_kv + i] = f; | ||
| } | ||
|
|
||
| // may need to cut off old tokens for sliding window | ||
| if (data_swa) { | ||
|
|
@@ -14154,9 +14164,19 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |
| } | ||
| } | ||
|
|
||
| for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { | ||
| for (int j = 0; j < n_kv; ++j) { | ||
| data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; | ||
| if (data) { | ||
| for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { | ||
| for (int j = 0; j < n_kv; ++j) { | ||
| data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (data_swa) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Overwriting it here may break gemma2. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, I'm not familiar with gemma2, so I haven't tested the PR on gemma2. I only tested this PR with Phi3 on CPU. I do not understand when it should be padded to […]. I'm confused by the original code that […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @ngxson I agree with @FanShupei here, I think […]. Not sure why this worked before though. Padding […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Both should be padded. The padding is necessary so that GPU kernels (such as the Metal Flash-Attention) do not have to perform extra checks for out-of-bounds access when working on chunks of data. |
||
| for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { | ||
| for (int j = 0; j < n_kv; ++j) { | ||
| data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.