diff --git a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h index 068a271c9f23..b1ebc749d1af 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h +++ b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h @@ -41,15 +41,6 @@ class AssignDescriptorMemoryLayouts { CGAEncodingAttr cgaLayout, ArrayRef usageShape, unsigned numCTAs); - -protected: - virtual Attribute getCompatibleSharedEncoding(Attribute enc, - ArrayRef shape, - Type elementType) { - return isCompatibleSharedEncoding(enc) ? enc : Attribute(); - } - -private: // Override with backend specific implementation virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *, ArrayRef, diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index 6e513e954214..3818477aa611 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -250,36 +250,23 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings( Attribute AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) { - auto getCompatibleEncodingForType = [&](Type type) -> Attribute { - if (auto memDescTy = dyn_cast(type)) { - return getCompatibleSharedEncoding(memDescTy.getEncoding(), - memDescTy.getShape(), - memDescTy.getElementType()); - } - if (auto tensorTy = dyn_cast(type)) { - return getCompatibleSharedEncoding(tensorTy.getEncoding(), - tensorTy.getShape(), - tensorTy.getElementType()); - } - return {}; - }; - // Check if there are any desired encodings available on the op if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) { - if (auto resultTy = dyn_cast(op->getResult(0).getType())) - if (auto compatible = getCompatibleSharedEncoding( - attr, resultTy.getShape(), resultTy.getElementType())) - return compatible; + if (auto enc = dyn_cast(attr)) { + if (isCompatibleSharedEncoding(enc)) + return enc; + } } // Ignore multiple users and just pick the first compatible layout for (auto use : op->getUsers()) { if (auto alloc = dyn_cast(use)) { - if (auto compatible = getCompatibleEncodingForType(alloc.getType())) - return compatible; + auto enc = alloc.getType().getEncoding(); + if (isCompatibleSharedEncoding(enc)) + return enc; } else if (auto store = dyn_cast(use)) { - if (auto compatible = - getCompatibleEncodingForType(store.getDst().getType())) - return compatible; + auto enc = store.getDst().getType().getEncoding(); + if (isCompatibleSharedEncoding(enc)) + return enc; } } return {}; @@ -455,9 +442,7 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { auto ctx = func.getContext(); auto numCTAs = triton::gpu::lookupNumCTAs(func); for (auto &[desc, einfo] : valueToEncodingInfo) { - auto descTy = desc.getType(); - auto existingTy = - RankedTensorType::get(descTy.getShape(), descTy.getElementType()); + auto existingTy = desc.getType().getBlockType(); Attribute newEncoding; if (einfo->desiredEncoding) { newEncoding = einfo->desiredEncoding; @@ -475,11 +460,10 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { SmallVector resultTys(func.getResultTypes()); for (auto [i, resultTy] : llvm::enumerate(resultTys)) { if (auto descTy = dyn_cast(resultTy)) { - auto existingTy = - RankedTensorType::get(descTy.getShape(), descTy.getElementType()); - auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs); - resultTys[i] = - getTensorDescTypeWithEncoding(nullptr, existingTy, encoding); + auto encoding = + getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs); + resultTys[i] = getTensorDescTypeWithEncoding( + nullptr, descTy.getBlockType(), encoding); } } func.setFunctionType(FunctionType::get(ctx, argTys, resultTys)); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 651b8145a17e..3e981ebc5f2a 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -3,6 +3,8 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -10,8 +12,7 @@ #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" -#include -#include +#include namespace mlir::triton::gpu { @@ -106,10 +107,7 @@ class FuseTransMMAV3Plus : public OpRewritePattern { return failure(); MemDescType allocType = allocOp.getType(); - auto allocEncoding = - dyn_cast(allocType.getEncoding()); - if (!allocEncoding) - return failure(); + auto allocEncoding = cast(allocType.getEncoding()); RankedTensorType srcTy = trans.getSrc().getType(); Dialect &dialect = allocEncoding.getDialect(); @@ -178,113 +176,6 @@ class ReshapeMemDesc : public OpRewritePattern { } }; -// Rewrite -// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma -// into -// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma -// -// The MMA operand layout is determined by the sink memdesc already feeding the -// dot-like op. This pattern back-propagates that layout through the tensor -// reshape/transpose chain, hoists local_alloc to the base tensor feeding that -// view chain, and replays those tensor views as memdesc reshape/transpose -// ops so the original local_alloc type is preserved. -class RewriteMmaOperandViewsToMemDescForDotOp - : public OpInterfaceRewritePattern { -public: - using OpInterfaceRewritePattern< - triton::DotOpInterface>::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, - PatternRewriter &rewriter) const override { - if (!isa(dotOp)) - return failure(); - - bool changed = false; - - if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) - changed = true; - - if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) - changed = true; - - return success(changed); - } - -private: - LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { - if (!isa(operand.getType())) - return failure(); - - Value beforeTrailing = operand; - while (auto view = beforeTrailing.getDefiningOp()) { - if (auto reshape = dyn_cast(view)) { - beforeTrailing = reshape.getSrc(); - continue; - } - if (auto trans = dyn_cast(view)) { - beforeTrailing = trans.getSrc(); - continue; - } - break; - } - - auto localAlloc = beforeTrailing.getDefiningOp(); - if (!localAlloc || !localAlloc.getSrc()) - return failure(); - - Value baseTensor = localAlloc.getSrc(); - SmallVector tensorReplaySteps; - MemDescType baseMemTy = localAlloc.getType(); - while (auto view = baseTensor.getDefiningOp()) { - if (auto reshape = dyn_cast(view)) { - MemDescType srcTy; - auto inferred = MemDescReshapeOp::inferReturnTypes( - getContext(), reshape.getLoc(), baseMemTy, - reshape.getSrc().getType().getShape(), srcTy); - assert(succeeded(inferred) && "backward memdesc reshape inference " - "must succeed"); - (void)inferred; - baseMemTy = srcTy; - } else if (auto trans = dyn_cast(view)) { - Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding()); - if (!srcEnc) - return failure(); - baseMemTy = MemDescType::get( - trans.getSrc().getType().getShape(), baseMemTy.getElementType(), - srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory()); - } else { - break; - } - tensorReplaySteps.push_back(view); - baseTensor = view->getOperand(0); - } - if (tensorReplaySteps.empty()) - return failure(); - - std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); - - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(localAlloc); - - Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), - baseMemTy, baseTensor); - for (Operation *op : tensorReplaySteps) { - if (auto reshape = dyn_cast(op)) { - rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, - reshape.getType().getShape()); - } else { - auto trans = cast(op); - rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, - trans.getOrder()); - } - } - rewriter.replaceOp(localAlloc, rewritten); - return success(); - } -}; - // Inject TMEM copy instructions into IR to efficiently load blocked scales for // scaled dot class UseShmemForScales @@ -446,7 +337,6 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); - patterns.add(context); patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index e8d88a9bdc8a..2dbaab2d77ce 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -20,58 +20,15 @@ class NvidiaGPUAssignDescriptorMemoryLayouts ArrayRef order, ttg::CGAEncodingAttr cgaLayout, Type elementType) override; - Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef shape, - Type elementType) override; bool isCompatibleSharedEncoding(Attribute enc) override; }; bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding( Attribute enc) { - if (auto nvmma = dyn_cast(enc)) + if (auto nvmma = dyn_cast(enc)) { return !nvmma.getTransposed(); - return false; -} - -Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( - Attribute enc, ArrayRef shape, Type elementType) { - if (isCompatibleSharedEncoding(enc)) - return enc; - - auto sharedLinear = dyn_cast(enc); - if (!sharedLinear) - return {}; - - auto *ctx = enc.getContext(); - auto cgaLayout = ttg::getCGALayout(sharedLinear); - auto order = ttg::getOrder(sharedLinear, shape); - - SmallVector preferredCandidates; - // TMA descriptors only support non-transposed layouts. Preserve Triton's - // default shape/order-based choice when it already matches this - // shared_linear layout. The full candidate scan below is only a fallback for - // equivalent non-transposed layouts not selected by the heuristic builder. - for (bool fp4Padded : {false, true}) { - auto preferred = ttg::NVMMASharedEncodingAttr::get( - ctx, shape, order, cgaLayout, elementType, fp4Padded); - preferredCandidates.push_back(preferred); - if (ttg::areLayoutsEquivalent(shape, sharedLinear, preferred)) - return preferred; - } - - unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth()); - for (bool fp4Padded : {false, true}) { - for (unsigned swizzle : {0u, 32u, 64u, 128u}) { - auto candidate = ttg::NVMMASharedEncodingAttr::get( - ctx, swizzle, /*transposed=*/false, elementBitWidth, fp4Padded, - cgaLayout); - if (llvm::is_contained(preferredCandidates, candidate)) - continue; - if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) - return candidate; - } } - - return {}; + return false; } // Build fallback encoding given shape, order, cga layout and element type diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 4434061722a8..f292a2fcb211 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -293,96 +293,3 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num tt.return %a: !ttg.memdesc<64x64xf32, #shared, #smem> } } - -// ----- - -#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> -#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -#tmem0 = #ttng.tensor_memory_encoding -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views - // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK-NOT: ttg.local_load - // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true - tt.func @swizzle0_operand_views( - %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, - %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, - %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { - %true = arith.constant true - %c0_i32 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> - %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> - %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> - %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> - %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> - %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> - %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> - ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> - tt.return - } -} - -// ----- - -#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> -#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> -#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> -#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> -#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> -#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> -#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> -#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> -#smem = #ttg.shared_memory -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> -module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot - // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> - // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> - // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> - // CHECK-NOT: tt.reshape - // CHECK-NOT: tt.trans - // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 - // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} - tt.func @swizzle0_operand_views_warp_group_dot( - %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, - %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, - %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { - %c0 = arith.constant 0 : i32 - %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> - %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> - %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> - %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> - %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> - %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> - %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> - %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} - : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> - %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> - tt.return %w#0 : tensor<128x256xf32, #mma_desc> - } -} diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index 1376692737b3..aff7ca9c66af 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -86,27 +86,6 @@ tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<64x64xf16>, %arg1: i // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> -#smem = #ttg.shared_memory - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { -// CHECK-DAG: #[[BLOCKED_16x128:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK-DAG: #[[NVMMA_128:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> -// CHECK-DAG: #[[NVMMA_TRANSPOSED_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> -tt.func public @descriptor_ignores_transposed_local_alloc(%arg0: !tt.tensordesc<16x128xf16>) { - // CHECK: %arg0: !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> - // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0{{.*}} : !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> -> tensor<16x128xf16, #[[BLOCKED_16x128]]> - // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<16x128xf16, #[[BLOCKED_16x128]]>) -> !ttg.memdesc<16x128xf16, #[[NVMMA_TRANSPOSED_32]], #smem> - %c0 = arith.constant 0 : i32 - %0 = tt.descriptor_load %arg0[%c0, %c0] : !tt.tensordesc<16x128xf16> -> tensor<16x128xf16, #blocked> - %1 = ttg.local_alloc %0 : (tensor<16x128xf16, #blocked>) -> !ttg.memdesc<16x128xf16, #shared, #smem> - tt.return -} -} - -// ----- - #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> @@ -144,25 +123,3 @@ tt.func public @tma_load_while(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return } } - -// ----- - -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -#shared_linear_base = #ttg.shared_linear<{offset = [[0, 0, 0, 0, 1], [0, 0, 0, 0, 2], [0, 0, 0, 0, 4], [0, 0, 0, 0, 8], [0, 0, 0, 1, 0], [0, 0, 0, 2, 0], [0, 0, 0, 4, 0], [0, 0, 1, 0, 0], [0, 0, 2, 0, 0], [0, 0, 4, 0, 0], [0, 0, 8, 0, 0], [0, 0, 16, 0, 0], [0, 0, 32, 0, 0], [0, 0, 64, 0, 0], [0, 0, 128, 0, 0]]}, alignment = 128> -#smem = #ttg.shared_memory - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { -// CHECK-DAG: #[[NVMMA_0:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> -// CHECK-DAG: #[[BLOCKED5D:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> -// CHECK-DAG: #[[SL_BASE:.*]] = #ttg.shared_linear<{{.*}}alignment = 128> -tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>) { - // CHECK: %arg0: !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> - // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0 - // CHECK-SAME: : !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> - // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #[[SL_BASE]], #smem> - %c0 = arith.constant 0 : i32 - %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> - %b_s = ttg.local_alloc %b : (tensor<1x1x256x8x16xf8E4M3FN, #blocked2>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared_linear_base, #smem> - tt.return -} -}