
llama : add MTP API #18886

Draft

ngxson wants to merge 6 commits into ggml-org:master from ngxson:xsn/mtp_model

Conversation

@ngxson
Collaborator

@ngxson ngxson commented Jan 16, 2026

Nothing here is working yet, but I'm pushing this to discuss the direction of MTP (multi-token prediction) integration in libllama.

Currently, this implementation is designed around llama_cross, which makes it work somewhat like the existing llama_encode/llama_decode flow. I'm targeting GLM support in the first iteration because its weights already contain the nextn tensors.

For the explanation of the API, see the comment below.

In theory, this should work the same way for DeepSeek V3, Xiaomi MiMo, GLM, and Eagle3 (see this comment).

@ngxson
Collaborator Author

ngxson commented Jan 17, 2026

I spent more time thinking about this and decided to change the approach:

  • We now use 2 different llama_context instances, one for MTP and one for the main LLM.
  • After a llama_decode() on the main LLM, a llama_mtp_start() call follows to copy the state from the main model to the MTP model. I'm not 100% sure if this is needed, but I'm adding it for clarity for now.

With this implementation, the user code will look like this:

// load model
llama_model * model_llm = llama_model_load_from_file("main-model.gguf", ...);
llama_model * model_mtp = model_llm; // by default, MTP is built-in
// optionally, may load MTP from another file
if (!llama_model_has_mtp(model_llm)) {
    model_mtp = llama_model_load_from_file("mtp-model.gguf", ...);
}

llama_context * ctx_llm = llama_init_from_model(model_llm, ...);
llama_context * ctx_mtp = llama_init_from_model(model_mtp, ...);

// example generating one token with main LLM

common_batch_add(batch, last_token, ...);
llama_decode(ctx_llm, batch);
llama_token curr_token = common_sampler_sample(...);

llama_mtp_start(ctx_llm, ctx_mtp);
// at this point, MTP state is populated to ctx_mtp->cross

std::vector<llama_token> mtp_tokens;
llama_token last_mtp_token = curr_token;
for (int i = 0; i < n_draft_max; i++) {
    common_batch_clear(batch);
    common_batch_add(batch, last_mtp_token, ...);
    llama_decode(ctx_mtp, batch); // use state from ctx_mtp->cross, and update it afterwards
    last_mtp_token = common_sampler_sample(...);
    mtp_tokens.push_back(last_mtp_token);
}

While on the surface, this approach doesn't seem much different from #18039 or #15225, it does have some important advantages:

  1. Because the internal MTP state is never leaked to the user code, the implementation can be extended in the future without breaking the API.
  2. One of the main problems currently is that the MTP model is very small, so the CPU <--> backend synchronization between each generated token becomes a bottleneck. In the long term, we can optimize this by not letting common_sampler_sample call llama_synchronize(). This allows MTP token generation to be fully asynchronous, and since sampling can be done on the backend, the backend never needs to pause until it has generated n_draft_max tokens (see the sketch after this list).
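
For illustration, here is a rough sketch of what that fully asynchronous drafting path could look like (all function names below except llama_synchronize are hypothetical, this is only a sketch of the idea, not existing API):

// hypothetical sketch: the drafting loop runs entirely on the backend, with
// sampling done backend-side, so the CPU synchronizes only once at the end
std::vector<llama_token> mtp_tokens(n_draft_max);

// enqueue n_draft_max decode + sample steps; the token sampled at step i is
// fed directly into step i+1 on the backend, never copied back to the CPU
llama_mtp_draft_async(ctx_mtp, curr_token, n_draft_max); // hypothetical

// a single synchronization to read back all drafted tokens at once
llama_synchronize(ctx_mtp);
llama_mtp_draft_get(ctx_mtp, mtp_tokens.data(), n_draft_max); // hypothetical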

@ngxson
Collaborator Author

ngxson commented Jan 17, 2026

2. This allows MTP token generation to be fully asynchronous, and since sampling can be done on the backend, the backend never needs to pause until it has generated n_draft_max tokens.

Elaborating on this point a bit more: while working on #12648, I noticed that synchronization has a significant impact on the generation speed of a small model.

The Sesame CSM consists of 2 parts: a main LLM that generates embeddings and a smaller decoder that generates 12 audio tokens for a given embedding. It is pretty much a speculative model with n_draft_max = 12. My implementation was quite slow compared to the PyTorch version, because the PyTorch version can generate all 12 tokens without interrupting the backend (it explicitly implements backend sampling to bypass the CPU sampler).

So I hope the same idea can be applied to the MTP model at some point, as backend sampling is already available in llama.cpp.

Member

@ggerganov ggerganov left a comment


Overall looks good. I agree that pt.2 from #18039 (comment) should be tested to see if it works for making the Eagle3 context decoder-only.

include/llama.h Outdated
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
bool is_mtp; // create context for Multi-Token Prediction (MTP)
Member

@ggerganov ggerganov Jan 18, 2026


Something to consider is moving llm_graph_type to the public interface and changing this to:

llama_graph_type gtype;

If we do it like this, we can even rework the llama_encode() and llama_decode() into a single llama_process().
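
For illustration, a rough sketch of what that could look like in the public header (the llama_graph_type values and llama_process() below are hypothetical, loosely mirroring the internal llm_graph_type):

// hypothetical public enum mirroring the internal llm_graph_type
enum llama_graph_type {
    LLAMA_GRAPH_TYPE_DEFAULT = 0,
    LLAMA_GRAPH_TYPE_ENCODER = 1,
    LLAMA_GRAPH_TYPE_DECODER = 2,
    LLAMA_GRAPH_TYPE_MTP     = 3, // hypothetical addition for MTP contexts
};

// in llama_context_params, replacing the bool is_mtp flag:
//     enum llama_graph_type gtype;

// hypothetical single entry point replacing llama_encode()/llama_decode();
// the context's gtype decides which graph is built and executed
int32_t llama_process(struct llama_context * ctx, struct llama_batch batch);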

Comment on lines +3256 to +3261
int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
ctx_llm->synchronize();

return ctx_llm->cpy_mtp_state(*ctx_mtp);
}

Member

@ggerganov ggerganov Jan 18, 2026


Also not sure if this is the best way, but it seems OK for now. I would look into generalizing it somehow so it's not too MTP-specific, i.e. a more generic mechanism for sharing data between contexts.

Collaborator Author


I'm thinking about another version of llama_cross that will encapsulate multiple llama_context inside it:

struct llama_cross {
    llama_context * ctx_enc; // text encoder models like T5
    llama_context * ctx_llm;
    llama_context * ctx_mtp;
    llama_context * ctx_mtmd;
};

The idea is that when llama_process() is called on one context, it will propagate the state to the other contexts.

For now I cannot think of a better way to avoid purpose-specific naming like mtmd and mtp, because the data shared between 2 contexts can vary depending on the task. But I think we can iterate on this idea.

For the current PR, I think I can proceed with llama_mtp_start because it will be easy to adapt to whatever API we may come up with in the future.
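
For illustration, the user-facing flow under that idea might look roughly like this (the llama_cross setup, llama_cross_attach() and llama_process() below are hypothetical, not existing API):

// hypothetical sketch of state propagation through a shared llama_cross
llama_cross cross = {};
cross.ctx_llm = ctx_llm;
cross.ctx_mtp = ctx_mtp;
llama_cross_attach(&cross); // hypothetical: associate the contexts with this cross

// processing on the main context stores whatever the sibling contexts need
// (for MTP: the last hidden states) into the shared cross
llama_process(ctx_llm, batch);

// processing on the MTP context then consumes that state from the cross,
// replacing the explicit llama_mtp_start() call
llama_process(ctx_mtp, batch);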

@jukofyork
Collaborator

This was just linked on Reddit today:

https://z-lab.ai/projects/dflash/

https://github.com/z-lab/dflash

and seems worth thinking about for any future MTP/Eagle API:

We demonstrate that a free lunch does exist. Our key insight is that the large AR target model’s hidden features implicitly contain information about future tokens, a phenomenon also observed by [4].

Instead of asking a tiny diffusion model to reason from scratch, DFlash conditions the draft model on context features extracted from the target model. This fuses the deep reasoning capabilities of the large model with the parallel generation speed of the small diffusion drafter.

@ngxson
Collaborator Author

ngxson commented Feb 5, 2026

I was quite distracted by other things lately, but I'm going back to this PR now. Hopefully I'll get GLM working first, then Eagle.

DeepSeek should be the same as GLM, but it needs changes to the conversion script, which will be done in a dedicated PR.

@ggerganov
Member

Alright, I've also just restarted work on #18039. I think the two implementations will converge at some point and we will be able to support both MTP and Eagle.


Labels

model Model specific
