From e627e19adb36ca0c5ce883dffa01675353787b5f Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 12 Sep 2023 13:46:44 +0200 Subject: [PATCH 1/2] [NFC] Create explicit conversion pattern for ExternElementwiseOp in TT->TTGPU pass. This is needed for forward-compatibility with MLIR that now has "inherent" and "discardable" attributes (https://mlir.llvm.org/OpenMeetings/2023-02-09-Properties.pdf) and the ExternElementwiseOp attrs do not propagate with the current `addNamedAttrs` implementation. --- .../TritonToTritonGPUPass.cpp | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b03d76ac4bdb..ab636fab5e21 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -600,6 +600,24 @@ struct TritonScanReturnPattern } }; +struct TritonExternElementwisePattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + triton::ExternElementwiseOp op, + typename triton::ExternElementwiseOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + addNamedAttrs( + rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands(), op.getLibnameAttr(), + op.getLibpathAttr(), op.getSymbolAttr(), op.getPureAttr()), + adaptor.getAttributes()); + return success(); + } +}; + struct TritonPrintPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -692,9 +710,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonStorePattern, - TritonGenericPattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern, - TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context); + TritonExternElementwisePattern, TritonPrintPattern, TritonAssertPattern, + TritonAtomicRMWPattern, TritonFuncOpPattern, TritonReturnOpPattern, + TritonCallOpPattern>(typeConverter, context); } // From 8e09ea44028767ce5ccddf5114f17850d20a090b Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 12 Sep 2023 13:46:44 +0200 Subject: [PATCH 2/2] [NFC] Create explicit conversion pattern for ExternElementwiseOp in TT->TTGPU pass. This is needed for forward-compatibility with MLIR that now has "inherent" and "discardable" attributes (https://mlir.llvm.org/OpenMeetings/2023-02-09-Properties.pdf) and the ExternElementwiseOp attrs do not propagate with the current `addNamedAttrs` implementation. --- .../TritonToTritonGPUPass.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index ab636fab5e21..de5ad6947c53 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -604,16 +604,16 @@ struct TritonExternElementwisePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - triton::ExternElementwiseOp op, - typename triton::ExternElementwiseOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult + matchAndRewrite(triton::ExternElementwiseOp op, + typename triton::ExternElementwiseOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp( - op, retType, adaptor.getOperands(), op.getLibnameAttr(), - op.getLibpathAttr(), op.getSymbolAttr(), op.getPureAttr()), - adaptor.getAttributes()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands(), op.getLibnameAttr(), + op.getLibpathAttr(), op.getSymbolAttr(), + op.getPureAttr()), + adaptor.getAttributes()); return success(); } };