Skip to content

Commit 3f2dac0

Browse files
[PaddlePaddle Hackathon 4][Frontend][Paddle]add conv3d for paddle frontend (#14290)
* add conv3d for paddle frontend
* codestyle
* fix bugs
1 parent 58dce66 commit 3f2dac0

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

python/tvm/relay/frontend/paddlepaddle.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,60 @@ def convert_conv2d_transpose(g, op, block):
400400
g.add_node(op.output("Output")[0], out)
401401

402402

403+
def convert_conv3d(g, op, block):
    """Operator converter for conv3d."""

    strides = op.attr("strides")
    dilations = op.attr("dilations")
    paddings = op.attr("paddings")
    groups = op.attr("groups")
    padding_algorithm = op.attr("padding_algorithm")

    data = g.get_node(op.input("Input")[0])
    weight = g.get_node(op.input("Filter")[0])
    # Filter layout is OIDHW: (out_channels, in_channels/groups, depth, height, width).
    out_channels, _, k_d, k_h, k_w = infer_shape(weight)

    if padding_algorithm == "VALID":
        paddings = [0, 0, 0]
    elif padding_algorithm == "SAME":
        # autopad pads the input tensor explicitly, so the convolution
        # itself runs with zero padding and unit dilation.
        dilations = [1, 1, 1]
        data = autopad(data, strides, [k_d, k_h, k_w], dilations)
        paddings = [0, 0, 0]
    elif padding_algorithm == "EXPLICIT":
        if len(paddings) == 3:
            # Symmetric per-axis padding (d, h, w): duplicate for both sides.
            paddings = paddings * 2
        elif len(paddings) == 6:
            # Paddle orders pads as (d_begin, h_begin, w_begin, d_end, h_end,
            # w_end); interleave begin/end halves per axis for relay.
            begins, ends = paddings[:3], paddings[3:]
            paddings = [p for pair in zip(begins, ends) for p in pair]
    else:
        msg = 'Value {} in attribute "padding" of operator Conv is not "valid."'
        raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))

    out = _op.nn.conv3d(
        data,
        weight,
        strides=strides,
        padding=paddings,
        dilation=dilations,
        groups=groups,
        channels=out_channels,
        kernel_size=[k_d, k_h, k_w],
    )
    g.add_node(op.output("Output")[0], out)
455+
456+
403457
def convert_dist(g, op, block):
404458
"""Operator converter for dist."""
405459

@@ -2416,6 +2470,7 @@ def convert_where_index(g, op, block):
24162470
"concat": convert_concat,
24172471
"conv2d": convert_conv2d,
24182472
"conv2d_transpose": convert_conv2d_transpose,
2473+
"conv3d": convert_conv3d,
24192474
"cos": convert_unary_op,
24202475
"cosh": convert_unary_op,
24212476
"cumsum": convert_cumsum,

tests/python/frontend/paddlepaddle/test_forward.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,49 @@ def forward(self, inputs):
554554
verify_model(Conv2DTranspose(stride=3, padding="SAME", groups=1), input_data=input_data)
555555

556556

557+
@tvm.testing.uses_gpu
def test_forward_conv3d():
    class Conv3D(nn.Layer):
        def __init__(self, stride=1, padding=0, dilation=1, groups=1, padding_mode="zeros"):
            super(Conv3D, self).__init__()
            # 3 input channels -> 6 output channels, 3x3x3 kernel.
            self.conv = nn.Conv3D(
                3,
                6,
                3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                padding_mode=padding_mode,
            )
            self.softmax = nn.Softmax()

        @paddle.jit.to_static
        def forward(self, inputs):
            return self.softmax(self.conv(inputs))

    # Each config exercises one padding/stride/dilation combination;
    # models are built lazily inside the loop, one per verify call.
    configs = [
        {},
        {"stride": 2, "padding": "VALID", "dilation": 3},
        {"stride": 2, "padding": "SAME", "dilation": 3},
        {"stride": 2, "padding": (3, 3, 4, 4, 2, 2), "dilation": 3},
        {"stride": 2, "padding": 3, "dilation": 3, "padding_mode": "reflect"},
        {"stride": 2, "padding": 3, "dilation": 3, "padding_mode": "replicate"},
        {"stride": 2, "padding": "SAME", "dilation": 2, "groups": 3},
    ]

    for shape in ([1, 3, 10, 10, 10], [1, 3, 12, 12, 12]):
        input_data = paddle.rand(shape, dtype="float32")
        for cfg in configs:
            verify_model(Conv3D(**cfg), input_data=input_data)
598+
599+
557600
@tvm.testing.uses_gpu
558601
def test_forward_dot():
559602
class Dot(nn.Layer):

0 commit comments

Comments
 (0)