From 07670a22c63b1fa335d6ec1c4a1e4255a920848c Mon Sep 17 00:00:00 2001
From: samuel
Date: Wed, 3 Sep 2025 13:25:21 -0300
Subject: [PATCH 1/4] feat: implemented sampling for MTP

---
 common/speculative.cpp | 50 ++++++-----------------------------------
 common/speculative.h   |  8 +++----
 include/llama.h        |  4 ++--
 src/llama-context.cpp  | 51 +++++++++++++++++++++---------------------
 src/llama-model.cpp    | 16 ++++---------
 5 files changed, 43 insertions(+), 86 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index c1d9149ea13d2..8d849df94b8e9 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -370,56 +370,20 @@ llama_token mtp_speculative_gen_draft(
     int32_t n_past,
     int32_t last_tok_idx) {
 
-    llama_token token_data[] = { id_last };
-    llama_pos pos_data[] = { n_past };
-    int32_t n_seq_id_data[] = { 1 };
-    llama_seq_id seq_id_data_internal[] = { 0 };
-    llama_seq_id* seq_id_data[] = {seq_id_data_internal};
-    int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
-
-    llama_batch batch = {
-        /*.n_tokens = */ 1,
-        /*.token    = */ token_data,
-        /*.embd     = */ nullptr,
-        /*.pos      = */ pos_data,
-        /*.n_seq_id = */ n_seq_id_data,
-        /*.seq_id   = */ seq_id_data,
-        /*.logits   = */ logits_data
-    };
-
-    return llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
-    //LOG_INF("updating kv cache for n_past: %d\n", n_past);
-
-    /*
     if (!smpl) {
         return -1;
     }
-    else {
-        common_sampler_sample(smpl, ctx, last_tok_idx, true);
-        const auto* cur_p = common_sampler_get_candidates(smpl);
-        //for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        //    LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-        //        k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
-        //}
-
-        const llama_token id = cur_p->data[0].id;
-        return id;
-    }
-    */
-    // LOG_INF("cur_p->size: %d\n", cur_p->size);
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    common_batch_add(batch, id_last, n_past, {0}, true);
 
-    // add drafted token for each sequence
+    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
 
-    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
-    // smpl will accept the token if it doesn't get rejected by main model later
-    // common_sampler_accept(smpl, id, true);
+    llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
 
-    //llama_tokens result;
-    //result.reserve(1);
-    //result.push_back(id);
-    //return result;
+    llama_batch_free(batch);
 
+    return id;
 }
@@ -438,4 +402,4 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start = 0, size_t n_tokens = -1);
diff --git a/include/llama.h b/include/llama.h
index 015c777763bf6..e43cd83468d0f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1454,8 +1454,8 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 1f04b72145b28..fb285a8d297c9 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2995,7 +2995,7 @@ void llama_opt_epoch(
         callback_eval);
 }
 
-llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
+void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
 
     const auto * model = llama_get_model(ctx);
@@ -3033,6 +3033,12 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
+        if (sched) ggml_backend_sched_free(sched);
+        return;
+    }
+
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
@@ -3044,29 +3050,24 @@ llama_token llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
-    //struct ggml_tensor * logits_mtp = res_mtp->get_logits();
-
-    //LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
-
-    //if (logits_mtp) {
-    //    ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
-    //}
-    struct ggml_tensor * token_id_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_argmax_result");
-
-
-    llama_token token_id = 0; // The C++ variable to hold the result.
-
-    // ggml_backend_tensor_get is the function for GPU->CPU copies.
-    // We are copying a single 32-bit integer.
-    ggml_backend_tensor_get(
-        token_id_tensor,
-        &token_id, // Pointer to our C++ variable
-        0, // Starting offset in bytes
-        sizeof(llama_token) // Number of bytes to copy
-    );
+    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
+
+    if (logits_mtp) {
+        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
+        if (backend_res) {
+            // ggml_backend_tensor_get is the function for GPU->CPU copies.
+            // We are copying a single 32-bit integer.
+            ggml_backend_tensor_get(logits_mtp,
+                logits_dest, // Pointer to our C++ variable
+                0, // Starting offset in bytes
+                ggml_nbytes(logits_mtp)); // Number of bytes to copy
+        } else {
+            LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
+        }
+    } else {
+        LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
+    }
     ggml_backend_sched_free(sched);
-
-    return token_id;
-}
-
+}
\ No newline at end of file
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f9921e4b6d448..dd4bf211b7e94 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13950,6 +13950,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
         llama_token last_token_id, int n_past
     ) : llm_graph_context(params) {
+
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -13964,8 +13965,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        ggml_tensor * cur;
-
         // get MTP embedding for last (conventionally sampled) token
         // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
         // LLAMA_LOG_INFO("step: '%d'\n", 5641);
@@ -13979,7 +13978,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
         //ggml_set_no_alloc(ctx0, true);
-        
+
         ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
 
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
@@ -13994,9 +13993,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat
 
-
-        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
-
+        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
@@ -14096,14 +14093,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
 
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
-        
+        
         res->t_logits = cur;
-        ggml_build_forward_expand(gf, res->t_logits);
-
-        struct ggml_tensor * token_id_tensor = ggml_argmax(ctx0, cur);
-        ggml_set_name(token_id_tensor, "mtp_argmax_result");
-        ggml_build_forward_expand(gf, token_id_tensor);
 
     }
 };

From 5a5bce85777041d841393b4396e28f8e3065bb10 Mon Sep 17 00:00:00 2001
From: samuel
Date: Wed, 3 Sep 2025 17:56:14 -0300
Subject: [PATCH 2/4] fix: add sample acceptance

---
 common/speculative.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 8d849df94b8e9..5edd4aa815bae 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -381,6 +381,14 @@ llama_token mtp_speculative_gen_draft(
 
     llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
 
+    const auto * cur_p = common_sampler_get_candidates(smpl);
+    for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
+        LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    }
+
+    common_sampler_accept(smpl, id, true);
+
     llama_batch_free(batch);
 
     return id;

From 8742ce0e39823eeb101bb5b6099ff4ca7be10c6e Mon Sep 17 00:00:00 2001
From: samuel
Date: Sat, 6 Sep 2025 00:21:18 -0300
Subject: [PATCH 3/4] feat: apply logits + greedy sampler

---
 common/sampling.cpp    |  4 ++++
 common/sampling.h      |  2 ++
 common/speculative.cpp | 19 +++++++++++++------
 3 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index a5824ebeedbaa..452cefee3b9ac 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -582,3 +582,7 @@ std::vector common_sampler_types_from_chars(const std::stri
 
     return samplers;
 }
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
+    llama_sampler_apply(gsmpl->chain, cur_p);
+}
\ No newline at end of file
diff --git a/common/sampling.h b/common/sampling.h
index 2064421db4e80..b424d7d6d70ca 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -105,3 +105,5 @@ std::vector common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
         const char * grammar_kind, const char * grammar_data);
+
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
\ No newline at end of file
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 5edd4aa815bae..77ed75913d5c7 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -379,15 +379,22 @@ llama_token mtp_speculative_gen_draft(
 
     llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
 
-    llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_n_vocab(vocab);
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
-    for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id = i;
+        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
     }
+    cur_p->sorted = false;
+
+    common_sampler_apply_chain(smpl, cur_p);
 
-    common_sampler_accept(smpl, id, true);
+    const llama_token id = cur_p->data[0].id;
 
     llama_batch_free(batch);
 

From 184087bc91aefbee42568b5b12a931efb0c9e989 Mon Sep 17 00:00:00 2001
From: samuel
Date: Sun, 7 Sep 2025 19:04:31 -0300
Subject: [PATCH 4/4] mtp (wip): add first concept about multiple predictions

---
 common/speculative.cpp  | 92 ++++++++++++++++++++++++++++-------------
 common/speculative.h    |  5 ++-
 include/llama.h         |  8 +++-
 src/llama-context.cpp   | 59 ++++++++++++++++----------
 src/llama-context.h     |  2 +
 src/llama-model.cpp     | 22 +++++++---
 src/llama-model.h       |  3 +-
 tools/server/server.cpp | 14 ++++---
 8 files changed, 138 insertions(+), 67 deletions(-)

diff --git a/common/speculative.cpp b/common/speculative.cpp
index 77ed75913d5c7..9c5433e32f3bc 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -363,57 +363,91 @@ llama_tokens common_speculative_gen_draft(
 }
 
 
-llama_token mtp_speculative_gen_draft(
+llama_tokens mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,
     int32_t n_past,
-    int32_t last_tok_idx) {
+    int32_t last_tok_idx,
+    int32_t n_mtp_draft) {
+
+    llama_tokens draft_tokens;
+    draft_tokens.reserve(n_mtp_draft);
+
+    llama_token current_token = id_last;
+    int32_t current_n_past = n_past;
+
+    float* prev_embedding_data = llama_get_embeddings_ith(ctx, last_tok_idx);
+    LOG_DBG("\n--- MTP total draft %d ---\n", n_mtp_draft);
+
+    // The same layer will draft multiple tokens before being validated
+    for (int i = 0; i < n_mtp_draft; ++i) {
+        if (prev_embedding_data == nullptr) {
+            LOG_DBG("ERROR: prev_embedding_data is null in iteration %d!\n", i);
+            break;
+        }
+        llama_batch batch = llama_batch_init(1, 0, 1);
+        common_batch_add(batch, current_token, current_n_past, {0}, true);
 
-    if (!smpl) {
-        return -1;
-    }
+        float* next_embedding_data = llama_build_and_execute_mtp_graph(
+            ctx, batch, prev_embedding_data, i
+        );
 
-    llama_batch batch = llama_batch_init(1, 0, 1);
-    common_batch_add(batch, id_last, n_past, {0}, true);
+        if (next_embedding_data == nullptr) {
+            LOG_DBG("ERROR: next_embedding_data returned null from graph execution\n", i);
+        }
 
-    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
+        // Apply logits + greedy: The main model has not yet selected
+        // the token as correct, so we cannot apply all samples.
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+        const int n_vocab = llama_n_vocab(vocab);
+
+        llama_token_data_array* cur_p = common_sampler_get_candidates(smpl);
+        cur_p->size = n_vocab;
+        for (int j = 0; j < n_vocab; ++j) {
+            cur_p->data[j].id = j;
+            // Place the MTP logits in the first slot of the context's logit buffer.
+            // This temporary storage is then read by the sampler.
+            cur_p->data[j].logit = llama_get_logits_ith(ctx, 0)[j];
+        }
+        cur_p->sorted = false;
+        common_sampler_apply_chain(smpl, cur_p);
 
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-    const int n_vocab = llama_n_vocab(vocab);
+        const llama_token new_id = cur_p->data[0].id;
 
-    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+        draft_tokens.push_back(new_id);
 
-    cur_p->size = n_vocab;
-    for (int i = 0; i < n_vocab; ++i) {
-        cur_p->data[i].id = i;
-        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
-    }
-    cur_p->sorted = false;
+        current_token = new_id;
+        current_n_past++;
+        prev_embedding_data = next_embedding_data;
 
-    common_sampler_apply_chain(smpl, cur_p);
+        llama_batch_free(batch);
 
-    const llama_token id = cur_p->data[0].id;
-
-    llama_batch_free(batch);
+        if (!next_embedding_data) {
+            break;
+        }
+    }
 
-    return id;
+    return draft_tokens;
 }
 
 
 void mtp_update_kv_cache(struct llama_context * ctx, std::vector& tokens, size_t batch_start, size_t n_tokens) {
-    mtp_kv_update_data token;
-
     if (n_tokens < 0) {
         n_tokens = tokens.size();
     }
 
-    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
-        token = tokens[i];
-        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));
+    for (size_t i = 0; i < std::min((size_t)tokens.size(), n_tokens); ++i) {
+        mtp_kv_update_data& token = tokens[i];
+
+        llama_batch batch = llama_batch_init(1, 0, 1);
+        common_batch_add(batch, token.id, token.n_past, {0}, true);
+
+        // Broken for now
+        // mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
 
-        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
+        llama_batch_free(batch);
     }
 
     tokens.clear();
diff --git a/common/speculative.h b/common/speculative.h
index 230f8382bccfc..f24f5c27effcc 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -35,12 +35,13 @@ void common_speculative_add_replacement_tgt_dft(
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-llama_token mtp_speculative_gen_draft(
+llama_tokens mtp_speculative_gen_draft(
     struct common_sampler* smpl,
     struct llama_context* ctx,
     llama_token id_last,
     int32_t n_past,
-    int32_t last_tok_idx);
+    int32_t last_tok_idx,
+    int32_t n_mtp_draft);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
diff --git a/include/llama.h b/include/llama.h
index e43cd83468d0f..9c4698431f894 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1454,8 +1454,12 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
-    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
+    LLAMA_API float* llama_build_and_execute_mtp_graph(
+        struct llama_context * ctx,
+        const llama_batch batch_inp,
+        float * prev_embedding_data,
+        int32_t mtp_head_idx
+    );
 
 #ifdef __cplusplus
 }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index fb285a8d297c9..2c9e25e3e259f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2995,16 +2995,19 @@ void llama_opt_epoch(
         callback_eval);
 }
 
-void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
-    const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
-
+float* llama_build_and_execute_mtp_graph(
+    struct llama_context * ctx,
+    const llama_batch batch_inp,
+    float * prev_embedding_data,
+    int32_t mtp_head_idx
+) {
     const auto * model = llama_get_model(ctx);
 
     auto res_mtp = std::make_unique(ctx->graph_max_nodes());
     std::unique_ptr mctx = ctx->mtp_memory_batch(batch_inp);
 
     std::vector idxs;
-    idxs.push_back(n_past);
+    idxs.push_back(batch_inp.pos[0]);
     llama_kv_cache_unified::slot_info sinfo = {
         /*.s0 =*/ 0,
         /*.s1 =*/ 0,
@@ -3024,50 +3027,64 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     auto params_mtp = std::make_unique(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
 
     ggml_backend_sched_t sched = params_mtp->sched;
 
-    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
-
     //if (mctx && !mctx->set_n_kv()) {
     //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
     //}
     static_cast(mctx.get())->set_n_kv();
 
-    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
+    auto * gf = model->build_mtp_graph(*params_mtp, mtp_head_idx);
 
     if (!gf) {
         LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
         if (sched) ggml_backend_sched_free(sched);
-        return;
+        return nullptr;
     }
 
     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
 
+    llama_token token_id = batch_inp.token[0];
     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
-    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
+    ggml_backend_tensor_set(mtp_token_id_input, &token_id, 0, sizeof(token_id)); // copy data to the newly allocated graph tensors
 
     ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
-    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
+
+    if (mtp_prev_embedding_input) {
+        ggml_backend_tensor_set(mtp_prev_embedding_input, prev_embedding_data, 0,
+            ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
+    } else {
+        LLAMA_LOG_WARN("%s: Could not find 'mtp_prev_embedding_input' tensor in the MTP graph.\n", __func__);
+    }
 
     ggml_backend_sched_graph_compute(sched, gf); // execute the graph
 
     struct ggml_tensor * logits_mtp = res_mtp->get_logits();
 
     if (logits_mtp) {
-        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
-        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
-        if (backend_res) {
-            // ggml_backend_tensor_get is the function for GPU->CPU copies.
-            // We are copying a single 32-bit integer.
-            ggml_backend_tensor_get(logits_mtp,
-                logits_dest, // Pointer to our C++ variable
-                0, // Starting offset in bytes
-                ggml_nbytes(logits_mtp)); // Number of bytes to copy
+        float * logits_dest = llama_get_logits_ith(ctx, 0);
+        // ggml_backend_tensor_get is the function for GPU->CPU copies.
+        // We are copying a single 32-bit integer.
+        ggml_backend_tensor_get(logits_mtp,
+            logits_dest, // Pointer to our C++ variable
+            0, // Starting offset in bytes
+            ggml_nbytes(logits_mtp)); // Number of bytes to copy
     } else {
-        LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
+        LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
+    }
+
+    struct ggml_tensor * next_embedding_tensor = ggml_get_tensor(res_mtp->get_ctx(), "mtp_next_embedding_output");
+    float * next_embedding_data_ptr = nullptr;
+
+    if (next_embedding_tensor) {
+        if (ctx->mtp_embedding_buffer.size() < ggml_nbytes(next_embedding_tensor)) {
+            ctx->mtp_embedding_buffer.resize(ggml_nbytes(next_embedding_tensor));
         }
+        ggml_backend_tensor_get(next_embedding_tensor, ctx->mtp_embedding_buffer.data(), 0, ggml_nbytes(next_embedding_tensor));
+        next_embedding_data_ptr = reinterpret_cast(ctx->mtp_embedding_buffer.data());
     } else {
-        LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
+        LLAMA_LOG_ERROR("%s: The MTP graph did not produce an output embedding tensor.\n", __func__);
     }
 
     ggml_backend_sched_free(sched);
+
+    return next_embedding_data_ptr;
 }
\ No newline at end of file
diff --git a/src/llama-context.h b/src/llama-context.h
index e8ea3a4c9be39..91a9e763a46b3 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -207,6 +207,8 @@ struct llama_context {
     ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
 
     std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp);
+
+    std::vector mtp_embedding_buffer;
 
 private:
     llm_graph_params graph_params(
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index dd4bf211b7e94..91506cbb1197b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13947,15 +13947,20 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
 struct llm_build_glm4_moe_mtp : public llm_graph_context {
     llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
+        int mtp_head_idx) : llm_graph_context(params) {
         // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
-        llama_token last_token_id, int n_past
-    ) : llm_graph_context(params) {
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         // Assuming a single MTP layer at the end
         const int il = hparams.n_layer - 1;
+
+        if (il < 0 || il >= (int)model.layers.size()) {
+            LLAMA_LOG_ERROR("FATAL ERROR: Calculated MTP layer index (%d) is out of bounds! The number of layers is %zu.\n",
+                il, model.layers.size());
+            GGML_ABORT("fatal error");
+        }
         const auto & mtp_layer = model.layers[il];
 
         // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
@@ -14089,13 +14094,20 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
             cur = ggml_add(ctx0, routed_out, shared_out);
             cb(cur, "ffn_out", il);
         }
+        ggml_tensor* final_hidden_state = cur;
 
         cur = ggml_add(ctx0, cur, ffn_inp);
 
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
-        
+
         res->t_logits = cur;
+
+        ggml_set_name(final_hidden_state, "mtp_next_embedding_output");
+        ggml_set_output(final_hidden_state);
+
+        ggml_build_forward_expand(gf, res->t_logits);
+        ggml_build_forward_expand(gf, final_hidden_state);
+
     }
 };
@@ -18690,13 +18702,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 }
 
 ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
-    llama_token last_token_id, int n_past) const {
+    int mtp_head_idx) const {
     std::unique_ptr llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique(*this, params, last_token_id, n_past);
+                llm = std::make_unique(*this, params, mtp_head_idx);
             } break;
         default:
             GGML_ABORT("fatal error");
diff --git a/src/llama-model.h b/src/llama-model.h
index b28a37488f78a..c27f1aea0aeaf 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -475,8 +475,7 @@ struct llama_model {
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
 
-    ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
-        llama_token last_token_id, int n_past) const;
+    ggml_cgraph * build_mtp_graph(const llm_graph_params& params, int mtp_head_idx) const;
 
 private:
     struct impl;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 34053cd040388..8290dfa42df92 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -2138,12 +2138,14 @@ struct server_context {
                 SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
                 slot.has_mtp = true;
 
-                // assume one speculative token (true of all well-known MTP models so far)
-                slot.batch_spec = llama_batch_init(2, 0, 1);
+                // TODO: Build token argument for MTP
+                const int n_mtp_draft_target = 5;
+
+                slot.batch_spec = llama_batch_init(1 + n_mtp_draft_target, 0, 1);
                 SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
 
                 params_base.speculative.n_min = 0;
-                params_base.speculative.n_max = 1;
+                params_base.speculative.n_max = n_mtp_draft_target;
 
                 SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
                 llama_set_embeddings(ctx, true);
@@ -3637,9 +3639,9 @@ struct server_context {
                 llama_tokens draft;
 
                 if (slot.has_mtp) {
-                    llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
-                    draft.reserve(1);
-                    draft.push_back(draft_id);
+                    int n_draft = std::min(n_draft_max, slot.params.speculative.n_max);
+
+                    draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx, n_draft);
                 } else {
                     struct common_speculative_params params_spec;