Skip to content

Commit b06ddf9

Browse files
committed
Cmake option to use conv2d direct
1 parent 6624650 commit b06ddf9

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ option(SD_VULKAN "sd: vulkan backend" OFF)
3131
option(SD_OPENCL "sd: opencl backend" OFF)
3232
option(SD_SYCL "sd: sycl backend" OFF)
3333
option(SD_MUSA "sd: musa backend" OFF)
34+
option(SD_CONV2D_DIRECT "sd: enable conv2d direct support" OFF)
3435
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
3536
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
3637
#option(SD_BUILD_SERVER "sd: build server example" ON)
@@ -77,6 +78,11 @@ if(SD_MUSA)
7778
endif()
7879
endif()
7980

81+
if(SD_CONV2D_DIRECT)
82+
message("-- Use CONV2D Direct for VAE")
83+
add_definitions(-DSD_USE_CONV2D_DIRECT)
84+
endif ()
85+
8086
set(SD_LIB stable-diffusion)
8187

8288
file(GLOB SD_LIB_SOURCES

ggml_extend.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,10 +1511,14 @@ class Conv2d : public UnaryBlock {
15111511
b = params["bias"];
15121512
}
15131513
if (direct) {
1514-
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL) || defined(SD_USE_OPENCL)
1515-
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1514+
#if defined(SD_USE_CONV2D_DIRECT)
1515+
#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) || defined(SD_USE_METAL) || defined(SD_USE_OPENCL)
1516+
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1517+
#else
1518+
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1519+
#endif
15161520
#else
1517-
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
1521+
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
15181522
#endif
15191523
} else {
15201524
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);

0 commit comments

Comments
 (0)