diff --git a/tests/filecheck/dialects/csl/csl-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-canonicalize.mlir new file mode 100644 index 0000000000..5075afe5c9 --- /dev/null +++ b/tests/filecheck/dialects/csl/csl-canonicalize.mlir @@ -0,0 +1,54 @@ +// RUN: xdsl-opt %s -p canonicalize --split-input-file | filecheck %s + + +builtin.module { +// CHECK-NEXT: builtin.module { + +%0 = arith.constant 512 : i16 +%1 = "csl.zeros"() : () -> memref<512xf32> +%2 = "csl.get_mem_dsd"(%1, %0) : (memref<512xf32>, i16) -> !csl + +%3 = arith.constant 1 : si16 +%4 = "csl.increment_dsd_offset"(%2, %3) <{"elem_type" = f32}> : (!csl, si16) -> !csl + +%5 = arith.constant 510 : ui16 +%6 = "csl.set_dsd_length"(%4, %5) : (!csl, ui16) -> !csl + +%int8 = arith.constant 1 : si8 +%7 = "csl.set_dsd_stride"(%6, %int8) : (!csl, si8) -> !csl + +"test.op"(%7) : (!csl) -> () + +// CHECK-NEXT: %0 = "csl.zeros"() : () -> memref<512xf32> +// CHECK-NEXT: %1 = arith.constant 510 : ui16 +// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"offsets" = [1 : si16], "strides" = [1 : si8]}> : (memref<512xf32>, ui16) -> !csl +// CHECK-NEXT: "test.op"(%2) : (!csl) -> () + + + +%8 = "test.op"() : () -> (!csl) +%9 = arith.constant 2 : si16 +%10 = "csl.increment_dsd_offset"(%8, %9) <{"elem_type" = f32}> : (!csl, si16) -> !csl +%11 = "csl.increment_dsd_offset"(%10, %9) <{"elem_type" = f32}> : (!csl, si16) -> !csl +%12 = arith.constant 509 : ui16 +%13 = arith.constant 511 : ui16 +%14 = "csl.set_dsd_length"(%11, %12) : (!csl, ui16) -> !csl +%15 = "csl.set_dsd_length"(%14, %13) : (!csl, ui16) -> !csl +%16 = arith.constant 2 : si8 +%17 = arith.constant 3 : si8 +%18 = "csl.set_dsd_stride"(%15, %16) : (!csl, si8) -> !csl +%19 = "csl.set_dsd_stride"(%18, %17) : (!csl, si8) -> !csl +"test.op"(%19) : (!csl) -> () + +// CHECK-NEXT: %3 = "test.op"() : () -> !csl +// CHECK-NEXT: %4 = arith.constant 2 : si16 +// CHECK-NEXT: %5 = arith.addi %4, %4 : si16 +// CHECK-NEXT: %6 = "csl.increment_dsd_offset"(%3, %5) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %7 = arith.constant 511 : ui16 +// CHECK-NEXT: %8 = "csl.set_dsd_length"(%6, %7) : (!csl, ui16) -> !csl +// CHECK-NEXT: %9 = arith.constant 3 : si8 +// CHECK-NEXT: %10 = "csl.set_dsd_stride"(%8, %9) : (!csl, si8) -> !csl +// CHECK-NEXT: "test.op"(%10) : (!csl) -> () + +} +// CHECK-NEXT: } diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index a9897f246b..ced6af4ed9 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -67,15 +67,18 @@ var_operand_def, ) from xdsl.parser import Parser +from xdsl.pattern_rewriter import RewritePattern from xdsl.printer import Printer from xdsl.traits import ( HasAncestor, + HasCanonicalizationPatternsTrait, HasParent, IsolatedFromAbove, IsTerminator, NoMemoryEffect, NoTerminator, OpTrait, + Pure, SymbolOpInterface, ) from xdsl.utils.exceptions import VerifyException @@ -1006,6 +1009,54 @@ def __init__( super().__init__(operands=[x_coord, y_coord, params], properties={"file": name}) +class DsdOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.csl import ( + GetDsdAndLengthFolding, + GetDsdAndOffsetFolding, + GetDsdAndStrideFolding, + ) + + return ( + GetDsdAndOffsetFolding(), + GetDsdAndLengthFolding(), + GetDsdAndStrideFolding(), + ) + + +class IncrementDsdOffsetOpHasCanonicalizationPatternsTrait( + HasCanonicalizationPatternsTrait +): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.csl import ( + ChainedDsdOffsetFolding, + ) + + return (ChainedDsdOffsetFolding(),) + + +class SetDsdLengthOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.csl import ( + ChainedDsdLengthFolding, + ) + + return (ChainedDsdLengthFolding(),) + + +class SetDsdStrideOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.csl import ( + ChainedDsdStrideFolding, + ) + + return (ChainedDsdStrideFolding(),) + + class _GetDsdOp(IRDLOperation, ABC): """ Abstract base class for CSL @get_dsd() @@ -1030,6 +1081,13 @@ class GetMemDsdOp(_GetDsdOp): offsets = opt_prop_def(ArrayAttr[AnyIntegerAttr]) strides = opt_prop_def(ArrayAttr[AnyIntegerAttr]) + traits = frozenset( + [ + Pure(), + DsdOpHasCanonicalizationPatternsTrait(), + ] + ) + def verify_(self) -> None: if not isinstance(self.result.type, DsdType): raise VerifyException("DSD type is not DsdType") @@ -1107,6 +1165,8 @@ class SetDsdBaseAddrOp(IRDLOperation): ) result = result_def(DsdType) + traits = frozenset([Pure()]) + def verify_(self) -> None: if ( not isinstance(self.result.type, DsdType) @@ -1145,6 +1205,8 @@ class IncrementDsdOffsetOp(IRDLOperation): elem_type = prop_def(DsdElementTypeConstr) result = result_def(DsdType) + traits = frozenset([Pure(), IncrementDsdOffsetOpHasCanonicalizationPatternsTrait()]) + def verify_(self) -> None: if ( not isinstance(self.result.type, DsdType) @@ -1170,6 +1232,8 @@ class SetDsdLengthOp(IRDLOperation): length = operand_def(u16_value) result = result_def(DsdType) + traits = frozenset([Pure(), SetDsdLengthOpHasCanonicalizationPatternsTrait()]) + def verify_(self) -> None: if ( not isinstance(self.result.type, DsdType) @@ -1196,6 +1260,8 @@ class SetDsdStrideOp(IRDLOperation): stride = operand_def(IntegerType(8, Signedness.SIGNED)) result = result_def(DsdType) + traits = frozenset([Pure(), SetDsdStrideOpHasCanonicalizationPatternsTrait()]) + def verify_(self) -> None: if ( not isinstance(self.result.type, DsdType) diff --git a/xdsl/transforms/canonicalization_patterns/csl.py b/xdsl/transforms/canonicalization_patterns/csl.py new file mode 100644 index 0000000000..90a92409f5 --- /dev/null +++ b/xdsl/transforms/canonicalization_patterns/csl.py @@ -0,0 +1,184 @@ +from xdsl.dialects import arith +from xdsl.dialects.builtin import ArrayAttr, IntegerAttr +from xdsl.dialects.csl import csl +from xdsl.ir import OpResult +from xdsl.pattern_rewriter import ( + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.utils.hints import isa + + +class GetDsdAndOffsetFolding(RewritePattern): + """ + Folds a `csl.get_mem_dsd` immediately followed by a `csl.increment_dsd_offset` + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> None: + # single use that is `@increment_dsd_offset` + if len(op.result.uses) != 1 or not isinstance( + offset_op := next(iter(op.result.uses)).operation, csl.IncrementDsdOffsetOp + ): + return + # only works on 1d + if op.offsets and len(op.offsets) > 1: + return + + # check if we can promote arith.const to property + if ( + isinstance(offset_op.offset, OpResult) + and isinstance(cnst := offset_op.offset.op, arith.Constant) + and isa(cnst.value, IntegerAttr) + ): + rewriter.replace_matched_op( + new_op := csl.GetMemDsdOp.build( + operands=[op.base_addr, op.sizes], + result_types=op.result_types, + properties={**op.properties, "offsets": ArrayAttr([cnst.value])}, + ) + ) + rewriter.replace_op(offset_op, [], new_results=[new_op.result]) + + +class GetDsdAndLengthFolding(RewritePattern): + """ + Folds a `csl.get_mem_dsd` immediately followed by a `csl.set_dsd_length` + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> None: + # single use that is `@set_dsd_length` + if len(op.result.uses) != 1 or not isinstance( + size_op := next(iter(op.result.uses)).operation, csl.SetDsdLengthOp + ): + return + # only works on 1d + if len(op.sizes) > 1: + return + + rewriter.replace_op( + size_op, + csl.GetMemDsdOp.build( + operands=[op.base_addr, [size_op.length]], + result_types=op.result_types, + properties=op.properties.copy(), + ), + ) + rewriter.erase_matched_op() + + +class GetDsdAndStrideFolding(RewritePattern): + """ + Folds a `csl.get_mem_dsd` immediately followed by a `csl.set_dsd_stride` + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> None: + # single use that is `@set_dsd_stride` + if len(op.result.uses) != 1 or not isinstance( + stride_op := next(iter(op.result.uses)).operation, csl.SetDsdStrideOp + ): + return + # only works on 1d + if op.offsets and len(op.offsets) > 1: + return + + # check if we can promote arith.const to property + if ( + isinstance(stride_op.stride, OpResult) + and isinstance(cnst := stride_op.stride.op, arith.Constant) + and isa(cnst.value, IntegerAttr) + ): + rewriter.replace_matched_op( + new_op := csl.GetMemDsdOp.build( + operands=[op.base_addr, op.sizes], + result_types=op.result_types, + properties={**op.properties, "strides": ArrayAttr([cnst.value])}, + ) + ) + rewriter.replace_op(stride_op, [], new_results=[new_op.result]) + + +class ChainedDsdOffsetFolding(RewritePattern): + """ + Folds a chain of `csl.increment_dsd_offset` + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: csl.IncrementDsdOffsetOp, rewriter: PatternRewriter + ) -> None: + # single use that is `@increment_dsd_offset` + if len(op.result.uses) != 1 or not isinstance( + next_op := next(iter(op.result.uses)).operation, csl.IncrementDsdOffsetOp + ): + return + + # check if we can promote arith.const to property + if op.elem_type == next_op.elem_type: + rewriter.replace_op( + next_op, + [ + new_offset := arith.Addi(op.offset, next_op.offset), + csl.IncrementDsdOffsetOp( + operands=[op.op, new_offset.result], + properties=op.properties.copy(), + result_types=op.result_types, + ), + ], + ) + rewriter.erase_matched_op() + + +class ChainedDsdLengthFolding(RewritePattern): + """ + Folds a chain of `csl.set_dsd_length` + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: csl.SetDsdLengthOp, rewriter: PatternRewriter + ) -> None: + # single use that is `@set_dsd_length` + if len(op.result.uses) != 1 or not isinstance( + next_op := next(iter(op.result.uses)).operation, csl.SetDsdLengthOp + ): + return + + # check if we can promote arith.const to property + rewriter.replace_matched_op( + rebuilt := csl.SetDsdLengthOp( + operands=[op.op, next_op.length], + properties=op.properties.copy(), + result_types=op.result_types, + ), + ) + rewriter.replace_op(next_op, [], new_results=[rebuilt.result]) + + +class ChainedDsdStrideFolding(RewritePattern): + """ + Folds a chain of `csl.set_dsd_stride` + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: csl.SetDsdStrideOp, rewriter: PatternRewriter + ) -> None: + # single use that is `@set_dsd_stride` + if len(op.result.uses) != 1 or not isinstance( + next_op := next(iter(op.result.uses)).operation, csl.SetDsdStrideOp + ): + return + + # check if we can promote arith.const to property + rewriter.replace_matched_op( + rebuilt := csl.SetDsdStrideOp( + operands=[op.op, next_op.stride], + properties=op.properties.copy(), + result_types=op.result_types, + ) + ) + rewriter.replace_op(next_op, [], new_results=[rebuilt.result])