Conversation
I spent more time thinking about this and I've decided to change the approach. With this implementation, the user code will look like this:

// load model
llama_model * model_llm = llama_model_load_from_file("main-model.gguf", ...);
llama_model * model_mtp = model_llm; // by default, MTP is built-in
// optionally, may load MTP from another file
if (!llama_model_has_mtp(model_llm)) {
    model_mtp = llama_model_load_from_file("mtp-model.gguf", ...);
}
llama_context * ctx_llm = llama_init_from_model(model_llm, ...);
llama_context * ctx_mtp = llama_init_from_model(model_mtp, ...);
// example generating one token with main LLM
common_batch_add(batch, last_token, ...);
llama_decode(ctx_llm, batch);
llama_token curr_token = common_sampler_sample(...);
llama_mtp_start(ctx_llm, ctx_mtp);
// at this point, MTP state is populated to ctx_mtp->cross
std::vector<llama_token> mtp_tokens;
llama_token last_mtp_token = curr_token;
for (int i = 0; i < n_draft_max; i++) {
    common_batch_clear(batch);
    common_batch_add(batch, last_mtp_token, ...);
    llama_decode(ctx_mtp, batch); // use state from ctx_mtp->cross, and update it afterwards
    last_mtp_token = common_sampler_sample(...);
    mtp_tokens.push_back(last_mtp_token);
}

While on the surface this approach doesn't seem much different from #18039 or #15225, it does have some important advantages:
Elaborating on this point a bit more: while working on #12648, I noticed that synchronization has a significant impact on the generation speed of a small model. The Sesame CSM consists of two parts: a main LLM that generates embeddings and a smaller decoder that generates 12 audio tokens for a given embedding. It is pretty much close to a speculative model, so I hope the same idea can be implemented for the MTP model at some point, as backend sampling is already there in llama.cpp.
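To make that cost concrete, here is a rough sketch (not code from #12648; `ctx_dec`, `smpl`, `batch`, `tok` and `n_step` are placeholder names) of the per-token pattern where every sampled token forces a backend-to-CPU round trip before the next decode can start:

```cpp
// rough sketch of the kind of per-token loop where synchronization dominates
// for a small decoder; the names here are placeholders, not real variables
llama_pos pos = 0;
for (int i = 0; i < n_step; ++i) {
    common_batch_clear(batch);
    common_batch_add(batch, tok, pos++, { 0 }, true);

    llama_decode(ctx_dec, batch);                   // graph is launched on the backend

    // sampling on the CPU forces a sync to read the logits back, and that round
    // trip is what hurts small models the most
    tok = common_sampler_sample(smpl, ctx_dec, -1);
}
```

With backend sampling the sampling step stays on the device, so the per-token round trip disappears; that is the idea I'd like to see reused for the MTP drafting loop as well.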
ggerganov left a comment
Overall looks good. I agree that pt. 2 from #18039 (comment) should be tested to see if it works for making the Eagle3 context decoder-only.
include/llama.h
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                 // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                 // ref: https://github.com/ggml-org/llama.cpp/pull/14363
bool is_mtp;     // create context for Multi-Token Prediction (MTP)
Something to consider is moving llm_graph_type to the public interface and changing this to:
llama_graph_type gtype;

If we do it like this, we can even rework llama_encode() and llama_decode() into a single llama_process().
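A possible shape of that change, purely as a sketch (the enum values mirror the internal llm_graph_type; the exact public names and the llama_process() signature are assumptions, not an agreed-upon API):

```cpp
// hypothetical public graph type in llama.h; none of this exists yet
enum llama_graph_type {
    LLAMA_GRAPH_TYPE_DEFAULT = 0,
    LLAMA_GRAPH_TYPE_ENCODER = 1,
    LLAMA_GRAPH_TYPE_DECODER = 2,
    // an MTP-specific value would presumably be added for this PR, e.g.:
    // LLAMA_GRAPH_TYPE_MTP = 3,
};

// in llama_context_params, instead of `bool is_mtp`:
//     enum llama_graph_type gtype; // which graph this context builds

// a single entry point dispatching on the context's graph type, which could
// eventually subsume llama_encode() and llama_decode()
int32_t llama_process(struct llama_context * ctx, struct llama_batch batch);
```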
int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
    ctx_llm->synchronize();

    return ctx_llm->cpy_mtp_state(*ctx_mtp);
}
Also not sure if this is the best way, but it seems OK for now. I would look into generalizing it somehow so it is not too MTP-specific, i.e. a more generic mechanism for sharing data between contexts.
I'm thinking about another version of llama_cross that will encapsulate multiple llama_context inside it:
struct llama_cross {
    llama_context * ctx_enc;  // text encoder models like T5
    llama_context * ctx_llm;
    llama_context * ctx_mtp;
    llama_context * ctx_mtmd;
};

Such that when llama_process() is called on one context, it will propagate the state to another context.
For now I cannot think of a better way to avoid having purpose-specific naming like mtmd, mtp because the data between 2 contexts can vary depending on the task. But I think we can iterate from this idea.
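As a rough illustration of that direction (everything here is hypothetical: llama_cross as a public struct, llama_process(), and how a context gets attached to the cross are all open questions; `cparams_*` and `batch_draft` are placeholder names):

```cpp
// hypothetical usage of a multi-context llama_cross: the contexts are grouped
// together, and processing a batch on the main context implicitly refreshes the
// shared state that the MTP context needs, so no explicit llama_mtp_start() call
// would be required
llama_cross cross = {};
cross.ctx_llm = llama_init_from_model(model_llm, cparams_llm);
cross.ctx_mtp = llama_init_from_model(model_mtp, cparams_mtp);

llama_process(cross.ctx_llm, batch);       // also propagates state to cross.ctx_mtp
llama_process(cross.ctx_mtp, batch_draft); // drafts using the propagated state
```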
For the current PR, I think I can proceed with llama_mtp_start because it will be easy to adapt to whatever API we may come up with in the future.
This was just linked on Reddit today: https://z-lab.ai/projects/dflash and https://github.com/z-lab/dflash. It seems worth thinking about for any future MTP/Eagle API.
I was quite distracted with other things lately; going back to this PR now. Hopefully I'll get GLM working first, then Eagle. Deepseek should be the same as GLM, but it needs changes to the conversion script, which will be done in a dedicated PR.
Alright, I've also just restarted work on #18039. I think the two implementations will converge at some point and we'll be able to support both MTP and Eagle.
Nothing here is working yet, but I'm pushing this to discuss the direction of MTP (multi-token prediction) integration in libllama.

Currently, this implementation is designed around llama_cross, which makes it a bit similar to how llama_encode|decode works. I'm targeting GLM support in the first iteration because the weights already have the nextn tensors. For the explanation of the API, see the comment below.

In theory, this should work the same way for deepseek3, xiaomi mimo, GLM, eagle3 (see this comment).
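Since the point of the MTP head is speculative decoding, a purely illustrative verification step on the main context might look roughly like the sketch below. It reuses `batch`, `curr_token`, `mtp_tokens` and `ctx_llm` from the drafting example above; `smpl` (a common_sampler) and `pos` are assumptions, and sampler-state handling plus KV-cache rollback of the rejected tail are omitted. This is not part of the PR, just a greedy acceptance rule for illustration.

```cpp
// hypothetical verification of the drafted tokens with the main model: submit the
// current token plus the MTP draft in one batch, then accept drafted tokens for as
// long as the main model agrees
common_batch_clear(batch);
common_batch_add(batch, curr_token, pos, { 0 }, true);
for (size_t i = 0; i < mtp_tokens.size(); ++i) {
    common_batch_add(batch, mtp_tokens[i], pos + 1 + (llama_pos) i, { 0 }, true);
}
llama_decode(ctx_llm, batch);

size_t n_accepted = 0;
for (size_t i = 0; i < mtp_tokens.size(); ++i) {
    // logits at batch index i are the main model's prediction for the token that
    // follows entry i (curr_token for i == 0, otherwise mtp_tokens[i - 1])
    const llama_token tok = common_sampler_sample(smpl, ctx_llm, (int) i);
    if (tok != mtp_tokens[i]) {
        break;
    }
    n_accepted++;
}
// the rejected part of the draft would then be removed from the KV cache
// (e.g. via llama_memory_seq_rm) before continuing generation
```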