diff --git a/common/arg.cpp b/common/arg.cpp
index 84b3c8f962da..07a4e4bb2aea 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2764,6 +2764,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.parse_special = true;
         }
     ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
+    add_opt(common_arg(
+        {"--mtp"},
+        string_format("also activate the MTP/NextN draft head during imatrix collection so its tensors "
+                      "(blk.<N>.nextn.eh_proj etc.) receive activations. No-op if the model has no MTP layers. "
+                      "(default: %s)", params.imat_mtp ? "true" : "false"),
+        [](common_params & params) {
+            params.imat_mtp = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX}));
     add_opt(common_arg(
         {"-pps"},
         string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"),
diff --git a/common/common.h b/common/common.h
index 4cca9d715680..b37d4ae9599c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -681,6 +681,7 @@ struct common_params {
     bool compute_ppl     = true;  // whether to compute perplexity
     bool show_statistics = false; // show imatrix statistics per tensor
     bool parse_special   = false; // whether to parse special tokens during imatrix tokenization
+    bool imat_mtp        = false; // also activate the MTP/NextN draft head so its tensors get imatrix data
 
     // cvector-generator params
     int n_pca_batch = 100;
diff --git a/include/llama.h b/include/llama.h
index 75095b22d08f..ba940554780b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -562,6 +562,10 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_swa      (const struct llama_model * model);
 
+    // Number of MTP / NextN draft head layers bundled with the model (0 if none).
+    // Used by callers (e.g. imatrix) that need to know whether the model carries an MTP head.
+    LLAMA_API int32_t llama_model_n_nextn    (const struct llama_model * model);
+
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8bf20a716eba..3cadf13400ec 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2182,6 +2182,10 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }
 
+int32_t llama_model_n_nextn(const llama_model * model) {
+    return (int32_t) model->hparams.nextn_predict_layers;
+}
+
 uint32_t llama_model_n_cls_out(const struct llama_model * model) {
     return model->hparams.n_cls_out;
 }
diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp
index 3f7f3a11dfa3..45e405be3dfb 100644
--- a/tools/imatrix/imatrix.cpp
+++ b/tools/imatrix/imatrix.cpp
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "../../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
 #include "gguf.h"
 
 #include <algorithm>
@@ -916,7 +917,53 @@ static void process_logits(
     }
 }
 
-static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) {
+// Run a forward pass through the MTP/NextN draft head so its weights
+// (blk.<N>.nextn.eh_proj etc.) receive activations and get recorded by
+// the imatrix collector. Mirrors common_speculative_state_draft_mtp::process():
+// the MTP head at position p is fed the next-token id (tokens[p+1]) paired
+// with the trunk's pre-norm hidden state h[p]. The last position of the
+// chunk has no next-token target and is dropped.
+static bool compute_imatrix_mtp(
+        llama_context * ctx_tgt,
+        llama_context * ctx_mtp,
+        const llama_token * tokens, // n_tokens consecutive tokens (covers this batch)
+        int32_t n_tokens,
+        int32_t pos_first, // absolute position of tokens[0] in the chunk
+        int32_t n_embd,
+        llama_seq_id seq_id,
+        llama_batch & mtp_batch) { // pre-allocated, embd-capable batch (token+embd both alloc'd)
+    if (n_tokens < 2) {
+        return true; // need at least one (h[p], token[p+1]) pair
+    }
+    const int32_t n_pairs = n_tokens - 1;
+
+    const size_t row_bytes = (size_t) n_embd * sizeof(float);
+
+    common_batch_clear(mtp_batch);
+
+    for (int32_t k = 0; k < n_pairs; ++k) {
+        // MTP position p+1 carries the next-token id and h[p] from the trunk.
+        common_batch_add(mtp_batch, tokens[k + 1], pos_first + k + 1, { seq_id }, false);
+    }
+
+    // Fill h[p] rows from the trunk's pre-norm output.
+    for (int32_t k = 0; k < n_pairs; ++k) {
+        const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, k);
+        if (h == nullptr) {
+            LOG_ERR("%s: trunk did not produce pre-norm embedding at row %d (was output enabled?)\n", __func__, k);
+            return false;
+        }
+        std::memcpy(mtp_batch.embd + (size_t) k * n_embd, h, row_bytes);
+    }
+
+    if (llama_decode(ctx_mtp, mtp_batch) != 0) {
+        LOG_ERR("%s: llama_decode(ctx_mtp) failed\n", __func__);
+        return false;
+    }
+    return true;
+}
+
+static bool compute_imatrix(llama_context * ctx, llama_context * ctx_mtp, const common_params & params, const int32_t n_ctx) {
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
@@ -975,15 +1022,37 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
 
     llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
+    // Optional MTP/NextN draft head batch. Only used when ctx_mtp != nullptr.
+    // llama_batch_init() only allocates one of token/embd; MTP needs both, so we
+    // patch in a token buffer alongside (same trick as common/speculative.cpp).
+    llama_batch mtp_batch = {};
+    bool mtp_enabled = (ctx_mtp != nullptr);
+    const int n_embd = llama_model_n_embd(model);
+    if (mtp_enabled) {
+        if (n_seq != 1) {
+            LOG_WRN("%s: --mtp is only supported with n_seq=1 (one sequence per batch); disabling MTP collection\n", __func__);
+            mtp_enabled = false;
+        } else {
+            mtp_batch = llama_batch_init(std::min(n_batch, n_ctx), n_embd, 1);
+            mtp_batch.token = (llama_token *) malloc(sizeof(llama_token) * std::min(n_batch, n_ctx));
+        }
+    }
+
     std::vector<float> logits;
     if (params.compute_ppl && num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+    LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d%s\n",
+            __func__, n_chunk, n_ctx, n_batch, n_seq, mtp_enabled ? " (mtp head active)" : "");
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
+    if (mtp_enabled) {
+        // Trunk must expose the pre-norm hidden state so we can feed it into the MTP head.
+        llama_set_embeddings_pre_norm(ctx, true);
+    }
+
     for (int i = 0; i < n_chunk; i += n_seq) {
         const int start =     i * n_ctx;
         const int end   = start + n_ctx;
@@ -994,6 +1063,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
 
         // clear the KV cache
         llama_memory_clear(llama_get_memory(ctx), true);
+        if (mtp_enabled) {
+            llama_memory_clear(llama_get_memory(ctx_mtp), true);
+        }
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -1027,6 +1099,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
             if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 llama_batch_free(batch);
+                if (mtp_enabled) {
+                    // no explicit free(mtp_batch.token): llama_batch_free() releases it too (avoids a double free)
+                    llama_batch_free(mtp_batch);
+                }
                 return false;
             }
 
@@ -1034,6 +1110,20 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
             }
+
+            if (mtp_enabled) {
+                // The sub-batch covers absolute positions [batch_start, batch_start + batch_size).
+                // tokens.data() + batch_start gives the matching token ids.
+                const int32_t pos_first = j * n_batch;
+                if (!compute_imatrix_mtp(ctx, ctx_mtp,
+                                         tokens.data() + batch_start, batch_size,
+                                         pos_first, n_embd, /*seq_id=*/0, mtp_batch)) {
+                    llama_batch_free(batch);
+                    // no explicit free(mtp_batch.token): llama_batch_free() releases it too (avoids a double free)
+                    llama_batch_free(mtp_batch);
+                    return false;
+                }
+            }
         }
 
 
@@ -1089,6 +1179,10 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
     }
 
     llama_batch_free(batch);
+    if (mtp_enabled) {
+        // no explicit free(mtp_batch.token): llama_batch_free() releases it too (avoids a double free)
+        llama_batch_free(mtp_batch);
+    }
 
     return true;
 }
@@ -1303,7 +1397,28 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    if (!compute_imatrix(ctx, params, n_ctx)) {
+    // Optional second context for the MTP/NextN draft head. Shares the same model
+    // as `ctx`; uses LLAMA_CONTEXT_TYPE_MTP so the MTP graph is built/run instead
+    // of the trunk graph. The trunk feeds it pre-norm hidden states each batch.
+    llama_context * ctx_mtp = nullptr;
+    if (params.imat_mtp) {
+        if (llama_model_n_nextn(model) == 0) {
+            LOG_WRN("%s: --mtp requested but model has no MTP/NextN layers; ignoring\n", __func__);
+        } else {
+            auto cparams_mtp = common_context_params_to_llama(params);
+            cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
+            cparams_mtp.n_rs_seq = 0;
+            ctx_mtp = llama_init_from_model(model, cparams_mtp);
+            if (ctx_mtp == nullptr) {
+                LOG_ERR("%s : failed to create MTP context\n", __func__);
+                return 1;
+            }
+            LOG_INF("%s: created MTP draft-head context for imatrix collection\n", __func__);
+        }
+    }
+
+    if (!compute_imatrix(ctx, ctx_mtp, params, n_ctx)) {
+        if (ctx_mtp) llama_free(ctx_mtp);
         return 1;
     }
 
@@ -1312,6 +1427,10 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
+    if (ctx_mtp) {
+        llama_free(ctx_mtp);
+    }
+
     llama_backend_free();
 
     return 0;