13 changes: 13 additions & 0 deletions include/llama.h
@@ -194,6 +194,13 @@ extern "C" {
LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
};

enum llama_graph_type {
LLAMA_GRAPH_TYPE_DEFAULT,
LLAMA_GRAPH_TYPE_ENCODER,
LLAMA_GRAPH_TYPE_DECODER,
LLAMA_GRAPH_TYPE_DECODER_MTP,
};

// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data {
llama_token id; // token id
@@ -376,6 +383,8 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;

enum llama_graph_type graph_type; // type of the computation graph to be used
};

// model quantization parameters
@@ -1007,6 +1016,10 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

// Copy the internal MTP state from ctx_llm to ctx_mtp, ready for MTP decoding.
// This must be done before calling llama_decode() on ctx_mtp
LLAMA_API int32_t llama_mtp_start(struct llama_context * ctx_llm, struct llama_context * ctx_mtp);

//
// backend sampling API [EXPERIMENTAL]
// note: use only if the llama_context was created with at least one llama_sampler_seq_config
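For reference, a minimal usage sketch of the new API (not part of this PR): the main model and the MTP head run as two contexts over the same model, the MTP context created with LLAMA_GRAPH_TYPE_DECODER_MTP, and llama_mtp_start() transfers the captured embeddings between the two decode calls. The surrounding setup (llama_init_from_model, llama_batch_get_one, decoding the same batch on the MTP context) is assumed here rather than taken from the PR.

// illustrative usage sketch (not part of this PR); assumes `model` was loaded
// beforehand and `tokens`/`n_tokens` hold the prompt
static int32_t run_main_then_mtp(llama_model * model, llama_token * tokens, int32_t n_tokens) {
    llama_context_params cparams_llm = llama_context_default_params();

    llama_context_params cparams_mtp = llama_context_default_params();
    cparams_mtp.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP; // MTP head gets its own context

    llama_context * ctx_llm = llama_init_from_model(model, cparams_llm);
    llama_context * ctx_mtp = llama_init_from_model(model, cparams_mtp);

    // 1. decode on the main context as usual
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    if (llama_decode(ctx_llm, batch) != 0) {
        return -1;
    }

    // 2. copy the captured MTP state into the MTP context
    if (llama_mtp_start(ctx_llm, ctx_mtp) != 0) {
        return -1;
    }

    // 3. decode on the MTP context to obtain the draft predictions
    return llama_decode(ctx_mtp, batch);
}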
72 changes: 71 additions & 1 deletion src/llama-context.cpp
@@ -187,6 +187,23 @@ llama_context::llama_context(
}
}

switch (params.graph_type) {
case LLAMA_GRAPH_TYPE_DEFAULT:
gtype = LLM_GRAPH_TYPE_DEFAULT;
break;
case LLAMA_GRAPH_TYPE_ENCODER:
gtype = LLM_GRAPH_TYPE_ENCODER;
break;
case LLAMA_GRAPH_TYPE_DECODER:
gtype = LLM_GRAPH_TYPE_DECODER;
break;
case LLAMA_GRAPH_TYPE_DECODER_MTP:
gtype = LLM_GRAPH_TYPE_DECODER_MTP;
break;
default:
throw std::runtime_error("invalid graph type");
}

LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
@@ -811,6 +828,23 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}

int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
if (ctx_mtp.gtype != LLM_GRAPH_TYPE_DECODER_MTP) {
LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
return -1;
}

if (cross.n_token == 0 || cross.n_embd == 0) {
LLAMA_LOG_ERROR("%s: no state to copy\n", __func__);
return -1;
}

// TODO: maybe std::move is better?
ctx_mtp.cross = cross;

return 0;
}

llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();

@@ -1469,6 +1503,18 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT

if (gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
if (model.hparams.nextn_predict_layers == 0) {
LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
return -1;
}
if ((uint32_t)batch_inp.n_tokens > n_ubatch()) {
// TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
return -1;
}
}

if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp);
@@ -1597,6 +1643,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
break;
}

const bool update_mtp_state = hparams.nextn_predict_layers > 0 && n_outputs > 0;

// set MTP state if needed
if (update_mtp_state) {
cross.n_embd = hparams.get_n_embd_mtp();
cross.n_token = n_outputs;
cross.mtp_embd.resize(cross.n_embd*cross.n_token);
}

// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
@@ -1625,7 +1680,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);

if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
@@ -1690,6 +1745,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);

// set MTP state if needed
if (update_mtp_state) {
const int64_t n_embd_mtp = cross.n_embd;
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_mtp <= (int64_t)cross.mtp_embd.size());
ggml_backend_tensor_get_async(backend_embd, t_embd, cross.mtp_embd.data(), 0, n_outputs*n_embd_mtp*sizeof(float));
}

switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
@@ -2982,6 +3045,7 @@ llama_context_params llama_context_default_params() {
/*.kv_unified =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.graph_type =*/ LLAMA_GRAPH_TYPE_DEFAULT,
};

return result;
@@ -3170,6 +3234,12 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}

int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
ctx_llm->synchronize();

return ctx_llm->cpy_mtp_state(*ctx_mtp);
}

Comment on lines +3237 to +3242
@ggerganov (Member) commented on Jan 18, 2026:

Also not sure if this is the best way, but seems OK for now. Would look into generalizing somehow to not be too MTP specific. I.e. a more generic mechanism for sharing data between contexts.

The PR author (Collaborator) replied:

I'm thinking about another version of llama_cross that will encapsulate multiple llama_context inside it:

struct llama_cross {
    llama_context * ctx_enc; // text encoder models like T5
    llama_context * ctx_llm;
    llama_context * ctx_mtp;
    llama_context * ctx_mtmd;
};

Such that when llama_process() is called on one context, it will propagate the state to another context.

For now I cannot think of a better way to avoid purpose-specific naming like mtmd and mtp, because the data shared between two contexts can vary depending on the task. But I think we can iterate on this idea.

For the current PR, I think I can proceed with llama_mtp_start because it will be easy to adapt to whatever API we may come up with in the future.
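A purely hypothetical sketch of that propagation idea (llama_process() and this form of llama_cross are not part of the current API; llama_mtp_start() stands in for a generic state transfer):

// hypothetical sketch of the cross-context propagation idea described above;
// llama_process() does not exist yet, and llama_mtp_start() stands in for a
// more generic state-transfer mechanism
int32_t llama_process(struct llama_cross * cross, struct llama_batch batch) {
    // decode on the main context first
    int32_t ret = llama_decode(cross->ctx_llm, batch);
    if (ret != 0) {
        return ret;
    }

    // propagate the captured state to the auxiliary contexts, if present
    if (cross->ctx_mtp != NULL) {
        ret = llama_mtp_start(cross->ctx_llm, cross->ctx_mtp);
    }

    return ret;
}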

bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
return ctx->set_sampler(seq_id, smpl);
}
4 changes: 4 additions & 0 deletions src/llama-context.h
@@ -77,6 +77,8 @@ struct llama_context {
float * get_embeddings();
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);

int32_t cpy_mtp_state(llama_context & ctx_mtp);

llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@@ -349,6 +351,8 @@ struct llama_context {
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_ptr buf_output;

llm_graph_type gtype;

bool has_evaluated_once = false;

// env: LLAMA_GRAPH_REUSE_DISABLE
25 changes: 25 additions & 0 deletions src/llama-graph.cpp
@@ -285,6 +285,17 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
}

void llm_graph_input_cross_mtp::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);

if (cross_mtp && !cross->mtp_embd.empty()) {
assert(cross_mtp->type == GGML_TYPE_F32);
assert(ggml_nelements(cross_mtp) == (int64_t)cross->mtp_embd.size());

ggml_backend_tensor_set(cross_mtp, cross->mtp_embd.data(), 0, ggml_nbytes(cross_mtp));
}
}

static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
const char * swa_type_str = "unknown";
@@ -1627,6 +1638,20 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
return cur;
}

ggml_tensor * llm_graph_context::build_inp_cross_mtp() const {
auto inp = std::make_unique<llm_graph_input_cross_mtp>(cross);

auto & cur = inp->cross_mtp;

GGML_ASSERT(cross != nullptr);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, cross->n_embd, cross->n_token);
ggml_set_input(cur);

res->add_input(std::move(inp));

return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
auto inp = std::make_unique<llm_graph_input_pos_bucket>(hparams);

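For orientation, a rough sketch of how a builder for the MTP graph type might consume this input (illustrative only; the actual GLM4_MOE MTP path is not shown in this excerpt, and the concat-then-project layout is an assumption about the MTP head, not taken from the PR):

// illustrative fragment inside an llm_graph_context-based builder for the MTP
// graph type; the embeddings captured from the main model arrive as an extra input
ggml_tensor * embd_prev = build_inp_cross_mtp();          // F32 [n_embd, n_token], filled from llama_cross::mtp_embd
ggml_tensor * tok_embd  = build_inp_embd(model.tok_embd); // embeddings of the tokens being drafted

// assumed MTP head layout: concatenate along the embedding dimension, then
// project back to n_embd before running the single MTP transformer block
ggml_tensor * cur = ggml_concat(ctx0, tok_embd, embd_prev, 0); // F32 [2*n_embd, n_token]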
24 changes: 22 additions & 2 deletions src/llama-graph.h
@@ -31,6 +31,7 @@ enum llm_graph_type {
LLM_GRAPH_TYPE_DEFAULT,
LLM_GRAPH_TYPE_ENCODER,
LLM_GRAPH_TYPE_DECODER,
LLM_GRAPH_TYPE_DECODER_MTP,
};

enum llm_ffn_op_type {
@@ -56,18 +57,24 @@ enum llm_norm_type {
};

// TODO: tmp - need something better to pass the data from the encoder to the decoder
// currently also used for passing embeddings from the main model to the MTP layers
struct llama_cross {
// the output embeddings from the encoder as a ggml tensor
// TODO: this needs more work to be correct, for now copy the embeddings data to host memory
// ref: https://github.com/ggml-org/llama.cpp/pull/11213#discussion_r1969892524
//ggml_tensor * t_embd = nullptr;

int64_t n_embd = 0;
int64_t n_enc = 0;
int64_t n_embd = 0;
int64_t n_enc = 0;
int64_t n_token = 0; // used by mtp

// embeddings data copied to host memory (tmp)
std::vector<float> v_embd;

// embeddings data to be passed to MTP layers
// TODO: optimize by using ggml_tensor here
std::vector<float> mtp_embd;

// needed to construct the cross-attention mask in the decoder
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
@@ -258,6 +265,18 @@ class llm_graph_input_cross_embd : public llm_graph_input_i {
const llama_cross * cross;
};

class llm_graph_input_cross_mtp : public llm_graph_input_i {
public:
llm_graph_input_cross_mtp(
const llama_cross * cross) : cross(cross) {}
virtual ~llm_graph_input_cross_mtp() = default;
void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * cross_mtp; // F32 [n_embd, n_token]

const llama_cross * cross;
};

class llm_graph_input_attn_no_cache : public llm_graph_input_i {
public:
llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
@@ -849,6 +868,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_cls() const;

ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_cross_mtp() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
4 changes: 4 additions & 0 deletions src/llama-hparams.cpp
@@ -76,6 +76,10 @@ uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}

uint32_t llama_hparams::get_n_embd_mtp() const {
return n_embd;
}

uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);

3 changes: 3 additions & 0 deletions src/llama-hparams.h
@@ -249,6 +249,9 @@ struct llama_hparams {
// dimension of output embeddings
uint32_t n_embd_out() const;

// dimension of cross embeddings between main LLM and MTP
uint32_t get_n_embd_mtp() const;

// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;

6 changes: 5 additions & 1 deletion src/llama-model.cpp
@@ -8162,7 +8162,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break;
case LLM_ARCH_GLM4_MOE:
{
llm = std::make_unique<llm_build_glm4_moe>(*this, params);
if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
llm = std::make_unique<llm_build_glm4_moe<true>>(*this, params);
} else {
llm = std::make_unique<llm_build_glm4_moe<false>>(*this, params);
}
} break;
case LLM_ARCH_BITNET:
{
Expand Down