Skip to content

Commit 6624650

Browse files
committed
Enable only for Vulkan, reduced duplicated code
1 parent ff452b8 commit 6624650

File tree

3 files changed

+65
-85
lines changed

3 files changed

+65
-85
lines changed

common.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class DownSampleBlock : public GGMLBlock {
1717
out_channels(out_channels),
1818
vae_downsample(vae_downsample) {
1919
if (vae_downsample) {
20-
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
20+
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}, {1, 1}, true, true));
2121
} else {
2222
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
2323
}
@@ -26,7 +26,7 @@ class DownSampleBlock : public GGMLBlock {
2626
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
2727
// x: [N, channels, h, w]
2828
if (vae_downsample) {
29-
auto conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv"]);
29+
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
3030

3131
x = ggml_pad(ctx, x, 1, 1, 0, 0);
3232
x = conv->forward(ctx, x);
@@ -49,12 +49,12 @@ class UpSampleBlock : public GGMLBlock {
4949
int out_channels)
5050
: channels(channels),
5151
out_channels(out_channels) {
52-
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
52+
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
5353
}
5454

5555
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
5656
// x: [N, channels, h, w]
57-
auto conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv"]);
57+
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
5858

5959
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
6060
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]

ggml_extend.hpp

Lines changed: 13 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,7 @@ class Conv2d : public UnaryBlock {
14751475
std::pair<int, int> padding;
14761476
std::pair<int, int> dilation;
14771477
bool bias;
1478+
bool direct;
14781479

14791480
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
14801481
enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
@@ -1492,67 +1493,32 @@ class Conv2d : public UnaryBlock {
14921493
std::pair<int, int> stride = {1, 1},
14931494
std::pair<int, int> padding = {0, 0},
14941495
std::pair<int, int> dilation = {1, 1},
1495-
bool bias = true)
1496+
bool bias = true,
1497+
bool direct = false)
14961498
: in_channels(in_channels),
14971499
out_channels(out_channels),
14981500
kernel_size(kernel_size),
14991501
stride(stride),
15001502
padding(padding),
15011503
dilation(dilation),
1502-
bias(bias) {}
1504+
bias(bias),
1505+
direct(direct) {}
15031506

15041507
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
15051508
struct ggml_tensor* w = params["weight"];
15061509
struct ggml_tensor* b = NULL;
15071510
if (bias) {
15081511
b = params["bias"];
15091512
}
1510-
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1511-
}
1512-
};
1513-
1514-
class Conv2dDirect : public UnaryBlock {
1515-
protected:
1516-
int64_t in_channels;
1517-
int64_t out_channels;
1518-
std::pair<int, int> kernel_size;
1519-
std::pair<int, int> stride;
1520-
std::pair<int, int> padding;
1521-
std::pair<int, int> dilation;
1522-
bool bias;
1523-
1524-
void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = "") {
1525-
enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16;
1526-
params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels);
1527-
if (bias) {
1528-
enum ggml_type wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32;
1529-
params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels);
1530-
}
1531-
}
1532-
1533-
public:
1534-
Conv2dDirect(int64_t in_channels,
1535-
int64_t out_channels,
1536-
std::pair<int, int> kernel_size,
1537-
std::pair<int, int> stride = {1, 1},
1538-
std::pair<int, int> padding = {0, 0},
1539-
std::pair<int, int> dilation = {1, 1},
1540-
bool bias = true)
1541-
: in_channels(in_channels),
1542-
out_channels(out_channels),
1543-
kernel_size(kernel_size),
1544-
stride(stride),
1545-
padding(padding),
1546-
dilation(dilation),
1547-
bias(bias) {}
1548-
1549-
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
1550-
struct ggml_tensor* w = params["weight"];
1551-
struct ggml_tensor* b = NULL;
1552-
if (bias) {
1553-
b = params["bias"];
1513+
if (direct) {
1514+
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL) || defined(SD_USE_OPENCL)
1515+
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1516+
#else
1517+
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1518+
#endif
1519+
} else {
1520+
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
15541521
}
1555-
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
15561522
}
15571523
};
15581524

vae.hpp

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,23 @@ class ResnetBlock : public UnaryBlock {
2020
out_channels(out_channels) {
2121
// temb_channels is always 0
2222
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
23-
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
23+
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
2424

2525
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
26-
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
26+
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(out_channels, out_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
2727

2828
if (out_channels != in_channels) {
29-
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, {1, 1}));
29+
blocks["nin_shortcut"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
3030
}
3131
}
3232

3333
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
3434
// x: [N, in_channels, h, w]
3535
// t_emb is always None
3636
auto norm1 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm1"]);
37-
auto conv1 = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv1"]);
37+
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
3838
auto norm2 = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm2"]);
39-
auto conv2 = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv2"]);
39+
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
4040

4141
auto h = x;
4242
h = norm1->forward(ctx, h);
@@ -51,7 +51,7 @@ class ResnetBlock : public UnaryBlock {
5151

5252
// skip connection
5353
if (out_channels != in_channels) {
54-
auto nin_shortcut = std::dynamic_pointer_cast<Conv2dDirect>(blocks["nin_shortcut"]);
54+
auto nin_shortcut = std::dynamic_pointer_cast<Conv2d>(blocks["nin_shortcut"]);
5555

5656
x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w]
5757
}
@@ -69,20 +69,20 @@ class AttnBlock : public UnaryBlock {
6969
AttnBlock(int64_t in_channels)
7070
: in_channels(in_channels) {
7171
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
72-
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
73-
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
74-
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
72+
blocks["q"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
73+
blocks["k"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
74+
blocks["v"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
7575

76-
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, in_channels, {1, 1}));
76+
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, in_channels, {1, 1}, {1, 1}, {0, 0}, {1, 1}, true, true));
7777
}
7878

7979
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
8080
// x: [N, in_channels, h, w]
8181
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
82-
auto q_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["q"]);
83-
auto k_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["k"]);
84-
auto v_proj = std::dynamic_pointer_cast<Conv2dDirect>(blocks["v"]);
85-
auto proj_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["proj_out"]);
82+
auto q_proj = std::dynamic_pointer_cast<Conv2d>(blocks["q"]);
83+
auto k_proj = std::dynamic_pointer_cast<Conv2d>(blocks["k"]);
84+
auto v_proj = std::dynamic_pointer_cast<Conv2d>(blocks["v"]);
85+
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
8686

8787
auto h_ = norm->forward(ctx, x);
8888

@@ -114,7 +114,7 @@ class AttnBlock : public UnaryBlock {
114114
}
115115
};
116116

117-
class AE3DConv : public Conv2dDirect {
117+
class AE3DConv : public Conv2d {
118118
public:
119119
AE3DConv(int64_t in_channels,
120120
int64_t out_channels,
@@ -123,8 +123,9 @@ class AE3DConv : public Conv2dDirect {
123123
std::pair<int, int> stride = {1, 1},
124124
std::pair<int, int> padding = {0, 0},
125125
std::pair<int, int> dilation = {1, 1},
126-
bool bias = true)
127-
: Conv2dDirect(in_channels, out_channels, kernel_size, stride, padding, dilation, bias) {
126+
bool bias = true,
127+
bool direct = false)
128+
: Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct) {
128129
int64_t kernel_padding = video_kernel_size / 2;
129130
blocks["time_mix_conv"] = std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(out_channels,
130131
out_channels,
@@ -141,7 +142,7 @@ class AE3DConv : public Conv2dDirect {
141142
// result: [N, OC, OH, OW]
142143
auto time_mix_conv = std::dynamic_pointer_cast<Conv3dnx1x1>(blocks["time_mix_conv"]);
143144

144-
x = Conv2dDirect::forward(ctx, x);
145+
x = Conv2d::forward(ctx, x);
145146
// timesteps = x.shape[0]
146147
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
147148
// x = conv3d(x)
@@ -240,7 +241,7 @@ class Encoder : public GGMLBlock {
240241
in_channels(in_channels),
241242
z_channels(z_channels),
242243
double_z(double_z) {
243-
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, ch, {3, 3}, {1, 1}, {1, 1}));
244+
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
244245

245246
size_t num_resolutions = ch_mult.size();
246247

@@ -268,18 +269,18 @@ class Encoder : public GGMLBlock {
268269
blocks["mid.block_2"] = std::shared_ptr<GGMLBlock>(new ResnetBlock(block_in, block_in));
269270

270271
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
271-
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}));
272+
blocks["conv_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
272273
}
273274

274275
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
275276
// x: [N, in_channels, h, w]
276277

277-
auto conv_in = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_in"]);
278+
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
278279
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
279280
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
280281
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
281282
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
282-
auto conv_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_out"]);
283+
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
283284

284285
auto h = conv_in->forward(ctx, x); // [N, ch, h, w]
285286

@@ -328,11 +329,14 @@ class Decoder : public GGMLBlock {
328329
int64_t out_channels,
329330
std::pair<int, int> kernel_size,
330331
std::pair<int, int> stride = {1, 1},
331-
std::pair<int, int> padding = {0, 0}) {
332+
std::pair<int, int> padding = {0, 0},
333+
std::pair<int, int> dilation = {1, 1},
334+
bool bias = true,
335+
bool direct = false){
332336
if (video_decoder) {
333337
return std::shared_ptr<GGMLBlock>(new AE3DConv(in_channels, out_channels, kernel_size, video_kernel_size, stride, padding));
334338
} else {
335-
return std::shared_ptr<GGMLBlock>(new Conv2dDirect(in_channels, out_channels, kernel_size, stride, padding));
339+
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias, direct));
336340
}
337341
}
338342

@@ -363,7 +367,7 @@ class Decoder : public GGMLBlock {
363367
size_t num_resolutions = ch_mult.size();
364368
int block_in = ch * ch_mult[num_resolutions - 1];
365369

366-
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}));
370+
blocks["conv_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, block_in, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true));
367371

368372
blocks["mid.block_1"] = get_resnet_block(block_in, block_in);
369373
blocks["mid.attn_1"] = std::shared_ptr<GGMLBlock>(new AttnBlock(block_in));
@@ -385,7 +389,7 @@ class Decoder : public GGMLBlock {
385389
}
386390

387391
blocks["norm_out"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(block_in));
388-
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1});
392+
blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}, {1, 1}, true, true);
389393
}
390394

391395
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
@@ -394,12 +398,12 @@ class Decoder : public GGMLBlock {
394398
// merge_strategy is always learned
395399
// time_mode is always conv-only, so we need to replace conv_out_op/resnet_op to AE3DConv/VideoResBlock
396400
// AttnVideoBlock will not be used
397-
auto conv_in = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_in"]);
401+
auto conv_in = std::dynamic_pointer_cast<Conv2d>(blocks["conv_in"]);
398402
auto mid_block_1 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_1"]);
399403
auto mid_attn_1 = std::dynamic_pointer_cast<AttnBlock>(blocks["mid.attn_1"]);
400404
auto mid_block_2 = std::dynamic_pointer_cast<ResnetBlock>(blocks["mid.block_2"]);
401405
auto norm_out = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm_out"]);
402-
auto conv_out = std::dynamic_pointer_cast<Conv2dDirect>(blocks["conv_out"]);
406+
auto conv_out = std::dynamic_pointer_cast<Conv2d>(blocks["conv_out"]);
403407

404408
// conv_in
405409
auto h = conv_in->forward(ctx, z); // [N, block_in, h, w]
@@ -472,9 +476,14 @@ class AutoencodingEngine : public GGMLBlock {
472476
dd_config.z_channels,
473477
use_video_decoder));
474478
if (use_quant) {
475-
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(dd_config.z_channels,
479+
blocks["post_quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(dd_config.z_channels,
476480
embed_dim,
477-
{1, 1}));
481+
{1, 1},
482+
{1, 1},
483+
{0, 0},
484+
{1, 1},
485+
true,
486+
true));
478487
}
479488
if (!decode_only) {
480489
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new Encoder(dd_config.ch,
@@ -486,17 +495,22 @@ class AutoencodingEngine : public GGMLBlock {
486495
if (use_quant) {
487496
int factor = dd_config.double_z ? 2 : 1;
488497

489-
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2dDirect(embed_dim * factor,
498+
blocks["quant_conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(embed_dim * factor,
490499
dd_config.z_channels * factor,
491-
{1, 1}));
500+
{1, 1},
501+
{1, 1},
502+
{0, 0},
503+
{1, 1},
504+
true,
505+
true));
492506
}
493507
}
494508
}
495509

496510
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
497511
// z: [N, z_channels, h, w]
498512
if (use_quant) {
499-
auto post_quant_conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["post_quant_conv"]);
513+
auto post_quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["post_quant_conv"]);
500514
z = post_quant_conv->forward(ctx, z); // [N, z_channels, h, w]
501515
}
502516
auto decoder = std::dynamic_pointer_cast<Decoder>(blocks["decoder"]);
@@ -513,7 +527,7 @@ class AutoencodingEngine : public GGMLBlock {
513527

514528
auto h = encoder->forward(ctx, x); // [N, 2*z_channels, h/8, w/8]
515529
if (use_quant) {
516-
auto quant_conv = std::dynamic_pointer_cast<Conv2dDirect>(blocks["quant_conv"]);
530+
auto quant_conv = std::dynamic_pointer_cast<Conv2d>(blocks["quant_conv"]);
517531
h = quant_conv->forward(ctx, h); // [N, 2*embed_dim, h/8, w/8]
518532
}
519533
return h;

0 commit comments

Comments
 (0)