@@ -378,7 +378,7 @@ llama_token mtp_speculative_gen_draft(
     const llama_seq_id draft_seq_id = 0;
     common_batch_add(mtp_batch, id_last, n_past, {0}, true);
 
-    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
+    mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_ONLY;
 
     // Perform the MTP draft generation decode. This writes the MTP layer's
     // KV state for the draft token into the cache.
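For orientation, the operation names visible across this change suggest a small op-type enum carried on mtp_params. The sketch below is a hypothetical reconstruction assembled only from the names that appear in this diff; MTP_OP_NONE is an assumption, and the real definition elsewhere in the branch may differ.

    // Hypothetical reconstruction of the MTP op-type enum, based only on
    // names visible in this diff; MTP_OP_NONE is assumed.
    enum mtp_op_type {
        MTP_OP_NONE,            // assumed default: a regular decode, no MTP work
        MTP_OP_WARMUP,          // fill the MTP layer's KV cache over the prompt
        MTP_OP_UPDATE_ACCEPTED, // re-run the MTP layer over accepted tokens
        MTP_OP_DRAFT_ONLY,      // generate a draft token (renamed here from MTP_OP_DRAFT_GEN)
    };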
@@ -406,58 +406,4 @@ llama_token mtp_speculative_gen_draft(
     common_sampler_apply_chain(smpl, cur_p);
 
     return cur_p->data[0].id;
-}
-
-
-void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
-    if (batch.n_tokens == 0) {
-        return;
-    }
-
-    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
-
-    llama_batch mtp_batch = batch;
-    if (is_prompt_warmup) {
-        mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
-    } else {
-        mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
-    }
-
-    for (int i = 0; i < mtp_batch.n_tokens; ++i) {
-        mtp_batch.logits[i] = true;
-    }
-    const int64_t t_start_us = ggml_time_us();
-    llama_decode(ctx, mtp_batch);
-    const int64_t t_end_us = ggml_time_us();
-    LOG_INF("[PERF-MTP] mtp_update_kv_cache internal decode (op=%d): %.2f ms\n", (int) mtp_batch.mtp_params.op_type, (t_end_us - t_start_us) / 1000.0);
-}
-
-void mtp_accept_tokens(
-    struct llama_context * ctx,
-    const std::vector<llama_token> & ids,
-    int32_t n_past_base,
-    llama_seq_id seq_id
-) {
-    if (ids.empty()) {
-        return;
-    }
-
-    // Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
-    // This sets up the context for a "forced sinfo" decode.
-    if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
-        return;
-    }
-
-    // Build a new batch containing only the accepted tokens.
-    llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
-    for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
-    }
-
-    mtp_update_kv_cache(ctx, accepted_batch, false);
-
-    // Clean up the forced state to not affect subsequent, normal decode calls.
-    llama_mtp_cancel_sinfo_update(ctx);
-
-    llama_batch_free(accepted_batch);
-}
+}
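For readers tracking the removal: a typical call site for mtp_accept_tokens would have looked roughly like the following, invoked after the target model validated a draft. This is a hypothetical illustration; on_draft_validated and its parameters are invented for the example, and only mtp_accept_tokens itself comes from the removed code.

    #include <vector>
    #include "llama.h"

    // Hypothetical call site for the removed helper; everything except
    // mtp_accept_tokens itself is invented for illustration.
    static void on_draft_validated(llama_context * ctx,
                                   const std::vector<llama_token> & accepted,
                                   int32_t n_past_base) {
        // Keep the MTP layer's KV cache in sync with the tokens the
        // target model actually accepted.
        mtp_accept_tokens(ctx, accepted, n_past_base, /*seq_id=*/0);
    }

Note also that the removed mtp_accept_tokens brackets its internal decode between llama_mtp_prepare_sinfo_for_update and llama_mtp_cancel_sinfo_update, with early returns in between. If this pattern returns in a later revision, a scope guard would keep the forced-sinfo state from leaking on an early exit. A minimal sketch, assuming only the prepare/cancel pair shown in the removed code:

    // Hypothetical RAII guard (not part of this commit) around the
    // forced-sinfo state, assuming the prepare/cancel entry points above.
    struct mtp_sinfo_guard {
        llama_context * ctx;
        bool            armed;
        mtp_sinfo_guard(llama_context * c, size_t n_accepted)
            : ctx(c), armed(llama_mtp_prepare_sinfo_for_update(c, n_accepted)) {}
        ~mtp_sinfo_guard() {
            // Always clear the forced state so later, normal decodes are unaffected.
            if (armed) { llama_mtp_cancel_sinfo_update(ctx); }
        }
    };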
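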