From 9374811a5aeb37310d5fc0130e76110a51836942 Mon Sep 17 00:00:00 2001 From: Tori Baker Date: Tue, 21 Apr 2026 01:47:52 -0700 Subject: [PATCH] Cherry-pick https://github.com/triton-lang/triton/pull/10032 which reverts a change that causes a crash in getTMABlockShape Added a minimal reproducer for our triton emitter as well in case this comes up again. PiperOrigin-RevId: 903075268 --- .../triton/common/cherry-pick-f43dff6f.patch | 475 ++++++++++++++++++ third_party/triton/common/series.bzl | 1 + .../triton/tests/dot/dot_transpose_tma.hlo | 27 + 3 files changed, 503 insertions(+) create mode 100644 third_party/triton/common/cherry-pick-f43dff6f.patch create mode 100644 xla/backends/gpu/codegen/triton/tests/dot/dot_transpose_tma.hlo diff --git a/third_party/triton/common/cherry-pick-f43dff6f.patch b/third_party/triton/common/cherry-pick-f43dff6f.patch new file mode 100644 index 0000000000000..6702e2ec18c87 --- /dev/null +++ b/third_party/triton/common/cherry-pick-f43dff6f.patch @@ -0,0 +1,475 @@ +// Cherry-pick go/triton-commit/f43dff6f + +--- a/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h ++++ b/include/triton/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.h +@@ -41,15 +41,6 @@ + CGAEncodingAttr cgaLayout, + ArrayRef usageShape, + unsigned numCTAs); +- +-protected: +- virtual Attribute getCompatibleSharedEncoding(Attribute enc, +- ArrayRef shape, +- Type elementType) { +- return isCompatibleSharedEncoding(enc) ? enc : Attribute(); +- } +- +-private: + // Override with backend specific implementation + virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *, + ArrayRef, + +--- a/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp +@@ -250,36 +250,23 @@ + + Attribute + AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) { +- auto getCompatibleEncodingForType = [&](Type type) -> Attribute { +- if (auto memDescTy = dyn_cast(type)) { +- return getCompatibleSharedEncoding(memDescTy.getEncoding(), +- memDescTy.getShape(), +- memDescTy.getElementType()); +- } +- if (auto tensorTy = dyn_cast(type)) { +- return getCompatibleSharedEncoding(tensorTy.getEncoding(), +- tensorTy.getShape(), +- tensorTy.getElementType()); +- } +- return {}; +- }; +- + // Check if there are any desired encodings available on the op + if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) { +- if (auto resultTy = dyn_cast(op->getResult(0).getType())) +- if (auto compatible = getCompatibleSharedEncoding( +- attr, resultTy.getShape(), resultTy.getElementType())) +- return compatible; ++ if (auto enc = dyn_cast(attr)) { ++ if (isCompatibleSharedEncoding(enc)) ++ return enc; ++ } + } + // Ignore multiple users and just pick the first compatible layout + for (auto use : op->getUsers()) { + if (auto alloc = dyn_cast(use)) { +- if (auto compatible = getCompatibleEncodingForType(alloc.getType())) +- return compatible; ++ auto enc = alloc.getType().getEncoding(); ++ if (isCompatibleSharedEncoding(enc)) ++ return enc; + } else if (auto store = dyn_cast(use)) { +- if (auto compatible = +- getCompatibleEncodingForType(store.getDst().getType())) +- return compatible; ++ auto enc = store.getDst().getType().getEncoding(); ++ if (isCompatibleSharedEncoding(enc)) ++ return enc; + } + } + return {}; +@@ -455,9 +442,7 @@ + auto ctx = func.getContext(); + auto numCTAs = triton::gpu::lookupNumCTAs(func); + for (auto &[desc, einfo] : valueToEncodingInfo) { +- auto descTy = desc.getType(); +- auto existingTy = +- RankedTensorType::get(descTy.getShape(), descTy.getElementType()); ++ auto existingTy = desc.getType().getBlockType(); + Attribute newEncoding; + if (einfo->desiredEncoding) { + newEncoding = einfo->desiredEncoding; +@@ -475,11 +460,10 @@ + SmallVector resultTys(func.getResultTypes()); + for (auto [i, resultTy] : llvm::enumerate(resultTys)) { + if (auto descTy = dyn_cast(resultTy)) { +- auto existingTy = +- RankedTensorType::get(descTy.getShape(), descTy.getElementType()); +- auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs); +- resultTys[i] = +- getTensorDescTypeWithEncoding(nullptr, existingTy, encoding); ++ auto encoding = ++ getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs); ++ resultTys[i] = getTensorDescTypeWithEncoding( ++ nullptr, descTy.getBlockType(), encoding); + } + } + func.setFunctionType(FunctionType::get(ctx, argTys, resultTys)); + +--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +@@ -3,6 +3,8 @@ + #include "mlir/Support/LogicalResult.h" + #include "mlir/Transforms/GreedyPatternRewriteDriver.h" + #include "mlir/Transforms/Passes.h" ++#include "triton/Analysis/Utility.h" ++#include "triton/Dialect/Triton/IR/Utility.h" + #include "triton/Dialect/TritonGPU/IR/Attributes.h" + #include "triton/Dialect/TritonGPU/IR/Dialect.h" + #include "triton/Dialect/TritonGPU/Transforms/Passes.h" +@@ -10,8 +12,7 @@ + #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + #include "triton/Tools/LayoutUtils.h" + #include "triton/Tools/LinearLayout.h" +-#include +-#include ++#include + + namespace mlir::triton::gpu { + +@@ -107,10 +108,7 @@ + return failure(); + + MemDescType allocType = allocOp.getType(); +- auto allocEncoding = +- dyn_cast(allocType.getEncoding()); +- if (!allocEncoding) +- return failure(); ++ auto allocEncoding = cast(allocType.getEncoding()); + RankedTensorType srcTy = trans.getSrc().getType(); + + auto ctx = getContext(); +@@ -178,113 +176,6 @@ + reshapeOp.getSrc()); + rewriter.replaceOpWithNewOp(allocOp, allocOp.getType(), + newAlloc); +- return success(); +- } +-}; +- +-// Rewrite +-// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma +-// into +-// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma +-// +-// The MMA operand layout is determined by the sink memdesc already feeding the +-// dot-like op. This pattern back-propagates that layout through the tensor +-// reshape/transpose chain, hoists local_alloc to the base tensor feeding that +-// view chain, and replays those tensor views as memdesc reshape/transpose +-// ops so the original local_alloc type is preserved. +-class RewriteMmaOperandViewsToMemDescForDotOp +- : public OpInterfaceRewritePattern { +-public: +- using OpInterfaceRewritePattern< +- triton::DotOpInterface>::OpInterfaceRewritePattern; +- +- LogicalResult matchAndRewrite(triton::DotOpInterface dotOp, +- PatternRewriter &rewriter) const override { +- if (!isa(dotOp)) +- return failure(); +- +- bool changed = false; +- +- if (rewriteOperand(dotOp.getA(), rewriter).succeeded()) +- changed = true; +- +- if (rewriteOperand(dotOp.getB(), rewriter).succeeded()) +- changed = true; +- +- return success(changed); +- } +- +-private: +- LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const { +- if (!isa(operand.getType())) +- return failure(); +- +- Value beforeTrailing = operand; +- while (auto view = beforeTrailing.getDefiningOp()) { +- if (auto reshape = dyn_cast(view)) { +- beforeTrailing = reshape.getSrc(); +- continue; +- } +- if (auto trans = dyn_cast(view)) { +- beforeTrailing = trans.getSrc(); +- continue; +- } +- break; +- } +- +- auto localAlloc = beforeTrailing.getDefiningOp(); +- if (!localAlloc || !localAlloc.getSrc()) +- return failure(); +- +- Value baseTensor = localAlloc.getSrc(); +- SmallVector tensorReplaySteps; +- MemDescType baseMemTy = localAlloc.getType(); +- while (auto view = baseTensor.getDefiningOp()) { +- if (auto reshape = dyn_cast(view)) { +- MemDescType srcTy; +- auto inferred = MemDescReshapeOp::inferReturnTypes( +- getContext(), reshape.getLoc(), baseMemTy, +- reshape.getSrc().getType().getShape(), srcTy); +- assert(succeeded(inferred) && "backward memdesc reshape inference " +- "must succeed"); +- (void)inferred; +- baseMemTy = srcTy; +- } else if (auto trans = dyn_cast(view)) { +- Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding()); +- if (!srcEnc) +- return failure(); +- baseMemTy = MemDescType::get( +- trans.getSrc().getType().getShape(), baseMemTy.getElementType(), +- srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory()); +- } else { +- break; +- } +- tensorReplaySteps.push_back(view); +- baseTensor = view->getOperand(0); +- } +- if (tensorReplaySteps.empty()) +- return failure(); +- +- std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end()); +- +- PatternRewriter::InsertionGuard guard(rewriter); +- rewriter.setInsertionPoint(localAlloc); +- +- Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(), +- baseMemTy, baseTensor); +- for (Operation *op : tensorReplaySteps) { +- if (auto reshape = dyn_cast(op)) { +- rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten, +- reshape.getType().getShape()); +- } else { +- auto trans = cast(op); +- rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten, +- trans.getOrder()); +- } +- } +- rewriter.replaceOp(localAlloc, rewritten); + return success(); + } + }; +@@ -450,7 +341,6 @@ + mlir::RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); +- patterns.add(context); + patterns.add(context); + ConvertLayoutOp::getCanonicalizationPatterns(patterns, context); + if (failed(applyPatternsGreedily(m, std::move(patterns)))) + +--- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp ++++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +@@ -20,58 +20,15 @@ + ArrayRef order, + ttg::CGAEncodingAttr cgaLayout, + Type elementType) override; +- Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef shape, +- Type elementType) override; + bool isCompatibleSharedEncoding(Attribute enc) override; + }; + + bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding( + Attribute enc) { +- if (auto nvmma = dyn_cast(enc)) ++ if (auto nvmma = dyn_cast(enc)) { + return !nvmma.getTransposed(); ++ } + return false; +-} +- +-Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding( +- Attribute enc, ArrayRef shape, Type elementType) { +- if (isCompatibleSharedEncoding(enc)) +- return enc; +- +- auto sharedLinear = dyn_cast(enc); +- if (!sharedLinear) +- return {}; +- +- auto *ctx = enc.getContext(); +- auto cgaLayout = ttg::getCGALayout(sharedLinear); +- auto order = ttg::getOrder(sharedLinear, shape); +- +- SmallVector preferredCandidates; +- // TMA descriptors only support non-transposed layouts. Preserve Triton's +- // default shape/order-based choice when it already matches this +- // shared_linear layout. The full candidate scan below is only a fallback for +- // equivalent non-transposed layouts not selected by the heuristic builder. +- for (bool fp4Padded : {false, true}) { +- auto preferred = ttg::NVMMASharedEncodingAttr::get( +- ctx, shape, order, cgaLayout, elementType, fp4Padded); +- preferredCandidates.push_back(preferred); +- if (ttg::areLayoutsEquivalent(shape, sharedLinear, preferred)) +- return preferred; +- } +- +- unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth()); +- for (bool fp4Padded : {false, true}) { +- for (unsigned swizzle : {0u, 32u, 64u, 128u}) { +- auto candidate = ttg::NVMMASharedEncodingAttr::get( +- ctx, swizzle, /*transposed=*/false, elementBitWidth, fp4Padded, +- cgaLayout); +- if (llvm::is_contained(preferredCandidates, candidate)) +- continue; +- if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)) +- return candidate; +- } +- } +- +- return {}; + } + + // Build fallback encoding given shape, order, cga layout and element type + +--- a/test/TritonGPU/dot-operands.mlir ++++ b/test/TritonGPU/dot-operands.mlir +@@ -293,96 +293,3 @@ + tt.return %a: !ttg.memdesc<64x64xf32, #shared, #smem> + } + } +- +-// ----- +- +-#blocked0 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +-#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +-#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +-#linear0 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +-#linear2 = #ttg.linear<{register = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 128], [16, 0], [32, 0], [64, 0]], lane = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]], warp = [[0, 32], [0, 64]], block = []}> +-#shared0 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +-#shared5 = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +-#smem = #ttg.shared_memory +-#tmem0 = #ttng.tensor_memory_encoding +-// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +-// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +-module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +- // CHECK-LABEL: @swizzle0_operand_views +- // CHECK-DAG: %[[A_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #{{.*}}> +- // CHECK-DAG: %[[A_BASE:.*]] = ttg.local_alloc %[[A_DESC]] : (tensor<128x128xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<128x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_DESC:.*]] = tt.descriptor_load {{.*}} : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> +- // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC]] : (tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_TR0:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR0]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_TR1:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-NOT: tt.reshape +- // CHECK-NOT: tt.trans +- // CHECK-NOT: ttg.local_load +- // CHECK: ttng.tc_gen5_mma %[[A_BASE]], %[[B_TR1]], %arg2, %true, %true +- tt.func @swizzle0_operand_views( +- %a_desc: !tt.tensordesc<128x128xf8E4M3FN>, +- %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, +- %acc: !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory>) { +- %true = arith.constant true +- %c0_i32 = arith.constant 0 : i32 +- %a = tt.descriptor_load %a_desc[%c0_i32, %c0_i32] : !tt.tensordesc<128x128xf8E4M3FN> -> tensor<128x128xf8E4M3FN, #blocked0> +- %b = tt.descriptor_load %b_desc[%c0_i32, %c0_i32, %c0_i32, %c0_i32, %c0_i32] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked1> +- %b1 = tt.reshape %b : tensor<1x1x256x8x16xf8E4M3FN, #blocked1> -> tensor<8x32x8x16xf8E4M3FN, #blocked2> +- %b2 = tt.trans %b1 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blocked2> -> tensor<32x8x8x16xf8E4M3FN, #blocked3> +- %b3 = tt.reshape %b2 : tensor<32x8x8x16xf8E4M3FN, #blocked3> -> tensor<256x128xf8E4M3FN, #linear0> +- %b4 = tt.trans %b3 {order = array} : tensor<256x128xf8E4M3FN, #linear0> -> tensor<128x256xf8E4M3FN, #linear2> +- %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blocked0>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem> +- %b_s = ttg.local_alloc %b4 : (tensor<128x256xf8E4M3FN, #linear2>) -> !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem> +- ttng.tc_gen5_mma %a_s, %b_s, %acc, %true, %true : !ttg.memdesc<128x128xf8E4M3FN, #shared0, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #shared5, #smem>, !ttg.memdesc<128x256xf32, #tmem0, #ttng.tensor_memory> +- tt.return +- } +-} +- +-// ----- +- +-#blockedA_desc = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +-#blockedB_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +-#blockedB0_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [1, 4, 8, 1], warpsPerCTA = [1, 4, 1, 1], order = [3, 2, 1, 0]}> +-#blockedB1_desc = #ttg.blocked<{sizePerThread = [1, 1, 1, 16], threadsPerWarp = [4, 8, 1, 1], warpsPerCTA = [4, 1, 1, 1], order = [3, 1, 0, 2]}> +-#linear_desc = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [128, 0], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[32, 0], [64, 0]], block = []}> +-#mma_desc = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 32]}> +-#sharedA_desc = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> +-#sharedB3_desc = #ttg.shared_linear<{offset = [[0, 1], [0, 2], [0, 4], [0, 8], [1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 16], [0, 32], [0, 64]]}, alignment = 128> +-#sharedB4_desc = #ttg.shared_linear<{offset = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [16, 0], [32, 0], [64, 0]]}, alignment = 128> +-#smem = #ttg.shared_memory +-// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +-// CHECK-DAG: #{{.*}} = #ttg.shared_linear<{{.*}}alignment = 128> +-module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { +- // CHECK-LABEL: @swizzle0_operand_views_warp_group_dot +- // CHECK-DAG: %[[B_DESC_LOAD:.*]] = tt.descriptor_load %arg1[%{{.*}}] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #{{.*}}> +- // CHECK-DAG: %[[B_BASE:.*]] = ttg.local_alloc %[[B_DESC_LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, {{.*}}>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_RS0:.*]] = ttg.memdesc_reshape %[[B_BASE]] : !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #{{.*}}, #smem{{.*}}> -> !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_TR:.*]] = ttg.memdesc_trans %[[B_RS0]] {order = array} : !ttg.memdesc<8x32x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_RS1:.*]] = ttg.memdesc_reshape %[[B_TR]] : !ttg.memdesc<32x8x8x16xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> +- // CHECK-DAG: %[[B_T:.*]] = ttg.memdesc_trans %[[B_RS1]] {order = array} : !ttg.memdesc<256x128xf8E4M3FN, {{.*}}, #smem{{.*}}> -> !ttg.memdesc<128x256xf8E4M3FN, #{{.*}}, #smem{{.*}}> +- // CHECK-NOT: tt.reshape +- // CHECK-NOT: tt.trans +- // CHECK: %[[DOT:.*]] = ttng.warp_group_dot %{{.*}}, %[[B_T]], %arg2 +- // CHECK: %[[WAIT:.*]]:3 = ttng.warp_group_dot_wait %[[DOT]], {{.*}}, %[[B_T]] {pendings = 0 : i32} +- tt.func @swizzle0_operand_views_warp_group_dot( +- %a_desc: !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc>, +- %b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>, +- %acc: tensor<128x256xf32, #mma_desc>) -> tensor<128x256xf32, #mma_desc> { +- %c0 = arith.constant 0 : i32 +- %a = tt.descriptor_load %a_desc[%c0, %c0] : !tt.tensordesc<128x128xf8E4M3FN, #sharedA_desc> -> tensor<128x128xf8E4M3FN, #blockedA_desc> +- %a_s = ttg.local_alloc %a : (tensor<128x128xf8E4M3FN, #blockedA_desc>) -> !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> +- %b_cm = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> +- %b0 = tt.reshape %b_cm : tensor<1x1x256x8x16xf8E4M3FN, #blockedB_desc> -> tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> +- %b1 = tt.trans %b0 {order = array} : tensor<8x32x8x16xf8E4M3FN, #blockedB0_desc> -> tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> +- %b2 = tt.reshape %b1 : tensor<32x8x8x16xf8E4M3FN, #blockedB1_desc> -> tensor<256x128xf8E4M3FN, #linear_desc> +- %b_s = ttg.local_alloc %b2 : (tensor<256x128xf8E4M3FN, #linear_desc>) -> !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> +- %b_t = ttg.memdesc_trans %b_s {order = array} : !ttg.memdesc<256x128xf8E4M3FN, #sharedB3_desc, #smem> -> !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> +- %r = ttng.warp_group_dot %a_s, %b_t, %acc {inputPrecision = 0 : i32, isAsync = true, maxNumImpreciseAcc = 1073741824 : i32} +- : !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem> * !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> -> tensor<128x256xf32, #mma_desc> +- %w:3 = ttng.warp_group_dot_wait %r, %a_s, %b_t {pendings = 0 : i32} : tensor<128x256xf32, #mma_desc>, !ttg.memdesc<128x128xf8E4M3FN, #sharedA_desc, #smem>, !ttg.memdesc<128x256xf8E4M3FN, #sharedB4_desc, #smem> +- tt.return %w#0 : tensor<128x256xf32, #mma_desc> +- } +-} + +--- a/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir ++++ b/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir +@@ -86,27 +86,6 @@ + + // ----- + +-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +-#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +-#smem = #ttg.shared_memory +- +-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { +-// CHECK-DAG: #[[BLOCKED_16x128:.*]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +-// CHECK-DAG: #[[NVMMA_128:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +-// CHECK-DAG: #[[NVMMA_TRANSPOSED_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}> +-tt.func public @descriptor_ignores_transposed_local_alloc(%arg0: !tt.tensordesc<16x128xf16>) { +- // CHECK: %arg0: !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> +- // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0{{.*}} : !tt.tensordesc<16x128xf16, #[[NVMMA_128]]> -> tensor<16x128xf16, #[[BLOCKED_16x128]]> +- // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<16x128xf16, #[[BLOCKED_16x128]]>) -> !ttg.memdesc<16x128xf16, #[[NVMMA_TRANSPOSED_32]], #smem> +- %c0 = arith.constant 0 : i32 +- %0 = tt.descriptor_load %arg0[%c0, %c0] : !tt.tensordesc<16x128xf16> -> tensor<16x128xf16, #blocked> +- %1 = ttg.local_alloc %0 : (tensor<16x128xf16, #blocked>) -> !ttg.memdesc<16x128xf16, #shared, #smem> +- tt.return +-} +-} +- +-// ----- +- + + #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +@@ -144,25 +123,3 @@ + tt.return + } + } +- +-// ----- +- +-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +-#shared_linear_base = #ttg.shared_linear<{offset = [[0, 0, 0, 0, 1], [0, 0, 0, 0, 2], [0, 0, 0, 0, 4], [0, 0, 0, 0, 8], [0, 0, 0, 1, 0], [0, 0, 0, 2, 0], [0, 0, 0, 4, 0], [0, 0, 1, 0, 0], [0, 0, 2, 0, 0], [0, 0, 4, 0, 0], [0, 0, 8, 0, 0], [0, 0, 16, 0, 0], [0, 0, 32, 0, 0], [0, 0, 64, 0, 0], [0, 0, 128, 0, 0]]}, alignment = 128> +-#smem = #ttg.shared_memory +- +-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { +-// CHECK-DAG: #[[NVMMA_0:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, rank = 5}> +-// CHECK-DAG: #[[BLOCKED5D:.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 16], threadsPerWarp = [1, 1, 4, 8, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}> +-// CHECK-DAG: #[[SL_BASE:.*]] = #ttg.shared_linear<{{.*}}alignment = 128> +-tt.func public @descriptor_arg_from_shared_linear_use(%b_desc: !tt.tensordesc<1x1x256x8x16xf8E4M3FN>) { +- // CHECK: %arg0: !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> +- // CHECK: %[[LOAD:.*]] = tt.descriptor_load %arg0 +- // CHECK-SAME: : !tt.tensordesc<1x1x256x8x16xf8E4M3FN, #[[NVMMA_0]]> -> tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]> +- // CHECK: ttg.local_alloc %[[LOAD]] : (tensor<1x1x256x8x16xf8E4M3FN, #[[BLOCKED5D]]>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #[[SL_BASE]], #smem> +- %c0 = arith.constant 0 : i32 +- %b = tt.descriptor_load %b_desc[%c0, %c0, %c0, %c0, %c0] : !tt.tensordesc<1x1x256x8x16xf8E4M3FN> -> tensor<1x1x256x8x16xf8E4M3FN, #blocked2> +- %b_s = ttg.local_alloc %b : (tensor<1x1x256x8x16xf8E4M3FN, #blocked2>) -> !ttg.memdesc<1x1x256x8x16xf8E4M3FN, #shared_linear_base, #smem> +- tt.return +-} +-} + diff --git a/third_party/triton/common/series.bzl b/third_party/triton/common/series.bzl index afe7c69555202..254fb6fa950d5 100644 --- a/third_party/triton/common/series.bzl +++ b/third_party/triton/common/series.bzl @@ -34,5 +34,6 @@ common_patch_list = [ "//third_party/triton:common/llvm_cl895542516.patch", "//third_party/triton:common/assert_fail.patch", "//third_party/triton:common/llvm_cl900404532.patch", + "//third_party/triton:common/cherry-pick-f43dff6f.patch", # Add new patches just above this line ] diff --git a/xla/backends/gpu/codegen/triton/tests/dot/dot_transpose_tma.hlo b/xla/backends/gpu/codegen/triton/tests/dot/dot_transpose_tma.hlo new file mode 100644 index 0000000000000..e7637692f0522 --- /dev/null +++ b/xla/backends/gpu/codegen/triton/tests/dot/dot_transpose_tma.hlo @@ -0,0 +1,27 @@ +// Regression test for b/503878530 +// RUN: triton_test_correctness --abs_error_bound=1e-3 --rel_error_bound=1e-3 %s + +dot { + p0 = f16[64,128,32]{2,1,0} parameter(0) + t = f16[64,32,128]{2,1,0} transpose(p0), dimensions={0,2,1} + b = f16[2048,128]{1,0} bitcast(t) + p1 = f16[128,256]{1,0} parameter(1) + ROOT dot = f16[2048,256]{1,0} dot(b, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, backend_config={"sizes":["128"]} +} + +ENTRY entry_computation { + p0 = f16[64,128,32]{2,1,0} parameter(0) + p1 = f16[128,256]{1,0} parameter(1) + ROOT fusion = f16[2048,256]{1,0} fusion(p0, p1), kind=kCustom, + calls=dot, + backend_config={"fusion_backend_config":{"kind":"__triton_nested_gemm_fusion", + "block_level_fusion_config":{ + "num_warps":"8", + "output_tiles":[{"sizes":["128","128"]}], + "num_ctas":1, + "num_stages":1, + "is_tma_allowed":true + } + } + } +} \ No newline at end of file