Skip to content

Commit ff452b8

Browse files
committed
Use Conv2dDirect (direct, non-im2col convolution) for the VAE stage
1 parent 7eb30d0 commit ff452b8

File tree

3 files changed

+97
-33
lines changed

3 files changed

+97
-33
lines changed

common.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class DownSampleBlock : public GGMLBlock {
1717
out_channels(out_channels),
1818
vae_downsample(vae_downsample) {
1919
if (vae_downsample) {
20-
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
20+
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
2121
} else {
2222
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
2323
}
@@ -26,7 +26,7 @@ class DownSampleBlock : public GGMLBlock {
2626
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
2727
// x: [N, channels, h, w]
2828
if (vae_downsample) {
29-
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
29+
auto conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv"]);
3030

3131
x = ggml_pad(ctx, x, 1, 1, 0, 0);
3232
x = conv->forward(ctx, x);
@@ -49,12 +49,12 @@ class UpSampleBlock : public GGMLBlock {
4949
int out_channels)
5050
: channels(channels),
5151
out_channels(out_channels) {
52-
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
52+
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
5353
}
5454

5555
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
5656
// x: [N, channels, h, w]
57-
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
57+
auto conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv"]);
5858

5959
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
6060
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]

ggml_extend.hpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
706706
return x;
707707
}
708708

709+
// 2D convolution via ggml_conv_2d_direct (no im2col staging tensor),
// with an optional bias add.
// x: [N, IC, IH, IW]
// w: [OC, IC, KH, KW]
// b: [OC,] or NULL
// s0/s1: stride (w, h), p0/p1: padding (w, h), d0/d1: dilation (w, h)
// Returns [N, OC, OH, OW].
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
                                                             struct ggml_tensor* x,
                                                             struct ggml_tensor* w,
                                                             struct ggml_tensor* b,
                                                             int s0 = 1,
                                                             int s1 = 1,
                                                             int p0 = 0,
                                                             int p1 = 0,
                                                             int d0 = 1,
                                                             int d1 = 1) {
    x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
    if (b != NULL) {
        // Reshape [OC,] -> [1, 1, OC, 1] so ggml_add broadcasts the bias
        // across N, OH and OW; no explicit ggml_repeat is needed.
        b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
        x = ggml_add(ctx, x, b);
    }
    return x;
}
727+
709728
// w: [OC,IC, KD, 1 * 1]
710729
// x: [N, IC, IH, IW]
711730
// b: [OC,]
@@ -1492,6 +1511,51 @@ class Conv2d : public UnaryBlock {
14921511
}
14931512
};
14941513

1514+
class Conv2dDirect : public UnaryBlock {
1515+
protected:
1516+
int64_t in_channels;
1517+
int64_t out_channels;
1518+
std::pair<int, int> kernel_size;
1519+
std::pair<int, int> stride;
1520+
std::pair<int, int> padding;
1521+
std::pair<int, int> dilation;
1522+
bool bias;
1523+
1524+
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
1525+
enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
1526+
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
1527+
if (bias) {
1528+
enum ggml_type wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
1529+
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
1530+
}
1531+
}
1532+
1533+
public:
1534+
Conv2dDirect(int64_t in_channels,
1535+
int64_t out_channels,
1536+
std::pair<int, int> kernel_size,
1537+
std::pair<int, int> stride = {1, 1},
1538+
std::pair<int, int> padding = {0, 0},
1539+
std::pair<int, int> dilation = {1, 1},
1540+
bool bias = true)
1541+
: in_channels(in_channels),
1542+
out_channels(out_channels),
1543+
kernel_size(kernel_size),
1544+
stride(stride),
1545+
padding(padding),
1546+
dilation(dilation),
1547+
bias(bias) {}
1548+
1549+
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
1550+
struct ggml_tensor* w = params["weight"];
1551+
struct ggml_tensor* b = NULL;
1552+
if (bias) {
1553+
b = params["bias"];
1554+
}
1555+
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1556+
}
1557+
};
1558+
14951559
class Conv3dnx1x1 : public UnaryBlock {
14961560
protected:
14971561
int64_t in_channels;

vae.hpp

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,23 @@ class ResnetBlock : public UnaryBlock {
2020
out_channels(out_channels) {
2121
// temb_channels is always 0
2222
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
23-
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
23+
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
2424

2525
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
26-
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
26+
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
2727

2828
if (out_channels != in_channels) {
29-
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}));
29+
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, {1, 1}));
3030
}
3131
}
3232

3333
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
3434
// x: [N, in_channels, h, w]
3535
// t_emb is always None
3636
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
37-
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
37+
auto conv1 = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv1"]);
3838
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
39-
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
39+
auto conv2 = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv2"]);
4040

4141
auto h = x;
4242
h = norm1->forward(ctx, h);
@@ -51,7 +51,7 @@ class ResnetBlock : public UnaryBlock {
5151

5252
// skip connection
5353
if (out_channels != in_channels) {
54-
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
54+
auto nin_shortcut = std::dynamic_pointer_cast<Conv2dDirect>(blocks["nin_shortcut"]);
5555

5656
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
5757
}
@@ -69,20 +69,20 @@ class AttnBlock : public UnaryBlock {
6969
AttnBlock(int64_t in_channels)
7070
: in_channels(in_channels) {
7171
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
72-
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
73-
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
74-
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
72+
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
73+
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
74+
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
7575

76-
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}));
76+
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
7777
}
7878

7979
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
8080
// x: [N, in_channels, h, w]
8181
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
82-
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
83-
auto k_proj = std::dynamic_pointer_cast<Conv2d>(blocks["k"]);
84-
auto v_proj = std::dynamic_pointer_cast<Conv2d>(blocks["v"]);
85-
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
82+
auto q_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["q"]);
83+
auto k_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["k"]);
84+
auto v_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["v"]);
85+
auto proj_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["proj_out"]);
8686

8787
auto h_ = norm->forward(ctx, x);
8888

@@ -114,7 +114,7 @@ class AttnBlock : public UnaryBlock {
114114
}
115115
};
116116

117-
class AE3DConv : public Conv2d {
117+
class AE3DConv : public Conv2dDirect {
118118
public:
119119
AE3DConv(int64_t in_channels,
120120
int64_t out_channels,
@@ -124,7 +124,7 @@ class AE3DConv : public Conv2d {
124124
std::pair<int, int> padding = {0, 0},
125125
std::pair<int, int> dilation = {1, 1},
126126
bool bias = true)
127-
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
127+
: Conv2dDirect(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
128128
int64_t kernel_padding = video_kernel_size / 2;
129129
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels,
130130
out_channels,
@@ -141,7 +141,7 @@ class AE3DConv : public Conv2d {
141141
// result: [N, OC, OH, OW]
142142
auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks["time_mix_conv"]);
143143

144-
x = Conv2d::forward(ctx, x);
144+
x = Conv2dDirect::forward(ctx, x);
145145
// timesteps = x.shape[0]
146146
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
147147
// x = conv3d(x)
@@ -240,7 +240,7 @@ class Encoder : public GGMLBlock {
240240
in_channels(in_channels),
241241
z_channels(z_channels),
242242
double_z(double_z) {
243-
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
243+
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
244244

245245
size_t num_resolutions = ch_mult.size();
246246

@@ -268,18 +268,18 @@ class Encoder : public GGMLBlock {
268268
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
269269

270270
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
271-
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
271+
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
272272
}
273273

274274
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
275275
// x: [N, in_channels, h, w]
276276

277-
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
277+
auto conv_in = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_in"]);
278278
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
279279
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
280280
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
281281
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
282-
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
282+
auto conv_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_out"]);
283283

284284
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
285285

@@ -332,7 +332,7 @@ class Decoder : public GGMLBlock {
332332
if (video_decoder) {
333333
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
334334
} else {
335-
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding));
335+
return std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, kernel_size, stride, padding));
336336
}
337337
}
338338

@@ -363,7 +363,7 @@ class Decoder : public GGMLBlock {
363363
size_t num_resolutions = ch_mult.size();
364364
int block_in = ch * ch_mult[num_resolutions - 1];
365365

366-
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
366+
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
367367

368368
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
369369
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
@@ -394,12 +394,12 @@ class Decoder : public GGMLBlock {
394394
// merge_strategy is always learned
395395
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
396396
// AttnVideoBlock will not be used
397-
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
397+
auto conv_in = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_in"]);
398398
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
399399
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
400400
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
401401
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
402-
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
402+
auto conv_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_out"]);
403403

404404
// conv_in
405405
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
@@ -472,7 +472,7 @@ class AutoencodingEngine : public GGMLBlock {
472472
dd_config.z_channels,
473473
use_video_decoder));
474474
if (use_quant) {
475-
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
475+
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(dd_config.z_channels,
476476
embed_dim,
477477
{1, 1}));
478478
}
@@ -486,7 +486,7 @@ class AutoencodingEngine : public GGMLBlock {
486486
if (use_quant) {
487487
int factor = dd_config.double_z ? 2 : 1;
488488

489-
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
489+
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(embed_dim * factor,
490490
dd_config.z_channels * factor,
491491
{1, 1}));
492492
}
@@ -496,7 +496,7 @@ class AutoencodingEngine : public GGMLBlock {
496496
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
497497
// z: [N, z_channels, h, w]
498498
if (use_quant) {
499-
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
499+
auto post_quant_conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["post_quant_conv"]);
500500
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
501501
}
502502
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
@@ -513,7 +513,7 @@ class AutoencodingEngine : public GGMLBlock {
513513

514514
auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
515515
if (use_quant) {
516-
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
516+
auto quant_conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["quant_conv"]);
517517
h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8]
518518
}
519519
return h;

0 commit comments

Comments
 (0)