Skip to content

Commit 15dff20

Browse files
Merge branch 'glm4-mtp-batch' of https://github.com/SamuelOliveirads/llama.cpp into glm4-mtp-graph-cache
2 parents 171346c + cae85fe commit 15dff20

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

src/llama-context.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,16 +1217,25 @@ int llama_context::decode(const llama_batch & batch_inp) {
12171217

12181218
// extract logits
12191219
if (t_logits && n_outputs > 0) {
1220-
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1221-
GGML_ASSERT(backend_res != nullptr);
1222-
GGML_ASSERT(logits != nullptr);
1223-
1224-
float * logits_out = logits + n_outputs_prev*n_vocab;
1225-
1226-
if (n_outputs) {
1227-
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1228-
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1229-
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1220+
// MTP operations that are purely for updating the KV cache
1221+
// (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor
1222+
// as a side effect of running the graph. If these logits are copied
1223+
// back to the main context buffer, they will overwrite the valid logits
1224+
// produced by the main model's pass, leading to incorrect sampling.
1225+
// This condition explicitly prevents that copy for cache-only operations.
1226+
if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP &&
1227+
batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) {
1228+
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
1229+
GGML_ASSERT(backend_res != nullptr);
1230+
GGML_ASSERT(logits != nullptr);
1231+
1232+
float * logits_out = logits + n_outputs_prev*n_vocab;
1233+
1234+
if (n_outputs) {
1235+
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
1236+
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
1237+
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
1238+
}
12301239
}
12311240
}
12321241

0 commit comments

Comments
 (0)