diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c04d35fd00..452cefee3b9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_sampler_apply(chain, &cur_p); + /*for (int k = 0; k < (int)cur_p.size; ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n", + k, 0, cur_p.data[k].id, cur_p.data[k].p); + }*/ + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); const llama_token id = cur_p.data[cur_p.selected].id; @@ -577,3 +582,7 @@ std::vector common_sampler_types_from_chars(const std::stri return samplers; } + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) { + llama_sampler_apply(gsmpl->chain, cur_p); +} \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index 2064421db4e..b424d7d6d70 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -105,3 +105,5 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p); \ No newline at end of file diff --git a/common/speculative.cpp b/common/speculative.cpp index 262b2c23e72..a7a40426821 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -5,6 +5,8 @@ #include "log.h" #include "common.h" #include "sampling.h" +#include "../src/llama-graph.h" +#include "../src/llama-context.h" #include #include @@ -359,3 +361,97 @@ llama_tokens common_speculative_gen_draft( } return result; } + + +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { + + if (!smpl) { + return -1; + } + llama_batch mtp_batch = llama_batch_init(1, 0, 1); + const llama_pos draft_pos = n_past; + const llama_seq_id draft_seq_id = 0; + common_batch_add(mtp_batch, id_last, n_past, {0}, true); + + mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; + + // Perform the MTP draft generation decode. This writes the MTP layer's + // KV state for the draft token into the cache. + llama_decode(ctx, mtp_batch); + llama_batch_free(mtp_batch); + + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. + llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_n_vocab(vocab); + llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); + cur_p->size = n_vocab; + for (int i = 0; i < n_vocab; ++i) { + cur_p->data[i].id = i; + cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // For a single-token batch, logits are always at index 0. + } + cur_p->sorted = false; + common_sampler_apply_chain(smpl, cur_p); + + return cur_p->data[0].id; +} + + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) { + if (batch.n_tokens == 0) { + return; + } + + LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? 
"PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens); + + llama_batch mtp_batch = batch; + if (is_prompt_warmup) { + mtp_batch.mtp_params.op_type = MTP_OP_WARMUP; + } else { + mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED; + } + + for (int i = 0; i < mtp_batch.n_tokens; ++i) { + mtp_batch.logits[i] = true; + } + llama_decode(ctx, mtp_batch); +} + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +) { + if (ids.empty()) { + return; + } + + // Prepare a resized copy of the validation sinfo to match the number of accepted tokens. + // This sets up the context for a "forced sinfo" decode. + if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) { + return; + } + + // Build a new batch containing only the accepted tokens. + llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1); + for (size_t i = 0; i < ids.size(); ++i) { + common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true); + } + + mtp_update_kv_cache(ctx, accepted_batch, false); + + // Clean up the forced state to not affect subsequent, normal decode calls. + llama_mtp_cancel_sinfo_update(ctx); + + llama_batch_free(accepted_batch); +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb..8b81f4ac77d 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,6 +12,12 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; +struct mtp_kv_update_data { + llama_token id; + int32_t n_past; + int32_t tok_idx; +}; + struct common_speculative * common_speculative_init( struct llama_context * ctx_tgt, struct llama_context * ctx_dft @@ -27,9 +33,27 @@ void common_speculative_add_replacement_tgt_dft( struct common_speculative * spec, const char *source, const char *dest); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx); + // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); + +void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup); + +void mtp_accept_tokens( + struct llama_context * ctx, + const std::vector & ids, + int32_t n_past_base, + llama_seq_id seq_id +); diff --git a/include/llama.h b/include/llama.h index 545e957e5f5..0b15d4bf1cd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -221,6 +221,17 @@ extern "C" { // - if not: only the last token is output // ) // + typedef enum { + MTP_OP_NONE, + MTP_OP_WARMUP, + MTP_OP_UPDATE_ACCEPTED, + MTP_OP_DRAFT_GEN, + } llama_mtp_op_type; + + typedef struct llama_mtp_params { + llama_mtp_op_type op_type; + } llama_mtp_params; + typedef struct llama_batch { int32_t n_tokens; @@ -230,6 +241,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + llama_mtp_params mtp_params; } llama_batch; enum llama_model_kv_override_type { @@ -495,6 +507,8 @@ extern "C" { LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); + LLAMA_API int32_t 
llama_model_n_nextn_layer(const struct llama_model * model); + // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure // - The output string is always null-terminated and cleared on failure @@ -548,6 +562,8 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); + + // // Adapters // @@ -1450,6 +1466,38 @@ extern "C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + // + // MTP + // + + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + + /** + * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo. + * This is used after speculative validation when only a subset of draft tokens are accepted. + * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted); + + /** + * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode. + * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned. + * @return true on success. + */ + LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx); + + /** + * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo. + */ + LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx); + + /** + * @brief Removes KV cache metadata for a specified sequence and token range. + * This makes the physical cells logically available again without deleting the tensor data. 
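+     * In the MTP flow this is called right after a draft-generation decode, so the cell
+     * written for the draft token becomes reusable by the main model's validation pass
+     * (see mtp_speculative_gen_draft in common/speculative.cpp).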
+ */ + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + #ifdef __cplusplus } #endif diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 18dcc6ddfe5..4b6fa3e6059 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2240,12 +2240,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 8698d89acec..c01960c55ea 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -275,7 +275,9 @@ bool llama_batch_allocr::init( } } - if (!ok) { + // TEMPORARILY DISABLING THIS SANITY CHECK + // TODO: UNDO THIS IF IT WORKS + /*if (!ok) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module of the context (i.e. 
the KV cache) for sequence %d is X = %d\n" @@ -284,7 +286,7 @@ bool llama_batch_allocr::init( __func__, s, s, p0, s, seq_pos_min(s)); return false; - } + }*/ } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { @@ -832,13 +834,14 @@ struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*.mtp_params =*/ { MTP_OP_NONE }, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26a5cf9c3f8..fb35d6c79de 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,8 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-graph.h" +#include "llama-kv-cache-unified.h" #include #include @@ -15,6 +17,11 @@ // // llama_context // +struct llama_context_kv_cache_data { + llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; + llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; + const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; +}; llama_context::llama_context( const llama_model & model, @@ -103,6 +110,8 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + kv_cache_data = new llama_context_kv_cache_data(); + { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; @@ -368,6 +377,7 @@ llama_context::llama_context( llama_context::~llama_context() { ggml_opt_free(opt_ctx); + delete static_cast(kv_cache_data); } void llama_context::synchronize() { @@ -522,6 +532,18 @@ float * llama_context::get_logits() { return logits; } +void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) { + output_reorder(); + + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + int64_t j = output_ids[i]; + + ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float)); +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -617,6 +639,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +ggml_tensor * llama_context::get_embeddings_tensor() { + return embd_tensor; +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -711,7 +737,8 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, + const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -723,7 +750,7 @@ llm_graph_result * 
llama_context::process_ubatch(const llama_ubatch & ubatch, ll // the new graph parameters // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); @@ -754,6 +781,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation + if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { + ret = GGML_STATUS_FAILED; + return nullptr; + } + } + // set the input data for the input tensors { //const auto t_start_us = ggml_time_us(); @@ -771,7 +805,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - + if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum"); + } return res; } @@ -832,7 +868,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE }); cparams.causal_attn = causal_attn_org; @@ -946,6 +982,8 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + auto * kvd = static_cast(kv_cache_data); + if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(batch_inp); @@ -1000,10 +1038,11 @@ int llama_context::decode(const llama_batch & batch_inp) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_context_ptr mctx; + std::unique_ptr mctx; while (true) { - mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + mctx = this->initialize_decode_context(batch_inp, output_all); + if (!mctx) { return -2; } @@ -1015,29 +1054,28 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_MEMORY_STATUS_NO_UPDATE: { LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); - return -2; } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { + if (kvd->forced_sinfos) { + LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + return -1; + } + if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); - continue; } } - LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); - return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); - return -2; } } @@ -1052,10 +1090,9 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; - + do { const auto & ubatch = mctx->get_ubatch(); - // count the outputs in this ubatch { int32_t n_outputs_new = 0; @@ -1071,10 +1108,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = 
n_outputs_new; } - ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); - + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; @@ -1120,71 +1155,81 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); - - float * logits_out = logits + n_outputs_prev*n_vocab; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + // MTP operations that are purely for updating the KV cache + // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // as a side effect of running the graph. If these logits are copied + // back to the main context buffer, they will overwrite the valid logits + // produced by the main model's pass, leading to incorrect sampling. + // This condition explicitly prevents that copy for cache-only operations. + if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && + batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } } // extract embeddings if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); - - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < 
ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1249,7 +1294,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. 
ggml_backend_sched_reset(sched.get()); } - return 0; } @@ -1389,7 +1433,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -1409,8 +1453,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, - const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + const llama_memory_context_i * mctx, + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1423,12 +1468,28 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.mtp_params =*/ mtp_params, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, }; } +std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { + const auto& vocab = model.vocab; + const auto& hparams = model.hparams; + + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return nullptr; + } + + return memory->init_batch(*balloc, 1, false); +} + ggml_status llama_context::graph_compute( ggml_cgraph * gf, bool batched) { @@ -1456,8 +1517,10 @@ ggml_status llama_context::graph_compute( return status; } -llm_graph_cb llama_context::graph_get_cb() const { - return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { +llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const { + ggml_backend_sched * cb_sched = sched_override ? 
sched_override : sched.get(); + + return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { if (il >= 0) { ggml_format_name(cur, "%s-%d", name, il); } else { @@ -1467,7 +1530,7 @@ llm_graph_cb llama_context::graph_get_cb() const { if (!cparams.offload_kqv) { if (strcmp(name, "kqv_merged_cont") == 0) { // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu); } } @@ -1480,7 +1543,7 @@ llm_graph_cb llama_context::graph_get_cb() const { for (const auto & backend : backends) { if (ggml_backend_get_device(backend.get()) == dev_layer) { if (ggml_backend_supports_op(backend.get(), cur)) { - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get()); } } } @@ -1489,6 +1552,10 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) { + return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload); +} + // // state save/load // @@ -2142,7 +2209,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -2233,6 +2300,7 @@ void llama_context::opt_epoch( llama_batch_free(batch); } + // // interface implementation // @@ -2274,6 +2342,8 @@ llama_context_params llama_context_default_params() { return result; } + + llama_context * llama_init_from_model( llama_model * model, llama_context_params params) { @@ -2412,6 +2482,7 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } + float * llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2926,3 +2997,122 @@ void llama_opt_epoch( callback_train, callback_eval); } + +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} + +bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty()) { + LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__); + return false; + } + + kvd->forced_sinfos = &last_sinfo; + return true; +} + + +bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) { + auto * kvd = static_cast(ctx->kv_cache_data); + const auto & last_sinfo = kvd->last_main_model_sinfos; + + if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) { + LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__); + return false; + } + + kvd->resized_sinfo_for_force = last_sinfo; + + kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted); + + kvd->forced_sinfos = &kvd->resized_sinfo_for_force; + + return true; +} + +void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) { + auto * kvd = static_cast(ctx->kv_cache_data); + kvd->forced_sinfos = nullptr; +} + +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); + } +} + +void llama_kv_cache_seq_rm(struct 
llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} + +/* + Initializes the memory context for a decode operation. + The logic follows a specific priority: + 1. Warmup: Always use a standard batch initialization. + 2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it. + 3. Default: Use a standard batch initialization, and if it's a main model pass, + save the resulting s-info for potential future reuse by MTP. +*/ +std::unique_ptr llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) { + auto * kvd = static_cast(kv_cache_data); + std::unique_ptr mctx; + + if (cparams.warmup) { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) { + LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__); + mctx = static_cast(memory.get())->init_batch_with_sinfos( + *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true + ); + } else { + mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + + if (batch_inp.mtp_params.op_type == MTP_OP_NONE) { + if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) { + kvd->last_main_model_sinfos = static_cast(mctx.get())->get_sinfos(); + } else { + kvd->last_main_model_sinfos.clear(); + } + } + } + + return mctx; +} + + +bool llama_context::prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params) { + + const char * target_tensor_name = "result_embd_pooled"; + ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name); + + const float * source_hidden_state = nullptr; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + source_hidden_state = this->embd; + } else { // MTP_OP_DRAFT_GEN + source_hidden_state = this->draft_input_hidden_state; + } + + if (source_hidden_state != nullptr && hidden_states_input != nullptr) { + const char * op_type; + if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + op_type = "MTP_UPDATE"; + } else { // MTP_OP_DRAFT_GEN + op_type = "DRAFT_GEN"; + } + + ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input)); + } else { + LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n", + __func__, target_tensor_name); + return false; + } + + return true; +} diff --git a/src/llama-context.h b/src/llama-context.h index 25c143d56df..4d77d5d81ae 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,6 +20,8 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +struct llama_context_kv_cache_data; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -27,6 +29,15 @@ struct llama_context { llama_context_params params); ~llama_context(); + + // The llama_context manages significant resources (GPU memory, file handles, PImpl data) + // and is fundamentally a non-copyable, non-movable object. Deleting these special + // member functions enforces this rule and is also technically required to allow the + // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header. 
+ llama_context(const llama_context &) = delete; + llama_context & operator=(const llama_context &) = delete; + llama_context(llama_context &&) = delete; + llama_context & operator=(llama_context &&) = delete; void synchronize(); @@ -59,6 +70,9 @@ struct llama_context { float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + ggml_tensor * get_embeddings_tensor(); + + const float * draft_input_hidden_state = nullptr; void attach_threadpool( ggml_threadpool_t threadpool, @@ -90,6 +104,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -98,7 +114,8 @@ struct llama_context { const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, - ggml_status & ret); + ggml_status & ret, + const llama_mtp_params & mtp_params); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -199,14 +216,32 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i); + + ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); + + std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + + // For MTP KV cache cell reuse + void * kv_cache_data; + private: llm_graph_params graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const; + + llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; - llm_graph_cb graph_get_cb() const; + // Methods for MTP decode + std::unique_ptr initialize_decode_context(const llama_batch & batch_inp, const bool output_all); + + bool prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params); // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); @@ -240,6 +275,7 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + ggml_tensor * embd_tensor = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -308,3 +344,4 @@ struct llama_context { mutable int32_t n_reused = 0; // number of times the previous graph was reused }; + diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8..be7de40454e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { return cur; } + +ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const { + auto inp = std::make_unique(); + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_name(inp->tokens, "mtp_inp_tokens"); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens); + } else { + 
GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings"); + } + + cb(cur, "mtp_inp_embd", -1); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos() const { auto inp = std::make_unique(hparams.n_pos_per_embd()); diff --git a/src/llama-graph.h b/src/llama-graph.h index 6ff49de3a1c..3c5feadfdc7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -94,6 +95,20 @@ class llm_graph_input_i { using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -402,6 +417,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_mtp_params mtp_params; uint32_t n_outputs; @@ -450,6 +466,7 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + mtp_params.op_type == other.mtp_params.op_type && n_outputs == other.n_outputs; } }; @@ -664,6 +681,7 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const; ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; @@ -818,3 +836,4 @@ struct llm_graph_context { // TODO: better name int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); + diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index e539142e6b8..8d9b1f631f7 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( } if (model.arch == LLM_ARCH_GLM4_MOE) { // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers; } // create a context for each buffer type @@ -508,6 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } +llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update) { + + if (sinfos.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + balloc.split_reset(); + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? 
balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); + } + + if (ubatches.size() != sinfos.size()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, sinfos, std::move(ubatches), is_inplace_update); +} + llama_memory_context_ptr llama_kv_cache_unified::init_full() { return std::make_unique(this); } @@ -928,64 +956,81 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } assert(res.s1 >= res.s0); + if (!res.empty()) { + std::string idxs_str; + for (const auto& vec : res.idxs) { + if (!vec.empty()) { + if (vec.size() > 8) { + idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; + } else { + idxs_str += " ["; + for(size_t i = 0; i < vec.size(); ++i) { + idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", "); + } + idxs_str += "]"; + } + } + } + } return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { - // keep track of the max sequence position that we would overwrite with this ubatch - // for non-SWA cache, this would be always empty - llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - seq_pos_max_rm[s] = -1; - } - - assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { - const uint32_t i = s*sinfo.size() + ii; - - auto & cells = v_cells[sinfo.strm[s]]; +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + // For "in-place" updates (MTP warmup/accept), we only update the tensor data. + // The cell metadata (logical position, sequence ID) has already been set + // by the main model's pass. We must skip all metadata modifications + // to prevent `pos_set` from asserting on an already-set cell. 
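+    // Note: is_inplace_update is propagated from llama_kv_cache_unified::init_batch_with_sinfos(),
+    // which the MTP warmup and accepted-token update paths use to reuse the slot info of the
+    // main model's previous decode instead of allocating new cells.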
+ if (!is_inplace_update) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } - const auto idx = sinfo.idxs[s][ii]; + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + auto & cells = v_cells[sinfo.strm[s]]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + const auto idx = sinfo.idxs[s][ii]; - cells.rm(idx); - } + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + cells.rm(idx); + } - cells.pos_set(idx, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } - } - // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence - // will be present in the cache. so we have to purge any position which is less than those we would overwrite - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (seq_pos_max_rm[s] == -1) { - continue; - } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } - GGML_ASSERT(s < seq_to_stream.size()); + GGML_ASSERT(s < seq_to_stream.size()); - auto & cells = v_cells[seq_to_stream[s]]; + auto & cells = v_cells[seq_to_stream[s]]; - if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", - __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } } } @@ -2290,7 +2335,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::slot_info_vec_t sinfos, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { + std::vector ubatches, + bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -2315,13 +2361,18 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); n_kv = kv->get_n_kv(); return true; } +void llama_kv_cache_unified_context::set_n_kv() { + n_kv = kv->get_n_kv(); +} + + llama_memory_status llama_kv_cache_unified_context::get_status() const { return status; } @@ -2384,6 +2435,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con kv->set_input_pos_bucket(dst, ubatch); } +void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) { + sinfos = new_sinfos; +} + uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 
256u : 32u; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 342a675962e..f64f7faa5c0 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -116,6 +116,12 @@ class llama_kv_cache_unified : public llama_memory_i { llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; + + llama_memory_context_ptr init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update); llama_memory_context_ptr init_full() override; @@ -181,7 +187,7 @@ class llama_kv_cache_unified : public llama_memory_i { slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); // // input API @@ -321,7 +327,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified_context( llama_kv_cache_unified * kv, slot_info_vec_t sinfos, - std::vector ubatches); + std::vector ubatches, + bool is_inplace_update = false); virtual ~llama_kv_cache_unified_context(); @@ -340,6 +347,7 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // uint32_t get_n_kv() const; + void set_n_kv(); // TODO: temporary bool get_supports_set_rows() const; @@ -362,6 +370,10 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_sinfos(slot_info_vec_t new_sinfos); + + const slot_info_vec_t & get_sinfos() const { return sinfos; } + private: llama_memory_status status; @@ -396,4 +408,6 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; + + bool is_inplace_update = false; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58ca7df707e..ab7daee356a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4507,9 +4507,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // but only PROCESS up to last layer (skipping final NextN layer) in forward pass for (int i = 0; i < n_layer; ++i) { int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; + // flags |= TENSOR_SKIP; } auto & layer = layers[i]; @@ -4573,12 +4574,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + + // our input/output layer sanity check prevents us from loading the eh_proj layer! + // this is because eh_proj is labelled with a layer number in existing GGUFs, + // so we need to set bid == to successfully load the tensors, but our io layer sanity check requires bid == -1. + // this function is a hack that creates the nextn layers as LLM_TENSOR_LAYER_REPEATING instead. 
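+                    // (Kept for reference only: the working fix marks the NextN tensors as
+                    //  LLM_TENSOR_LAYER_REPEATING in llama-arch.cpp, since they are stored under
+                    //  a block with a non-negative id.)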
+ /* auto create_tensor_override_io_sanity_check = + [&](llm_tensor type_enum, const char * suffix, int bid, const std::initializer_list& ne, int flags) -> ggml_tensor * { + + auto tn_orig = tn(type_enum, suffix, bid); + llm_tensor_info info_override = *tn_orig.info; + info_override.layer = LLM_TENSOR_LAYER_REPEATING; + + auto tn_override = tn_orig; + tn_override.info = &info_override; + + return create_tensor(tn_override, ne, flags); + };*/ + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // layer.nextn.eh_proj = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i, { 2 * n_embd, n_embd }, flags); + // layer.nextn.embed_tokens = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.enorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_ENORM, "weight", i, { n_embd }, flags); + // layer.nextn.hnorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_HNORM, "weight", i, { n_embd }, flags); + // layer.nextn.shared_head_head = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.shared_head_norm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i, { n_embd }, flags); } } } @@ -13763,159 +13789,294 @@ struct llm_build_glm4 : public llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); + if (params.mtp_params.op_type != MTP_OP_NONE) { + ggml_tensor* hidden_states_from_main_model; - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); + if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) { + hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - auto * inp_attn = build_attn_inp_kv_unified(); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } else { + hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_states_from_main_model, "result_embd_pooled"); + ggml_set_input(hidden_states_from_main_model); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_states_from_main_model; + res->add_input(std::move(inp_mtp)); + } - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not 
processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head); - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + } else { + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, 
beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); - // Post-attention norm - cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); + // cb(cur, "result_norm", -1); + res->t_embd = cur; - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - 
-                        model.layers[il].ffn_down, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(cur, "ffn_out", il);
-            } else {
-                // Process routed experts using existing MoE infrastructure
-                ggml_tensor * routed_out = build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        model.layers[il].ffn_exp_probs_b,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, hparams.expert_weights_norm,
-                        true, hparams.expert_weights_scale,
-                        (llama_expert_gating_func_type) hparams.expert_gating_func,
-                        il);
-                cb(routed_out, "ffn_moe_out", il);
+            // Use the main model's lm_head
+            res->t_logits = build_lora_mm(model.output, cur);
+        }
-                // Process shared expert on original input
-                ggml_tensor * shared_out = build_ffn(cur,
-                        model.layers[il].ffn_up_shexp, NULL, NULL,
-                        model.layers[il].ffn_gate_shexp, NULL, NULL,
-                        model.layers[il].ffn_down_shexp, NULL, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(shared_out, "ffn_shexp_out", il);
+        ggml_build_forward_expand(gf, res->t_logits);
-                // Final output: routed_output + shared_output
-                cur = ggml_add(ctx0, routed_out, shared_out);
-                cb(cur, "ffn_out", il);
+    }
+
+private:
+    ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
+            int64_t n_embd_head
+    ) {
+        ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);
+
+        const int il = hparams.n_layer - 1;
+        ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy);
+
+        ggml_set_name(sum_node, "mtp_input_sum");
+
+        ggml_tensor * inp_pos = build_inp_pos();
+        auto * inp_attn = build_attn_inp_kv_unified();
+        ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
+
+        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
+        ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
+        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
+        ggml_tensor * cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
+
+        // now proceed through last layer (skipped in main model)
+        ggml_tensor * inpSA = cur;
+        // Pre-attention norm for the MTP block
+        cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
+            if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
+            if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
+            if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // Apply Q/K norm if available (GLM-4.5 355B variant)
+            if (mtp_layer.attn_q_norm) {
+                Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (mtp_layer.attn_k_norm) {
+                Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
             }
-            cur = ggml_add(ctx0, cur, ffn_inp);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
-            cur = build_cvec(cur, il);
-            cb(cur, "l_out", il);
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
-            // input for next layer
-            inpL = cur;
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    mtp_layer.wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
-        cur = inpL;
-        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
-        cb(cur, "result_norm", -1);
-        res->t_embd = cur;
+        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
+        // moe ffn for nextn block
+        {
+            // Process routed experts using existing MoE infrastructure
+            ggml_tensor * routed_out = build_moe_ffn(cur,
+                    mtp_layer.ffn_gate_inp,
+                    mtp_layer.ffn_up_exps,
+                    mtp_layer.ffn_gate_exps,
+                    mtp_layer.ffn_down_exps,
+                    mtp_layer.ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    true, hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(routed_out, "ffn_moe_out", il);
-        cb(cur, "result_output", -1);
-        res->t_logits = cur;
+            // Process shared expert on original input
+            ggml_tensor * shared_out = build_ffn(cur,
+                    mtp_layer.ffn_up_shexp, NULL, NULL,
+                    mtp_layer.ffn_gate_shexp, NULL, NULL,
+                    mtp_layer.ffn_down_shexp, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(shared_out, "ffn_shexp_out", il);
-        ggml_build_forward_expand(gf, cur);
+            // Final output: routed_output + shared_output
+            cur = ggml_add(ctx0, routed_out, shared_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
+
+        return cur;
     }
 };

@@ -18144,8 +18305,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 }

 ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
+    std::unique_ptr<llm_graph_context> llm;
-
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
@@ -18503,9 +18664,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             GGML_ABORT("fatal error");
     }

-    // add on pooling layer
-    llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
-
+    if (params.mtp_params.op_type == MTP_OP_NONE) {
+        // add on pooling layer
+        llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
+    }
     return llm->res->get_gf();
 }

@@ -18587,6 +18749,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
     return nullptr;
 }

+int32_t llama_model_n_nextn_layer(const llama_model * model) {
+    return model->hparams.nextn_predict_layers;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
@@ -18820,3 +18986,4 @@ bool llama_model_is_diffusion(const llama_model * model) {
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
     return model->tensors_by_name;
 }
+
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index a255d481a4d..a24532c6939 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };

+
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1294,6 +1295,8 @@ struct server_slot {
     mtmd_context * mctx = nullptr;

     common_speculative * spec = nullptr;
+    bool has_mtp = false;
+    int32_t last_tok_idx = -1;
     std::vector<common_adapter_lora_info> lora;

@@ -1391,7 +1394,7 @@ struct server_slot {
     }

     bool need_embd() const {
-        return server_task_type_need_embd(task_type);
+        return server_task_type_need_embd(task_type) || has_mtp;
     }

     bool need_logits() const {
@@ -1401,9 +1404,14 @@ struct server_slot {
     // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
     // also we cannot split if the pooling would require any past tokens
    bool can_split() const {
+        //fprintf(stderr, "need_embd() %d\n", need_embd());
+        //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr);
+        //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+
         return !need_embd() ||
-               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
+               (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch?
     }

     bool can_batch_with(server_slot & other_slot) const {
@@ -1431,7 +1439,8 @@ struct server_slot {
     }

     bool can_speculate() const {
-        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+        return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
+        // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
     }

     void add_token(const completion_token_output & token) {
@@ -1566,6 +1575,7 @@ struct server_slot {
     }
 };

+
 struct server_metrics {
     int64_t t_start = 0;

@@ -2122,6 +2132,22 @@ struct server_context {
                 }
             }

+            // if model has MTP and no draft model is specified...
+            else if (llama_model_n_nextn_layer(model) > 0) {
+                SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
+                slot.has_mtp = true;
+
+                // assume one speculative token (true of all well-known MTP models so far)
+                slot.batch_spec = llama_batch_init(2, 0, 1);
+                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
+
+                params_base.speculative.n_min = 0;
+                params_base.speculative.n_max = 1;
+
+                SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
+                llama_set_embeddings(ctx, true);
+            }
+
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

             slot.params.sampling = params_base.sampling;
@@ -3368,6 +3394,7 @@ struct server_context {
                     const bool need_embd = server_task_type_need_embd(slot.task_type);

                     common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
+
                     slot.cache_tokens.push_back(cur_tok);

                     slot.n_prompt_tokens_processed++;
@@ -3482,6 +3509,21 @@ struct server_context {
                 continue; // continue loop of n_batch
             }

+            if (slot_batched && slot_batched->has_mtp &&
+                (slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT)) {
+
+                // Prepare the context to reuse the exact sinfo layout (including multiple u-batches)
+                // from the main model's prompt processing pass. This ensures the MTP layer's
+                // KV cache is perfectly aligned.
+                if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
+                    mtp_update_kv_cache(ctx, batch_view, true);
+                    // Clean up the forced state to not affect subsequent decodes.
+                    llama_mtp_cancel_sinfo_update(ctx);
+                } else {
+                    LOG_ERR("%s: Failed to prepare the MTP for warmup.\n", __func__);
+                }
+            }
+
             // move the head of the batch forward with the number of tokens we just processed
             i_next = i + n_tokens;

@@ -3516,11 +3558,14 @@ struct server_context {
             }

             const int tok_idx = slot.i_batch - i;
-
+            // Sets the initial state for the first draft generation.
+            if (slot.has_mtp) {
+                llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
+            }
             llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+            slot.last_tok_idx = tok_idx;

             slot.i_batch = -1;
-
             common_sampler_accept(slot.smpl, id, true);

             slot.n_decoded += 1;
@@ -3590,23 +3635,33 @@ struct server_context {

             llama_token id = slot.sampled;

-            struct common_speculative_params params_spec;
-            params_spec.n_draft = n_draft_max;
-            params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-            params_spec.p_min = slot.params.speculative.p_min;
+            llama_tokens draft;
+            if (slot.has_mtp) {
+                llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                draft.reserve(1);
+                draft.push_back(draft_id);
+            }
+            else {
+                struct common_speculative_params params_spec;
+                params_spec.n_draft = n_draft_max;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                params_spec.p_min = slot.params.speculative.p_min;
+
+                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();

-            const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-            llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+            }

             // ignore small drafts
-            if (slot.params.speculative.n_min > (int) draft.size()) {
-                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+            if (slot.params.speculative.n_min > (int)draft.size()) {
+                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
                 continue;
             }

             // keep track of total number of drafted tokens tested
             slot.n_draft_total += draft.size();
+            SLT_DBG(slot, "draft size = %d\n", (int)draft.size());

             // construct the speculation batch
             common_batch_clear(slot.batch_spec);
@@ -3617,11 +3672,20 @@ struct server_context {
             }

             SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
             llama_decode(ctx, slot.batch_spec);

             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+            if (slot.has_mtp) {
+                if (!ids.empty()) {
+                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
+                } else {
+                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
+                }
+
+                mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);
+            }

             slot.n_past += ids.size();
             slot.n_decoded += ids.size();