Skip to content

Commit edc23f9

Browse files
committed
Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
1 parent 6fb54c1 commit edc23f9

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

src/llama-sparse-topk.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,21 +131,20 @@ using std::function;
131131
if (dbg && t0 == 0) {
132132
cb(logits_all, "idxkv_logits_all", -1);
133133
}
134-
// Reshape and apply ReLU: [N_kv, H, Tc]
135-
ggml_tensor * logits_resh = ggml_reshape_3d(ctx, logits_all, N_kv, H, Tc);
136-
ggml_tensor * logits_act = ggml_relu(ctx, logits_resh);
137-
// Weights slice [H, Tc] and broadcast-mul, then sum over H → [N_kv, Tc]
138-
ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]);
139-
140-
// reshape to [1, H, Tc] so it can broadcast across N_kv
141-
ggml_tensor * w3 = ggml_reshape_3d(ctx, w_slice, 1, H, Tc);
142-
ggml_tensor * w_bcast = ggml_repeat(ctx, w3, logits_act);
143-
ggml_tensor * contrib = ggml_mul(ctx, logits_act, w_bcast); // [N_kv, H, Tc]
144-
// Sum over head dimension (ne1): permute to [H, N_kv, Tc] and sum rows
145-
ggml_tensor * contrib_perm = ggml_permute(ctx, contrib, 1, 0, 2, 3);
146-
contrib_perm = ggml_cont(ctx, contrib_perm);
147-
ggml_tensor * sum_h = ggml_sum_rows(ctx, contrib_perm); // [1, N_kv, Tc]
148-
scores_tc = ggml_reshape_2d(ctx, sum_h, N_kv, Tc); // [N_kv, Tc]
134+
// Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
135+
for (int64_t h = 0; h < H; ++h) {
136+
size_t off_h = (size_t) h * (size_t) Tc * logits_all->nb[1];
137+
ggml_tensor * logits_h = ggml_view_2d(ctx, logits_all, N_kv, Tc, logits_all->nb[1], off_h);
138+
ggml_tensor * logits_h_act = ggml_relu(ctx, logits_h);
139+
size_t w_off = (size_t) h * weights->nb[0];
140+
ggml_tensor * w_row = ggml_view_2d(ctx, weights, 1, Tc, weights->nb[1], t0*weights->nb[1] + w_off);
141+
if (w_row->type != logits_h_act->type) {
142+
w_row = ggml_cast(ctx, w_row, logits_h_act->type);
143+
}
144+
ggml_tensor * w_bcast = ggml_repeat(ctx, w_row, logits_h_act); // [N_kv, Tc]
145+
ggml_tensor * contrib = ggml_mul(ctx, logits_h_act, w_bcast); // [N_kv, Tc]
146+
scores_tc = scores_tc ? ggml_add(ctx, scores_tc, contrib) : contrib;
147+
}
149148

150149

151150
// Safe K-scale proxy application after head reduction (always apply)

0 commit comments

Comments
 (0)