Skip to content

Commit 29c3d91

Browse files
authored
[HEU][Paddle TensorRT No.69-72,74-85] Add UnaryOp converter (#70535)
* add converter * fix * add marker * fix * fix
1 parent e8c33cd commit 29c3d91

File tree

5 files changed

+397
-37
lines changed

5 files changed

+397
-37
lines changed

paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc

Lines changed: 58 additions & 21 deletions
Original file line number · Diff line number · Diff line change
@@ -94,7 +94,24 @@ DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
9494
DEFINE_GENERAL_PATTERN(Mish, paddle::dialect::MishOp)
9595
DEFINE_GENERAL_PATTERN(AssignValue, paddle::dialect::AssignValueOp)
9696
DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op)
97+
DEFINE_GENERAL_PATTERN(Exp, paddle::dialect::ExpOp)
98+
DEFINE_GENERAL_PATTERN(Abs, paddle::dialect::AbsOp)
99+
DEFINE_GENERAL_PATTERN(Abs_, paddle::dialect::Abs_Op)
100+
DEFINE_GENERAL_PATTERN(Sin, paddle::dialect::SinOp)
101+
DEFINE_GENERAL_PATTERN(Cos, paddle::dialect::CosOp)
102+
DEFINE_GENERAL_PATTERN(Sinh, paddle::dialect::SinhOp)
103+
DEFINE_GENERAL_PATTERN(Cosh, paddle::dialect::CoshOp)
104+
DEFINE_GENERAL_PATTERN(Asinh, paddle::dialect::AsinhOp)
105+
DEFINE_GENERAL_PATTERN(Acosh, paddle::dialect::AcoshOp)
106+
DEFINE_GENERAL_PATTERN(Atanh, paddle::dialect::AtanhOp)
107+
DEFINE_GENERAL_PATTERN(Ceil, paddle::dialect::CeilOp)
108+
DEFINE_GENERAL_PATTERN(Rsqrt, paddle::dialect::RsqrtOp)
109+
DEFINE_GENERAL_PATTERN(Reciprocal, paddle::dialect::ReciprocalOp)
110+
DEFINE_GENERAL_PATTERN(Erf, paddle::dialect::ErfOp)
111+
DEFINE_GENERAL_PATTERN(Sign, paddle::dialect::SignOp)
112+
DEFINE_GENERAL_PATTERN(Round, paddle::dialect::RoundOp)
97113
DEFINE_GENERAL_PATTERN(Numel, paddle::dialect::NumelOp)
114+
98115
#undef DEFINE_GENERAL_PATTERN
99116

100117
// Add ReduceCommonOpPattern base class to simplify code
@@ -267,8 +284,30 @@ class ActOpPattern : public pir::OpRewritePattern<OpType> {
267284
using TanhOpPattern = ActOpPattern<paddle::dialect::TanhOp>;
268285
using CeluOpPattern = ActOpPattern<paddle::dialect::CeluOp>;
269286
using TanhShrinkOpPattern = ActOpPattern<paddle::dialect::TanhShrinkOp>;
270-
using LogicalNotOpPattern = ActOpPattern<paddle::dialect::LogicalNotOp>;
271-
using LogicalNot_OpPattern = ActOpPattern<paddle::dialect::LogicalNot_Op>;
287+
288+
template <typename OpType>
289+
class Logical_NotOpPattern : public pir::OpRewritePattern<OpType> {
290+
public:
291+
using pir::OpRewritePattern<OpType>::OpRewritePattern;
292+
bool MatchAndRewrite(OpType op,
293+
pir::PatternRewriter &rewriter) const override {
294+
if (op->HasAttribute(kCanRunTrtAttr) &&
295+
op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
296+
return false;
297+
}
298+
pir::Value x = op.operand_source(0);
299+
auto x_dtype = pir::GetDataTypeFromValue(x);
300+
if (!x_dtype.isa<pir::BoolType>()) {
301+
VLOG(3) << " logical_not op only support bool input in tensorrt.";
302+
return false;
303+
}
304+
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
305+
return true;
306+
}
307+
};
308+
using LogicalNotOpPattern = Logical_NotOpPattern<paddle::dialect::LogicalNotOp>;
309+
using LogicalNot_OpPattern =
310+
Logical_NotOpPattern<paddle::dialect::LogicalNot_Op>;
272311

273312
class Pool2dOpPattern
274313
: public pir::OpRewritePattern<paddle::dialect::Pool2dOp> {
@@ -538,24 +577,6 @@ class ArangeOpPattern
538577
}
539578
};
540579

541-
class SignOpPattern : public pir::OpRewritePattern<paddle::dialect::SignOp> {
542-
public:
543-
using pir::OpRewritePattern<paddle::dialect::SignOp>::OpRewritePattern;
544-
bool MatchAndRewrite(paddle::dialect::SignOp op,
545-
pir::PatternRewriter &rewriter) const override {
546-
if (op->HasAttribute(kCanRunTrtAttr) &&
547-
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
548-
return false;
549-
}
550-
#if IS_TRT_VERSION_LT(8200)
551-
VLOG(3) << "sign op is only supported by tensorrt8.2 above ";
552-
return false;
553-
#endif
554-
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
555-
return true;
556-
}
557-
};
558-
559580
class GroupNormOpPattern
560581
: public pir::OpRewritePattern<paddle::dialect::GroupNormOp> {
561582
public:
@@ -2273,6 +2294,23 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
22732294
ADD_PATTERN(Mish)
22742295
ADD_PATTERN(AssignValue)
22752296
ADD_PATTERN(AssignValue_)
2297+
ADD_PATTERN(Exp)
2298+
ADD_PATTERN(Abs)
2299+
ADD_PATTERN(Abs_)
2300+
ADD_PATTERN(Cos)
2301+
ADD_PATTERN(Sin)
2302+
ADD_PATTERN(Cos)
2303+
ADD_PATTERN(Sinh)
2304+
ADD_PATTERN(Cosh)
2305+
ADD_PATTERN(Asinh)
2306+
ADD_PATTERN(Acosh)
2307+
ADD_PATTERN(Atanh)
2308+
ADD_PATTERN(Ceil)
2309+
ADD_PATTERN(Rsqrt)
2310+
ADD_PATTERN(Reciprocal)
2311+
ADD_PATTERN(Erf)
2312+
ADD_PATTERN(Sign)
2313+
ADD_PATTERN(Round)
22762314
ADD_PATTERN(Numel)
22772315
#if IS_TRT_VERSION_GE(8600)
22782316
ADD_PATTERN(Layer_norm)
@@ -2283,7 +2321,6 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
22832321
ps.Add(std::make_unique<DepthwiseConv2dTransposeOpPattern>(context));
22842322
ps.Add(std::make_unique<DeformableConvOpPattern>(context));
22852323
ps.Add(std::make_unique<ArangeOpPattern>(context));
2286-
ps.Add(std::make_unique<SignOpPattern>(context));
22872324
ps.Add(std::make_unique<LogicalNotOpPattern>(context));
22882325
ps.Add(std::make_unique<BitwiseAndOpPattern>(context));
22892326
ps.Add(std::make_unique<BitwiseOrOpPattern>(context));

python/paddle/tensorrt/converter_utils.py

Lines changed: 27 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -686,6 +686,29 @@ def squeeze_trt(network, input_tensor, axes):
686686
def unary_op_converter(network, paddle_op, inputs):
687687
from paddle.tensorrt import PrecisionMode
688688

689+
ops_type_map = {
690+
"pd_op.sqrt": [trt.UnaryOperation.SQRT],
691+
"pd_op.sqrt_": [trt.UnaryOperation.SQRT],
692+
"pd_op.floor": [trt.UnaryOperation.FLOOR],
693+
"pd_op.exp": [trt.UnaryOperation.EXP],
694+
"pd_op.abs": [trt.UnaryOperation.ABS],
695+
"pd_op.abs_": [trt.UnaryOperation.ABS],
696+
"pd_op.sin": [trt.UnaryOperation.SIN],
697+
"pd_op.cos": [trt.UnaryOperation.COS],
698+
"pd_op.sinh": [trt.UnaryOperation.SINH],
699+
"pd_op.cosh": [trt.UnaryOperation.COSH],
700+
"pd_op.asinh": [trt.UnaryOperation.ASINH],
701+
"pd_op.acosh": [trt.UnaryOperation.ACOSH],
702+
"pd_op.atanh": [trt.UnaryOperation.ATANH],
703+
"pd_op.ceil": [trt.UnaryOperation.CEIL],
704+
"pd_op.reciprocal": [trt.UnaryOperation.RECIP],
705+
"pd_op.erf": [trt.UnaryOperation.ERF],
706+
"pd_op.sign": [trt.UnaryOperation.SIGN],
707+
"pd_op.round": [trt.UnaryOperation.ROUND],
708+
"pd_op.logical_not": [trt.UnaryOperation.NOT],
709+
"pd_op.rsqrt": [trt.UnaryOperation.SQRT, trt.UnaryOperation.RECIP],
710+
}
711+
689712
input_tensor = inputs[0]
690713
layer = None
691714
org_type = input_tensor.dtype
@@ -707,9 +730,10 @@ def unary_op_converter(network, paddle_op, inputs):
707730
identity_layer.set_output_type(0, trt.float16)
708731
input_tensor = identity_layer.get_output(0)
709732

710-
if paddle_op.name() in ["pd_op.logical_not", "pd_op.logical_not_"]:
711-
layer = network.add_unary(input_tensor, trt.UnaryOperation.NOT)
712-
input_tensor = layer.get_output(0)
733+
if paddle_op.name() in ops_type_map:
734+
for trt_op in ops_type_map[paddle_op.name()]:
735+
layer = network.add_unary(input_tensor, trt_op)
736+
input_tensor = layer.get_output(0)
713737
else:
714738
raise NotImplementedError(
715739
f"Unsupported unary operation: {paddle_op.name()}"

python/paddle/tensorrt/impls/ops.py

Lines changed: 23 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -11,21 +11,32 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import tensorrt as trt
1514

15+
from paddle.tensorrt.converter_utils import unary_op_converter
1616
from paddle.tensorrt.register import converter_registry
1717

18-
ops_type_map = {
19-
"pd_op.sqrt": trt.UnaryOperation.SQRT,
20-
"pd_op.sqrt_": trt.UnaryOperation.SQRT,
21-
"pd_op.floor": trt.UnaryOperation.FLOOR,
22-
}
23-
2418

2519
@converter_registry.register("pd_op.sqrt", trt_version="trt_version_ge=8.0")
2620
@converter_registry.register("pd_op.sqrt_", trt_version="trt_version_ge=8.0")
27-
@converter_registry.register("pd_op.floor", trt_version="8.x")
28-
def sqrt_converter(network, paddle_op, inputs):
29-
input_tensor = inputs[0]
30-
layer = network.add_unary(input_tensor, ops_type_map[paddle_op.name()])
31-
return layer.get_output(0)
21+
@converter_registry.register("pd_op.floor", trt_version="trt_version_ge=8.0")
22+
@converter_registry.register("pd_op.exp", trt_version="trt_version_ge=8.0")
23+
@converter_registry.register("pd_op.abs", trt_version="trt_version_ge=8.0")
24+
@converter_registry.register("pd_op.abs_", trt_version="trt_version_ge=8.0")
25+
@converter_registry.register("pd_op.sin", trt_version="trt_version_ge=8.0")
26+
@converter_registry.register("pd_op.cos", trt_version="trt_version_ge=8.0")
27+
@converter_registry.register("pd_op.sinh", trt_version="trt_version_ge=8.0")
28+
@converter_registry.register("pd_op.cosh", trt_version="trt_version_ge=8.0")
29+
@converter_registry.register("pd_op.asinh", trt_version="trt_version_ge=8.0")
30+
@converter_registry.register("pd_op.acosh", trt_version="trt_version_ge=8.0")
31+
@converter_registry.register("pd_op.atanh", trt_version="trt_version_ge=8.0")
32+
@converter_registry.register("pd_op.ceil", trt_version="trt_version_ge=8.0")
33+
@converter_registry.register(
34+
"pd_op.reciprocal", trt_version="trt_version_ge=8.0"
35+
)
36+
@converter_registry.register("pd_op.erf", trt_version="trt_version_ge=8.0")
37+
@converter_registry.register("pd_op.rsqrt", trt_version="trt_version_ge=8.0")
38+
@converter_registry.register("pd_op.sign", trt_version="trt_version_ge=8.2")
39+
@converter_registry.register("pd_op.round", trt_version="trt_version_ge=8.2")
40+
def UnaryOpConverter(network, paddle_op, inputs):
41+
layer_output = unary_op_converter(network, paddle_op, inputs)
42+
return layer_output

test/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -14,7 +14,7 @@ if(NOT WIN32 AND TENSORRT_FOUND)
1414
set_tests_properties(test_converter_conv PROPERTIES TIMEOUT "300")
1515
set_tests_properties(test_export PROPERTIES TIMEOUT "500")
1616
set_tests_properties(test_converter_norm PROPERTIES TIMEOUT "300")
17-
set_tests_properties(test_converter_ops PROPERTIES TIMEOUT "300")
17+
set_tests_properties(test_converter_ops PROPERTIES TIMEOUT "500")
1818
set_tests_properties(test_converter_stat PROPERTIES TIMEOUT "300")
1919
set_tests_properties(test_converter_math PROPERTIES TIMEOUT "300")
2020
set_tests_properties(test_converter_activation PROPERTIES TIMEOUT "300")

0 commit comments

Comments
 (0)