Commit 5859cb9

mtp-graph (wip): testing different ways to allow graph reuse
1 parent 15dff20 commit 5859cb9

6 files changed: +117 -41 lines changed

common/speculative.cpp

Lines changed: 6 additions & 0 deletions

@@ -382,7 +382,10 @@ llama_token mtp_speculative_gen_draft(
 
     // Perform the MTP draft generation decode. This writes the MTP layer's
    // KV state for the draft token into the cache.
+    const int64_t t_start_us = ggml_time_us();
     llama_decode(ctx, mtp_batch);
+    const int64_t t_end_us = ggml_time_us();
+    LOG_INF("[PERF-MTP] mtp_speculative_gen_draft internal decode: %.2f ms\n", (t_end_us - t_start_us) / 1000.0);
     llama_batch_free(mtp_batch);
 
     // CRITICAL: Purge the metadata for the draft token we just wrote.

@@ -423,7 +426,10 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
         mtp_batch.logits[i] = true;
     }
+    const int64_t t_start_us = ggml_time_us();
     llama_decode(ctx, mtp_batch);
+    const int64_t t_end_us = ggml_time_us();
+    LOG_INF("[PERF-MTP] mtp_update_kv_cache internal decode (op=%d): %.2f ms\n", (int)mtp_batch.mtp_params.op_type, (t_end_us - t_start_us) / 1000.0);
 }
 
 void mtp_accept_tokens(
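
Both hunks bracket llama_decode with the same manual start/stop pattern. A minimal sketch, not part of this commit, of an RAII helper built on the same ggml_time_us()/LOG_INF API that would collapse each pair to one declaration (the perf_scope_timer name is hypothetical):

    #include <cstdint>

    // Hypothetical helper: logs the elapsed time for a scope on destruction.
    struct perf_scope_timer {
        const char *  label;
        const int64_t t_start_us;

        explicit perf_scope_timer(const char * label)
            : label(label), t_start_us(ggml_time_us()) {}

        ~perf_scope_timer() {
            const int64_t t_end_us = ggml_time_us();
            LOG_INF("[PERF-MTP] %s: %.2f ms\n", label, (t_end_us - t_start_us) / 1000.0);
        }
    };

    // Usage inside mtp_speculative_gen_draft:
    // {
    //     perf_scope_timer timer("mtp_speculative_gen_draft internal decode");
    //     llama_decode(ctx, mtp_batch);
    // }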

include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -226,6 +226,7 @@ extern "C" {
         MTP_OP_WARMUP,
         MTP_OP_UPDATE_ACCEPTED,
         MTP_OP_DRAFT_GEN,
+        MTP_OP_MAIN_VALIDATION,
     } llama_mtp_op_type;
 
     typedef struct llama_mtp_params {
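
MTP_OP_MAIN_VALIDATION marks a main-model decode that validates previously drafted tokens; the hunks below group it with MTP_OP_NONE for pooling, embedding extraction, and sinfo capture, while excluding it from the MTP-specific input preparation. A sketch of the intended call site, assuming the fork's llama_batch exposes mtp_params as the other files in this commit do (the batch setup itself is illustrative):

    // Sketch: tag the validation decode so it takes the main-model path.
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    batch.mtp_params.op_type = MTP_OP_MAIN_VALIDATION;
    llama_decode(ctx, batch);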

src/llama-context.cpp

Lines changed: 26 additions & 3 deletions

@@ -35,6 +35,7 @@ struct llama_context_kv_cache_data {
     llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force;
     const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr;
     std::map<llama_graph_cache_key, llm_graph_result_ptr> graph_cache;
+    llm_graph_result_ptr gf_res_prev_validation;
 };
 
 llama_context::llama_context(

@@ -788,24 +789,35 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
            LLAMA_LOG_INFO("[GRAPH-CACHE] MISS, RECONSTRUCTING THE STRUCTURE of the graph for key (op=%d, tok=%d, out=%d)\n",
                (int)key.op_type, key.n_tokens, key.n_outputs);
 
+            const int64_t t_reset_start_us = ggml_time_us();
            ggml_backend_sched_reset(sched.get());
+            const int64_t t_reset_end_us = ggml_time_us();
            ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
            res->reset();
            res->set_params(gparams);
+            const int64_t t_build_start_us = ggml_time_us();
            res->gf = model.build_graph(gparams);
+            const int64_t t_build_end_us = ggml_time_us();
+            LLAMA_LOG_INFO("[PERF-GRAPH] Graph build (op=%d): %.2f ms\n", (int)mtp_params.op_type, (t_build_end_us - t_build_start_us) / 1000.0);
 
            if (!res->gf) {
                LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
                ret = GGML_STATUS_FAILED;
                return nullptr;
            }
 
+            const int64_t t_alloc_start_us = ggml_time_us();
            if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
                LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
                ret = GGML_STATUS_ALLOC_FAILED;
                return nullptr;
            }
+            const int64_t t_alloc_end_us = ggml_time_us();
+            LLAMA_LOG_INFO("[PERF-GRAPH] sched_reset: %.2f ms | sched_alloc: %.2f ms (op=%d)\n",
+                (t_reset_end_us - t_reset_start_us) / 1000.0,
+                (t_alloc_end_us - t_alloc_start_us) / 1000.0,
+                (int)mtp_params.op_type);
        // }
 
    } else {

@@ -818,14 +830,19 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
    } else {
        LLAMA_LOG_INFO("%s: RECONSTRUCTED graph...\n", __func__);
 
+        const int64_t t_reset_start_us = ggml_time_us();
        ggml_backend_sched_reset(sched.get());
+        const int64_t t_reset_end_us = ggml_time_us();
        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
        res->reset();
        res->set_params(gparams);
        //const auto t_start_us = ggml_time_us();
 
+        const int64_t t_build_start_us = ggml_time_us();
        res->gf = model.build_graph(gparams);
+        const int64_t t_build_end_us = ggml_time_us();
+        LLAMA_LOG_INFO("[PERF-GRAPH] Graph build (op=%d): %.2f ms\n", (int)mtp_params.op_type, (t_build_end_us - t_build_start_us) / 1000.0);
 
        //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
 

@@ -835,15 +852,21 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
            return nullptr;
        }
 
+        const int64_t t_alloc_start_us = ggml_time_us();
        if (!ggml_backend_sched_alloc_graph(sched.get(), res->gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
            ret = GGML_STATUS_ALLOC_FAILED;
            return nullptr;
        }
+        const int64_t t_alloc_end_us = ggml_time_us();
+        LLAMA_LOG_INFO("[PERF-GRAPH] sched_reset: %.2f ms | sched_alloc: %.2f ms (op=%d)\n",
+            (t_reset_end_us - t_reset_start_us) / 1000.0,
+            (t_alloc_end_us - t_alloc_start_us) / 1000.0,
+            (int)mtp_params.op_type);
        }
    }
 
-    if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
+    if (mtp_params.op_type != MTP_OP_NONE && mtp_params.op_type != MTP_OP_MAIN_VALIDATION) {
        if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
            ret = GGML_STATUS_FAILED;
            return nullptr;

@@ -1241,7 +1264,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
    // extract embeddings
    if (t_embd && n_outputs > 0) {
-        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE || batch_inp.mtp_params.op_type == MTP_OP_MAIN_VALIDATION) {
            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
            GGML_ASSERT(backend_embd != nullptr);
 

@@ -3133,7 +3156,7 @@ std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context
    } else {
        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
 
-        if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
+        if (batch_inp.mtp_params.op_type == MTP_OP_NONE || batch_inp.mtp_params.op_type == MTP_OP_MAIN_VALIDATION) {
            if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
                kvd->last_main_model_sinfos = static_cast<llama_kv_cache_unified_context *>(mctx.get())->get_sinfos();
            } else {
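
The cache-miss log prints its key as (op=%d, tok=%d, out=%d), which implies the shape of llama_graph_cache_key. A sketch of that key and the lookup pattern, assuming a struct with operator< since graph_cache is a std::map; the real definition lives elsewhere in the fork and may differ:

    #include <cstdint>
    #include <tuple>

    // Assumed key shape; fields inferred from the [GRAPH-CACHE] log above.
    struct llama_graph_cache_key {
        llama_mtp_op_type op_type;
        int32_t           n_tokens;
        int32_t           n_outputs;

        bool operator<(const llama_graph_cache_key & other) const {
            return std::tie(op_type, n_tokens, n_outputs)
                 < std::tie(other.op_type, other.n_tokens, other.n_outputs);
        }
    };

    // Lookup pattern against the std::map member declared above:
    // auto it = kvd->graph_cache.find(key);
    // if (it == kvd->graph_cache.end()) {
    //     // MISS: sched_reset -> build_graph -> sched_alloc, then insert.
    // } else {
    //     // HIT: reuse the cached llm_graph_result_ptr.
    // }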

src/llama-graph.cpp

Lines changed: 10 additions & 22 deletions

@@ -442,34 +442,22 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
 
 bool llm_graph_result::can_reuse(const llm_graph_params & params) {
    if (!this->params.allow_reuse(params)) {
-        if (debug > 1) {
-            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
-        }
-
+        LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'allow_reuse': incompatible parameters.\n");
+        LLAMA_LOG_WARN("  n_tokens: %d vs %d, op_type: %d vs %d\n",
+            this->params.ubatch.n_tokens, params.ubatch.n_tokens,
+            (int)this->params.mtp_params.op_type, (int)params.mtp_params.op_type);
        return false;
    }
 
-    if (debug > 1) {
-        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
-    }
-
-    bool res = true;
-
-    for (auto & input : inputs) {
-        const bool cur = input->can_reuse(params);
-
-        if (debug > 1) {
-            LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur);
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (!inputs[i]->can_reuse(params)) {
+            LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'can_reuse' of input node #%zu.\n", i);
+            return false;
        }
-
-        res = res && cur;
    }
 
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
-    }
-
-    return res;
+    LLAMA_LOG_DEBUG("%s: can reuse graph = true\n", __func__);
+    return true;
 }
 
 llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
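
The rewrite turns can_reuse from an accumulate-everything loop into a fail-fast scan: it returns at the first incompatible input and logs that input's index, at the cost of no longer reporting any later failures in the same pass. A minimal standalone illustration of the new control flow (generic names, not the llama.cpp types):

    #include <cstdio>
    #include <vector>

    // Fail-fast scan: stops at the first failing check and names its index,
    // mirroring the [GRAPH-REUSE-FAIL] diagnostics above.
    static bool can_reuse_all(const std::vector<bool> & checks) {
        for (size_t i = 0; i < checks.size(); ++i) {
            if (!checks[i]) {
                std::printf("[GRAPH-REUSE-FAIL] input node #%zu\n", i);
                return false; // later checks are never evaluated
            }
        }
        return true;
    }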

src/llama-model.cpp

Lines changed: 50 additions & 4 deletions

@@ -13793,7 +13793,21 @@ struct llm_build_glm4_moe : public llm_graph_context {
 
        ggml_tensor * cur;
 
-        if (params.mtp_params.op_type != MTP_OP_NONE) {
+        // if (params.mtp_params.op_type != MTP_OP_NONE && params.mtp_params.op_type != MTP_OP_MAIN_VALIDATION) {
+        //     ggml_tensor* hidden_states_from_main_model;
+
+        //     if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
+        //         hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
+        //     } else {
+        //         hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
+        //     }
+        //     ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
+        //     ggml_set_input(hidden_states_from_main_model);
+
+        //     auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
+        //     inp_mtp->states = hidden_states_from_main_model;
+        //     res->add_input(std::move(inp_mtp));
+        if (params.mtp_params.op_type != MTP_OP_NONE && params.mtp_params.op_type != MTP_OP_MAIN_VALIDATION) {
            ggml_tensor* hidden_states_from_main_model;
 
            if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {

@@ -13971,8 +13985,9 @@ struct llm_build_glm4_moe : public llm_graph_context {
        ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);
 
        const int il = hparams.n_layer - 1;
+        // cb(embd_copy, "mtp_embd_copy", il);
        ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy);
-
+        // cb(sum_node, "mtp_sum_node", il);
        ggml_set_name(sum_node, "mtp_input_sum");
 
        ggml_tensor * inp_pos = build_inp_pos();

@@ -13983,30 +13998,48 @@ struct llm_build_glm4_moe : public llm_graph_context {
        ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
 
        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
+        // cb(combined, "mtp_combined", il);
+
        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
 
        // now proceed through last layer (skipped in main model)
        ggml_tensor * inpSA = cur;
        // Pre-attention norm for the MTP block
        cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        // cb(cur, "mtp_attn_norm", il);
 
        // self-attention
        {
            ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
+            // if (mtp_layer.bq) {
+            //     Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
+            //     cb(Qcur, "mtp_q_bias", il); // ADDED
+            // }
            if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
            cb(Qcur, "Qcur", il);
 
            ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
+            // if (mtp_layer.bk) {
+            //     Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
+            //     cb(Kcur, "mtp_k_bias", il); // ADDED
+            // }
            if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
            cb(Kcur, "Kcur", il);
 
            ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
+            // if (mtp_layer.bv) {
+            //     Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
+            //     cb(Vcur, "mtp_v_bias", il); // ADDED
+            // }
            if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
            cb(Vcur, "Vcur", il);
 
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            // cb(Qcur, "mtp_q_reshaped", il);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            // cb(Kcur, "mtp_k_reshaped", il);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+            // cb(Vcur, "mtp_v_reshaped", il);
 
            // Apply Q/K norm if available (GLM-4.5 355B variant)
            if (mtp_layer.attn_q_norm) {

@@ -14023,12 +14056,14 @@ struct llm_build_glm4_moe : public llm_graph_context {
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );
+            // cb(Qcur, "mtp_q_rope", il);
 
            Kcur = ggml_rope_ext(
                ctx0, Kcur, inp_pos, nullptr,
                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                ext_factor, attn_factor, beta_fast, beta_slow
            );
+            // cb(Kcur, "mtp_k_rope", il);
 
            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);

@@ -14040,8 +14075,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
        }
 
        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        // cb(ffn_inp, "mtp_ffn_inp", il);
 
        cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);
+        // cb(cur, "post_attn_norm", il);
 
        // moe ffn for nextn block
        {

@@ -14073,7 +14110,10 @@ struct llm_build_glm4_moe : public llm_graph_context {
            cb(cur, "ffn_out", il);
        }
        cur = ggml_add(ctx0, cur, ffn_inp);
+        // cb(cur, "mtp_ffn_residual", il);
+
        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
+        // cb(cur, "mtp_final_norm", il);
        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
 
        return cur;

@@ -18305,7 +18345,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 }
 
 ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
-
+    const int64_t t_start_us = ggml_time_us();
    std::unique_ptr<llm_graph_context> llm;
    switch (arch) {
        case LLM_ARCH_LLAMA:

@@ -18664,10 +18704,16 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            GGML_ABORT("fatal error");
    }
 
-    if (params.mtp_params.op_type == MTP_OP_NONE) {
+    if (params.mtp_params.op_type == MTP_OP_NONE || params.mtp_params.op_type == MTP_OP_MAIN_VALIDATION) {
        // add on pooling layer
        llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
    }
+    const int64_t t_end_us = ggml_time_us();
+    LLAMA_LOG_INFO(
+        "[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
+        (t_end_us - t_start_us) / 1000.0,
+        params.mtp_params.op_type != MTP_OP_NONE && params.mtp_params.op_type != MTP_OP_MAIN_VALIDATION ? "yes" : "no"
+    );
    return llm->res->get_gf();
 }
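
The "(MTP path: ...)" label in the final hunk is the same "is this an MTP graph?" test used by the input-preparation guard in llama-context.cpp; note that it must combine the two exclusions with &&, since "op != MTP_OP_NONE || op != MTP_OP_MAIN_VALIDATION" holds for every op (no single value can equal both constants at once). A hypothetical helper, not in the commit, that would centralize the predicate:

    // Hypothetical: one place for the "MTP graph" test used by the [PERF]
    // label and by process_ubatch()'s prepare_mtp_graph_inputs guard.
    static bool llama_mtp_is_mtp_graph(llama_mtp_op_type op) {
        return op != MTP_OP_NONE && op != MTP_OP_MAIN_VALIDATION;
    }

    // e.g.: LLAMA_LOG_INFO("... (MTP path: %s)\n", ...,
    //           llama_mtp_is_mtp_graph(params.mtp_params.op_type) ? "yes" : "no");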
