
Commit 3bfa5d3

mtp-graph (feat): simplify graph logic
1 parent 5859cb9 commit 3bfa5d3

7 files changed: +195, -442 lines

common/speculative.cpp

Lines changed: 2 additions & 56 deletions
@@ -378,7 +378,7 @@ llama_token mtp_speculative_gen_draft(
     const llama_seq_id draft_seq_id = 0;
     common_batch_add(mtp_batch, id_last, n_past, {0}, true);
 
-    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_ONLY;
 
     // Perform the MTP draft generation decode. This writes the MTP layer's
     // KV state for the draft token into the cache.
@@ -406,58 +406,4 @@ llama_token mtp_speculative_gen_draft(
     common_sampler_apply_chain(smpl, cur_p);
 
     return cur_p->data[0].id;
-}
-
-
-void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
-    if (batch.n_tokens == 0) {
-        return;
-    }
-
-    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
-
-    llama_batch mtp_batch = batch;
-    if (is_prompt_warmup) {
-        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
-    } else {
-        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
-    }
-
-    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
-        mtp_batch.logits[i] = true;
-    }
-    const int64_t t_start_us = ggml_time_us();
-    llama_decode(ctx, mtp_batch);
-    const int64_t t_end_us = ggml_time_us();
-    LOG_INF("[PERF-MTP] mtp_update_kv_cache internal decode (op=%d): %.2f ms\n", (int)mtp_batch.mtp_params.op_type, (t_end_us - t_start_us) / 1000.0);
-}
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-) {
-    if (ids.empty()) {
-        return;
-    }
-
-    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
-    // This sets up the context for a "forced sinfo" decode.
-    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
-        return;
-    }
-
-    // Build a new batch containing only the accepted tokens.
-    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
-    for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
-    }
-
-    mtp_update_kv_cache(ctx, accepted_batch, false);
-
-    // Clean up the forced state to not affect subsequent, normal decode calls.
-    llama_mtp_cancel_sinfo_update(ctx);
-
-    llama_batch_free(accepted_batch);
-}
+}
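For context, the two helpers deleted above were how callers kept the MTP head's KV cache in sync outside the main decode: a separate llama_decode() per speculation round (MTP_OP_UPDATE_ACCEPTED), plus one for prompt warmup (MTP_OP_WARMUP). A minimal sketch of that now-removed call site, assuming "accepted" holds the draft ids the main model kept after validation; the wrapper name is illustrative and this no longer compiles against the post-commit tree:

#include <vector>
#include "common/speculative.h"  // pre-commit header, declared the helpers removed above

// Pre-commit caller-side sync (the path this commit deletes).
static void sync_mtp_cache(struct llama_context * ctx,
                           const std::vector<llama_token> & accepted,
                           int32_t n_past_base,
                           llama_seq_id seq_id) {
    // Resized the validation sinfo to accepted.size(), batched just the
    // accepted tokens, and issued an extra MTP-only decode tagged
    // MTP_OP_UPDATE_ACCEPTED via mtp_update_kv_cache().
    mtp_accept_tokens(ctx, accepted, n_past_base, seq_id);
}

After this commit that bookkeeping path is gone; the one remaining MTP decode, in mtp_speculative_gen_draft(), is tagged MTP_OP_DRAFT_ONLY.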

common/speculative.h

Lines changed: 1 addition & 10 deletions
@@ -47,13 +47,4 @@ llama_tokens common_speculative_gen_draft(
     struct common_speculative * spec,
     struct common_speculative_params params,
     const llama_tokens & prompt,
-    llama_token id_last);
-
-void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-);
+    llama_token id_last);

include/llama.h

Lines changed: 2 additions & 24 deletions
@@ -223,10 +223,8 @@ extern "C" {
     //
     typedef enum {
        MTP_OP_NONE,
-       MTP_OP_WARMUP,
-       MTP_OP_UPDATE_ACCEPTED,
-       MTP_OP_DRAFT_GEN,
-       MTP_OP_MAIN_VALIDATION,
+       MTP_OP_DRAFT_ONLY,
+       MTP_OP_UNIFIED,
     } llama_mtp_op_type;
 
     typedef struct llama_mtp_params {
@@ -1473,26 +1471,6 @@ extern "C" {
 
     LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
 
-    /**
-     * @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
-     * This is used after speculative validation when only a subset of draft tokens are accepted.
-     * @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
-     * @return true on success.
-     */
-    LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
-
-    /**
-     * @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
-     * This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
-     * @return true on success.
-     */
-    LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
-
-    /**
-     * @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
-     */
-    LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
-
     /**
      * @brief Removes KV cache metadata for a specified sequence and token range.
      * This makes the physical cells logically available again without deleting the tensor data.
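With the enum reduced to two active ops, per-decode selection collapses to a single branch. A sketch of the assumed usage; that MTP_OP_UNIFIED covers a fused main-model validation plus MTP KV update is inferred from the name and the ops it replaces, not stated in this diff, and the helper below is illustrative rather than part of the tree:

#include "llama.h"

// Assumed post-commit op selection for an MTP-aware decode.
static void set_mtp_op(struct llama_batch & batch, bool drafting) {
    batch.mtp_params.op_type = drafting
        ? MTP_OP_DRAFT_ONLY  // MTP head drafts a token (successor to MTP_OP_DRAFT_GEN)
        : MTP_OP_UNIFIED;    // assumed: validation + MTP KV update in one pass
}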
