diff --git a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h index b1ebc749d1af..068a271c9f23 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h +++ b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h @@ -41,6 +41,15 @@ class AssignDescriptorMemoryLayouts { CGAEncodingAttr cgaLayout, ArrayRef usageShape, unsigned numCTAs); + +protected: + virtual Attribute getCompatibleSharedEncoding(Attribute enc, + ArrayRef shape, + Type elementType) { + return isCompatibleSharedEncoding(enc) ? enc : Attribute(); + } + +private: // Override with backend specific implementation virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *, ArrayRef, diff --git a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp index 3818477aa611..6e513e954214 100644 --- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp @@ -250,23 +250,36 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings( Attribute AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) { + auto getCompatibleEncodingForType = [&](Type type) -> Attribute { + if (auto memDescTy = dyn_cast(type)) { + return getCompatibleSharedEncoding(memDescTy.getEncoding(), + memDescTy.getShape(), + memDescTy.getElementType()); + } + if (auto tensorTy = dyn_cast(type)) { + return getCompatibleSharedEncoding(tensorTy.getEncoding(), + tensorTy.getShape(), + tensorTy.getElementType()); + } + return {}; + }; + // Check if there are any desired encodings available on the op if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) { - if (auto enc = dyn_cast(attr)) { - if (isCompatibleSharedEncoding(enc)) - return enc; - } + if (auto resultTy = dyn_cast(op->getResult(0).getType())) + if (auto compatible = getCompatibleSharedEncoding( + attr, resultTy.getShape(), resultTy.getElementType())) + return compatible; } // Ignore multiple users and just pick the first compatible layout for (auto use : op->getUsers()) { if (auto alloc = dyn_cast(use)) { - auto enc = alloc.getType().getEncoding(); - if (isCompatibleSharedEncoding(enc)) - return enc; + if (auto compatible = getCompatibleEncodingForType(alloc.getType())) + return compatible; } else if (auto store = dyn_cast(use)) { - auto enc = store.getDst().getType().getEncoding(); - if (isCompatibleSharedEncoding(enc)) - return enc; + if (auto compatible = + getCompatibleEncodingForType(store.getDst().getType())) + return compatible; } } return {}; @@ -442,7 +455,9 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { auto ctx = func.getContext(); auto numCTAs = triton::gpu::lookupNumCTAs(func); for (auto &[desc, einfo] : valueToEncodingInfo) { - auto existingTy = desc.getType().getBlockType(); + auto descTy = desc.getType(); + auto existingTy = + RankedTensorType::get(descTy.getShape(), descTy.getElementType()); Attribute newEncoding; if (einfo->desiredEncoding) { newEncoding = einfo->desiredEncoding; @@ -460,10 +475,11 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) { SmallVector resultTys(func.getResultTypes()); for (auto [i, resultTy] : llvm::enumerate(resultTys)) { if (auto descTy = dyn_cast(resultTy)) { - auto encoding = - getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs); - resultTys[i] = getTensorDescTypeWithEncoding( - nullptr, descTy.getBlockType(), encoding); + auto existingTy = + RankedTensorType::get(descTy.getShape(), descTy.getElementType()); + auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs); + resultTys[i] = + getTensorDescTypeWithEncoding(nullptr, existingTy, encoding); } } func.setFunctionType(FunctionType::get(ctx, argTys, resultTys)); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 8fa3c6db297d..134b8ea8154b 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -3,8 +3,6 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -12,7 +10,8 @@ #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" -#include +#include +#include namespace mlir::triton::gpu { @@ -108,7 +107,10 @@ class FuseTransMMAV3Plus : public OpRewritePattern { return failure(); MemDescType allocType = allocOp.getType(); - auto allocEncoding = cast(allocType.getEncoding()); + auto allocEncoding = + dyn_cast(allocType.getEncoding()); + if (!allocEncoding) + return failure(); RankedTensorType srcTy = trans.getSrc().getType(); auto ctx = getContext(); @@ -180,6 +182,113 @@ class ReshapeMemDesc : public OpRewritePattern { } }; +// Rewrite +// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma +// into +// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma +// +// The MMA operand layout is determined by the sink memdesc already feeding the +// dot-like op. This pattern back-propagates that layout through the tensor +// reshape/transpose chain, hoists local_alloc to the base tensor feeding that +// view chain, and replays those tensor views as memdesc reshape/transpose +// ops so the original local_alloc type is preserved. +class RewriteMmaOperandViewsToMemDescForDotOp + : public OpInterfaceRewritePattern { +public: + using OpInterfaceRewritePattern< + triton::DotOpInterface>::OpInterfaceRewritePattern; + + LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, + PatternRewriter &rewriter) const override { + if (!isa(dotOp)) + return failure(); + + bool changed = false; + + if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) + changed = true; + + if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) + changed = true; + + return success(changed); + } + +private: + LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { + if (!isa(operand.getType())) + return failure(); + + Value beforeTrailing = operand; + while (auto view = beforeTrailing.getDefiningOp()) { + if (auto reshape = dyn_cast(view)) { + beforeTrailing = reshape.getSrc(); + continue; + } + if (auto trans = dyn_cast(view)) { + beforeTrailing = trans.getSrc(); + continue; + } + break; + } + + auto localAlloc = beforeTrailing.getDefiningOp(); + if (!localAlloc || !localAlloc.getSrc()) + return failure(); + + Value baseTensor = localAlloc.getSrc(); + SmallVector tensorReplaySteps; + MemDescType baseMemTy = localAlloc.getType(); + while (auto view = baseTensor.getDefiningOp()) { + if (auto reshape = dyn_cast(view)) { + MemDescType srcTy; + auto inferred = MemDescReshapeOp::inferReturnTypes( + getContext(), reshape.getLoc(), baseMemTy, + reshape.getSrc().getType().getShape(), srcTy); + assert(succeeded(inferred) && "backward memdesc reshape inference " + "must succeed"); + (void)inferred; + baseMemTy = srcTy; + } else if (auto trans = dyn_cast(view)) { + Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding()); + if (!srcEnc) + return failure(); + baseMemTy = MemDescType::get( + trans.getSrc().getType().getShape(), baseMemTy.getElementType(), + srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory()); + } else { + break; + } + tensorReplaySteps.push_back(view); + baseTensor = view->getOperand(0); + } + if (tensorReplaySteps.empty()) + return failure(); + + std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); + + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(localAlloc); + + Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), + baseMemTy, baseTensor); + for (Operation *op : tensorReplaySteps) { + if (auto reshape = dyn_cast(op)) { + rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, + reshape.getType().getShape()); + } else { + auto trans = cast(op); + rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, + trans.getOrder()); + } + } + rewriter.replaceOp(localAlloc, rewritten); + return success(); + } +}; + // Inject TMEM copy instructions into IR to efficiently load blocked scales for // scaled dot class UseShmemForScales @@ -341,6 +450,7 @@ class TritonGPUOptimizeDotOperandsPass mlir::RewritePatternSet patterns(context); patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); if (failed(applyPatternsGreedily(m, std::move(patterns)))) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index 2dbaab2d77ce..e8d88a9bdc8a 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -20,17 +20,60 @@ class NvidiaGPUAssignDescriptorMemoryLayouts ArrayRef order, ttg::CGAEncodingAttr cgaLayout, Type elementType) override; + Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef shape, + Type elementType) override; bool isCompatibleSharedEncoding(Attribute enc) override; }; bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding( Attribute enc) { - if (auto nvmma = dyn_cast(enc)) { + if (auto nvmma = dyn_cast(enc)) return !nvmma.getTransposed(); - } return false; } +Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( + Attribute enc, ArrayRef shape, Type elementType) { + if (isCompatibleSharedEncoding(enc)) + return enc; + + auto sharedLinear = dyn_cast(enc); + if (!sharedLinear) + return {}; + + auto *ctx = enc.getContext(); + auto cgaLayout = ttg::getCGALayout(sharedLinear); + auto order = ttg::getOrder(sharedLinear, shape); + + SmallVector preferredCandidates; + // TMA descriptors only support non-transposed layouts. Preserve Triton's + // default shape/order-based choice when it already matches this + // shared_linear layout. The full candidate scan below is only a fallback for + // equivalent non-transposed layouts not selected by the heuristic builder. + for (bool fp4Padded : {false, true}) { + auto preferred = ttg::NVMMASharedEncodingAttr::get( + ctx, shape, order, cgaLayout, elementType, fp4Padded); + preferredCandidates.push_back(preferred); + if (ttg::areLayoutsEquivalent(shape, sharedLinear, preferred)) + return preferred; + } + + unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth()); + for (bool fp4Padded : {false, true}) { + for (unsigned swizzle : {0u, 32u, 64u, 128u}) { + auto candidate = ttg::NVMMASharedEncodingAttr::get( + ctx, swizzle, /*transposed=*/false, elementBitWidth, fp4Padded, + cgaLayout); + if (llvm::is_contained(preferredCandidates, candidate)) + continue; + if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) + return candidate; + } + } + + return {}; +} + // Build fallback encoding given shape, order, cga layout and element type Attribute NvidiaGPUAssignDescriptorMemoryLayouts::buildFallbackSharedEncoding( mlir::MLIRContext *ctx, ArrayRef shape, ArrayRef order, diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index f292a2fcb211..4434061722a8 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -293,3 +293,96 @@ module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num tt.return %a: !ttg.memdesc<64x64xf32, #shared, #smem> } } + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> +#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +#tmem0 = #ttng.tensor_memory_encoding +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views + // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK-NOT: ttg.local_load + // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true + tt.func @swizzle0_operand_views( + %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, + %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { + %true = arith.constant true + %c0_i32 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> + %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> + %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> + %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> + %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> + %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> + %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> + ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> + tt.return + } +} + +// ----- + +#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +#smem = #ttg.shared_memory +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot + // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> + // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> + // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> + // CHECK-NOT: tt.reshape + // CHECK-NOT: tt.trans + // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 + // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} + tt.func @swizzle0_operand_views_warp_group_dot( + %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, + %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, + %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { + %c0 = arith.constant 0 : i32 + %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> + %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> + %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> + %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> + %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> + %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> + %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> + %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} + : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> + %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> + tt.return %w#0 : tensor<128x256xf32, #mma_desc> + } +} diff --git a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir index aff7ca9c66af..1376692737b3 100644 --- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir @@ -86,6 +86,27 @@ tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<64x64xf16>, %arg1: i // ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { +// CHECK-DAG: #[[BLOCKED_16x128:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK-DAG: #[[NVMMA_128:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +// CHECK-DAG: #[[NVMMA_TRANSPOSED_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +tt.func public @descriptor_ignores_transposed_local_alloc(%arg0: !tt.tensordesc<16x128xf16>) { + // CHECK: %arg0: !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0{{.*}} : !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> -> tensor<16x128xf16, #[[BLOCKED_16x128]]> + // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<16x128xf16, #[[BLOCKED_16x128]]>) -> !ttg.memdesc<16x128xf16, #[[NVMMA_TRANSPOSED_32]], #smem> + %c0 = arith.constant 0 : i32 + %0 = tt.descriptor_load %arg0[%c0, %c0] : !tt.tensordesc<16x128xf16> -> tensor<16x128xf16, #blocked> + %1 = ttg.local_alloc %0 : (tensor<16x128xf16, #blocked>) -> !ttg.memdesc<16x128xf16, #shared, #smem> + tt.return +} +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> @@ -123,3 +144,25 @@ tt.func public @tma_load_while(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, tt.return } } + +// ----- + +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +#shared_linear_base = #ttg.shared_linear<{offset = [[0, 0, 0, 0, 1], [0, 0, 0, 0, 2], [0, 0, 0, 0, 4], [0, 0, 0, 0, 8], [0, 0, 0, 1, 0], [0, 0, 0, 2, 0], [0, 0, 0, 4, 0], [0, 0, 1, 0, 0], [0, 0, 2, 0, 0], [0, 0, 4, 0, 0], [0, 0, 8, 0, 0], [0, 0, 16, 0, 0], [0, 0, 32, 0, 0], [0, 0, 64, 0, 0], [0, 0, 128, 0, 0]]}, alignment = 128> +#smem = #ttg.shared_memory + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { +// CHECK-DAG: #[[NVMMA_0:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +// CHECK-DAG: #[[BLOCKED5D:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +// CHECK-DAG: #[[SL_BASE:.*]] = #ttg.shared_linear<{{.*}}alignment = 128> +tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>) { + // CHECK: %arg0: !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> + // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0 + // CHECK-SAME: : !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> + // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #[[SL_BASE]], #smem> + %c0 = arith.constant 0 : i32 + %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> + %b_s = ttg.local_alloc %b : (tensor<1x1x256x8x16xf8E4M3FN, #blocked2>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared_linear_base, #smem> + tt.return +} +}