Skip to content

Commit 6fa88e3

Browse files
authored
[PaddlePaddle Hackathon 4][Frontend][Paddle] Add thresholded_relu/index_select/eye/linspace/take_along_axis/dist for the Paddle frontend (#14172)
Add thresholded_relu/index_select/eye/linspace/take_along_axis/dist to the Paddle frontend. Note that in paddle 2.1.3, eye/linspace/take_along_axis are not supported; the test cases pass completely under version 2.4.2.
1 parent caf6b03 commit 6fa88e3

File tree

2 files changed

+247
-0
lines changed

2 files changed

+247
-0
lines changed

python/tvm/relay/frontend/paddlepaddle.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,30 @@ def convert_conv2d_transpose(g, op, block):
400400
g.add_node(op.output("Output")[0], out)
401401

402402

403+
def convert_dist(g, op, block):
    """Operator converter for dist.

    Computes the p-norm of (x - y) with the special cases paddle.dist
    defines for p in {+inf, -inf, 0}, and emits a 1-element tensor.
    """

    x = g.get_node(op.input("X")[0])
    y = g.get_node(op.input("Y")[0])
    # z is already the element-wise |x - y|, so no further abs is needed
    # in the branches below.
    z = _op.abs(_op.subtract(x, y))
    dtype = infer_type(x).checked_type.dtype
    p = op.attr("p")
    if p == np.inf:
        # L-inf norm: largest absolute difference.
        out = _op.reduce.max(z)
    elif p == -np.inf:
        # Negative-infinity "norm": smallest absolute difference.
        # (-np.inf replaces np.NINF, which was removed in NumPy 2.0.)
        out = _op.reduce.min(z)
    elif p == 0.0:
        # L0 "norm": number of non-zero differences (sign of a
        # non-negative value is 0 or 1).
        out = _op.reduce.sum(_op.sign(z))
    else:
        # General case: (sum(|x - y| ** p)) ** (1 / p).
        inv_p = _expr.const(1.0 / p, dtype=dtype)
        p = _expr.const(p, dtype=dtype)
        power_z = _op.power(z, p)
        sum_pow = _op.reduce.sum(power_z)
        out = _op.power(sum_pow, inv_p)
    # paddle.dist returns a tensor of shape (1,), not a 0-d scalar.
    # Note: (1) is just the int 1; (1,) makes the intended shape explicit.
    out = _op.full(out, shape=(1,))
    g.add_node(op.output("Out")[0], out)
425+
426+
403427
def convert_cumsum(g, op, block):
404428
"""Operator converter for cumsum."""
405429

@@ -475,6 +499,39 @@ def convert_elementwise_op(g, op, block):
475499
g.add_node(op.output("Out")[0], out)
476500

477501

502+
def convert_linspace(g, op, block):
    """Operator converter for linspace."""

    start = g.get_node(op.input("Start")[0])
    stop = g.get_node(op.input("Stop")[0])
    num = g.get_node(op.input("Num")[0])
    dtype = _convert_dtype_value(op.attr("dtype"))

    # First round the endpoints/count into the requested output dtype
    # (e.g. float endpoints truncate when an integer dtype is requested).
    start = _op.cast(start, dtype)
    stop = _op.cast(stop, dtype)
    num = _op.cast(num, dtype)

    # All intermediate arithmetic happens in a floating-point dtype wide
    # enough for the requested output dtype.
    calc_dtype = "float32" if dtype in ["int32", "float32"] else "float64"
    start = _op.cast(start, calc_dtype)
    stop = _op.cast(stop, calc_dtype)
    num = _op.cast(num, calc_dtype)
    one = _expr.const(1, calc_dtype)
    zero = _expr.const(0, calc_dtype)
    # A single sample would otherwise divide the span by zero; keep the
    # divisor at num in that case.
    segments = _op.where(num > one, num - one, num - zero)
    step = _op.divide(_op.subtract(stop, start), segments)
    # argwhere over a ones-vector of length num yields the index column
    # [0, 1, ..., num - 1] without needing a static length.
    counts = _op.argwhere(_op.ones(num, dtype=calc_dtype))
    counts = _op.cast(counts, dtype=calc_dtype)
    out = _op.add(start, _op.multiply(step, counts))
    out = _op.squeeze(out, axis=[1])
    out = _op.cast(out, dtype)
    g.add_node(op.output("Out")[0], out)
533+
534+
478535
def convert_elu(g, op, block):
479536
"""Operator converter for elu."""
480537

@@ -514,6 +571,27 @@ def convert_expand_as(g, op, block):
514571
g.add_node(op.output("Out")[0], out)
515572

516573

574+
def convert_eye(g, op, block):
    """Operator converter for eye."""

    num_rows = op.attr("num_rows")
    num_columns = op.attr("num_columns")
    diag_len = min(num_rows, num_columns)
    dtype = _convert_dtype_value(op.attr("dtype"))

    base = _op.zeros((num_rows, num_columns), dtype)
    if diag_len == 0:
        # Degenerate matrix (zero rows or columns): nothing to scatter.
        out = base
    else:
        diag = _op.ones(diag_len, dtype)
        idx = _op.arange(
            _expr.const(0, dtype="int32"), _expr.const(diag_len, dtype="int32"), dtype="int32"
        )
        # Write ones at positions (i, i) to form the identity diagonal.
        out = _op.scatter_nd(base, _op.stack([idx, idx], axis=0), diag, "update")
    g.add_node(op.output("Out")[0], out)
593+
594+
517595
def convert_feed(g, op, block):
518596
"""Converter for model input node."""
519597

@@ -830,6 +908,16 @@ def get_interpolate_mode(op):
830908
g.add_node(op.output("Out")[0], out)
831909

832910

911+
def convert_index_select(g, op, block):
    """Operator converter for index_select."""

    data = g.get_node(op.input("X")[0])
    indices = g.get_node(op.input("Index")[0])
    dim = op.attr("dim")
    # mode="wrap" maps out-of-range indices back into the valid range.
    out = _op.transform.take(data, indices, dim, mode="wrap")
    g.add_node(op.output("Out")[0], out)
919+
920+
833921
def convert_instance_norm(g, op, block):
834922
"""Operator converter for instance_norm."""
835923

@@ -2072,13 +2160,27 @@ def convert_swish(g, op, block):
20722160

20732161

20742162
def convert_take_along_axis(g, op, block):
    """Operator converter for take_along_axis."""

    data = g.get_node(op.input("Input")[0])
    indices = g.get_node(op.input("Index")[0])
    dim = op.attr("Axis")
    # Relay's gather implements the take-along-axis semantics directly.
    g.add_node(op.output("Result")[0], _op.gather(data, dim, indices))
20802170

20812171

2172+
def convert_thresholded_relu(g, op, block):
    """Operator converter for thresholded_relu.

    y = x where x > threshold, else 0.
    """

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype
    threshold = _expr.const(op.attr("threshold"), dtype)
    zero = _expr.const(0, dtype=dtype)
    # Use _op.where, matching every other converter in this file, rather
    # than reaching through tvm.relay.where for the same operator.
    out = _op.where(x > threshold, x, zero)
    g.add_node(op.output("Out")[0], out)
2182+
2183+
20822184
def convert_tile(g, op, block):
20832185
"""Operator converter for tile."""
20842186

@@ -2220,6 +2322,7 @@ def convert_where_index(g, op, block):
22202322
"cumsum": convert_cumsum,
22212323
"depthwise_conv2d": convert_conv2d,
22222324
"depthwise_conv2d_transpose": convert_conv2d_transpose,
2325+
"dist": convert_dist,
22232326
"dot": convert_dot,
22242327
"dropout": convert_dropout,
22252328
"elementwise_add": convert_elementwise_op,
@@ -2238,6 +2341,7 @@ def convert_where_index(g, op, block):
22382341
"exp": convert_unary_op,
22392342
"expand_v2": convert_expand,
22402343
"expand_as_v2": convert_expand_as,
2344+
"eye": convert_eye,
22412345
"feed": convert_feed,
22422346
"fill_any_like": convert_fill_any_like,
22432347
"fill_constant": convert_fill_constant,
@@ -2254,6 +2358,7 @@ def convert_where_index(g, op, block):
22542358
"hard_shrink": convert_hard_shrink,
22552359
"hard_sigmoid": convert_hard_sigmoid,
22562360
"hard_swish": convert_hard_swish,
2361+
"index_select": convert_index_select,
22572362
"instance_norm": convert_instance_norm,
22582363
"isfinite_v2": convert_unary_op,
22592364
"isinf_v2": convert_unary_op,
@@ -2262,6 +2367,7 @@ def convert_where_index(g, op, block):
22622367
"leaky_relu": convert_leaky_relu,
22632368
"less_equal": convert_elementwise_op,
22642369
"less_than": convert_elementwise_op,
2370+
"linspace": convert_linspace,
22652371
"log": convert_unary_op,
22662372
"log2": convert_unary_op,
22672373
"log10": convert_unary_op,
@@ -2333,6 +2439,7 @@ def convert_where_index(g, op, block):
23332439
"tan": convert_unary_op,
23342440
"tanh": convert_unary_op,
23352441
"top_k": convert_topk,
2442+
"thresholded_relu": convert_thresholded_relu,
23362443
"tile": convert_tile,
23372444
"top_k_v2": convert_topk,
23382445
"transpose2": convert_transpose,

tests/python/frontend/paddlepaddle/test_forward.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,5 +1992,145 @@ def forward(self, inputs):
19921992
verify_model(Mish(), input_data=input_data)
19931993

19941994

1995+
@tvm.testing.uses_gpu
def test_forward_thresholded_relu():
    class ThresholdedRelu1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # Default threshold.
            return nn.functional.thresholded_relu(inputs)

    class ThresholdedRelu2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return nn.functional.thresholded_relu(inputs, threshold=0.5)

    # Cover 1-D through 4-D inputs.
    for shape in ([10], [2, 3], [5, 10, 11], [3, 4, 5, 6]):
        data = paddle.randn(shape=shape, dtype="float32")
        verify_model(ThresholdedRelu1(), input_data=data)
        verify_model(ThresholdedRelu2(), input_data=data)
2012+
2013+
2014+
@tvm.testing.uses_gpu
def test_forward_index_select():
    class IndexSelect1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, x, index):
            return paddle.index_select(x, index, axis=0)

    class IndexSelect2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, x, index):
            # Negative axis: select along the last dimension.
            return paddle.index_select(x, index, axis=-1)

    # Cover 1-D through 4-D inputs; index 1 repeats to check gathering
    # the same slice twice.
    for shape in ([10], [2, 3], [5, 10, 11], [3, 4, 5, 6]):
        data = paddle.randn(shape=shape, dtype="float32")
        index = paddle.to_tensor([0, 1, 1], dtype="int32")
        verify_model(IndexSelect1(), input_data=[data, index])
        verify_model(IndexSelect2(), input_data=[data, index])
2032+
2033+
2034+
@tvm.testing.uses_gpu
def test_forward_eye():
    # paddle.eye takes no tensor arguments, so each model threads a dummy
    # input through to give verify_model something to feed.
    class Eye1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.eye(3, 5, dtype="int32"), paddle.eye(3, 5, dtype="float32"), inputs

    class Eye2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            return paddle.eye(5, 3, dtype="int64"), paddle.eye(5, 3, dtype="float64"), inputs

    class Eye3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # Degenerate shapes: zero rows, and a fully empty matrix.
            return paddle.eye(0, 3, dtype="int64"), paddle.eye(0, 0, dtype="float64"), inputs

    class Eye4(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # num_columns=None produces a square matrix.
            return paddle.eye(4, None, dtype="int64"), paddle.eye(4, None, dtype="float64"), inputs

    dummy = paddle.to_tensor([1], dtype="float32")
    for model in (Eye1(), Eye2(), Eye3(), Eye4()):
        verify_model(model, input_data=[dummy])
2061+
2062+
2063+
@tvm.testing.uses_gpu
def test_forward_linspace():
    class Linspace1(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # Literal start/stop/num arguments, one call per output dtype.
            out1 = paddle.linspace(0.5, 7, 1, "int32")
            out2 = paddle.linspace(1.3, 7.1, 5, "float32")
            out3 = paddle.linspace(1, 1000000000, 10, "int64")
            out4 = paddle.linspace(1, 7.1, 5, "float64")
            return out1, out2, out3, out4, inputs

    class Linspace2(nn.Layer):
        @paddle.jit.to_static
        def forward(self, inputs):
            # start/stop/num constructed as tensors inside the graph.
            start = paddle.to_tensor([-2.5])
            stop = paddle.to_tensor([31.6])
            num = paddle.to_tensor([13])
            start = paddle.cast(start, "float32")
            stop = paddle.cast(stop, "float32")
            num = paddle.cast(num, "int32")
            out1 = paddle.linspace(start, stop, num, "int32")
            out2 = paddle.linspace(start, stop, num, "float32")
            out3 = paddle.linspace(start, stop, num, "int64")
            out4 = paddle.linspace(start, stop, num, "float64")
            return out1, out2, out3, out4, inputs

    class Linspace3(nn.Layer):
        @paddle.jit.to_static
        def forward(self, start, stop, num):
            # start/stop/num arrive as model inputs (dynamic case).
            out1 = paddle.linspace(start, stop, num, "int32")
            out2 = paddle.linspace(start, stop, num, "float32")
            out3 = paddle.linspace(start, stop, num, "int64")
            out4 = paddle.linspace(start, stop, num, "float32")
            return out1

    start = paddle.cast(paddle.to_tensor([1.3]), "float32")
    stop = paddle.cast(paddle.to_tensor([5.1]), "float32")
    num = paddle.cast(paddle.to_tensor([3]), "int32")
    dummy = paddle.to_tensor([1], dtype="float32")
    verify_model(Linspace1(), input_data=[dummy])
    verify_model(Linspace2(), input_data=[dummy])
    verify_model(Linspace3(), input_data=[start, stop, num], use_vm=True)
    # num == 1 exercises the single-sample path in the converter.
    num = paddle.cast(paddle.to_tensor([1]), "int32")
    verify_model(Linspace3(), input_data=[start, stop, num], use_vm=True)
2111+
2112+
2113+
@tvm.testing.uses_gpu
def test_forward_dist():
    class Dist(nn.Layer):
        @paddle.jit.to_static
        def forward(self, x, y):
            # Exercise the special-cased norms (0, +inf, -inf) as well as
            # integral and fractional p values.
            l0_norm = paddle.dist(x, y, 0)
            l2_norm = paddle.dist(x, y, 2)
            float_norm = paddle.dist(x, y, 1.3)
            inf_norm = paddle.dist(x, y, float("inf"))
            ninf_norm = paddle.dist(x, y, float("-inf"))
            return l0_norm, l2_norm, float_norm, inf_norm, ninf_norm

    x = paddle.to_tensor([[3, 3], [3, 3]], dtype="float32")
    y = paddle.to_tensor([[1, 2], [3, 4]], dtype="float32")
    w = paddle.to_tensor([[1, 2]], dtype="float32")
    v = paddle.to_tensor([[2.1]], dtype="float32")
    # Equal-shape pairs and broadcasting pairs.
    for lhs, rhs in ((x, y), (x, w), (w, v), (y, v)):
        verify_model(Dist(), input_data=[lhs, rhs])
2133+
2134+
19952135
if __name__ == "__main__":
19962136
tvm.testing.main()

0 commit comments

Comments
 (0)