@@ -94,7 +94,24 @@ DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
9494DEFINE_GENERAL_PATTERN (Mish, paddle::dialect::MishOp)
9595DEFINE_GENERAL_PATTERN (AssignValue, paddle::dialect::AssignValueOp)
9696DEFINE_GENERAL_PATTERN (AssignValue_, paddle::dialect::AssignValue_Op)
97+ DEFINE_GENERAL_PATTERN (Exp, paddle::dialect::ExpOp)
98+ DEFINE_GENERAL_PATTERN (Abs, paddle::dialect::AbsOp)
99+ DEFINE_GENERAL_PATTERN (Abs_, paddle::dialect::Abs_Op)
100+ DEFINE_GENERAL_PATTERN (Sin, paddle::dialect::SinOp)
101+ DEFINE_GENERAL_PATTERN (Cos, paddle::dialect::CosOp)
102+ DEFINE_GENERAL_PATTERN (Sinh, paddle::dialect::SinhOp)
103+ DEFINE_GENERAL_PATTERN (Cosh, paddle::dialect::CoshOp)
104+ DEFINE_GENERAL_PATTERN (Asinh, paddle::dialect::AsinhOp)
105+ DEFINE_GENERAL_PATTERN (Acosh, paddle::dialect::AcoshOp)
106+ DEFINE_GENERAL_PATTERN (Atanh, paddle::dialect::AtanhOp)
107+ DEFINE_GENERAL_PATTERN (Ceil, paddle::dialect::CeilOp)
108+ DEFINE_GENERAL_PATTERN (Rsqrt, paddle::dialect::RsqrtOp)
109+ DEFINE_GENERAL_PATTERN (Reciprocal, paddle::dialect::ReciprocalOp)
110+ DEFINE_GENERAL_PATTERN (Erf, paddle::dialect::ErfOp)
111+ DEFINE_GENERAL_PATTERN (Sign, paddle::dialect::SignOp)
112+ DEFINE_GENERAL_PATTERN (Round, paddle::dialect::RoundOp)
97113DEFINE_GENERAL_PATTERN (Numel, paddle::dialect::NumelOp)
114+
98115#undef DEFINE_GENERAL_PATTERN
99116
100117// Add ReduceCommonOpPattern base class to simplify code
@@ -267,8 +284,30 @@ class ActOpPattern : public pir::OpRewritePattern<OpType> {
267284using TanhOpPattern = ActOpPattern<paddle::dialect::TanhOp>;
268285using CeluOpPattern = ActOpPattern<paddle::dialect::CeluOp>;
269286using TanhShrinkOpPattern = ActOpPattern<paddle::dialect::TanhShrinkOp>;
270- using LogicalNotOpPattern = ActOpPattern<paddle::dialect::LogicalNotOp>;
271- using LogicalNot_OpPattern = ActOpPattern<paddle::dialect::LogicalNot_Op>;
287+
288+ template <typename OpType>
289+ class Logical_NotOpPattern : public pir ::OpRewritePattern<OpType> {
290+ public:
291+ using pir::OpRewritePattern<OpType>::OpRewritePattern;
292+ bool MatchAndRewrite (OpType op,
293+ pir::PatternRewriter &rewriter) const override {
294+ if (op->HasAttribute (kCanRunTrtAttr ) &&
295+ op->template attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
296+ return false ;
297+ }
298+ pir::Value x = op.operand_source (0 );
299+ auto x_dtype = pir::GetDataTypeFromValue (x);
300+ if (!x_dtype.isa <pir::BoolType>()) {
301+ VLOG (3 ) << " logical_not op only support bool input in tensorrt." ;
302+ return false ;
303+ }
304+ op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
305+ return true ;
306+ }
307+ };
308+ using LogicalNotOpPattern = Logical_NotOpPattern<paddle::dialect::LogicalNotOp>;
309+ using LogicalNot_OpPattern =
310+ Logical_NotOpPattern<paddle::dialect::LogicalNot_Op>;
272311
273312class Pool2dOpPattern
274313 : public pir::OpRewritePattern<paddle::dialect::Pool2dOp> {
@@ -538,24 +577,6 @@ class ArangeOpPattern
538577 }
539578};
540579
541- class SignOpPattern : public pir ::OpRewritePattern<paddle::dialect::SignOp> {
542- public:
543- using pir::OpRewritePattern<paddle::dialect::SignOp>::OpRewritePattern;
544- bool MatchAndRewrite (paddle::dialect::SignOp op,
545- pir::PatternRewriter &rewriter) const override {
546- if (op->HasAttribute (kCanRunTrtAttr ) &&
547- op->attribute <pir::BoolAttribute>(kCanRunTrtAttr ).data ()) {
548- return false ;
549- }
550- #if IS_TRT_VERSION_LT(8200)
551- VLOG (3 ) << " sign op is only supported by tensorrt8.2 above " ;
552- return false ;
553- #endif
554- op->set_attribute (kCanRunTrtAttr , rewriter.bool_attr (true ));
555- return true ;
556- }
557- };
558-
559580class GroupNormOpPattern
560581 : public pir::OpRewritePattern<paddle::dialect::GroupNormOp> {
561582 public:
@@ -2273,6 +2294,23 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
22732294 ADD_PATTERN (Mish)
22742295 ADD_PATTERN (AssignValue)
22752296 ADD_PATTERN (AssignValue_)
2297+ ADD_PATTERN (Exp)
2298+ ADD_PATTERN (Abs)
2299+ ADD_PATTERN (Abs_)
2300+ ADD_PATTERN (Cos)
2301+ ADD_PATTERN (Sin)
2302+ ADD_PATTERN (Cos)
2303+ ADD_PATTERN (Sinh)
2304+ ADD_PATTERN (Cosh)
2305+ ADD_PATTERN (Asinh)
2306+ ADD_PATTERN (Acosh)
2307+ ADD_PATTERN (Atanh)
2308+ ADD_PATTERN (Ceil)
2309+ ADD_PATTERN (Rsqrt)
2310+ ADD_PATTERN (Reciprocal)
2311+ ADD_PATTERN (Erf)
2312+ ADD_PATTERN (Sign)
2313+ ADD_PATTERN (Round)
22762314 ADD_PATTERN (Numel)
22772315#if IS_TRT_VERSION_GE(8600)
22782316 ADD_PATTERN (Layer_norm)
@@ -2283,7 +2321,6 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
22832321 ps.Add (std::make_unique<DepthwiseConv2dTransposeOpPattern>(context));
22842322 ps.Add (std::make_unique<DeformableConvOpPattern>(context));
22852323 ps.Add (std::make_unique<ArangeOpPattern>(context));
2286- ps.Add (std::make_unique<SignOpPattern>(context));
22872324 ps.Add (std::make_unique<LogicalNotOpPattern>(context));
22882325 ps.Add (std::make_unique<BitwiseAndOpPattern>(context));
22892326 ps.Add (std::make_unique<BitwiseOrOpPattern>(context));
0 commit comments