@@ -20,23 +20,23 @@ class ResnetBlock : public UnaryBlock {
2020 out_channels (out_channels) {
2121 // temb_channels is always 0
2222 blocks[" norm1" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (in_channels));
23- blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
23+ blocks[" conv1" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
2424
2525 blocks[" norm2" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (out_channels));
26- blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2d (out_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
26+ blocks[" conv2" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (out_channels, out_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
2727
2828 if (out_channels != in_channels) {
29- blocks[" nin_shortcut" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, {1 , 1 }));
29+ blocks[" nin_shortcut" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, {1 , 1 }));
3030 }
3131 }
3232
3333 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
3434 // x: [N, in_channels, h, w]
3535 // t_emb is always None
3636 auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm1" ]);
37- auto conv1 = std::dynamic_pointer_cast<Conv2d >(blocks[" conv1" ]);
37+ auto conv1 = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv1" ]);
3838 auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm2" ]);
39- auto conv2 = std::dynamic_pointer_cast<Conv2d >(blocks[" conv2" ]);
39+ auto conv2 = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv2" ]);
4040
4141 auto h = x;
4242 h = norm1->forward (ctx, h);
@@ -51,7 +51,7 @@ class ResnetBlock : public UnaryBlock {
5151
5252 // skip connection
5353 if (out_channels != in_channels) {
54- auto nin_shortcut = std::dynamic_pointer_cast<Conv2d >(blocks[" nin_shortcut" ]);
54+ auto nin_shortcut = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" nin_shortcut" ]);
5555
5656 x = nin_shortcut->forward (ctx, x); // [N, out_channels, h, w]
5757 }
@@ -69,20 +69,20 @@ class AttnBlock : public UnaryBlock {
6969 AttnBlock (int64_t in_channels)
7070 : in_channels(in_channels) {
7171 blocks[" norm" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (in_channels));
72- blocks[" q" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }));
73- blocks[" k" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }));
74- blocks[" v" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }));
72+ blocks[" q" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
73+ blocks[" k" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
74+ blocks[" v" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
7575
76- blocks[" proj_out" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, in_channels, {1 , 1 }));
76+ blocks[" proj_out" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, in_channels, {1 , 1 }));
7777 }
7878
7979 struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
8080 // x: [N, in_channels, h, w]
8181 auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm" ]);
82- auto q_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" q" ]);
83- auto k_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" k" ]);
84- auto v_proj = std::dynamic_pointer_cast<Conv2d >(blocks[" v" ]);
85- auto proj_out = std::dynamic_pointer_cast<Conv2d >(blocks[" proj_out" ]);
82+ auto q_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" q" ]);
83+ auto k_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" k" ]);
84+ auto v_proj = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" v" ]);
85+ auto proj_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" proj_out" ]);
8686
8787 auto h_ = norm->forward (ctx, x);
8888
@@ -114,7 +114,7 @@ class AttnBlock : public UnaryBlock {
114114 }
115115};
116116
117- class AE3DConv : public Conv2d {
117+ class AE3DConv : public Conv2dDirect {
118118public:
119119 AE3DConv (int64_t in_channels,
120120 int64_t out_channels,
@@ -124,7 +124,7 @@ class AE3DConv : public Conv2d {
124124 std::pair<int , int > padding = {0 , 0 },
125125 std::pair<int , int > dilation = {1 , 1 },
126126 bool bias = true )
127- : Conv2d (in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
127+ : Conv2dDirect (in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
128128 int64_t kernel_padding = video_kernel_size / 2 ;
129129 blocks[" time_mix_conv" ] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1 (out_channels,
130130 out_channels,
@@ -141,7 +141,7 @@ class AE3DConv : public Conv2d {
141141 // result: [N, OC, OH, OW]
142142 auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks[" time_mix_conv" ]);
143143
144- x = Conv2d ::forward (ctx, x);
144+ x = Conv2dDirect ::forward (ctx, x);
145145 // timesteps = x.shape[0]
146146 // x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
147147 // x = conv3d(x)
@@ -240,7 +240,7 @@ class Encoder : public GGMLBlock {
240240 in_channels(in_channels),
241241 z_channels(z_channels),
242242 double_z(double_z) {
243- blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
243+ blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, ch, {3 , 3 }, {1 , 1 }, {1 , 1 }));
244244
245245 size_t num_resolutions = ch_mult.size ();
246246
@@ -268,18 +268,18 @@ class Encoder : public GGMLBlock {
268268 blocks[" mid.block_2" ] = std::shared_ptr<GGMLBlock>(new ResnetBlock (block_in, block_in));
269269
270270 blocks[" norm_out" ] = std::shared_ptr<GGMLBlock>(new GroupNorm32 (block_in));
271- blocks[" conv_out" ] = std::shared_ptr<GGMLBlock>(new Conv2d (block_in, double_z ? z_channels * 2 : z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
271+ blocks[" conv_out" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (block_in, double_z ? z_channels * 2 : z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
272272 }
273273
274274 virtual struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x) {
275275 // x: [N, in_channels, h, w]
276276
277- auto conv_in = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_in" ]);
277+ auto conv_in = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_in" ]);
278278 auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_1" ]);
279279 auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks[" mid.attn_1" ]);
280280 auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_2" ]);
281281 auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm_out" ]);
282- auto conv_out = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_out" ]);
282+ auto conv_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_out" ]);
283283
284284 auto h = conv_in->forward (ctx, x); // [N, ch, h, w]
285285
@@ -332,7 +332,7 @@ class Decoder : public GGMLBlock {
332332 if (video_decoder) {
333333 return std::shared_ptr<GGMLBlock>(new AE3DConv (in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
334334 } else {
335- return std::shared_ptr<GGMLBlock>(new Conv2d (in_channels, out_channels, kernel_size, stride, padding));
335+ return std::shared_ptr<GGMLBlock>(new Conv2dDirect (in_channels, out_channels, kernel_size, stride, padding));
336336 }
337337 }
338338
@@ -363,7 +363,7 @@ class Decoder : public GGMLBlock {
363363 size_t num_resolutions = ch_mult.size ();
364364 int block_in = ch * ch_mult[num_resolutions - 1 ];
365365
366- blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, block_in, {3 , 3 }, {1 , 1 }, {1 , 1 }));
366+ blocks[" conv_in" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (z_channels, block_in, {3 , 3 }, {1 , 1 }, {1 , 1 }));
367367
368368 blocks[" mid.block_1" ] = get_resnet_block (block_in, block_in);
369369 blocks[" mid.attn_1" ] = std::shared_ptr<GGMLBlock>(new AttnBlock (block_in));
@@ -394,12 +394,12 @@ class Decoder : public GGMLBlock {
394394 // merge_strategy is always learned
395395 // time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
396396 // AttnVideoBlock will not be used
397- auto conv_in = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_in" ]);
397+ auto conv_in = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_in" ]);
398398 auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_1" ]);
399399 auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks[" mid.attn_1" ]);
400400 auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks[" mid.block_2" ]);
401401 auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks[" norm_out" ]);
402- auto conv_out = std::dynamic_pointer_cast<Conv2d >(blocks[" conv_out" ]);
402+ auto conv_out = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" conv_out" ]);
403403
404404 // conv_in
405405 auto h = conv_in->forward (ctx, z); // [N, block_in, h, w]
@@ -472,7 +472,7 @@ class AutoencodingEngine : public GGMLBlock {
472472 dd_config.z_channels ,
473473 use_video_decoder));
474474 if (use_quant) {
475- blocks[" post_quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2d (dd_config.z_channels ,
475+ blocks[" post_quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (dd_config.z_channels ,
476476 embed_dim,
477477 {1 , 1 }));
478478 }
@@ -486,7 +486,7 @@ class AutoencodingEngine : public GGMLBlock {
486486 if (use_quant) {
487487 int factor = dd_config.double_z ? 2 : 1 ;
488488
489- blocks[" quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2d (embed_dim * factor,
489+ blocks[" quant_conv" ] = std::shared_ptr<GGMLBlock>(new Conv2dDirect (embed_dim * factor,
490490 dd_config.z_channels * factor,
491491 {1 , 1 }));
492492 }
@@ -496,7 +496,7 @@ class AutoencodingEngine : public GGMLBlock {
496496 struct ggml_tensor * decode (struct ggml_context * ctx, struct ggml_tensor * z) {
497497 // z: [N, z_channels, h, w]
498498 if (use_quant) {
499- auto post_quant_conv = std::dynamic_pointer_cast<Conv2d >(blocks[" post_quant_conv" ]);
499+ auto post_quant_conv = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" post_quant_conv" ]);
500500 z = post_quant_conv->forward (ctx, z); // [N, z_channels, h, w]
501501 }
502502 auto decoder = std::dynamic_pointer_cast<Decoder>(blocks[" decoder" ]);
@@ -513,7 +513,7 @@ class AutoencodingEngine : public GGMLBlock {
513513
514514 auto h = encoder->forward (ctx, x); // [N, 2*z_channels, h/8, w/8]
515515 if (use_quant) {
516- auto quant_conv = std::dynamic_pointer_cast<Conv2d >(blocks[" quant_conv" ]);
516+ auto quant_conv = std::dynamic_pointer_cast<Conv2dDirect >(blocks[" quant_conv" ]);
517517 h = quant_conv->forward (ctx, h); // [N, 2*embed_dim, h/8, w/8]
518518 }
519519 return h;
0 commit comments