Skip to content

Commit 002433b

Browse files
am17anqnixsynapse
authored andcommitted
CUDA: add conv_2d_transpose (ggml-org#14287)
* CUDA: add conv_2d_transpose * remove direct include of cuda_fp16 * Review: add brackets for readability, remove ggml_set_param and add asserts
1 parent 55184b4 commit 002433b

File tree

2 files changed

+219
-1104
lines changed

2 files changed

+219
-1104
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "ggml-cuda/concat.cuh"
1313
#include "ggml-cuda/conv-transpose-1d.cuh"
1414
#include "ggml-cuda/conv2d-dw.cuh"
15+
#include "ggml-cuda/conv2d-transpose.cuh"
1516
#include "ggml-cuda/convert.cuh"
1617
#include "ggml-cuda/count-equal.cuh"
1718
#include "ggml-cuda/cpy.cuh"
@@ -2341,6 +2342,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23412342
case GGML_OP_CONV_2D_DW:
23422343
ggml_cuda_op_conv2d_dw(ctx, dst);
23432344
break;
2345+
case GGML_OP_CONV_TRANSPOSE_2D:
2346+
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
2347+
break;
23442348
case GGML_OP_CONV_TRANSPOSE_1D:
23452349
ggml_cuda_op_conv_transpose_1d(ctx,dst);
23462350
break;
@@ -3252,6 +3256,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
32523256
}
32533257
case GGML_OP_IM2COL:
32543258
case GGML_OP_CONV_2D_DW:
3259+
case GGML_OP_CONV_TRANSPOSE_2D:
32553260
case GGML_OP_POOL_2D:
32563261
case GGML_OP_SUM:
32573262
case GGML_OP_SUM_ROWS:

0 commit comments

Comments
 (0)