@@ -131,21 +131,20 @@ using std::function;
     if (dbg && t0 == 0) {
         cb(logits_all, "idxkv_logits_all", -1);
     }
-    // Reshape and apply ReLU: [N_kv, H, Tc]
-    ggml_tensor * logits_resh = ggml_reshape_3d(ctx, logits_all, N_kv, H, Tc);
-    ggml_tensor * logits_act  = ggml_relu(ctx, logits_resh);
-    // Weights slice [H, Tc] and broadcast-mul, then sum over H → [N_kv, Tc]
-    ggml_tensor * w_slice = ggml_view_2d(ctx, weights, H, Tc, weights->nb[1], t0*weights->nb[1]);
-
-    // reshape to [1, H, Tc] so it can broadcast across N_kv
-    ggml_tensor * w3      = ggml_reshape_3d(ctx, w_slice, 1, H, Tc);
-    ggml_tensor * w_bcast = ggml_repeat(ctx, w3, logits_act);
-    ggml_tensor * contrib = ggml_mul(ctx, logits_act, w_bcast); // [N_kv, H, Tc]
-    // Sum over head dimension (ne1): permute to [H, N_kv, Tc] and sum rows
-    ggml_tensor * contrib_perm = ggml_permute(ctx, contrib, 1, 0, 2, 3);
-    contrib_perm = ggml_cont(ctx, contrib_perm);
-    ggml_tensor * sum_h = ggml_sum_rows(ctx, contrib_perm); // [1, N_kv, Tc]
-    scores_tc = ggml_reshape_2d(ctx, sum_h, N_kv, Tc);      // [N_kv, Tc]
+    // Streaming per-head accumulation to avoid [N_kv, H, Tc] temporaries
+    for (int64_t h = 0; h < H; ++h) {
+        size_t off_h = (size_t) h * (size_t) Tc * logits_all->nb[1];
+        ggml_tensor * logits_h     = ggml_view_2d(ctx, logits_all, N_kv, Tc, logits_all->nb[1], off_h);
+        ggml_tensor * logits_h_act = ggml_relu(ctx, logits_h);
+        size_t w_off = (size_t) h * weights->nb[0];
+        ggml_tensor * w_row = ggml_view_2d(ctx, weights, 1, Tc, weights->nb[1], t0*weights->nb[1] + w_off);
+        if (w_row->type != logits_h_act->type) {
+            w_row = ggml_cast(ctx, w_row, logits_h_act->type);
+        }
+        ggml_tensor * w_bcast = ggml_repeat(ctx, w_row, logits_h_act); // [N_kv, Tc]
+        ggml_tensor * contrib = ggml_mul(ctx, logits_h_act, w_bcast); // [N_kv, Tc]
+        scores_tc = scores_tc ? ggml_add(ctx, scores_tc, contrib) : contrib;
+    }
 
 
     // Safe K-scale proxy application after head reduction (always apply)
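
For reference, the graph built by the new loop computes, for every cached token n and chunk position t, scores_tc(n, t) = sum over h of weights(h, t0 + t) * relu(logits(n, h, t)). Below is a minimal scalar sketch of that reduction; the function name, flat-array layout, and parameter names are illustrative assumptions, not part of this commit.

// Scalar reference for the per-head reduction above (hypothetical layout:
// logits stored head-major as [H][Tc][N_kv], weights as [H][T_total] rows,
// scores returned with n fastest to match the [N_kv, Tc] ggml tensor).
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<float> reduce_heads_ref(const std::vector<float> & logits,   // H * Tc * N_kv values
                                    const std::vector<float> & weights,  // H * T_total values
                                    int64_t N_kv, int64_t H, int64_t Tc,
                                    int64_t T_total, int64_t t0) {
    std::vector<float> scores(N_kv * Tc, 0.0f);
    for (int64_t h = 0; h < H; ++h) {
        for (int64_t t = 0; t < Tc; ++t) {
            const float w = weights[h*T_total + t0 + t];                         // head weight at absolute position
            for (int64_t n = 0; n < N_kv; ++n) {
                const float a = std::max(logits[(h*Tc + t)*N_kv + n], 0.0f);     // ReLU on the per-head logit
                scores[t*N_kv + n] += w * a;                                     // accumulate over heads
            }
        }
    }
    return scores;
}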