@@ -1217,16 +1217,25 @@ int llama_context::decode(const llama_batch & batch_inp) {
12171217
12181218 // extract logits
12191219 if (t_logits && n_outputs > 0 ) {
1220- ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched.get (), t_logits);
1221- GGML_ASSERT (backend_res != nullptr );
1222- GGML_ASSERT (logits != nullptr );
1223-
1224- float * logits_out = logits + n_outputs_prev*n_vocab;
1225-
1226- if (n_outputs) {
1227- GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all);
1228- GGML_ASSERT ((n_outputs_prev + n_outputs)*n_vocab <= (int64_t ) logits_size);
1229- ggml_backend_tensor_get_async (backend_res, t_logits, logits_out, 0 , n_outputs*n_vocab*sizeof (float ));
1220+ // MTP operations that exist purely to update the KV cache
1221+ // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logits tensor
1222+ // as a side effect of running the graph. If those logits were copied
1223+ // back into the main context buffer, they would overwrite the valid logits
1224+ // produced by the main model's pass, leading to incorrect sampling.
1225+ // This check explicitly skips the copy for such cache-only operations.
1226+ if (batch_inp.mtp_params .op_type != MTP_OP_WARMUP &&
1227+ batch_inp.mtp_params .op_type != MTP_OP_UPDATE_ACCEPTED) {
1228+ ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched.get (), t_logits);
1229+ GGML_ASSERT (backend_res != nullptr );
1230+ GGML_ASSERT (logits != nullptr );
1231+
1232+ float * logits_out = logits + n_outputs_prev*n_vocab;
1233+
1234+ if (n_outputs) {
1235+ GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all);
1236+ GGML_ASSERT ((n_outputs_prev + n_outputs)*n_vocab <= (int64_t ) logits_size);
1237+ ggml_backend_tensor_get_async (backend_res, t_logits, logits_out, 0 , n_outputs*n_vocab*sizeof (float ));
1238+ }
12301239 }
12311240 }
12321241
0 commit comments