diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3d97b4e985bf..a7b4926d8aa0 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -230,6 +230,54 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, state.addTypes({resultType}); } +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) -> other +struct CanonicalizeMaskedLoadPattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::LoadOp loadOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return mlir::failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future.x + auto otherVal = loadOp.getOther(); + if (!otherVal) + return mlir::failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return mlir::success(); + } +}; + +void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //-- StoreOp -- void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value, @@ -257,6 +305,47 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, evict); } +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedStorePattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp storeOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return mlir::failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return mlir::success(); + } +}; + +void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //-- TransOp -- mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 850182366b79..d9c4356a94a7 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -101,95 +101,6 @@ class CombineSelectMaskedLoadPattern : public mlir::RewritePattern { } }; -// load(ptr, splat(1), ...) -> load(ptr, ...) -// load(ptr, splat(0), other, ...) -> other -struct CanonicalizeMaskedLoadPattern - : public mlir::OpRewritePattern { - CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context) - : OpRewritePattern(context, 1) {} - - mlir::LogicalResult - matchAndRewrite(triton::LoadOp loadOp, - mlir::PatternRewriter &rewriter) const override { - auto mask = loadOp.getMask(); - if (!mask) - return mlir::failure(); - - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); - if (!constantMask) - return mlir::failure(); - - auto splatMask = constantMask.getValue().dyn_cast(); - if (!splatMask) - return mlir::failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - } else { - // mask = splat(0) - - // If there's no "other", the value is "undef". Perhaps we want to - // optimize it in the future.x - auto otherVal = loadOp.getOther(); - if (!otherVal) - return mlir::failure(); - rewriter.replaceOp(loadOp, otherVal); - } - return mlir::success(); - } -}; - -void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) -// store(ptr, value, splat(0), ...) -> [none] -struct CanonicalizeMaskedStorePattern - : public mlir::OpRewritePattern { - CanonicalizeMaskedStorePattern(mlir::MLIRContext *context) - : OpRewritePattern(context, 1) {} - - mlir::LogicalResult - matchAndRewrite(triton::StoreOp storeOp, - mlir::PatternRewriter &rewriter) const override { - auto mask = storeOp.getMask(); - if (!mask) - return mlir::failure(); - - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); - if (!constantMask) - return mlir::failure(); - - auto splatMask = constantMask.getValue().dyn_cast(); - if (!splatMask) - return mlir::failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), - storeOp.getEvict()); - } else { - // mask = splat(0) - rewriter.eraseOp(storeOp); - } - return mlir::success(); - } -}; - -void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - #define GEN_PASS_CLASSES #include "triton/Dialect/Triton/Transforms/Passes.h.inc"