Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 105 additions & 2 deletions tools/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "../../src/llama-ext.h"
#include "gguf.h"

#include <algorithm>
Expand Down Expand Up @@ -916,7 +917,7 @@ static void process_logits(
}
}

static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
static bool compute_imatrix(llama_context * ctx, llama_context * ctx_mtp, const common_params & params, const int32_t n_ctx) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);

Expand All @@ -926,6 +927,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
}

const int32_t n_embd = ctx_mtp ? llama_model_n_embd(model) : 0;

auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: tokenizing the input ..\n", __func__);

Expand Down Expand Up @@ -975,12 +978,28 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c

llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);

// MTP batch: needs both tokens and embeddings
// llama_batch_init with embd>0 allocates embd but not token, so we init with embd=0
// to get token, and manage embd separately via mtp_embd_buf
// When ctx_mtp is null, batch_mtp stays zero-initialized and both vectors stay empty.
llama_batch batch_mtp = {};
// one row of n_embd floats per token slot; assigned to batch_mtp.embd only around the MTP decode
std::vector<float> mtp_embd_buf;
// one row of n_embd floats per sequence: the last hidden state of a batch, carried over to
// become the embedding of the first token of the next batch
std::vector<float> mtp_pending_h;
if (ctx_mtp) {
// same capacity expression as the one used for the main `batch` above
const int mtp_n_tokens = std::min(n_batch, n_ctx*n_seq);
batch_mtp = llama_batch_init(mtp_n_tokens, 0, 1);
mtp_embd_buf.resize((size_t) mtp_n_tokens * n_embd, 0.0f);
// NOTE(review): unlike the line above, this product has no size_t cast; n_seq's declaration is
// not visible here - if it is an int the multiplication is done in int before the conversion.
mtp_pending_h.resize(n_seq * n_embd, 0.0f);
}

std::vector<float> logits;
if (params.compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}

LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
if (ctx_mtp) {
LOG_INF("%s: MTP context active - will also collect imatrix for MTP layers\n", __func__);
}

std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

Expand All @@ -994,6 +1013,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c

// clear the KV cache
llama_memory_clear(llama_get_memory(ctx), true);
if (ctx_mtp) {
llama_memory_clear(llama_get_memory(ctx_mtp), true);
std::fill(mtp_pending_h.begin(), mtp_pending_h.end(), 0.0f);
}

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
Expand Down Expand Up @@ -1027,9 +1050,64 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
if (ctx_mtp) {
llama_batch_free(batch_mtp);
}
return false;
}

// Also decode through MTP context to collect imatrix for MTP layers
//
// The MTP batch mirrors the main batch (token, pos, seq ids) and additionally carries one
// embedding row per token, taken from the main context's pre-norm output shifted by one position.
// NOTE(review): batches with batch_size < 16 skip this whole block, including the update of
// mtp_pending_h at its end, so the next MTP batch would start from an older pending row.
// The reason for the threshold of 16 is not visible here - TODO confirm and document.
if (ctx_mtp && batch_size >= 16) {
// size in bytes of one hidden-state row (n_embd floats)
const size_t row_bytes = (size_t) n_embd * sizeof(float);
// NOTE(review): h_tgt is read below without a null check - confirm that
// llama_get_embeddings_pre_norm() cannot return nullptr after a successful llama_decode(ctx, batch).
const float * h_tgt = llama_get_embeddings_pre_norm(ctx);

// n_seq_batch and batch_size are defined earlier in the loop, outside this view
const int mtp_n_tokens = n_seq_batch * batch_size;

common_batch_clear(batch_mtp);
batch_mtp.n_tokens = mtp_n_tokens;

// copy the per-token metadata of the main batch verbatim; the logits flag is set for every token
for (int k = 0; k < mtp_n_tokens; ++k) {
batch_mtp.token[k] = batch.token[k];
batch_mtp.pos[k] = batch.pos[k];
batch_mtp.n_seq_id[k] = batch.n_seq_id[k];
for (int s = 0; s < batch.n_seq_id[k]; ++s) {
batch_mtp.seq_id[k][s] = batch.seq_id[k][s];
}
batch_mtp.logits[k] = true;
}

// Fill embeddings: shift h_tgt right by one position.
// Position k gets the embedding from position k-1 (previous step's hidden state).
// Position 0 gets the pending_h from the previous batch (or zeros for the first batch).
// h_tgt layout: [seq0_batch_tokens..., seq1_batch_tokens..., ...] each of size batch_size*n_embd
// NOTE(review): this layout is assumed by the indexing below, not checked - TODO confirm it
// matches what llama_get_embeddings_pre_norm() actually returns.
for (int seq = 0; seq < n_seq_batch; ++seq) {
float * dst = mtp_embd_buf.data() + (size_t) seq * batch_size * n_embd;
const float * src = h_tgt + (size_t) seq * batch_size * n_embd;
// First token in this sequence gets the pending h from the previous batch
memcpy(dst, mtp_pending_h.data() + (size_t) seq * n_embd, row_bytes);
// Remaining tokens get the shifted h_tgt
// (this condition always holds here because the enclosing if requires batch_size >= 16)
if (batch_size > 1) {
memcpy(dst + n_embd, src, (size_t) (batch_size - 1) * row_bytes);
}
}

// Point batch_mtp.embd to our buffer (all sequences contiguous)
batch_mtp.embd = mtp_embd_buf.data();

// Unlike the main decode above, a failure here is only logged: processing continues and the
// function does not return false.
if (llama_decode(ctx_mtp, batch_mtp)) {
LOG_WRN("%s: MTP decode failed - skipping MTP imatrix collection\n", __func__);
}

// Save last h for each sequence as pending for the next batch
// (taken from the main context's output, so this runs whether or not the MTP decode succeeded)
for (int seq = 0; seq < n_seq_batch; ++seq) {
const float * last_h = h_tgt + (size_t) seq * batch_size * n_embd + (size_t) (batch_size - 1) * n_embd;
memcpy(mtp_pending_h.data() + (size_t) seq * n_embd, last_h, row_bytes);
}

// Reset embd pointer since batch_mtp doesn't own the buffer
// (mtp_embd_buf is a std::vector and releases its own storage)
batch_mtp.embd = nullptr;
}

if (params.compute_ppl && num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
Expand Down Expand Up @@ -1089,6 +1167,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
}

llama_batch_free(batch);
if (ctx_mtp) {
llama_batch_free(batch_mtp);
}

return true;
}
Expand Down Expand Up @@ -1291,6 +1372,24 @@ int main(int argc, char ** argv) {
return 1;
}

// Enable pre-norm embeddings collection on the main context
// so we can feed them to the MTP context
// NOTE(review): this is enabled unconditionally, before it is known whether an MTP context can
// be created below - confirm the extra collection is acceptable for models without MTP layers.
// The meaning of the `masked` argument is defined in llama-ext.h, not visible here.
llama_set_embeddings_pre_norm(ctx, true, /*masked*/ false);

// Try to create an MTP context for collecting imatrix data for MTP layers
// On failure ctx_mtp stays nullptr and compute_imatrix() runs without the MTP pass;
// nothing is logged in that case.
llama_context * ctx_mtp = nullptr;
{
// start from the same context parameters as the main context, then switch the context type
auto cparams_mtp = common_context_params_to_llama(params);
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
// reuse the eval callback configured in params (presumably the imatrix collector - verify
// where params.cb_eval is assigned)
cparams_mtp.cb_eval = params.cb_eval;
cparams_mtp.cb_eval_user_data = params.cb_eval_user_data;
ctx_mtp = llama_init_from_model(model, cparams_mtp);
if (ctx_mtp) {
llama_set_embeddings_pre_norm(ctx_mtp, true, /*masked*/ true);
LOG_INF("%s: created MTP context for imatrix collection\n", __func__);
}
}

const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
Expand All @@ -1303,14 +1402,18 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}

if (!compute_imatrix(ctx, params, n_ctx)) {
if (!compute_imatrix(ctx, ctx_mtp, params, n_ctx)) {
return 1;
}

g_collector.save_imatrix();

LOG("\n");
llama_perf_context_print(ctx);
if (ctx_mtp) {
llama_perf_context_print(ctx_mtp);
llama_free(ctx_mtp);
}

llama_backend_free();

Expand Down