@@ -20,23 +20,23 @@ class ResnetBlock : public UnaryBlock {
2020 out_channels (out_channels) {
2121 // temb_channels is always 0
2222 blocks[" norm1" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (in_channels));
23- blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
23+ blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true ));
2424
2525 blocks[" norm2" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (out_channels));
26- blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (out_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
26+ blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (out_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true ));
2727
2828 if (out_channels != in_channels) {
29- blocks[" nin_shortcut" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, {1 , 1 }));
29+ blocks[" nin_shortcut" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , true ));
3030 }
3131 }
3232
3333 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
3434 // x: [N, in_channels, h, w]
3535 // t_emb is always None
3636 auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm1" ]);
37- auto conv1 = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv1" ]);
37+ auto conv1 = std::dynamic_pointer_cast<Conv2d >(blocks[" conv1" ]);
3838 auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm2" ]);
39- auto conv2 = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv2" ]);
39+ auto conv2 = std::dynamic_pointer_cast<Conv2d >(blocks[" conv2" ]);
4040
4141 auto h = x;
4242 h = norm1->forward (ctx, h);
@@ -51,7 +51,7 @@ class ResnetBlock : public UnaryBlock {
5151
5252 // skip connection
5353 if (out_channels != in_channels) {
54- auto nin_shortcut = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" nin_shortcut" ]);
54+ auto nin_shortcut = std::dynamic_pointer_cast<Conv2d >(blocks[" nin_shortcut" ]);
5555
5656 x = nin_shortcut->forward (ctx, x); // [N, out_channels, h, w]
5757 }
@@ -69,20 +69,20 @@ class AttnBlock : public UnaryBlock {
6969 AttnBlock (int64_t in_channels)
7070 : in_channels(in_channels) {
7171 blocks[" norm" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (in_channels));
72- blocks[" q" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
73- blocks[" k" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
74- blocks[" v" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
72+ blocks[" q" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , true )); // positional args after the channels: kernel_size, stride, padding, dilation, bias, direct (order as forwarded to Conv2d by AE3DConv; confirm against Conv2d's declaration). k, v and proj_out use the same layout.
73+ blocks[" k" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , true ));
74+ blocks[" v" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , true ));
7575
76- blocks[" proj_out" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
76+ blocks[" proj_out" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }, { 1 , 1 }, { 0 , 0 }, { 1 , 1 }, true , true ));
7777 }
7878
7979 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
8080 // x: [N, in_channels, h, w]
8181 auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm" ]);
82- auto q_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" q" ]);
83- auto k_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" k" ]);
84- auto v_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" v" ]);
85- auto proj_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" proj_out" ]);
82+ auto q_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" q" ]);
83+ auto k_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" k" ]);
84+ auto v_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" v" ]);
85+ auto proj_out = std::dynamic_pointer_cast<Conv2d >(blocks[" proj_out" ]);
8686
8787 auto h_ = norm->forward (ctx, x);
8888
@@ -114,7 +114,7 @@ class AttnBlock : public UnaryBlock {
114114 }
115115};
116116
117- class AE3DConv : public Conv2dDirect {
117+ class AE3DConv : public Conv2d {
118118public:
119119 AE3DConv (int64_t in_channels,
120120 int64_t out_channels,
@@ -123,8 +123,9 @@ class AE3DConv : public Conv2dDirect {
123123 std::pair<int , int > stride = {1 , 1 },
124124 std::pair<int , int > padding = {0 , 0 },
125125 std::pair<int , int > dilation = {1 , 1 },
126- bool bias = true )
127- : Conv2dDirect(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
126+ bool bias = true ,
127+ bool direct = false )
128+ : Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct) {
128129 int64_t kernel_padding = video_kernel_size / 2 ;
129130 blocks[" time_mix_conv" ] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1 (out_channels,
130131 out_channels,
@@ -141,7 +142,7 @@ class AE3DConv : public Conv2dDirect {
141142 // result: [N, OC, OH, OW]
142143 auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks[" time_mix_conv" ]);
143144
144- x = Conv2dDirect ::forward (ctx, x);
145+ x = Conv2d ::forward (ctx, x);
145146 // timesteps = x.shape[0]
146147 // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
147148 // x = conv3d(x)
@@ -240,7 +241,7 @@ class Encoder : public GGMLBlock {
240241 in_channels(in_channels),
241242 z_channels(z_channels),
242243 double_z(double_z) {
243- blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
244+ blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true ));
244245
245246 size_t num_resolutions = ch_mult.size ();
246247
@@ -268,18 +269,18 @@ class Encoder : public GGMLBlock {
268269 blocks[" mid.block_2" ] = std::shared_ptr<GGMLBlock>(new ResnetBlock (block_in, block_in));
269270
270271 blocks[" norm_out" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (block_in));
271- blocks[" conv_out" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (block_in, double_z ? z_channels * 2 : z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
272+ blocks[" conv_out" ] = std::shared_ptr<GGMLBlock>(new Conv2d (block_in, double_z ? z_channels * 2 : z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true ));
272273 }
273274
274275 virtual struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
275276 // x: [N, in_channels, h, w]
276277
277- auto conv_in = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_in" ]);
278+ auto conv_in = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_in" ]);
278279 auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_1" ]);
279280 auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks[" mid.attn_1" ]);
280281 auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_2" ]);
281282 auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm_out" ]);
282- auto conv_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_out" ]);
283+ auto conv_out = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_out" ]);
283284
284285 auto h = conv_in->forward (ctx, x); // [N, ch, h, w]
285286
@@ -328,11 +329,14 @@ class Decoder : public GGMLBlock {
328329 int64_t out_channels,
329330 std::pair<int , int > kernel_size,
330331 std::pair<int , int > stride = {1 , 1 },
331- std::pair<int , int > padding = {0 , 0 }) {
332+ std::pair<int , int > padding = {0 , 0 },
333+ std::pair<int , int > dilation = {1 , 1 },
334+ bool bias = true ,
335+ bool direct = false ){
332336 if (video_decoder) {
333- return std::shared_ptr<GGMLBlock>(new AE3DConv (in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
337+ return std::shared_ptr<GGMLBlock>(new AE3DConv (in_channels, out_channels, kernel_size, video_kernel_size, stride, padding, dilation, bias, direct)); // forward dilation/bias/direct so the video-decoder path honors the caller's settings, matching the Conv2d branch below
334338 } else {
335- return std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, kernel_size, stride, padding));
339+ return std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct ));
336340 }
337341 }
338342
@@ -363,7 +367,7 @@ class Decoder : public GGMLBlock {
363367 size_t num_resolutions = ch_mult.size ();
364368 int block_in = ch * ch_mult[num_resolutions - 1 ];
365369
366- blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (z_channels, block_in, {3 , 3 }, {1 , 1 }, {1 , 1 }));
370+ blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, block_in, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true ));
367371
368372 blocks[" mid.block_1" ] = get_resnet_block (block_in, block_in);
369373 blocks[" mid.attn_1" ] = std::shared_ptr<GGMLBlock>(new AttnBlock (block_in));
@@ -385,7 +389,7 @@ class Decoder : public GGMLBlock {
385389 }
386390
387391 blocks[" norm_out" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (block_in));
388- blocks[" conv_out" ] = get_conv_out (block_in, out_ch, {3 , 3 }, {1 , 1 }, {1 , 1 });
392+ blocks[" conv_out" ] = get_conv_out (block_in, out_ch, {3 , 3 }, {1 , 1 }, {1 , 1 }, { 1 , 1 }, true , true );
389393 }
390394
391395 virtual struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * z) {
@@ -394,12 +398,12 @@ class Decoder : public GGMLBlock {
394398 // merge_strategy is always learned
395399 // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
396400 // AttnVideoBlock will not be used
397- auto conv_in = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_in" ]);
401+ auto conv_in = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_in" ]);
398402 auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_1" ]);
399403 auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks[" mid.attn_1" ]);
400404 auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_2" ]);
401405 auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm_out" ]);
402- auto conv_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_out" ]);
406+ auto conv_out = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_out" ]);
403407
404408 // conv_in
405409 auto h = conv_in->forward (ctx, z); // [N, block_in, h, w]
@@ -472,9 +476,14 @@ class AutoencodingEngine : public GGMLBlock {
472476 dd_config.z_channels ,
473477 use_video_decoder));
474478 if (use_quant) {
475- blocks[" post_quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (dd_config.z_channels ,
479+ blocks[" post_quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2d (dd_config.z_channels ,
476480 embed_dim,
477- {1 , 1 }));
481+ {1 , 1 },
482+ {1 , 1 },
483+ {0 , 0 },
484+ {1 , 1 },
485+ true ,
486+ true ));
478487 }
479488 if (!decode_only) {
480489 blocks[" encoder" ] = std::shared_ptr<GGMLBlock>(new Encoder (dd_config.ch ,
@@ -486,17 +495,22 @@ class AutoencodingEngine : public GGMLBlock {
486495 if (use_quant) {
487496 int factor = dd_config.double_z ? 2 : 1 ;
488497
489- blocks[" quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (embed_dim * factor,
498+ blocks[" quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2d (embed_dim * factor,
490499 dd_config.z_channels * factor,
491- {1 , 1 }));
500+ {1 , 1 },
501+ {1 , 1 },
502+ {0 , 0 },
503+ {1 , 1 },
504+ true ,
505+ true ));
492506 }
493507 }
494508 }
495509
496510 struct ggml_tensor * decode (struct ggml_context * ctx, struct ggml_tensor * z) {
497511 // z: [N, z_channels, h, w]
498512 if (use_quant) {
499- auto post_quant_conv = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" post_quant_conv" ]);
513+ auto post_quant_conv = std::dynamic_pointer_cast<Conv2d >(blocks[" post_quant_conv" ]);
500514 z = post_quant_conv->forward (ctx, z); // [N, z_channels, h, w]
501515 }
502516 auto decoder = std::dynamic_pointer_cast<Decoder>(blocks[" decoder" ]);
@@ -513,7 +527,7 @@ class AutoencodingEngine : public GGMLBlock {
513527
514528 auto h = encoder->forward (ctx, x); // [N, 2*z_channels, h/8, w/8]
515529 if (use_quant) {
516- auto quant_conv = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" quant_conv" ]);
530+ auto quant_conv = std::dynamic_pointer_cast<Conv2d >(blocks[" quant_conv" ]);
517531 h = quant_conv->forward (ctx, h); // [N, 2*embed_dim, h/8, w/8]
518532 }
519533 return h;
0 commit comments