diff --git a/include/triton/Dialect/Triton/IR/Traits.h b/include/triton/Dialect/Triton/IR/Traits.h index 804b1648e943..dbbf876cb513 100644 --- a/include/triton/Dialect/Triton/IR/Traits.h +++ b/include/triton/Dialect/Triton/IR/Traits.h @@ -3,6 +3,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LogicalResult.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op); LogicalResult verifySameOperandsEncoding(Operation *op, bool allowTensorPointerType = false); - +LogicalResult verifyEquivalentType(Type typeA, Type typeB); LogicalResult verifySameOperandsAndResultEncoding(Operation *op, bool allowTensorPointerType = false); diff --git a/include/triton/Dialect/Triton/IR/TritonInterfaces.td b/include/triton/Dialect/Triton/IR/TritonInterfaces.td index f51cca0bc254..a9188cbf638d 100644 --- a/include/triton/Dialect/Triton/IR/TritonInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonInterfaces.td @@ -2,6 +2,7 @@ #define TRITON_INTERFACES include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">; def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">; @@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">; def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">; +// A trait equivalent to InferTypeOpAdaptor, but that checks for structural +// equivalence of the layouts of the result rather than just layout equality. +def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{ + static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { + if (lhs.size() != rhs.size()) + return false; + return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) { + auto [lhs, rhs] = tup; + return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs)); + }); + } +}]>; + #endif // TRITON_INTERFACES diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index f18bf9fe3893..0262d8507b54 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [ def TT_TransOp : TT_Op<"trans", [Pure, TransposeOpInterface, - InferTypeOpAdaptorWithIsCompatible, + InferTypeOpWithLayoutEquivalence, SameOperandsAndResultElementType]> { let summary = "rearrange the dimensions of a tensor"; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 7c17ed2decd1..c4812e9f2630 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -235,26 +235,6 @@ LogicalResult TransOp::inferReturnTypes( return success(); } -bool TransOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { - assert(lhs.size() == rhs.size()); - assert(lhs.size() == 1); - auto lhsType = cast(lhs[0]); - auto rhsType = cast(rhs[0]); - - if (lhsType.getShape() != rhsType.getShape()) - return false; - - auto lhsEnc = lhsType.getEncoding(); - auto rhsEnc = rhsType.getEncoding(); - // If there's no encoding or the encodings are the same - if (lhsEnc == rhsEnc) - return true; - - return cast(&lhsEnc.getDialect()) - ->verifyLayoutsAreEqual(lhsType.getShape(), lhsEnc, rhsEnc, {}) - .succeeded(); -} - //-- DotOp -- LogicalResult DotOp::inferReturnTypes(MLIRContext *context, std::optional location, diff --git a/lib/Dialect/Triton/IR/Traits.cpp b/lib/Dialect/Triton/IR/Traits.cpp index 690826f4efaf..a38e37bb0734 100644 --- a/lib/Dialect/Triton/IR/Traits.cpp +++ b/lib/Dialect/Triton/IR/Traits.cpp @@ -3,12 +3,33 @@ #include #include "mlir/IR/TypeUtilities.h" +#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; +LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) { + auto tensorTypeA = dyn_cast(typeA); + auto tensorTypeB = dyn_cast(typeB); + if (!(bool(tensorTypeA) && bool(tensorTypeB))) + return typeA == typeB ? success() : failure(); + auto encodingA = tensorTypeA.getEncoding(); + auto encodingB = tensorTypeB.getEncoding(); + auto shapeA = tensorTypeA.getShape(); + auto shapeB = tensorTypeB.getShape(); + if (shapeA != shapeB) + return failure(); + + // If there's no encoding or the encodings are the same + if (encodingA == encodingB) + return success(); + + return cast(&encodingA.getDialect()) + ->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {}); +} + static LogicalResult verifySameEncoding(Type typeA, Type typeB, bool allowTensorPointerType) { // TODO(Keren): the allowTensorPointerType argument is a hack to allow.