diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b03d76ac4bdb..de5ad6947c53 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -600,6 +600,24 @@ struct TritonScanReturnPattern } }; +struct TritonExternElementwisePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, + typename triton::ExternElementwiseOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands(), op.getLibnameAttr(), + op.getLibpathAttr(), op.getSymbolAttr(), + op.getPureAttr()), + adaptor.getAttributes()); + return success(); + } +}; + struct TritonPrintPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -692,9 +710,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, - TritonGenericPattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern, - TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context); + TritonExternElementwisePattern, TritonPrintPattern, TritonAssertPattern, + TritonAtomicRMWPattern, TritonFuncOpPattern, TritonReturnOpPattern, + TritonCallOpPattern>(typeConverter, context); } //