llama: avoid copying logits during prompt decode in MTP#23198

Merged
am17an merged 3 commits into ggml-org:master from am17an:mtp-pp-fix on May 17, 2026
@am17an
Contributor

@am17an am17an commented May 17, 2026

Overview

Avoid copying the logits for every token in the batch during prompt processing with MTP, since MTP only requires the pre-norm hidden state. This reduces memory traffic considerably and in turn increases prompt-processing (PP) speed with MTP.
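
To get a feel for why skipping the per-token logits copy matters, here is a back-of-the-envelope sketch. The batch size, vocabulary size, hidden size and the f32 element size are all assumed placeholder values chosen for illustration; they are not taken from this PR or from the actual model.

```python
# Illustrative only: every size below is an assumed placeholder,
# not a real model dimension and not a measurement from this PR.
BYTES_F32 = 4

def copy_bytes(n_tokens: int, row_width: int, elem_size: int = BYTES_F32) -> int:
    """Bytes moved when one row of `row_width` elements is copied per token."""
    return n_tokens * row_width * elem_size

n_tokens = 2048     # tokens in one prompt batch (assumed)
n_vocab = 150_000   # vocabulary size (assumed placeholder)
n_embd = 5_000      # hidden size (assumed placeholder)

logits_bytes = copy_bytes(n_tokens, n_vocab)   # logits for every token
hidden_bytes = copy_bytes(n_tokens, n_embd)    # one hidden-state row per token

print(f"logits for every token: {logits_bytes / 2**20:.1f} MiB")
print(f"hidden state per token: {hidden_bytes / 2**20:.1f} MiB")
print(f"ratio: {logits_bytes / hidden_bytes:.0f}x")
```

Under these assumptions the logits rows are wider than the hidden-state rows by the ratio n_vocab / n_embd, which is where the saved memory traffic would come from.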

@am17an am17an requested review from a team, CISC and ggerganov as code owners May 17, 2026 10:22
Comment thread src/llama-context.cpp Outdated
Comment thread src/models/qwen35moe.cpp
Member

@ggerganov ggerganov left a comment

A quick bench on RTX 5090 with Qwen3.6 27B Q4_K

Image

@am17an am17an merged commit 3e12fbd into ggml-org:master May 17, 2026
75 of 81 checks passed
@am17an am17an deleted the mtp-pp-fix branch May 17, 2026 15:30
@d-r-e

d-r-e commented May 17, 2026

> A quick bench on RTX 5090 with Qwen3.6 27B Q4_K
>
> Image

Are the legend colors swapped?

@pwilkin
Member

pwilkin commented May 17, 2026

@d-r-e no, MTP does negatively impact prompt processing, but under this PR the negative impact is halved.

@cb88

cb88 commented May 17, 2026

2x MI50 with Qwen 27B Q4_1 does see some improvement with this PR:
MI50 without MTP = 500 t/s
with MTP = 250 t/s
with MTP and this PR = 300 t/s
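
As a rough check of the three throughput figures above, the relative slowdown can be computed directly. This is a minimal sketch; the function name is illustrative and nothing here comes from llama.cpp itself.

```python
def slowdown_pct(baseline_tps: float, tps: float) -> float:
    """Throughput lost relative to the baseline, in percent."""
    return (1.0 - tps / baseline_tps) * 100.0

baseline = 500.0     # without MTP
mtp_before = 250.0   # with MTP, before this PR
mtp_after = 300.0    # with MTP, with this PR

print(f"MTP penalty before PR: {slowdown_pct(baseline, mtp_before):.0f}%")
print(f"MTP penalty with PR:   {slowdown_pct(baseline, mtp_after):.0f}%")
print(f"speedup from the PR:   {mtp_after / mtp_before:.2f}x")
```

On these particular numbers the MTP penalty goes from 50% to 40%, which is a smaller recovery than the halving mentioned earlier in the thread.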

@tha80 tha80 mentioned this pull request May 17, 2026
11 tasks
@0cc4m
Contributor

0cc4m commented May 17, 2026

> @d-r-e no, MTP does negatively impact prompt processing, but under this PR the negative impact is halved.

Why does it affect prompt processing?

@Mithras

Mithras commented May 17, 2026

Made sure this PR is included and re-tested:
unsloth/Qwen3.6-27B-MTP-GGUF/Qwen3.6-27B-UD-Q5_K_XL.gguf

| model     |             test |             t/s |     peak t/s |          ttfr (ms) |       est_ppt (ms) |      e2e_ttft (ms) |
|:----------|-----------------:|----------------:|-------------:|-------------------:|-------------------:|-------------------:|
| qwen36-27 |  pp2048 @ d16384 | 1843.93 ± 14.82 |              |   9098.06 ± 124.65 |   9097.48 ± 124.65 |   9098.06 ± 124.65 |
| qwen36-27 |   tg128 @ d16384 |    74.46 ± 3.24 | 84.00 ± 0.82 |                    |                    |                    |
| qwen36-27 |  pp2048 @ d65536 |  1449.72 ± 9.78 |              |  42344.98 ± 292.27 |  42344.40 ± 292.27 |  42344.98 ± 292.27 |
| qwen36-27 |   tg128 @ d65536 |    61.97 ± 2.35 | 68.33 ± 4.78 |                    |                    |                    |
| qwen36-27 | pp2048 @ d131072 |  1075.30 ± 2.36 |              | 112238.48 ± 281.78 | 112237.90 ± 281.78 | 112238.48 ± 281.78 |
| qwen36-27 |  tg128 @ d131072 |    48.40 ± 2.53 | 55.00 ± 0.00 |                    |                    |                    |

Pretty much the same as #22673 (comment), which probably already included this PR. Still an almost 50% PP hit.

@pwilkin
Member

pwilkin commented May 17, 2026

> Why does it affect prompt processing?

Due to the embeddings copy, most likely.

DrBearJew referenced this pull request in DrBearJew/RoxxY May 17, 2026
* llama: avoid copying logits during prompt decode in MTP

* review: update comment

* llama-graph: call set_output for t_h_pre_norm
@jtjstock

This particular commit seems to regress the acceptance rate: I lose about 5% at n = 4 in coding tasks. Otherwise, the prefill improvement is amazing.

@am17an
Contributor Author

am17an commented May 21, 2026

@jtjstock this commit just stops copying stuff we never read. In a way it's a free optimization, and it should not affect anything except the prefill.

@jtjstock

jtjstock commented May 21, 2026

@am17an Just did a bunch of runs. One at latest (b9254): draft acceptance = 0.75134 (3913 accepted / 5208 generated),
and one with this and a dependent commit reverted: draft acceptance = 0.83896 (4376 accepted / 5216 generated).

Same prompt. Same seed. Temp 0. Same hardware, minutes apart. Same compiler. The only oddity seems to be the outputs from latest (b9254): they are slightly different each time, though highly similar, whereas with the build that reverts the mainline commits (3e12fbd plus the dependent 49c21f9, included for a clean revert and build) they are identical each time.
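
For reference, the two draft acceptance figures quoted above follow directly from the accepted and generated counts. A minimal sketch; the helper name and the dictionary labels are made up for illustration.

```python
def acceptance(accepted: int, generated: int) -> float:
    """Fraction of drafted tokens that were accepted."""
    return accepted / generated

runs = {
    "latest (b9254)": (3913, 5208),
    "commits reverted": (4376, 5216),
}

rates = {name: acceptance(a, g) for name, (a, g) in runs.items()}
for name, rate in rates.items():
    print(f"{name}: {rate:.5f}")

gap_pp = (rates["commits reverted"] - rates["latest (b9254)"]) * 100.0
print(f"gap: {gap_pp:.2f} percentage points")
```

The gap between the two runs works out to roughly 8.8 percentage points of draft acceptance.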

I'm running this on Windows. The command used:
llama-server.exe ^
-m models/localweights/Qwen3.6-35B-A3B-MTP-IQ4_XS-Q8nextn-GGUF/Qwen3.6-35B-A3B-MTP-IQ4_XS-Q8nextn.gguf ^
--no-mmap --ctx-size 16384 --port 12345 ^
--flash-attn on --cache-type-k q8_0 --cache-type-v q8_0 ^
--spec-type draft-mtp --spec-draft-n-max 4 ^
-np 1 --temp 0 --seed 0 ^
-sm layer -ngl 999 --tensor-split 20,18 ^
-t 4 -tb 8 --no-warmup --metrics

The prompt: write fizzbuzz in 16 different programming languages

@jtjstock

jtjstock commented May 21, 2026

@am17an Also just tested the official build of b9254 (llama-b9254-bin-win-cuda-13.1-x64); I'm seeing the same variability in the output and a similarly degraded acceptance rate.

Maybe it's specific to my config? It's running on 2x 5060 Ti 16GB.

Edit: did some more testing. The issue is still present without quantized KV, but it does disappear when I switch to -sm tensor; output from -sm tensor is consistent across runs.

@jtjstock

jtjstock commented Jun 2, 2026

@am17an I think the issue I see, which is still present at head, is that the MTP heads are being fed the raw pre-norm residual, but Qwen seems to do better with the post-norm. Before this PR, Qwen was effectively reading a post-norm state because t_h_pre_norm aliased the buffer that the final norm was writing to. So it was better before because of a side effect.
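
The kind of side effect described above can be pictured with a toy example. This is only a sketch of buffer aliasing in general, using plain Python lists and a weight-free RMS norm as stand-ins. It does not reproduce ggml's buffer allocation or the real graph code, and every name in it is made up for illustration.

```python
import math

def rms_norm_inplace(buf, eps=1e-6):
    """Normalize buf in place by its root mean square (no learned weight)."""
    rms = math.sqrt(sum(x * x for x in buf) / len(buf) + eps)
    for i, x in enumerate(buf):
        buf[i] = x / rms

hidden = [3.0, -4.0, 12.0, 0.5]

# Case 1: the "pre-norm" handle shares storage with the buffer being normalized.
cur_a = list(hidden)
aliased = cur_a              # same underlying storage, no copy
rms_norm_inplace(cur_a)      # the norm overwrites that storage

# Case 2: the "pre-norm" handle is a separate snapshot taken before the norm.
cur_b = list(hidden)
snapshot = list(cur_b)       # independent copy
rms_norm_inplace(cur_b)

print("aliased handle :", aliased)    # reads normalized values
print("snapshot handle:", snapshot)   # still reads the raw values
```

In the aliased case, a reader that believes it holds the pre-norm state actually sees the normalized values, which is the effect the comment attributes to the old behavior.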

It is a very small change to fix. Move the final build_norm before res->t_h_pre_norm = cur in the main graph, and move res->t_h_pre_norm = cur after the build_norm in the mtp one, in qwen35moe.cpp and qwen35.cpp. Two lines moved in each file.
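
The reordering proposed above can be sketched on a toy graph builder. Here build_graph, rms_norm, the capture_after_norm flag and the "out" key are hypothetical stand-ins, not the code in qwen35moe.cpp or qwen35.cpp, and the norm omits the learned weight.

```python
import math

def rms_norm(x, eps=1e-6):
    """Return x divided by its root mean square (learned weight omitted)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def build_graph(hidden, capture_after_norm):
    """Toy stand-in for the end of a model graph.

    capture_after_norm=False: record the hidden state, then apply the final norm.
    capture_after_norm=True:  apply the final norm first, then record its output.
    """
    res = {}
    cur = list(hidden)
    if not capture_after_norm:
        res["t_h_pre_norm"] = cur
    cur = rms_norm(cur)
    if capture_after_norm:
        res["t_h_pre_norm"] = cur
    res["out"] = cur   # hypothetical name for the final normalized output
    return res

hidden = [3.0, -4.0, 12.0, 0.5]
before = build_graph(hidden, capture_after_norm=False)
after = build_graph(hidden, capture_after_norm=True)

print("captured before the norm:", before["t_h_pre_norm"])
print("captured after the norm: ", after["t_h_pre_norm"])
```

Only the tensor handed to the MTP head changes between the two orderings; the final output is the same in both.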

I could be wrong, but it does result in a noticeable improvement to acceptance with the corresponding bump in token generation. I see acceptance move from ~75% to ~80% at n-max 4, and generation from ~203 t/s to ~225 t/s on tensor split. Prefill was of course unaffected.

I'm not looking into the -sm layer issue I noted before, as I'm not using layer split anymore, but it was still present when I tested.

@am17an
Contributor Author

am17an commented Jun 2, 2026

@jtjstock the model's MTP head was trained on the pre-norm hidden state. I don't understand how the post-norm residual would make it better.

@jtjstock

jtjstock commented Jun 2, 2026 via email

@am17an
Contributor Author

am17an commented Jun 2, 2026

No thanks. Please stop posting irrelevant comments

@jtjstock

jtjstock commented Jun 2, 2026

Sorry to bother you, but I don't see any source that specifically says Qwen 3.5/3.6's MTP is trained on the pre-norm hidden state. I do see a paper that says the training method is undisclosed (https://arxiv.org/html/2605.09992v1), and vLLM seems to use the post-norm hidden state in the branch that Qwen 3.5's MTP follows.

Again, I do apologise for bothering you, you've done a tremendous amount of work on this and it is not my intention to waste your time.

@am17an
Contributor Author

am17an commented Jun 2, 2026

Okay, I just checked vLLM and you're right, they do pass in the post-norm hidden state. My assumption was that it was a DeepSeek-like MTP, which passes the pre-norm hidden state. Thanks for pointing this out.

9 participants