@@ -131,20 +131,21 @@ using std::function;
     if (dbg && t0 == 0) {
         cb(logits_all, "idxkv_logits_all", -1);
     }
-    // Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
-    for (int64_t h = 0; h < H; ++h) {
-        size_t off_h = (size_t) h * (size_t) Tc * logits_all->nb[1];
-        ggml_tensor * logits_h = ggml_view_2d(ctx, logits_all, N_kv, Tc, logits_all->nb[1], off_h);
-        ggml_tensor * logits_h_act = ggml_relu(ctx, logits_h);
-        size_t w_off = (size_t) h * weights->nb[0];
-        ggml_tensor * w_row = ggml_view_2d(ctx, weights, 1, Tc, weights->nb[1], t0*weights->nb[1] + w_off);
-        if (w_row->type != logits_h_act->type) {
-            w_row = ggml_cast(ctx, w_row, logits_h_act->type);
-        }
-        ggml_tensor * w_bcast = ggml_repeat(ctx, w_row, logits_h_act); // [N_kv, Tc]
-        ggml_tensor * contrib = ggml_mul(ctx, logits_h_act, w_bcast);  // [N_kv, Tc]
-        scores_tc = scores_tc ? ggml_add(ctx, scores_tc, contrib) : contrib;
-    }
+    // Reshape and apply ReLU: [N_kv, H, Tc]
+    ggml_tensor * logits_resh = ggml_reshape_3d(ctx, logits_all, N_kv, H, Tc);
+    ggml_tensor * logits_act  = ggml_relu(ctx, logits_resh);
+    // Weights slice [H, Tc] and broadcast-mul, then sum over H → [N_kv, Tc]
+    ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]);
+
+    // reshape to [1, H, Tc] so it can broadcast across N_kv
+    ggml_tensor * w3      = ggml_reshape_3d(ctx, w_slice, 1, H, Tc);
+    ggml_tensor * w_bcast = ggml_repeat(ctx, w3, logits_act);
+    ggml_tensor * contrib = ggml_mul(ctx, logits_act, w_bcast); // [N_kv, H, Tc]
+    // Sum over head dimension (ne1): permute to [H, N_kv, Tc] and sum rows
+    ggml_tensor * contrib_perm = ggml_permute(ctx, contrib, 1, 0, 2, 3);
+    contrib_perm = ggml_cont(ctx, contrib_perm);
+    ggml_tensor * sum_h = ggml_sum_rows(ctx, contrib_perm); // [1, N_kv, Tc]
+    scores_tc = ggml_reshape_2d(ctx, sum_h, N_kv, Tc);      // [N_kv, Tc]


     // Safe K-scale proxy application after head reduction (always apply)
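For reference, both the removed loop and the replacement graph compute, for each KV position n and chunk token t, scores[n][t] = Σ_h relu(logits[n][h][t]) · weights[h][t0 + t]. Below is a minimal plain-C sketch of that reduction. The layout mirrors ggml's convention (ne0 varies fastest), but the sizes, the score_chunk name, and the test values are hypothetical placeholders, not anything taken from the actual model:

#include <stddef.h>
#include <stdio.h>

/* Hypothetical toy sizes; the real N_kv, H, Tc, t0 come from the model/graph. */
enum { N_KV = 4, H_HEADS = 2, TC = 3, T0 = 1, T_TOTAL = 8 };

static float relu_f(float x) { return x > 0.0f ? x : 0.0f; }

/* scores[n][t] = sum_h relu(logits[n][h][t]) * w[h][T0 + t].
 * logits is [N_KV, H_HEADS, TC] with n varying fastest (ggml's ne0);
 * w is [H_HEADS, T_TOTAL] with h varying fastest. */
static void score_chunk(const float * logits, const float * w, float * scores) {
    for (int t = 0; t < TC; ++t) {
        for (int n = 0; n < N_KV; ++n) {
            float acc = 0.0f;
            for (int h = 0; h < H_HEADS; ++h) {
                acc += relu_f(logits[((size_t) t*H_HEADS + h)*N_KV + n])
                     * w[(size_t) (T0 + t)*H_HEADS + h];
            }
            scores[(size_t) t*N_KV + n] = acc; /* [N_KV, TC], n fastest */
        }
    }
}

int main(void) {
    float logits[TC*H_HEADS*N_KV], w[T_TOTAL*H_HEADS], scores[TC*N_KV];
    for (int i = 0; i < TC*H_HEADS*N_KV; ++i) logits[i] = (float)(i % 5) - 2.0f;
    for (int i = 0; i < T_TOTAL*H_HEADS; ++i)  w[i] = 0.5f;
    score_chunk(logits, w, scores);
    for (int i = 0; i < TC*N_KV; ++i)
        printf("%.2f%c", scores[i], (i + 1) % N_KV ? ' ' : '\n');
    return 0;
}

Two observations on the graph version, offered tentatively: ggml's element-wise binary ops can usually broadcast src1 on their own when the shapes are repeat-compatible, so the explicit ggml_repeat mainly serves to make the shape expansion visible in the graph; and, as the removed comment notes, this path does materialize [N_kv, H, Tc] intermediates (logits_act, w_bcast, contrib, and the ggml_cont copy) that the streaming per-head loop was written to avoid, trading memory for a simpler graph.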