llama: avoid copying logits during prompt decode in MTP#23198
Conversation
|
@d-r-e no, MTP does negatively impact prompt processing, but under this PR the negative impact is halved. |
|
2xMI50 qwen 27b Q4_1 does see some improvement with this PR |
Why does it affect prompt processing? |
|
Made sure this PR is included and re-tested: pretty much the same as #22673 (comment) which probably had the PR already. Still almost 50% pp hit |
Due to the embeddings copy, most likely. |
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
This comment has been minimized.
This comment has been minimized.
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
|
This particular commit seems to regress the acceptance rate, I lose about 5% at n = 4 in coding tasks. Otherwise the prefill improvement is amazing |
|
@jtjstock this commit just stops copying stuff we don't ever read, in a way it's a free optimization and it should not affect anything except the prefill |
|
@am17an Just did a bunch of runs, one at latest(b9254): draft acceptance = 0.75134 ( 3913 accepted / 5208 generated) Same prompt. Same Seed. Temp 0. Same hardware minutes apart. Same compiler. The only oddity seems to be the outputs from the latest(b9254), they are slightly different each time, but highly similar, where as with mainline commits (3e12fbd + dependent 49c21f9, included this last one for a clean revert and build) reverted build they are identical each time. I'm running this on windows, the command used: The prompt: write fizzbuzz in 16 different programming languages |
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
|
@am17an Also just tested the official build of b9254(llama-b9254-bin-win-cuda-13.1-x64), seeing the same variability in the output and similar degraded acceptance rate. Maybe it's specific to my config? It's running on 2x 5060ti 16GB Edit: did some more testing. Issue still present without quantized kv, but does disappear when I switch to -sm tensor, output from -sm tensor is internally consistent across runs. |
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
* llama: avoid copying logits during prompt decode in MTP * review: update comment * llama-graph: call set_output for t_h_pre_norm
|
@am17an I think the issue I see, which still presents at head, is the MTP heads are being fed the raw pre-norm residual, but Qwen seems to do better with the post-norm. Before this PR, Qwen was effectively reading a post-norm because t_h_pre_norm aliased the buffer that the final norm was writing to. So it was better before because of a side effect. It is a very small change to fix. Move the final build_norm before res->t_h_pre_norm = cur in the main graph, and move res->t_h_pre_norm = cur after the build_norm in the mtp one, in qwen35moe.cpp and qwen35.cpp. Two lines moved in each file. I could be wrong, but it does result in a noticeable improvement to acceptance with the corresponding bump in token generation. I see acceptance move from ~75% to ~80% at n-max 4, and generation from ~203 t/s to ~225 t/s on tensor split. Prefill was of course unaffected. I'm not looking at the sm layer issue I noted before, as I'm not using layer split anymore, but it's still present when I tested. |
|
@jtjstock the model's MTP head was trained on the pre-norm hidden state. I don't understand how the post-norm residual would make it better. |
|
Try it?
…On Tue, Jun 2, 2026 at 12:32 AM Aman Gupta ***@***.***> wrote:
*am17an* left a comment (ggml-org/llama.cpp#23198)
<#23198 (comment)>
@jtjstock <https://github.com/jtjstock> the model's MTP head was trained
on the pre-norm residual. I don't understand how the post-norm residual
would make it better.
—
Reply to this email directly, view it on GitHub
<#23198?email_source=notifications&email_token=AEZEUVEXU7OY322LAT3LAED45ZKFJA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINJZHA3TGNBRGYY2M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLDGN5XXIZLSL5RWY2LDNM#issuecomment-4598734161>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AEZEUVGP5WQDWNYH62W6W3T45ZKFJAVCNFSM6AAAAACZBIXW32VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHM2DKOJYG4ZTIMJWGE>
.
Triage notifications, keep track of coding agent tasks and review pull
requests on the go with GitHub Mobile for iOS
<https://github.com/notifications/mobile/ios/AEZEUVGI5QZBVZFFJEDVV7D45ZKFJA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINJZHA3TGNBRGYY2M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJKTGN5XXIZLSL5UW64Y>
and Android
<https://github.com/notifications/mobile/android/AEZEUVHEOD6EYGRQTSJJZ4T45ZKFJA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINJZHA3TGNBRGYY2M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLTGN5XXIZLSL5QW4ZDSN5UWI>.
Download it today!
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
|
No thanks. Please stop posting irrelevant comments |
|
Sorry I am bothering you, but I don't see any source that specifically says Qwen 3.5/3.6's MTP is trained on pre-norm hidden state, I do see a paper that says the training method is undisclosed(https://arxiv.org/html/2605.09992v1), and vLLM seems to use the post-norm hidden state in the branch qwen 3.5's mtp follows. Again, I do apologise for bothering you, you've done a tremendous amount of work on this and it is not my intention to waste your time. |
|
Okay I just checked vLLM and you're right, they do pass in the post-norm hidden state, my assumption was that is was a deepseek like MTP which passes the pre norm hidden state. Thanks for pointing this out |

Overview
Avoid copying the logits for every token in the batch when doing prompt processing for MTP since it only requires the pre-norm. This reduces memory traffic quite a bit and in turn increases PP speed with MTP.
Additional information
Requirements