diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 3f7f3a11dfa..37f5b3db713 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "../../src/llama-ext.h" #include "gguf.h" #include @@ -916,7 +917,7 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) { +static bool compute_imatrix(llama_context * ctx, llama_context * ctx_mtp, const common_params & params, const int32_t n_ctx) { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -926,6 +927,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); } + const int32_t n_embd = ctx_mtp ? llama_model_n_embd(model) : 0; + auto tim1 = std::chrono::high_resolution_clock::now(); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -975,12 +978,28 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); + // MTP batch: needs both tokens and embeddings + // llama_batch_init with embd>0 allocates embd but not token, so we init with embd=0 + // to get token, and manage embd separately via mtp_embd_buf + llama_batch batch_mtp = {}; + std::vector mtp_embd_buf; + std::vector mtp_pending_h; + if (ctx_mtp) { + const int mtp_n_tokens = std::min(n_batch, n_ctx*n_seq); + batch_mtp = llama_batch_init(mtp_n_tokens, 0, 1); + mtp_embd_buf.resize((size_t) mtp_n_tokens * n_embd, 0.0f); + mtp_pending_h.resize(n_seq * n_embd, 0.0f); + } + std::vector logits; if (params.compute_ppl && num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); } LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + if (ctx_mtp) { + LOG_INF("%s: MTP context active - will also collect imatrix for MTP layers\n", __func__); + } std::vector workers(std::thread::hardware_concurrency() - 1); @@ -994,6 +1013,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c // clear the KV cache llama_memory_clear(llama_get_memory(ctx), true); + if (ctx_mtp) { + llama_memory_clear(llama_get_memory(ctx_mtp), true); + std::fill(mtp_pending_h.begin(), mtp_pending_h.end(), 0.0f); + } for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -1027,9 +1050,64 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c if (llama_decode(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); llama_batch_free(batch); + if (ctx_mtp) { + llama_batch_free(batch_mtp); + } return false; } + // Also decode through MTP context to collect imatrix for MTP layers + if (ctx_mtp && batch_size >= 16) { + const size_t row_bytes = (size_t) n_embd * sizeof(float); + const float * h_tgt = llama_get_embeddings_pre_norm(ctx); + + const int mtp_n_tokens = n_seq_batch * batch_size; + + common_batch_clear(batch_mtp); + batch_mtp.n_tokens = mtp_n_tokens; + + for (int k = 0; k < mtp_n_tokens; ++k) { + batch_mtp.token[k] = batch.token[k]; + batch_mtp.pos[k] = batch.pos[k]; + batch_mtp.n_seq_id[k] = batch.n_seq_id[k]; + for (int s = 0; s < batch.n_seq_id[k]; ++s) { + batch_mtp.seq_id[k][s] = batch.seq_id[k][s]; + } + batch_mtp.logits[k] = true; + } + + // Fill embeddings: shift h_tgt right by one position. + // Position k gets the embedding from position k-1 (previous step's hidden state). + // Position 0 gets the pending_h from the previous batch (or zeros for the first batch). + // h_tgt layout: [seq0_batch_tokens..., seq1_batch_tokens..., ...] each of size batch_size*n_embd + for (int seq = 0; seq < n_seq_batch; ++seq) { + float * dst = mtp_embd_buf.data() + (size_t) seq * batch_size * n_embd; + const float * src = h_tgt + (size_t) seq * batch_size * n_embd; + // First token in this sequence gets the pending h from the previous batch + memcpy(dst, mtp_pending_h.data() + (size_t) seq * n_embd, row_bytes); + // Remaining tokens get the shifted h_tgt + if (batch_size > 1) { + memcpy(dst + n_embd, src, (size_t) (batch_size - 1) * row_bytes); + } + } + + // Point batch_mtp.embd to our buffer (all sequences contiguous) + batch_mtp.embd = mtp_embd_buf.data(); + + if (llama_decode(ctx_mtp, batch_mtp)) { + LOG_WRN("%s: MTP decode failed - skipping MTP imatrix collection\n", __func__); + } + + // Save last h for each sequence as pending for the next batch + for (int seq = 0; seq < n_seq_batch; ++seq) { + const float * last_h = h_tgt + (size_t) seq * batch_size * n_embd + (size_t) (batch_size - 1) * n_embd; + memcpy(mtp_pending_h.data() + (size_t) seq * n_embd, last_h, row_bytes); + } + + // Reset embd pointer since batch_mtp doesn't own the buffer + batch_mtp.embd = nullptr; + } + if (params.compute_ppl && num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); @@ -1089,6 +1167,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c } llama_batch_free(batch); + if (ctx_mtp) { + llama_batch_free(batch_mtp); + } return true; } @@ -1291,6 +1372,24 @@ int main(int argc, char ** argv) { return 1; } + // Enable pre-norm embeddings collection on the main context + // so we can feed them to the MTP context + llama_set_embeddings_pre_norm(ctx, true, /*masked*/ false); + + // Try to create an MTP context for collecting imatrix data for MTP layers + llama_context * ctx_mtp = nullptr; + { + auto cparams_mtp = common_context_params_to_llama(params); + cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP; + cparams_mtp.cb_eval = params.cb_eval; + cparams_mtp.cb_eval_user_data = params.cb_eval_user_data; + ctx_mtp = llama_init_from_model(model, cparams_mtp); + if (ctx_mtp) { + llama_set_embeddings_pre_norm(ctx_mtp, true, /*masked*/ true); + LOG_INF("%s: created MTP context for imatrix collection\n", __func__); + } + } + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", @@ -1303,7 +1402,7 @@ int main(int argc, char ** argv) { LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - if (!compute_imatrix(ctx, params, n_ctx)) { + if (!compute_imatrix(ctx, ctx_mtp, params, n_ctx)) { return 1; } @@ -1311,6 +1410,10 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); + if (ctx_mtp) { + llama_perf_context_print(ctx_mtp); + llama_free(ctx_mtp); + } llama_backend_free();