Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "triton/Dialect/Triton/IR/Types.h"

Expand All @@ -27,7 +28,7 @@ LogicalResult verifyTensorLayouts(Operation *op);

LogicalResult verifySameOperandsEncoding(Operation *op,
bool allowTensorPointerType = false);

LogicalResult verifyEquivalentType(Type typeA, Type typeB);
LogicalResult
verifySameOperandsAndResultEncoding(Operation *op,
bool allowTensorPointerType = false);
Expand Down
14 changes: 14 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TRITON_INTERFACES

include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

def TensorSizeTrait : NativeOpTrait<"TensorSizeTrait">;
def VerifyTensorLayoutsTrait : NativeOpTrait<"VerifyTensorLayoutsTrait">;
Expand All @@ -13,4 +14,17 @@ def SameLoadStoreOperandsAndResultShape : NativeOpTrait<"SameLoadStoreOperandsAn
def SameLoadStoreOperandsEncoding : NativeOpTrait<"SameLoadStoreOperandsEncoding">;
def SameLoadStoreOperandsAndResultEncoding : NativeOpTrait<"SameLoadStoreOperandsAndResultEncoding">;

// A trait equivalent to InferTypeOpAdaptor, but that checks for structural
// equivalence of the layouts of the result rather than just layout equality.
def InferTypeOpWithLayoutEquivalence : InferTypeOpAdaptorBase<[{
static bool isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
if (lhs.size() != rhs.size())
return false;
return llvm::all_of(llvm::zip(lhs, rhs), [](auto tup) {
auto [lhs, rhs] = tup;
return succeeded(OpTrait::impl::verifyEquivalentType(lhs, rhs));
});
}
}]>;

#endif // TRITON_INTERFACES
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def TT_SplitOp : TT_Op<"split", [

def TT_TransOp : TT_Op<"trans", [Pure,
TransposeOpInterface,
InferTypeOpAdaptorWithIsCompatible,
InferTypeOpWithLayoutEquivalence,
SameOperandsAndResultElementType]> {

let summary = "rearrange the dimensions of a tensor";
Expand Down
20 changes: 0 additions & 20 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,26 +235,6 @@ LogicalResult TransOp::inferReturnTypes(
return success();
}

bool TransOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) {
assert(lhs.size() == rhs.size());
assert(lhs.size() == 1);
auto lhsType = cast<RankedTensorType>(lhs[0]);
auto rhsType = cast<RankedTensorType>(rhs[0]);

if (lhsType.getShape() != rhsType.getShape())
return false;

auto lhsEnc = lhsType.getEncoding();
auto rhsEnc = rhsType.getEncoding();
// If there's no encoding or the encodings are the same
if (lhsEnc == rhsEnc)
return true;

return cast<DialectInferLayoutInterface>(&lhsEnc.getDialect())
->verifyLayoutsAreEqual(lhsType.getShape(), lhsEnc, rhsEnc, {})
.succeeded();
}

//-- DotOp --
LogicalResult
DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/Triton/IR/Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,33 @@
#include <numeric>

#include "mlir/IR/TypeUtilities.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;

LogicalResult OpTrait::impl::verifyEquivalentType(Type typeA, Type typeB) {
auto tensorTypeA = dyn_cast<RankedTensorType>(typeA);
auto tensorTypeB = dyn_cast<RankedTensorType>(typeB);
if (!(bool(tensorTypeA) && bool(tensorTypeB)))
return typeA == typeB ? success() : failure();
auto encodingA = tensorTypeA.getEncoding();
auto encodingB = tensorTypeB.getEncoding();
auto shapeA = tensorTypeA.getShape();
auto shapeB = tensorTypeB.getShape();
if (shapeA != shapeB)
return failure();

// If there's no encoding or the encodings are the same
if (encodingA == encodingB)
return success();

return cast<triton::DialectInferLayoutInterface>(&encodingA.getDialect())
->verifyLayoutsAreEqual(shapeA, encodingA, encodingB, {});
}

static LogicalResult verifySameEncoding(Type typeA, Type typeB,
bool allowTensorPointerType) {
// TODO(Keren): the allowTensorPointerType argument is a hack to allow.
Expand Down