Skip to content

Commit 06fabe4

Browse files
[PaddlePaddle Hackathon 4][Frontend][Paddle]add grid-sample/gaussian_random/flip/fill_zeros_like/unique for paddle frontend (#14277)
Add grid-sample/gaussian_random/flip/fill_zeros_like/unique for paddle frontend.
1 parent 6fa88e3 commit 06fabe4

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

python/tvm/relay/frontend/paddlepaddle.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,17 @@ def convert_fill_constant_batch_size_like(g, op, block):
680680
g.add_node(op.output("Out")[0], out)
681681

682682

683+
def convert_fill_zeros_like(g, op, block):
684+
"""Operator converter for fill_zeros_like."""
685+
686+
x = g.get_node(op.input("X")[0])
687+
dtype = op.attr("dtype")
688+
dtype = _convert_dtype_value(dtype)
689+
value = _expr.const(0, dtype=dtype)
690+
out = _op.transform.full_like(x, value).astype(dtype)
691+
g.add_node(op.output("Out")[0], out)
692+
693+
683694
def convert_flatten(g, op, block):
684695
"""Operator converter for flatten."""
685696

@@ -707,6 +718,21 @@ def convert_flatten(g, op, block):
707718
g.add_node(op.output("Out")[0], out)
708719

709720

721+
def convert_flip(g, op, block):
722+
"""Operator converter for flip."""
723+
724+
x = g.get_node(op.input("X")[0])
725+
axis = op.attr("axis")
726+
727+
for i, ax in enumerate(axis):
728+
if i == 0:
729+
out = _op.reverse(x, ax)
730+
else:
731+
out = _op.reverse(out, ax)
732+
733+
g.add_node(op.output("Out")[0], out)
734+
735+
710736
def convert_gather(g, op, block):
711737
"""Operator converter for gather."""
712738

@@ -730,6 +756,17 @@ def convert_gather_nd(g, op, block):
730756
g.add_node(op.output("Out")[0], out)
731757

732758

759+
def convert_gaussian_random(g, op, block):
760+
"""Operator converter for convert_gaussian_random."""
761+
762+
mean = op.attr("mean")
763+
std = op.attr("std")
764+
shape = op.attr("shape")
765+
seed = op.attr("seed")
766+
out = _op.random.normal(key=seed, shape=shape, mean=mean, scale=std)
767+
g.add_node(op.output("Out")[0], out)
768+
769+
733770
def convert_gelu(g, op, block):
734771
"""Operator converter for gelu."""
735772

@@ -741,6 +778,32 @@ def convert_gelu(g, op, block):
741778
g.add_node(op.output("Out")[0], out)
742779

743780

781+
def convert_grid_sampler(g, op, block):
782+
"""Operator converter for grid_sampler."""
783+
784+
x = g.get_node(op.input("X")[0])
785+
data_shape = infer_shape(x)
786+
grid = g.get_node(op.input("Grid")[0])
787+
mode = op.attr("mode")
788+
padding_mode = op.attr("padding_mode")
789+
align_corners = op.attr("align_corners")
790+
791+
if len(data_shape) == 4:
792+
layout = "NCHW"
793+
axes = [0, 3, 1, 2]
794+
grid = _op.transform.transpose(grid, axes)
795+
elif len(data_shape) == 5:
796+
layout = "NCDHW"
797+
axes = [0, 4, 1, 2, 3]
798+
grid = _op.transform.transpose(grid, axes)
799+
else:
800+
msg = f"only 4D and 5D are supported."
801+
raise ValueError(msg)
802+
803+
out = _op.image.grid_sample(x, grid, mode, layout, padding_mode, align_corners)
804+
g.add_node(op.output("Output")[0], out)
805+
806+
744807
def convert_group_norm(g, op, block):
745808
"""Operator converter for group_norm."""
746809

@@ -2255,6 +2318,40 @@ def convert_transpose(g, op, block):
22552318
g.add_node(op.output("Out")[0], out)
22562319

22572320

2321+
def convert_unique(g, op, block):
2322+
"""Operator converter for unique."""
2323+
2324+
x = g.get_node(op.input("X")[0])
2325+
return_index = op.attr("return_index")
2326+
return_inverse = op.attr("return_inverse")
2327+
return_counts = op.attr("return_counts")
2328+
axis = op.attr("axis")
2329+
dtype = op.attr("dtype")
2330+
dtype = _convert_dtype_value(dtype)
2331+
2332+
if len(axis) == 0:
2333+
x = _op.reshape(x, [-1])
2334+
2335+
if return_counts:
2336+
unique, indices, inverse_indices, _, counts = _op.unique(
2337+
x, is_sorted=True, return_counts=True
2338+
)
2339+
else:
2340+
unique, indices, inverse_indices, _ = _op.unique(x, is_sorted=True, return_counts=False)
2341+
2342+
out = unique
2343+
if dtype != infer_type(out).checked_type.dtype:
2344+
out = _op.cast(out, dtype)
2345+
g.add_node(op.output("Out")[0], unique)
2346+
2347+
if return_index:
2348+
g.add_node(op.output("Indices")[0], indices)
2349+
if return_inverse:
2350+
g.add_node(op.output("Index")[0], inverse_indices)
2351+
if return_counts:
2352+
g.add_node(op.output("Counts")[0], counts)
2353+
2354+
22582355
def convert_unsqueeze(g, op, block):
22592356
"""Operator converter for unsqueeze."""
22602357

@@ -2346,14 +2443,18 @@ def convert_where_index(g, op, block):
23462443
"fill_any_like": convert_fill_any_like,
23472444
"fill_constant": convert_fill_constant,
23482445
"fill_constant_batch_size_like": convert_fill_constant_batch_size_like,
2446+
"fill_zeros_like": convert_fill_zeros_like,
23492447
"flatten_contiguous_range": convert_flatten,
23502448
"floor": convert_unary_op,
23512449
"floor_mod": convert_elementwise_op,
2450+
"flip": convert_flip,
23522451
"gather": convert_gather,
23532452
"gather_nd": convert_gather_nd,
2453+
"gaussian_random": convert_gaussian_random,
23542454
"gelu": convert_gelu,
23552455
"greater_equal": convert_elementwise_op,
23562456
"greater_than": convert_elementwise_op,
2457+
"grid_sampler": convert_grid_sampler,
23572458
"group_norm": convert_group_norm,
23582459
"hard_shrink": convert_hard_shrink,
23592460
"hard_sigmoid": convert_hard_sigmoid,
@@ -2443,6 +2544,7 @@ def convert_where_index(g, op, block):
24432544
"tile": convert_tile,
24442545
"top_k_v2": convert_topk,
24452546
"transpose2": convert_transpose,
2547+
"unique": convert_unique,
24462548
"unsqueeze2": convert_unsqueeze,
24472549
"unstack": convert_unstack,
24482550
"where": convert_where,

tests/python/frontend/paddlepaddle/test_forward.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,23 @@ def forward(self, x, y):
679679
verify_model(ExpandAs(), [x_data, y_data])
680680

681681

682+
@tvm.testing.uses_gpu
683+
def test_forward_fill_zeros_like():
684+
class FilZeroLike(nn.Layer):
685+
def __init__(self, dtype=None):
686+
super(FilZeroLike, self).__init__()
687+
self.dtype = dtype
688+
689+
@paddle.jit.to_static
690+
def forward(self, x):
691+
return paddle.zeros_like(x, dtype=self.dtype)
692+
693+
input_shape = [2, 3, 5]
694+
input_data = paddle.rand(input_shape, dtype="float32")
695+
verify_model(FilZeroLike("float32"), input_data=input_data)
696+
verify_model(FilZeroLike("int32"), input_data=input_data)
697+
698+
682699
@tvm.testing.uses_gpu
683700
def test_forward_flatten():
684701
class Flatten(nn.Layer):
@@ -697,6 +714,23 @@ def forward(self, x):
697714
verify_model(Flatten(2, -2), input_data=input_data)
698715

699716

717+
@tvm.testing.uses_gpu
718+
def test_forward_flip():
719+
class Flip(nn.Layer):
720+
def __init__(self, axis):
721+
super(Flip, self).__init__()
722+
self.axis = axis
723+
724+
@paddle.jit.to_static
725+
def forward(self, x):
726+
return paddle.flip(x, axis=self.axis)
727+
728+
input_data = paddle.rand([2, 3, 4], dtype="float32")
729+
verify_model(Flip(0), input_data)
730+
verify_model(Flip(-1), input_data)
731+
verify_model(Flip([0, 1]), input_data)
732+
733+
700734
@tvm.testing.uses_gpu
701735
def test_forward_gather():
702736
class Gather(nn.Layer):
@@ -750,6 +784,39 @@ def forward(self, inputs):
750784
verify_model(GroupNorm(num_channels, 2), input_data, rtol=1e-4, atol=1e-4)
751785

752786

787+
@tvm.testing.uses_gpu
788+
def test_forward_grid_sampler():
789+
class GridSampler(nn.Layer):
790+
def __init__(self, mode="bilinear", padding_mode="zeros", align_corners=True):
791+
super(GridSampler, self).__init__()
792+
self.mode = mode
793+
self.padding_mode = padding_mode
794+
self.align_corners = align_corners
795+
796+
def forward(self, x, grid):
797+
return paddle.nn.functional.grid_sample(
798+
x,
799+
grid,
800+
mode=self.mode,
801+
padding_mode=self.padding_mode,
802+
align_corners=self.align_corners,
803+
)
804+
805+
x_2D = paddle.rand(shape=[4, 4, 8, 8], dtype="float32")
806+
grid_2D = paddle.rand(shape=[4, 8, 8, 2], dtype="float32")
807+
verify_model(GridSampler(mode="nearest"), input_data=[x_2D, grid_2D])
808+
verify_model(GridSampler(padding_mode="reflection"), input_data=[x_2D, grid_2D])
809+
verify_model(GridSampler(padding_mode="border"), input_data=[x_2D, grid_2D])
810+
verify_model(GridSampler(align_corners=False), input_data=[x_2D, grid_2D])
811+
812+
x_3D = paddle.rand(shape=[4, 4, 4, 4, 4], dtype="float32")
813+
grid_3D = paddle.rand(shape=[4, 8, 8, 8, 3], dtype="float32")
814+
verify_model(GridSampler(mode="nearest"), input_data=[x_3D, grid_3D])
815+
verify_model(GridSampler(padding_mode="reflection"), input_data=[x_3D, grid_3D])
816+
verify_model(GridSampler(padding_mode="border"), input_data=[x_3D, grid_3D])
817+
verify_model(GridSampler(align_corners=False), input_data=[x_3D, grid_3D])
818+
819+
753820
@tvm.testing.uses_gpu
754821
def test_forward_scatter():
755822
class Scatter(nn.Layer):
@@ -1394,6 +1461,45 @@ def slice5(inputs):
13941461
# verify_model(slice5, input_data=paddle.randn((4,)))
13951462

13961463

1464+
@tvm.testing.uses_gpu
1465+
def test_forward_unique():
1466+
class Unique(nn.Layer):
1467+
def __init__(
1468+
self,
1469+
return_index=False,
1470+
return_inverse=False,
1471+
return_counts=False,
1472+
axis=None,
1473+
dtype="int64",
1474+
):
1475+
super(Unique, self).__init__()
1476+
self.return_index = return_index
1477+
self.return_inverse = return_inverse
1478+
self.return_counts = return_counts
1479+
self.axis = None
1480+
self.dtype = dtype
1481+
1482+
@paddle.jit.to_static
1483+
def forward(self, inputs):
1484+
result = paddle.unique(
1485+
inputs,
1486+
return_inverse=self.return_inverse,
1487+
return_counts=self.return_counts,
1488+
axis=self.axis,
1489+
dtype=self.dtype,
1490+
)
1491+
return result
1492+
1493+
input_shape = [2, 3, 5]
1494+
input_data = paddle.rand(input_shape)
1495+
verify_model(Unique(), input_data=input_data)
1496+
verify_model(Unique(return_index=True), input_data=input_data)
1497+
verify_model(Unique(return_index=True, return_inverse=True), input_data=input_data)
1498+
verify_model(
1499+
Unique(return_index=True, return_inverse=True, return_counts=True), input_data=input_data
1500+
)
1501+
1502+
13971503
@tvm.testing.uses_gpu
13981504
def run_math_api(func):
13991505
api_name = func.__name__.split("_")[-1]

0 commit comments

Comments
 (0)