@@ -99,10 +99,12 @@ class AttnBlock : public UnaryBlock {
9999 k = ggml_cont (ctx, ggml_permute (ctx, k, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
100100 k = ggml_reshape_3d (ctx, k, c, h * w, n); // [N, h * w, in_channels]
101101
102- auto v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
103- v = ggml_reshape_3d (ctx, v, h * w, c, n); // [N, in_channels, h * w]
102+ auto v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
103+ v = ggml_cont (ctx, ggml_permute (ctx, v, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
104+ v = ggml_reshape_3d (ctx, v, c, h * w, n); // [N, h * w, in_channels]
104105
105- h_ = ggml_nn_attention (ctx, q, k, v, false ); // [N, h * w, in_channels]
106+ // h_ = ggml_nn_attention(ctx, q, k, v, false); // [N, h * w, in_channels]
107+ h_ = ggml_nn_attention_ext (ctx, q, k, v, 1 , nullptr , false , true , true );
106108
107109 h_ = ggml_cont (ctx, ggml_permute (ctx, h_, 1 , 0 , 2 , 3 )); // [N, in_channels, h * w]
108110 h_ = ggml_reshape_4d (ctx, h_, w, h, c, n); // [N, in_channels, h, w]
@@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner {
612614 };
613615};
614616
615- #endif
617+ #endif
0 commit comments