Skip to content

Commit 9e9a84a

Browse files
committed
Revert last change as it was objectively worse.
1 parent edc23f9 commit 9e9a84a

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/llama-sparse-topk.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,20 +131,21 @@ using std::function;
131131
if (dbg && t0 == 0) {
132132
cb(logits_all, "idxkv_logits_all", -1);
133133
}
134-
// Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
135-
for (int64_t h = 0; h < H; ++h) {
136-
size_t off_h = (size_t) h * (size_t) Tc * logits_all->nb[1];
137-
ggml_tensor * logits_h = ggml_view_2d(ctx, logits_all, N_kv, Tc, logits_all->nb[1], off_h);
138-
ggml_tensor * logits_h_act = ggml_relu(ctx, logits_h);
139-
size_t w_off = (size_t) h * weights->nb[0];
140-
ggml_tensor * w_row = ggml_view_2d(ctx, weights, 1, Tc, weights->nb[1], t0*weights->nb[1] + w_off);
141-
if (w_row->type != logits_h_act->type) {
142-
w_row = ggml_cast(ctx, w_row, logits_h_act->type);
143-
}
144-
ggml_tensor * w_bcast = ggml_repeat(ctx, w_row, logits_h_act); // [N_kv, Tc]
145-
ggml_tensor * contrib = ggml_mul(ctx, logits_h_act, w_bcast); // [N_kv, Tc]
146-
scores_tc = scores_tc ? ggml_add(ctx, scores_tc, contrib) : contrib;
147-
}
134+
// Reshape and apply ReLU: [N_kv, H, Tc]
135+
ggml_tensor * logits_resh = ggml_reshape_3d(ctx, logits_all, N_kv, H, Tc);
136+
ggml_tensor * logits_act = ggml_relu(ctx, logits_resh);
137+
// Weights slice [H, Tc] and broadcast-mul, then sum over H → [N_kv, Tc]
138+
ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]);
139+
140+
// reshape to [1, H, Tc] so it can broadcast across N_kv
141+
ggml_tensor * w3 = ggml_reshape_3d(ctx, w_slice, 1, H, Tc);
142+
ggml_tensor * w_bcast = ggml_repeat(ctx, w3, logits_act);
143+
ggml_tensor * contrib = ggml_mul(ctx, logits_act, w_bcast); // [N_kv, H, Tc]
144+
// Sum over head dimension (ne1): permute to [H, N_kv, Tc] and sum rows
145+
ggml_tensor * contrib_perm = ggml_permute(ctx, contrib, 1, 0, 2, 3);
146+
contrib_perm = ggml_cont(ctx, contrib_perm);
147+
ggml_tensor * sum_h = ggml_sum_rows(ctx, contrib_perm); // [1, N_kv, Tc]
148+
scores_tc = ggml_reshape_2d(ctx, sum_h, N_kv, Tc); // [N_kv, Tc]
148149

149150

150151
// Safe K-scale proxy application after head reduction (always apply)

0 commit comments

Comments
 (0)