diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index f682f54a1c44..8139ebf1ae6b 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -21,7 +21,11 @@ mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs) set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td) -mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs) +mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs) + +set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td) +mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(TritonTableGen) diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index b1f1597c5aa7..53485ce1de28 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -13,6 +13,7 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "triton/Dialect/Triton/IR/Dialect.h.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" #include "triton/Dialect/Triton/IR/OpsEnums.h.inc" #include "triton/Dialect/Triton/IR/Traits.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/include/triton/Dialect/Triton/IR/OpInterfaces.h b/include/triton/Dialect/Triton/IR/OpInterfaces.h new file mode 100644 index 000000000000..1489422d3e25 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/OpInterfaces.h @@ -0,0 +1,21 @@ +#ifndef TRITON_IR_OP_INTERFACES_H_ +#define TRITON_IR_OP_INTERFACES_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +namespace triton { + +namespace impl { + +LogicalResult 
verifyTransposeOpInterface(Operation *op); + +} // namespace impl + +} // namespace triton +} // namespace mlir + +#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc" + +#endif // TRITON_IR_OP_INTERFACES_H_ diff --git a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td new file mode 100644 index 000000000000..4208f966b357 --- /dev/null +++ b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -0,0 +1,36 @@ +#ifndef TRITON_OP_INTERFACES +#define TRITON_OP_INTERFACES + +include "mlir/IR/OpBase.td" + + +def TransposeOpInterface : OpInterface<"TransposeOpInterface"> { + let description = [{ + This interface is implemented by operations that perform a transpose. + It provides methods to access common properties such as the order attribute and the source operand. + }]; + + let cppNamespace = "::mlir::triton"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Get the source operand of the transposition. + }], + /*retType=*/"::mlir::Value", + /*methodName=*/"getSrc", + /*args=*/(ins)>, + InterfaceMethod< + /*desc=*/[{ + Get the order of the transposition. 
+ }], + /*retType=*/"::mlir::ArrayRef", + /*methodName=*/"getOrder", + /*args=*/(ins)> + ]; + + let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }]; +} + + +#endif // TRITON_OP_INTERFACES diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 7f0a6aed9617..cdf9b0ea772b 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -8,16 +8,13 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface -include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface -include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface include "mlir/Interfaces/SideEffectInterfaces.td" // Pure include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" // @@ -44,8 +41,7 @@ class TT_Op traits = []> : def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast int64 to pointer"; let arguments = (ins TT_I64Like:$src); @@ -58,8 +54,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise, def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast pointer to int64"; let arguments = (ins 
TT_PtrLike:$src); @@ -73,8 +68,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise, def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Cast between types of the same bitwidth"; let arguments = (ins TT_Type:$src); @@ -89,8 +83,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise, def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise, SameOperandsAndResultShape, SameOperandsAndResultEncoding, - Pure, - /*DeclareOpInterfaceMethods*/]> { + Pure]> { let summary = "Floating point casting for custom types"; let description = [{ @@ -118,8 +111,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise, // def TT_ClampFOp : TT_Op<"clampf", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Clamp operation for floating point types"; let description = [{ @@ -149,8 +142,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise, // def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Precise sqrt for floating point types"; let description = [{ @@ -165,8 +158,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise, } def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Precise div for floating point types"; let description = [{ @@ -181,8 +174,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise, } def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, - SameOperandsAndResultType, - Pure]> { + SameOperandsAndResultType, + Pure]> { let summary = "Most significant N bits of the 2N-bit product of two integers"; let description = [{ @@ -200,12 +193,12 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise, // Pointer Arith Ops // def TT_AddPtrOp : TT_Op<"addptr", - [Pure, - Elementwise, - 
SameOperandsAndResultShape, - SameOperandsAndResultEncoding, - TypesMatchWith<"result type matches ptr type", - "result", "ptr", "$_self">]> { + [Pure, + Elementwise, + SameOperandsAndResultShape, + SameOperandsAndResultEncoding, + TypesMatchWith<"result type matches ptr type", + "result", "ptr", "$_self">]> { let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset); let results = (outs TT_PtrLike:$result); @@ -546,6 +539,7 @@ def TT_SplitOp : TT_Op<"split", [ } def TT_TransOp : TT_Op<"trans", [Pure, + TransposeOpInterface, DeclareOpInterfaceMethods, SameOperandsAndResultElementType]> { @@ -579,16 +573,15 @@ def TT_TransOp : TT_Op<"trans", [Pure, }]; let arguments = ( - ins TT_TensorOrMemDesc:$src, + ins TT_Tensor:$src, DenseI32ArrayAttr:$order ); - let results = (outs TT_TensorOrMemDesc:$result); + let results = (outs TT_Tensor:$result); let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; let hasFolder = 1; - let hasVerifier = 1; } // @@ -677,10 +670,10 @@ def TT_DotOp : TT_Op<"dot", [Pure, // DotScaled Op // def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, - AttrSizedOperandSegments, - DotLike, - TypesMatchWith<"result's type matches accumulator's type", - "d", "c", "$_self">]> { + AttrSizedOperandSegments, + DotLike, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { let summary = "dot_scaled"; let description = [{ @@ -783,10 +776,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return", // External Elementwise op // def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise, - SameOperandsAndResultEncoding, - SameVariadicOperandSize, - DeclareOpInterfaceMethods, - ConditionallySpeculatable]> { + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods, + ConditionallySpeculatable]> { let description = [{ call an external function $symbol implemented in $libpath/$libname with $args diff --git a/include/triton/Dialect/Triton/IR/Types.h b/include/triton/Dialect/Triton/IR/Types.h 
index 74fa4ba961ac..17d2dbc8ccd8 100644
--- a/include/triton/Dialect/Triton/IR/Types.h
+++ b/include/triton/Dialect/Triton/IR/Types.h
@@ -8,7 +8,7 @@
 
 #define GET_TYPEDEF_CLASSES
 #include "triton/Dialect/Triton/IR/Types.h.inc"
-#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc"
+#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"
 
 namespace mlir {
 
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 8983fae24da1..77cb2c8bf092 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -7,6 +7,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "triton/Dialect/Triton/IR/TritonTypes.td"
 include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
+include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
 include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -221,6 +222,29 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
   let hasVerifier = 1;
 }
 
+def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
+                                                  TransposeOpInterface,
+                                                  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                                  SameOperandsAndResultElementType]> {
+  let summary = "transpose the descriptor";
+
+  let description = [{
+    This operation returns a new descriptor
+    representing a transposed view of the buffer.
+  }];
+
+  let arguments = (
+    ins TT_MemDescType:$src,
+    DenseI32ArrayAttr:$order
+  );
+
+  let results = (outs TT_MemDescType:$result);
+
+  let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";
+
+  let hasFolder = 1;
+}
+
 def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Load a buffer from local memory into a distributed tensor";
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index 3840bf4199e5..3a141a73faa0 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -38,9 +38,8 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
   if (isa(op)) {
     aliasInfo.insert(result);
     pessimistic = false;
-  } else if (isa(op)) {
-    // extract_slice %src
-    // trans %src
+  } else if (isa<triton::gpu::MemDescSubviewOp, triton::gpu::MemDescTransOp>(
+                 op)) {
     aliasInfo = AliasInfo(operands[0]->getValue());
     pessimistic = false;
   } else {
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 8ba0fd3356f6..aa8840433efa 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -269,27 +269,38 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
     return success();
   }
 };
+struct MemDescTransOpConversion
+    : public ConvertOpToLLVMPattern<MemDescTransOp> {
+  using ConvertOpToLLVMPattern<MemDescTransOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto resultTy = cast<MemDescType>(op.getType());
+    auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding());
+    auto llvmElemTy =
+        getTypeConverter()->convertType(resultTy.getElementType());
+    auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
+                                                      llvmElemTy, rewriter);
+    auto dstSmemObj = SharedMemoryObject(
+        srcSmemObj.base, srcSmemObj.baseElemType,
+        /*strides=*/applyPermutation(srcSmemObj.strides,
op.getOrder()),
+        /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
+    auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
+    rewriter.replaceOp(op, retVal);
+    return success();
+  }
+};
+
 struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
   using ConvertOpToLLVMPattern<TransOp>::ConvertOpToLLVMPattern;
   LogicalResult
   matchAndRewrite(TransOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto resultTy = cast(op.getType());
-    if (auto enc = dyn_cast(resultTy.getEncoding())) {
-      auto llvmElemTy =
-          getTypeConverter()->convertType(resultTy.getElementType());
-      auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
-                                                        llvmElemTy, rewriter);
-      auto dstSmemObj = SharedMemoryObject(
-          srcSmemObj.base, srcSmemObj.baseElemType,
-          /*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
-          /*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
-      auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
-      rewriter.replaceOp(op, retVal);
-      return success();
-    } else if (auto enc = mlir::dyn_cast(
-                   resultTy.getEncoding())) {
+    auto resultTy = cast<RankedTensorType>(op.getType());
+    if (auto enc =
+            mlir::dyn_cast<BlockedEncodingAttr>(resultTy.getEncoding())) {
       // If the dst encoding is blocked, then TransOp::inferReturnTypes
       // ensures that:
       // - the src encoding is also blocked, and
@@ -302,9 +313,10 @@ struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
       rewriter.replaceOp(op, ret);
       return success();
     }
     return emitOptionalError(loc, "unsupported encoding for TransOp");
   }
 };
+
 struct BroadcastOpConversion
     : public ConvertOpToLLVMPattern<triton::BroadcastOp> {
   using ConvertOpToLLVMPattern<triton::BroadcastOp>::ConvertOpToLLVMPattern;
@@ -407,6 +419,7 @@ void mlir::triton::populateViewOpToLLVMPatterns(
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
+
patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 752daa7ff055..f9d1586441d4 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonIR Ops.cpp Traits.cpp Types.cpp + OpInterfaces.cpp DEPENDS TritonTableGen diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index dc24177125a6..f9789585122d 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/UB/IR/UBOps.h" -#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" @@ -12,8 +11,10 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "triton/Dialect/Triton/IR/AttrInterfaces.cpp.inc" #include "triton/Dialect/Triton/IR/Dialect.cpp.inc" -#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/OpInterfaces.cpp.inc" +#include "triton/Dialect/Triton/IR/TypeInterfaces.cpp.inc" using namespace mlir; using namespace mlir::triton; diff --git a/lib/Dialect/Triton/IR/OpInterfaces.cpp b/lib/Dialect/Triton/IR/OpInterfaces.cpp new file mode 100644 index 000000000000..7f3a966bffdb --- /dev/null +++ b/lib/Dialect/Triton/IR/OpInterfaces.cpp @@ -0,0 +1,34 @@ +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/Triton/IR/Types.h" + +namespace mlir { +namespace triton { +namespace impl { + +LogicalResult verifyTransposeOpInterface(Operation *op) { + TransposeOpInterface transposeOp = 
cast(op); + auto rank = cast(transposeOp.getSrc().getType()).getRank(); + auto order = transposeOp.getOrder(); + if (rank != order.size()) { + return op->emitError( + "order must have the same size as the rank of the operand and result"); + } + + SmallVector sortedOrder(order); + llvm::sort(sortedOrder); + for (int32_t i = 0; i < sortedOrder.size(); i++) { + if (sortedOrder[i] != i) { + return op->emitError("order must be a permutation of [0, ..., rank - 1]"); + } + } + + return success(); +} + +} // namespace impl +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 60f9bdf95771..269c32553eff 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -207,7 +207,7 @@ LogicalResult TransOp::inferReturnTypes( DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // type is the same as the input - auto argTy = cast(operands[0].getType()); + auto argTy = cast(operands[0].getType()); auto order = properties.as()->order.asArrayRef(); SmallVector retShape = applyPermutation(argTy.getShape(), order); @@ -223,35 +223,8 @@ LogicalResult TransOp::inferReturnTypes( return failure(); } } - if (auto memDescTy = dyn_cast(argTy)) { - inferredReturnTypes.push_back(MemDescType::get( - retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), - memDescTy.getMutableMemory())); - } else { - inferredReturnTypes.push_back( - RankedTensorType::get(retShape, retEltTy, retEncoding)); - } - return success(); -} - -LogicalResult TransOp::verify() { - // Check that the op's `order` attribute is a permutation of the right length. 
- auto srcTy = getSrc().getType(); - - ArrayRef order = getOrder(); - if (order.size() != srcTy.getRank()) { - return emitError("order must have the same size as the rank of the " - "operand and result"); - } - - SmallVector sortedOrder(order); - llvm::sort(sortedOrder); - for (int32_t i = 0; i < sortedOrder.size(); i++) { - if (sortedOrder[i] != i) { - return emitError("order must be a permutation of [0, ..., rank - 1]"); - } - } - + inferredReturnTypes.push_back( + RankedTensorType::get(retShape, retEltTy, retEncoding)); return success(); } @@ -266,8 +239,8 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional location, inferredReturnTypes.push_back(accTy); // verify encodings - auto aEnc = cast(operands[0].getType()).getEncoding(); - auto bEnc = cast(operands[1].getType()).getEncoding(); + auto aEnc = cast(operands[0].getType()).getEncoding(); + auto bEnc = cast(operands[1].getType()).getEncoding(); auto retEnc = accTy.getEncoding(); if (aEnc) { assert(bEnc && retEnc); diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index fc2e979c2c0e..233883964f0f 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -1,5 +1,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -131,4 +132,48 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( return success(); } +OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) { + // transpose(x, order=[0, 1, ...]) -> x + if (isIota(getOrder())) { + return getSrc(); + } + + // transpose(transpose(x)) -> transpose(x) + if (auto innerTrans = getSrc().getDefiningOp()) { + setOrder(applyPermutation(innerTrans.getOrder(), getOrder())); + setOperand(innerTrans.getSrc()); + return getResult(); + } + + return {}; +} + +LogicalResult MemDescTransOp::inferReturnTypes( + MLIRContext *context, 
std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the input + auto argTy = cast(operands[0].getType()); + auto order = properties.as()->order.asArrayRef(); + SmallVector retShape = applyPermutation(argTy.getShape(), order); + + auto retEltTy = argTy.getElementType(); + Attribute argEncoding = argTy.getEncoding(); + Attribute retEncoding; + if (argEncoding) { + Dialect &dialect = argEncoding.getDialect(); + auto inferLayoutInterface = dyn_cast(&dialect); + if (inferLayoutInterface + ->inferTransOpEncoding(argEncoding, order, retEncoding) + .failed()) { + return failure(); + } + } + auto memDescTy = cast(argTy); + inferredReturnTypes.push_back(MemDescType::get( + retShape, retEltTy, retEncoding, memDescTy.getMemorySpace(), + memDescTy.getMutableMemory())); + return success(); +} + } // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index fdb9bafa0b6b..b6b376101a0b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -148,7 +148,7 @@ class SwizzleShmemConvert : public OpRewritePattern { LogicalResult matchAndRewrite(ConvertLayoutOp cvtOp, PatternRewriter &rewriter) const override { // Match outerCvt(trans(innerCvt(x))). - auto trans = cvtOp.getSrc().getDefiningOp(); + auto trans = cvtOp.getSrc().getDefiningOp(); if (!trans || trans.getOrder() != ArrayRef{1, 0}) return failure(); @@ -170,9 +170,9 @@ class SwizzleShmemConvert : public OpRewritePattern { // Set needTrans to true here. newInnerCvtEnc is computed based on // argEncoding which is before the transpose. Without needTrans we will // compute vec and maxPhase based on incorrect m, n and k size of mma. 
The - // type inference of TransOp simply swap the order but doesn't fix the vec - // and maxPhase for the YType, hence it would causing incorrect swizzling - // code. + // type inference of MemDescTransOp simply swap the order but doesn't fix + // the vec and maxPhase for the YType, hence it would causing incorrect + // swizzling code. auto newInnerCvtEnc = SharedEncodingAttr::get(getContext(), cvtEncoding, srcTy.getShape(), /*order=*/getOrder(srcTy.getEncoding()), @@ -187,8 +187,8 @@ class SwizzleShmemConvert : public OpRewritePattern { MemDescType::get(srcTy.getShape(), srcTy.getElementType(), newInnerCvtEnc, sharedMemorySpace), trans.getSrc()); - auto newTrans = rewriter.create(trans.getLoc(), alloc, - ArrayRef({1, 0})); + auto newTrans = rewriter.create(trans.getLoc(), alloc, + ArrayRef({1, 0})); rewriter.replaceOpWithNewOp(trans, sharedLoadTy, newTrans); return success(); } @@ -326,13 +326,13 @@ class FuseTransHopper : public OpRewritePattern { return failure(); // Match outerCvt(trans(innerCvt(x))). - auto trans = allocOp.getSrc().getDefiningOp(); + auto trans = allocOp.getSrc().getDefiningOp(); if (!trans || trans.getOrder() != ArrayRef({1, 0})) return failure(); MemDescType allocType = allocOp.getType(); auto allocEncoding = cast(allocType.getEncoding()); - TensorOrMemDesc srcTy = trans.getSrc().getType(); + MemDescType srcTy = trans.getSrc().getType(); // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 // without transpose for other data types.) 
@@ -361,8 +361,8 @@ class FuseTransHopper : public OpRewritePattern { allocType.getMemorySpace()); auto newAlloc = rewriter.create(allocOp.getLoc(), innerTy, trans.getSrc()); - rewriter.replaceOpWithNewOp(allocOp, newAlloc, - ArrayRef({1, 0})); + rewriter.replaceOpWithNewOp(allocOp, newAlloc, + ArrayRef({1, 0})); return success(); } }; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 6b12287f4d7d..0d7bd5bdc24e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -1194,7 +1194,7 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, // come from an MemDescSubview op. Only ConvertLayout and Trans ops are // allowed in between. Value transitiveOperand = operand; - while (isa_and_nonnull( + while (isa_and_nonnull( transitiveOperand.getDefiningOp()) || isa(transitiveOperand)) { auto blockArg = dyn_cast(transitiveOperand); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 4fc61754bff3..f90c6b7475d3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -136,7 +136,8 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, // TODO: can we use an early_inc iterator? for (OpOperand &use : oldUse->getUses()) { // Non-subview/trans ops will be replaced by `val`. 
- if (!isa(use.getOwner())) { + if (!isa( + use.getOwner())) { operandsToReplace.push_back(&use); continue; } @@ -155,9 +156,9 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, newVal = builder.create( subview.getLoc(), newDstType, val, subview.getOffsets()); newVal.getDefiningOp()->setAttrs(user->getAttrs()); - } else if (auto trans = dyn_cast(user)) { - newVal = builder.create(trans.getLoc(), val, - trans.getOrderAttr()); + } else if (auto trans = dyn_cast(user)) { + newVal = builder.create(trans.getLoc(), val, + trans.getOrder()); newVal.getDefiningOp()->setAttrs(user->getAttrs()); } assert(newVal); diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp index bff277c59314..540cf081c53a 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -100,7 +100,7 @@ class TritonGPUReorderInstructionsPass }); // Move transpositions just after their definition opToMove.clear(); - m.walk([&](triton::TransOp op) { + m.walk([&](triton::TransposeOpInterface op) { Operation *argOp = op.getSrc().getDefiningOp(); if (!argOp) return; diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index bb60c1821ad7..fa8ec2b926eb 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -379,13 +379,13 @@ inferTransOpDstEncoding(Attribute srcEnc, ArrayRef order) { return std::nullopt; } -static std::optional inferDstEncoding(triton::TransOp op, - Attribute encoding) { +static std::optional +inferDstEncoding(triton::TransposeOpInterface op, Attribute encoding) { return inferTransOpDstEncoding(encoding, op.getOrder()); } -static std::optional inferSrcEncoding(triton::TransOp op, - Attribute encoding) { +static std::optional +inferSrcEncoding(triton::TransposeOpInterface op, Attribute encoding) { // We want to solve for 
srcEnc in // transpose(srcEnc, order) -> dstEnc. // Given the identity @@ -468,7 +468,7 @@ std::optional inferSrcEncoding(Operation *op, Attribute encoding) { return inferSrcEncoding(join, encoding); if (auto split = dyn_cast(op)) return inferSrcEncoding(split, encoding); - if (auto trans = dyn_cast(op)) + if (auto trans = dyn_cast(op)) return inferSrcEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferSrcEncoding(reshape, encoding); @@ -495,7 +495,7 @@ std::optional inferDstEncoding(Operation *op, Attribute encoding) { return inferDstEncoding(join, encoding); if (auto split = dyn_cast(op)) return inferDstEncoding(split, encoding); - if (auto trans = dyn_cast(op)) + if (auto trans = dyn_cast(op)) return inferDstEncoding(trans, encoding); if (auto reshape = dyn_cast(op)) return inferDstEncoding(reshape, encoding); diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 4f3af58e85f2..109395ae0466 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -58,7 +58,7 @@ tt.func @trans(%A : !tt.ptr) { // CHECK: %0 -> %0 %tensor = triton_gpu.local_alloc : () -> !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> // CHECK: %1 -> %0 - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %b = triton_gpu.memdesc_trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> tt.return } diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index a12e3a026071..db2e2947b8e8 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -304,7 +304,7 @@ tt.func @scratch() { tt.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 %tensor = triton_gpu.local_alloc : () -> 
!tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> - %b = tt.trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %b = triton_gpu.memdesc_trans %tensor {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> tt.return } @@ -450,7 +450,7 @@ tt.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr< // CHECK-NEXT: offset = 16384, size = 8192 %c_shared_init = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> %a_shared, %b_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init) -> (!tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>) { - %c0 = tt.trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> + %c0 = triton_gpu.memdesc_trans %c_shared_init {order=array} : !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<32x128xf16, #A_SHARED_T, #triton_gpu.shared_memory, mutable> // CHECK-NEXT: offset = 24576, size = 8192 %c1 = triton_gpu.local_alloc : () -> !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> scf.yield %b_shared, %a_shared: !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable>, !tt.memdesc<128x32xf16, #A_SHARED, #triton_gpu.shared_memory, mutable> diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index df4f5ab01feb..65d802d9950b 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -132,7 +132,7 @@ tt.func @subview() { // CHECK-LABEL: trans tt.func 
@trans(%a: !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory>) { // CHECK-NOT: gpu.barrier - %b = tt.trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> + %b = triton_gpu.memdesc_trans %a {order=array} : !tt.memdesc<16x32xf16, #A_SHARED, #triton_gpu.shared_memory> -> !tt.memdesc<32x16xf16, #A_SHARED_T, #triton_gpu.shared_memory> tt.return } diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index c7df02322476..d121d285d3f9 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -278,8 +278,8 @@ tt.func public @fn(%arg0: tensor<2x4x8x16xf32, #blocked>, %arg1: tensor<16x32x64 #shared3 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0], hasLeadingOffset = true}> module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 8 : i32, "triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: !tt.memdesc<2x4x8x16xf32, #shared>, %arg1: !tt.memdesc<16x32xf32, #shared2>) { - %a = tt.trans %arg0 {order = array} : !tt.memdesc<2x4x8x16xf32, #shared> -> !tt.memdesc<4x16x8x2xf32, #shared1> - %b = tt.trans %arg1 {order = array} : !tt.memdesc<16x32xf32, #shared2> -> !tt.memdesc<32x16xf32, #shared3> + %a = triton_gpu.memdesc_trans %arg0 {order = array} : !tt.memdesc<2x4x8x16xf32, #shared> -> !tt.memdesc<4x16x8x2xf32, #shared1> + %b = triton_gpu.memdesc_trans %arg1 {order = array} : !tt.memdesc<16x32xf32, #shared2> -> !tt.memdesc<32x16xf32, #shared3> tt.return } } // end module diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 3b727b4e95e2..2ec11a24f197 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1562,7 +1562,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %21 = tt.load %20 : tensor<32x32x!tt.ptr, 
#blocked4> %22 = triton_gpu.convert_layout %21 : tensor<32x32xf16, #blocked4> -> tensor<32x32xf16, #blocked> %23 = triton_gpu.local_alloc %22 : (tensor<32x32xf16, #blocked>) -> !tt.memdesc<32x32xf16, #shared> - %24 = tt.trans %23 {order=array} : !tt.memdesc<32x32xf16, #shared> -> !tt.memdesc<32x32xf16, #shared1> + %24 = triton_gpu.memdesc_trans %23 {order=array} : !tt.memdesc<32x32xf16, #shared> -> !tt.memdesc<32x32xf16, #shared1> %25 = triton_gpu.local_load %24 : !tt.memdesc<32x32xf16, #shared1> -> tensor<32x32xf16, #blocked> %26 = triton_gpu.convert_layout %19 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> %27 = triton_gpu.convert_layout %25 : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> @@ -1994,7 +1994,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %68 = tt.addptr %17, %65 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> %69 = tt.load %68 : tensor<256x64x!tt.ptr, #blocked> %70 = triton_gpu.local_alloc %69 : (tensor<256x64xf16, #blocked>) -> !tt.memdesc<256x64xf16, #shared> - %71 = tt.trans %70 {order=array} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1> + %71 = triton_gpu.memdesc_trans %70 {order=array} : !tt.memdesc<256x64xf16, #shared> -> !tt.memdesc<64x256xf16, #shared1> %72 = triton_gpu.convert_layout %67 : tensor<32x64xf16, #blocked> -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked3}>> %73 = triton_gpu.local_load %71 : !tt.memdesc<64x256xf16, #shared1> -> tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked3}>> %74 = triton_gpu.convert_layout %arg8 : tensor<32x256xf32, #blocked3> -> tensor<32x256xf32, #mma> diff --git a/test/TritonGPU/loop-pipeline-cuda.mlir b/test/TritonGPU/loop-pipeline-cuda.mlir index 7b8bed9a18f4..3cb8511b0b5c 100644 --- a/test/TritonGPU/loop-pipeline-cuda.mlir +++ b/test/TritonGPU/loop-pipeline-cuda.mlir @@ 
-50,7 +50,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %25 = triton_gpu.memdesc_trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -141,7 +141,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> + %73 = triton_gpu.memdesc_trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<64x32xf32, #shared1, 
#triton_gpu.shared_memory> %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index d6653b2b004c..4ad94615c84b 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -48,7 +48,7 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> + %25 = triton_gpu.memdesc_trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ 
-140,7 +140,7 @@ module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 %70 = tt.load %59 : tensor<32x64x!tt.ptr, #blocked1> %71 = triton_gpu.convert_layout %62 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %72 = triton_gpu.local_alloc %70 : (tensor<32x64xf32, #blocked1>) -> !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> - %73 = tt.trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> + %73 = triton_gpu.memdesc_trans %72 {order=array} : !tt.memdesc<32x64xf32, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> %74 = triton_gpu.local_load %73 : !tt.memdesc<64x32xf32, #shared1, #triton_gpu.shared_memory, mutable> -> tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> %75 = tt.dot %71, %74, %cst : tensor<64x64xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<64x32xf32, #mma> %76 = tt.load %61 : tensor<32x32x!tt.ptr, #blocked1> @@ -239,7 +239,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced // CHECK-LABEL: loop_with_dot_and_transpose // CHECK: triton_gpu.local_alloc {{.*}}, mutable> -// CHECK: tt.trans {{.*}}, mutable> -> {{.*}}, mutable> +// CHECK: triton_gpu.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> #blocked = #triton_gpu.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> @@ -253,7 +253,7 @@ module 
attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %0 = scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 iter_args(%arg3 = %cst) -> (tensor<32x32xf32, #blocked>) : i32 { %2 = tt.load %arg4 : tensor<32x32x!tt.ptr, #blocked1> %3 = triton_gpu.local_alloc %2 : (tensor<32x32xf32, #blocked1>) -> !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> - %4 = tt.trans %3 {order = array} : !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> + %4 = triton_gpu.memdesc_trans %3 {order = array} : !tt.memdesc<32x32xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> %5 = triton_gpu.local_load %4 : !tt.memdesc<32x32xf32, #shared1, #triton_gpu.shared_memory> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %6 = triton_gpu.convert_layout %2 : tensor<32x32xf32, #blocked1> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %7 = tt.dot %6, %5, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x32xf32, #blocked> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index f541cacd0e8b..f3784fbe8c2b 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -397,7 +397,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %20 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %cst_2 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1> - %23 = tt.trans %20 {order=array} : 
!tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %23 = triton_gpu.memdesc_trans %20 {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> %24 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> %25 = triton_nvidia_gpu.warp_group_dot %24, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> @@ -509,7 +509,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : // CHECK: %[[R:.+]]:{{.+}} = scf.for // CHECK: %[[DOT1:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} // CHECK: triton_gpu.async_wait {{.*}} {num = 1 : i32} - // CHECK: %[[TRANS:.+]] = tt.trans{{.*}} : !tt.memdesc + // CHECK: %[[TRANS:.+]] = triton_gpu.memdesc_trans{{.*}} : !tt.memdesc // CHECK: %[[DOT2:.+]] = triton_nvidia_gpu.warp_group_dot{{.*}} %[[TRANS]] // CHECK: triton_nvidia_gpu.warp_group_dot_wait %[[DOT1]], %[[DOT2]], %[[ALLOC1]], %[[ALLOC2]], %[[TRANS]] {pendings = 2 : i32} // CHECK: scf.yield @@ -518,7 +518,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %21 = triton_nvidia_gpu.warp_group_dot %19, %20, %arg6 : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, 
#triton_gpu.shared_memory> + %23 = triton_gpu.memdesc_trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> %25 = triton_nvidia_gpu.warp_group_dot %cst_4, %23, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 2}>> * !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x64xf32, #mma> %26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr, #blocked>, tensor<128x16xf32, #mma1> @@ -681,7 +681,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : %dot1.1 = arith.addf %dot1, %dot1 : tensor<128x16xf32, #mma1> %l = tt.load %arg5 : tensor<64x16x!tt.ptr, #blocked> %c = triton_gpu.local_alloc %l : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> - %23 = tt.trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> + %23 = triton_gpu.memdesc_trans %c {order=array} : !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared, #triton_gpu.shared_memory> // This dot can be async even though %prev_dot2 is not used directly by an // async dot, because that use follows the synchronous dot above. 
%prev_dot2.1 = arith.addf %prev_dot2, %prev_dot2 : tensor<128x64xf32, #mma> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 82f726899ea4..973b35defb0e 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -931,7 +931,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %22 = arith.truncf %21 : tensor<128x16xf32, #mma> to tensor<128x16xf16, #mma> %23 = triton_gpu.convert_layout %22 : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %24 = triton_gpu.local_alloc %18 : (tensor<64x16xf16, #blocked>) -> !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> - %25 = tt.trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> + %25 = triton_gpu.memdesc_trans %24 {order=array} : !tt.memdesc<64x16xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> %26 = triton_gpu.local_load %25 : !tt.memdesc<16x64xf16, #shared1, #triton_gpu.shared_memory> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %27 = tt.dot %23, %26, %arg4 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x64xf32, #mma> scf.yield %21, %27 : tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma> @@ -1290,7 +1290,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: scf.for // CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] // CHECK: %[[BUFFER_2:.*]] = triton_gpu.local_alloc %[[LOAD_1]] -// CHECK: %[[TRANS:.*]] = tt.trans %[[BUFFER_2]] +// CHECK: %[[TRANS:.*]] = triton_gpu.memdesc_trans %[[BUFFER_2]] // CHECK: %[[LOCAL_LOAD_1:.*]] = triton_gpu.local_load %[[TRANS]] // CHECK: triton_gpu.async_wait // CHECK: 
triton_gpu.memdesc_subview %[[BUFFER_1]] @@ -1362,7 +1362,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : scf.for %arg1 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %10 = tt.load %9 : tensor<16x16x!tt.ptr<f32>, #blocked> %11 = triton_gpu.local_alloc %10 : (tensor<16x16xf32, #blocked>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> - %12 = tt.trans %11 {order = array<i32: 1, 0>} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> + %12 = triton_gpu.memdesc_trans %11 {order = array<i32: 1, 0>} : !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> %13 = triton_gpu.local_load %12 : !tt.memdesc<16x16xf32, #shared1, #triton_gpu.shared_memory> -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 : i32 { %14 = tt.load %9 : tensor<16x16x!tt.ptr<f32>, #blocked> diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp index abea5db773a2..f55ab7855440 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp @@ -205,8 +205,8 @@ static void moveDownCoversion(triton::FuncOp funcOp) { // Move transpositions just after their definition. static void moveUpTranspose(triton::FuncOp funcOp) { - SmallVector<triton::TransOp> transOps; - funcOp.walk([&](triton::TransOp op) { transOps.push_back(op); }); + SmallVector<triton::TransposeOpInterface> transOps; + funcOp.walk([&](triton::TransposeOpInterface op) { transOps.push_back(op); }); for (auto op : transOps) if (Operation *argOp = op.getSrc().getDefiningOp())