diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 33c4516b47f5..fa909d4df94c 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -7,7 +7,6 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/Transforms/Passes.h" @@ -18,35 +17,7 @@ namespace mlir::triton { namespace { bool isZero(Value val) { - if (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())) - return true; - // broadcast(constant_0) - if (auto bc = val.getDefiningOp()) { - if (matchPattern(bc.getSrc(), m_Zero()) || - matchPattern(bc.getSrc(), m_AnyZeroFloat())) - return true; - } - return false; -} - -bool isBroadcastConstantCombinable(Attribute value) { - if (auto denseValue = dyn_cast(value)) { - return denseValue.isSplat(); - } - return isa(value); -} - -DenseElementsAttr getConstantValue(Builder &builder, Attribute value, - Value bcast_res) { - auto resType = cast(bcast_res.getType()); - DenseElementsAttr res; - if (auto denseValue = dyn_cast(value)) { - res = - DenseElementsAttr::get(resType, denseValue.getSplatValue()); - } else { - res = DenseElementsAttr::get(resType, value); - } - return res; + return (matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat())); } bool isAddPtrOffsetCombinable(Value first, Value second) { @@ -231,7 +202,6 @@ class CombineOpsPass : public TritonCombineOpsBase { // %} patterns.add(context); patterns.add(context); - patterns.add(context); patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 5a2fcecfa949..e3588f587757 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -44,11 +44,4 @@ def CombineAddPtrPattern : Pat< (TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)), [(Constraint> $idx0, $idx1)]>; -// broadcast(cst) => cst -def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">; -def CombineBroadcastConstantPattern : Pat< - (TT_BroadcastOp:$bcast_res (Arith_ConstantOp $value)), - (Arith_ConstantOp (getConstantValue $value, $bcast_res), (location $bcast_res)), - [(Constraint> $value)]>; - #endif diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 43479a3d9f9a..486fc1c7b9da 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -206,18 +206,6 @@ struct MoveBroadcastAfterElementwisePattern } }; -template -class CanonicalizePattern : public OpRewritePattern { -public: - explicit CanonicalizePattern(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(OpType op, - PatternRewriter &rewriter) const override { - return OpType::canonicalize(op, rewriter); - } -}; - class ReorderBroadcastPass : public ::impl::TritonReorderBroadcastBase { public: @@ -226,8 +214,8 @@ class ReorderBroadcastPass RewritePatternSet patterns(context); ModuleOp m = getOperation(); - patterns.add>(context); - patterns.add>(context); + BroadcastOp::getCanonicalizationPatterns(patterns, context); + ExpandDimsOp::getCanonicalizationPatterns(patterns, context); // elementwise(broadcast(a)) => broadcast(elementwise(a)) patterns.add(context); // elementwise(splat(a), splat(b), ...) => splat(elementwise(a, b, ...)) diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index fd31e5c782fe..c425343d1954 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -134,3 +134,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked> } } // end module + +// ----- + +// CHECK-LABEL: @fold_broadcast_constant_pattern +tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { + // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> + %const = arith.constant dense<1.0> : tensor<8x1xf32> + %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> + + // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> + tt.return %bst_out : tensor<8x2xf32> +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index 41a3ba15a8ee..ecaa60e53c7d 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -208,16 +208,6 @@ tt.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr, tensor<8xf32>, tensor<8xf32> } -// CHECK-LABEL: @test_combine_broadcast_constant_pattern -tt.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> { - // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32> - %const = arith.constant dense<1.0> : tensor<8x1xf32> - %bst_out = tt.broadcast %const : tensor<8x1xf32> -> tensor<8x2xf32> - - // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32> - tt.return %bst_out : tensor<8x2xf32> -} - // CHECK-LABEL: @test_canonicalize_masked_load_pattern tt.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { %true_mask = arith.constant dense : tensor<8xi1> diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 67dc3f2b640b..e6231a389acb 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -189,8 +189,8 @@ def make_ttir(mod, metadata, options): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_rewrite_tensor_pointer(pm) - passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) passes.common.add_licm(pm) diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 8188e5ea09ea..6d6d70fc87e3 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -189,8 +189,8 @@ def make_ttir(mod, metadata, opt): pm.enable_debug() passes.common.add_inliner(pm) passes.ttir.add_rewrite_tensor_pointer(pm) - passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) + passes.ttir.add_combine(pm) passes.ttir.add_reorder_broadcast(pm) passes.common.add_cse(pm) passes.common.add_licm(pm)