diff --git a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir index 7e352951bd15..a5efead5c604 100644 --- a/test/Conversion/amd/tritongpu_tdm_stride_order.mlir +++ b/test/Conversion/amd/tritongpu_tdm_stride_order.mlir @@ -4,9 +4,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @tdm_load_runtime_strides(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) { %c_shape = arith.constant 128 : i32 - %c_stride0 = arith.constant 128 : i64 - %c_stride1 = arith.constant 1 : i64 - // expected-error @+2 {{requires at least one dimension to have stride 1}} + // expected-error @+2 {{last dimension must have stride 1}} // expected-error @+1 {{failed to legalize operation}} %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride0, %stride1] : , <64x64xf16, #shared> tt.return @@ -17,12 +15,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %stride0: i64, %stride1: i64) { + // CHECK-LABEL: tdm_1x1_tensor + tt.func public @tdm_1x1_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %stride1: i64) { %c_shape = arith.constant 1 : i32 %c_stride1 = arith.constant 1 : i64 - // expected-error @+2 {{requires at least one dimension to have stride 1}} - // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%stride1, %stride1] : , <64x64xf16, #shared> + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %c_stride1] : , <64x64xf16, #shared> tt.return } } @@ -34,7 +31,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr tt.func public @tdm_wrong_stride_order(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { %c_shape = arith.constant 128 : i32 %c_stride1 = arith.constant 1 : i64 - // expected-error @+2 {{requires shared order [rank-2, rank-1, rank-3, rank-4, ..., 0] because dim[rank-2] has stride 1}} + // expected-error @+2 {{last dimension must have stride 1}} // expected-error @+1 {{failed to legalize operation}} %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride1, %runtime_stride] : , <64x64xf16, #shared> tt.return @@ -103,13 +100,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #shared = #ttg.padded_shared<[32:+4] {order = [0, 1, 2], shape = [1, 1, 1]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - tt.func public @tdm_1x1xi(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { + tt.func public @tdm_1xix1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %runtime_stride: i64) { %c_stride1 = arith.constant 1 : i64 %c_shape = arith.constant 1 : i32 %c_shape2 = arith.constant 128 : i32 // expected-error @+2 {{requires all stride 1 dimensions to be consecutive starting from the last dimension}} // expected-error @+1 {{failed to legalize operation}} - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %c_stride1, %runtime_stride] : , <1x1x1xf16, #shared> + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape, %c_shape2], [%c_stride1, %runtime_stride, %c_stride1] : , <1x1x1xf16, #shared> tt.return } } diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index c2011f238f52..df0ce78f99e6 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -703,7 +703,7 @@ LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() { return emitOpError("TDM store only supports single interval paddings."); auto shapePerCTA = triton::gpu::getShapePerCTA(paddedEnc, blockShape); - if (intervals[0] != shapePerCTA[paddedEnc.getOrder().front()]) + if (intervals[0] != shapePerCTA.back()) return emitOpError("TDM store padding is only supported when padding " "interval equals the innermost block dimension (got " "padInterval=") @@ -769,13 +769,6 @@ LogicalResult AsyncTDMScatterOp::verify() { if (!paddedEnc && !swizzledEnc) return emitOpError("Invalid shared memory layout for TDM"); - auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); - auto sharedOrder = triton::gpu::getOrder( - cast(smemTy.getEncoding()), - shapePerCTA); - if (sharedOrder[0] != (sharedOrder.size() - 1)) - return emitOpError("TDM scatter only supports row-major shared order"); - return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 9dc4a0ac8a75..89f65619fa52 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1267,15 +1267,12 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto shapePerCTA = triton::gpu::getShapePerCTA(encoding, tensorDescTy.getShape()); - auto sharedOrder = triton::gpu::getOrder( - cast(encoding), shapePerCTA); - bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); mlir::LLVM::AMD::emitTDMLoadStore( rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps, padInterval, padAmount, offset, dstPtrs, op.getPred(), multicastMask, - elementType, barrierPtr, /*isLoad=*/true, sharedLayout, encoding, ctaId, - isRowMajor); + elementType, barrierPtr, /*isLoad=*/true, sharedLayout, encoding, + ctaId); rewriter.eraseOp(op); return success(); @@ -1347,17 +1344,13 @@ struct AsyncTDMCopyLocalToGlobalOpConversion auto ctaId = targetInfo.getClusterCTAId(rewriter, loc); auto shapePerCTA = triton::gpu::getShapePerCTA(smemTy); - auto sharedOrder = triton::gpu::getOrder( - cast(smemTy.getEncoding()), - shapePerCTA); - bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 32); mlir::LLVM::AMD::emitTDMLoadStore( rewriter, loc, getTypeConverter(), desc, shapePerCTA, numWarps, padInterval, padAmount, offset, srcPtrs, pred, /*multicastMask=*/{}, elementType, barrierPtr, - /*isLoad=*/false, sharedLayout, encoding, ctaId, isRowMajor); + /*isLoad=*/false, sharedLayout, encoding, ctaId); rewriter.eraseOp(op); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index a1f417f17693..f656767826b6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -127,12 +127,6 @@ SmallVector TDMDescriptor::getAllGroups() const { return result; } -// Swap the trailing two dimensions of a vector for TDM operations. -template void swapTrailingDims(SmallVector &vec) { - assert(vec.size() >= 2 && "need at least 2 dims to swap"); - std::swap(vec[vec.size() - 2], vec[vec.size() - 1]); -} - // Decode a full TDM descriptor from all 4 group vectors for 3D-5D tensors // Returns (base, tensorShape[], tensorStride[], blockShape[]) std::tuple, SmallVector, SmallVector> @@ -233,7 +227,7 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, unsigned padInterval, unsigned padAmount, SmallVector tensorShape, SmallVector tensorStride, Value srcPtr, - bool isRowMajor, Attribute sharedEncoding) { + Attribute sharedEncoding) { size_t numDims = tensorShape.size(); assert(numDims >= 1 && numDims <= 5 && tensorStride.size() == numDims && "TDM only supported for 1D-5D tensors."); @@ -242,12 +236,6 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, "Block/tensor/stride dim count must all be equal."); auto b = TritonLLVMOpBuilder(loc, rewriter); - if (!isRowMajor) { - swapTrailingDims(blockShape); - swapTrailingDims(tensorStride); - swapTrailingDims(tensorShape); - } - // Define common values for better readability Value v16 = b.i32_val(16); Value v32 = b.i64_val(32); @@ -508,29 +496,6 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, return TDMDescriptor{group0, group1, group2, group3}; } -// Returns a copy of `layout` where the semantics of dimA and dimB are -// exchanged: new.apply(x)[dimA] == old.apply(x)[dimB] and vice versa. The -// output dimension order is preserved. -static triton::LinearLayout -swapOutDimSemantics(const triton::LinearLayout &layout, StringAttr dimA, - StringAttr dimB) { - assert(layout.hasOutDim(dimA)); - assert(layout.hasOutDim(dimB)); - SmallVector> renamedOutDims; - for (auto [name, size] : layout.getOutDims()) { - if (name == dimA) - renamedOutDims.push_back({dimB, size}); - else if (name == dimB) - renamedOutDims.push_back({dimA, size}); - else - renamedOutDims.push_back({name, size}); - } - // Transpose to restore the original output dimension order. - return triton::LinearLayout(layout.getBases(), renamedOutDims, - /*requireSurjective=*/false) - .transposeOuts(llvm::to_vector(layout.getOutDimNames())); -} - // Fill TDM descriptor for regular load/store operations (1D-5D tensors) void fillTDMDescriptor( RewriterBase &rewriter, Location loc, @@ -542,7 +507,7 @@ void fillTDMDescriptor( SmallVector offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Value barrierPtr, const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore, - bool isRowMajor, ArrayRef warpsPerCTA) { + ArrayRef warpsPerCTA) { size_t numDims = offset.size(); assert(numDims >= 1 && numDims <= 5 && "TDM supports 1D to 5D tensors."); assert(!dstPtrs.empty() && "dstPtrs cannot be empty"); @@ -553,22 +518,6 @@ void fillTDMDescriptor( Type globalPtrTy = ptr_ty(ctx, 1); Type sharedPtrTy = ptr_ty(ctx, 3); - // For col-major tensors the TDM descriptor was created with the trailing two - // dimensions swapped. Swap shapePerCTA and offset to match that hardware - // view, and rename the same two dims in the shared layout to align out dims. - std::optional adjustedSharedLayout; - if (!isRowMajor) { - swapTrailingDims(shapePerCTA); - swapTrailingDims(offset); - if (numDims >= 2) { - auto dimN_2 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 2)); - auto dimN_1 = StringAttr::get(ctx, "dim" + std::to_string(numDims - 1)); - adjustedSharedLayout = swapOutDimSemantics(sharedLayout, dimN_2, dimN_1); - } - } - const auto &tdmViewSharedLayout = - adjustedSharedLayout ? *adjustedSharedLayout : sharedLayout; - // Decode the full TDM descriptor to get all values auto [srcPtr, tensorShape, tensorStride, decodedBlockShape] = decodeTDMDescriptorFull( @@ -588,7 +537,7 @@ void fillTDMDescriptor( auto kPartition = str_attr("partition"); auto cgaLayout = triton::gpu::SharedLinearEncodingAttr::get( - ctx, tdmViewSharedLayout, /*layoutAlignment=*/16) + ctx, sharedLayout, /*layoutAlignment=*/16) .getCGALayout() .getLinearLayout(); @@ -615,7 +564,7 @@ void fillTDMDescriptor( } srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset); - auto tdmToShared = tdmLayout.invertAndCompose(tdmViewSharedLayout); + auto tdmToShared = tdmLayout.invertAndCompose(sharedLayout); auto sharedOffsets = applyLinearLayout( loc, rewriter, tdmToShared, {{kMessage, b.i32_val(0)}, {kWarp, warpId}, {kBlock, ctaId}}); @@ -956,7 +905,7 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc, SmallVector globalOffset, ArrayRef instrDstPtrs, Value pred, Value multicastMask, Value barrier, const triton::LinearLayout &instrSharedLayout, Value ctaId, - bool isLoad, bool isRowMajor, ArrayRef warpsPerCTA) { + bool isLoad, ArrayRef warpsPerCTA) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto v8i32Ty = VectorType::get(8, rewriter.getI32Type()); Value group4Zero = LLVM::ZeroOp::create(rewriter, loc, v8i32Ty); @@ -972,7 +921,7 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc, group0Vec, group1Vec, std::ref(group2Vec), std::ref(group3Vec), globalOffset, instrDstPtrs, pred, multicastMask, barrier, instrSharedLayout, ctaId, !isLoad, - isRowMajor, warpsPerCTA); + warpsPerCTA); auto group0 = packLLVector(loc, group0Vec, rewriter); auto group1 = packLLVector(loc, group1Vec, rewriter); @@ -988,11 +937,11 @@ emitTDMIntrinsic(RewriterBase &rewriter, Location loc, auto group0Vec = SmallVector(desc.begin(), desc.begin() + 4); auto group1Vec = SmallVector(desc.begin() + 4, desc.end()); - fillTDMDescriptor( - rewriter, loc, typeConverter, elementType, effectiveBlockShape, - numWarps, padInterval, padAmount, group0Vec, group1Vec, std::nullopt, - std::nullopt, globalOffset, instrDstPtrs, pred, multicastMask, barrier, - instrSharedLayout, ctaId, !isLoad, isRowMajor, warpsPerCTA); + fillTDMDescriptor(rewriter, loc, typeConverter, elementType, + effectiveBlockShape, numWarps, padInterval, padAmount, + group0Vec, group1Vec, std::nullopt, std::nullopt, + globalOffset, instrDstPtrs, pred, multicastMask, barrier, + instrSharedLayout, ctaId, !isLoad, warpsPerCTA); auto group0 = packLLVector(loc, group0Vec, rewriter); auto group1 = packLLVector(loc, group1Vec, rewriter); @@ -1024,7 +973,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, Value pred, Value multicastMask, Type elementType, Value barrierPtr, bool isLoad, const triton::LinearLayout &sharedLayout, - Attribute encoding, Value ctaId, bool isRowMajor) { + Attribute encoding, Value ctaId) { auto b = TritonLLVMOpBuilder(loc, rewriter); size_t numDims = blockShape.size(); assert(numDims <= 5); @@ -1040,8 +989,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, emitTDMIntrinsic(rewriter, loc, typeConverter, desc, numDims, elementType, to_vector(blockShape), numWarps, padInterval, padAmount, to_vector(offset), dstPtrs, pred, multicastMask, - barrierPtr, sharedLayout, ctaId, isLoad, isRowMajor, - warpsPerCTA); + barrierPtr, sharedLayout, ctaId, isLoad, warpsPerCTA); return; } @@ -1105,7 +1053,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, emitTDMIntrinsic(rewriter, loc, typeConverter, desc, numDims, elementType, effectiveBlockShape, numWarps, padInterval, padAmount, globalOffset, instrDstPtrs, pred, multicastMask, barrier, - sliceLayout, ctaId, isLoad, isRowMajor, warpsPerCTA); + sliceLayout, ctaId, isLoad, warpsPerCTA); } } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h index 1ec26ea6f1f1..f757404602fe 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h @@ -37,7 +37,7 @@ TDMDescriptor createTDMDescriptor(RewriterBase &rewriter, Location loc, unsigned padInterval, unsigned padAmount, SmallVector tensorShape, SmallVector tensorStride, Value srcPtr, - bool isRowMajor, Attribute sharedEncoding); + Attribute sharedEncoding); // Update the global memory address with offset, and fill the shared memory // address and pred in a given TDM descriptor for regular load/store (1D-5D). @@ -53,7 +53,7 @@ void fillTDMDescriptor( SmallVector offset, ArrayRef dstPtrs, Value pred, Value multicastMask, Value barrierPtr, const triton::LinearLayout &sharedLayout, Value ctaId, bool isStore, - bool isRowMajor, ArrayRef warpsPerCTA); + ArrayRef warpsPerCTA); // Fill TDM descriptor for gather/scatter operations (2D only). // Gather reads from non-contiguous rows in global memory to LDS. @@ -94,7 +94,7 @@ void emitTDMLoadStore(RewriterBase &rewriter, Location loc, Value pred, Value multicastMask, Type elementType, Value barrierPtr, bool isLoad, const triton::LinearLayout &sharedLayout, - Attribute encoding, Value ctaId, bool isRowMajor); + Attribute encoding, Value ctaId); // Returns (warpsPerCTA, numTDMInstructions) for a given shared encoding. // For PartitionedSharedEncodingAttr, computes a partition-aligned warp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 350a1c05a77f..4c806d64dfa1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -16,8 +16,6 @@ namespace { // compatible with TDM. Requirements: // - The shared order must be [rank-1, rank-2, ..., 0]. // - All stride-1 dimensions must be consecutive trailing dims. -// Additionally, a single stride-1 dimension may appear at the rank-2 -// position (col-major) if the shared order has rank-2 and rank-1 swapped. LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, Attribute sharedEnc, ArrayRef shape, @@ -25,41 +23,22 @@ LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, int rank = shape.size(); auto sharedOrder = triton::gpu::getOrder( cast(sharedEnc), shape); - - SmallVector strideOneDims; - for (auto [dim, strideVal] : llvm::enumerate(strides)) { - if (getConstantIntValue(getAsOpFoldResult(strideVal)).value_or(0) == 1) - strideOneDims.push_back(dim); - } - - if (strideOneDims.empty()) - return op.emitError() << "requires at least one dimension to have stride 1"; - - // If the only stride-1 dim is the second-to-last dimension (col-major) we can - // safely reorder the dimensions during lowering. - bool isColMajor = - strideOneDims.size() == 1 && strideOneDims.front() == rank - 2; - SmallVector expectedOrder(llvm::reverse(llvm::seq(rank))); - if (isColMajor) - std::swap(expectedOrder[0], expectedOrder[1]); - if (sharedOrder != ArrayRef(expectedOrder)) { - if (isColMajor) - return op.emitError() - << "requires shared order [rank-2, rank-1, rank-3, " - "rank-4, ..., 0] because dim[rank-2] has stride 1"; return op.emitError() << "requires shared order [rank-1, rank-2, ..., 0]"; } - if (strideOneDims.size() > 1) { - unsigned numStride1Dims = strideOneDims.size(); - for (unsigned i = 0; i < numStride1Dims; ++i) { - if (strideOneDims[i] != rank - numStride1Dims + i) - return op.emitError() << "requires all stride 1 dimensions to be " - "consecutive starting from the last dimension"; - } - } + auto isStride1 = [](Value v) { + return getConstantIntValue(getAsOpFoldResult(v)).value_or(0) == 1; + }; + auto reversedStrides = llvm::reverse(strides); + auto firstNonStride1 = llvm::find_if_not(reversedStrides, isStride1); + if (firstNonStride1 == reversedStrides.begin()) + return op.emitError() << "last dimension must have stride 1"; + if (llvm::any_of(llvm::make_range(firstNonStride1, reversedStrides.end()), + isStride1)) + return op.emitError() << "requires all stride 1 dimensions to be " + "consecutive starting from the last dimension"; return success(); } @@ -102,8 +81,8 @@ struct MakeTensorDescOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto basePtr = adaptor.getBase(); - auto tensorShape = llvm::to_vector(adaptor.getShape()); - auto tensorStride = llvm::to_vector(adaptor.getStrides()); + auto tensorShape = adaptor.getShape(); + auto tensorStride = adaptor.getStrides(); auto result = op.getResult(); auto tensorDescTy = result.getType(); @@ -138,8 +117,7 @@ struct MakeTensorDescOpConversion // Create TDM descriptor for 2D-5D tensors auto tdmDesc = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, shapePerCTA, numWarps, - padInterval, padAmount, tensorShape, tensorStride, basePtr, isRowMajor, - sharedEnc); + padInterval, padAmount, tensorShape, tensorStride, basePtr, sharedEnc); SmallVector groups = tdmDesc.getAllGroups(); diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index e008f14f5b62..7ff89e563c76 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -327,7 +327,7 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr, # - NUM_BUFFERS: ttgl.constexpr, USE_TDM: ttgl.constexpr, IS_B_K_CONTIG: ttgl.constexpr): + NUM_BUFFERS: ttgl.constexpr, USE_TDM: ttgl.constexpr): a_dtype: ttgl.constexpr = a_ptr.type.element_ty b_dtype: ttgl.constexpr = b_ptr.type.element_ty ttgl.static_assert(a_dtype.is_fp16() or a_dtype.is_bf16(), "Only fp16/bf16 supported for A") @@ -335,19 +335,11 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # ttgl.static_assert(NUM_BUFFERS >= 2, "NUM_BUFFERS must be at least 2") BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) - if IS_B_K_CONTIG: - BLOCKED_LAYOUT_B: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) - else: - BLOCKED_LAYOUT_B: ttgl.constexpr = ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1]) WMMA_LAYOUT: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [[0, 1], [1, 0]], [], [16, 16, 32]) SHARED_LAYOUT_A: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_K, 8]], [BLOCK_M, BLOCK_K], [1, 0]) - if IS_B_K_CONTIG: - SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_N, 8]], [BLOCK_K, BLOCK_N], - [1, 0]) - else: - SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_K, 8]], [BLOCK_K, BLOCK_N], - [0, 1]) + SHARED_LAYOUT_B: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[BLOCK_N, 8]], [BLOCK_K, BLOCK_N], + [1, 0]) OPERAND_LAYOUT_A: ttgl.constexpr = ttgl.DotOperandLayout(0, WMMA_LAYOUT, 8) OPERAND_LAYOUT_B: ttgl.constexpr = ttgl.DotOperandLayout(1, WMMA_LAYOUT, 8) @@ -375,8 +367,8 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # offs_am = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT))) % M a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak - offs_bk = ttgl.arange(0, BLOCK_K, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT_B)) - offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT_B))) % N + offs_bk = ttgl.arange(0, BLOCK_K, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT)) + offs_bn = (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT))) % N b_ptrs = b_ptr + offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, shape=[NUM_BUFFERS] + a_desc.block_shape, layout=a_desc.layout) @@ -457,26 +449,23 @@ def gemm_async_pipelined_kernel(a_ptr, b_ptr, c_ptr, # @pytest.mark.parametrize("BLOCK_M,BLOCK_N,BLOCK_K", [(m, n, k) for (m, n) in [(32, 32), (64, 64)] \ for k in [32, 64]]) @pytest.mark.parametrize("NUM_BUFFERS", [2, 4]) -@pytest.mark.parametrize("IS_B_K_CONTIG", [True, False]) @pytest.mark.parametrize("ASYNC_LOAD_TYPE", ["ASYNC_COPY", "TDM"]) -def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, IS_B_K_CONTIG, ASYNC_LOAD_TYPE): +def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, ASYNC_LOAD_TYPE): # Inner strides need to be constexpr (1) to get contiguity. Note the compiler frontend does the same for normal dispatches signature = { "a_ptr": "*fp16", "b_ptr": "*fp16", "c_ptr": "*fp32", # "M": "i32", "N": "i32", "K": "i32", # "stride_am": "i32", "stride_ak": "constexpr", # - "stride_bk": "constexpr", "stride_bn": "constexpr", # + "stride_bk": "i32", "stride_bn": "constexpr", # "stride_cm": "i32", "stride_cn": "constexpr", # "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "BLOCK_K": "constexpr", # - "NUM_BUFFERS": "constexpr", "USE_TDM": "constexpr", "IS_B_K_CONTIG": "constexpr" + "NUM_BUFFERS": "constexpr", "USE_TDM": "constexpr" } constexprs = { - "stride_ak": 1, "stride_cn": 1, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, "NUM_BUFFERS": - NUM_BUFFERS, "USE_TDM": ASYNC_LOAD_TYPE == "TDM", "IS_B_K_CONTIG": IS_B_K_CONTIG + "stride_ak": 1, "stride_bn": 1, "stride_cn": 1, "BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "BLOCK_K": BLOCK_K, + "NUM_BUFFERS": NUM_BUFFERS, "USE_TDM": ASYNC_LOAD_TYPE == "TDM" } - constexprs["stride_bn"] = BLOCK_N if not IS_B_K_CONTIG else 1 - constexprs["stride_bk"] = BLOCK_K if IS_B_K_CONTIG else 1 fn = gemm_async_pipelined_kernel @@ -497,7 +486,7 @@ def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, IS assert len(re.findall("tensor_load_to_lds", amdgcn)) == NUM_BUFFERS * 2 else: copy_instr_for_A = BLOCK_M // 4 // 4 - copy_isntr_for_B = (BLOCK_K if IS_B_K_CONTIG else BLOCK_N) // 4 // 4 + copy_isntr_for_B = BLOCK_K // 4 // 4 copy_instr_per_iter = copy_instr_for_A + copy_isntr_for_B for cnt in range(NUM_BUFFERS - 1, -1, -1): assert re.search(f"s_wait_asynccnt 0x{(cnt * copy_instr_per_iter):x}", amdgcn) @@ -509,9 +498,8 @@ def test_compile_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, IS for k in [32, 64]]) @pytest.mark.parametrize("NUM_BUFFERS", [2, 4]) @pytest.mark.parametrize("M,N,K", [(256, 256, 512), (240, 240, 496), (250, 250, 510)]) -@pytest.mark.parametrize("IS_B_K_CONTIG", [True, False]) @pytest.mark.parametrize("ASYNC_LOAD_TYPE", ["ASYNC_COPY", "TDM"]) -def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, N, K, IS_B_K_CONTIG, ASYNC_LOAD_TYPE): +def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, N, K, ASYNC_LOAD_TYPE): if triton.cdiv(K, BLOCK_K) < NUM_BUFFERS: pytest.skip("Skip tests where K/BLOCK_K < NUM_BUFFERS") @@ -525,7 +513,7 @@ def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, c = torch.zeros((M, N), dtype=torch.float32) a_device = a.cuda() - b_device = b.cuda() if IS_B_K_CONTIG else b.data.T.contiguous().T.cuda() + b_device = b.cuda() c_device = c.cuda() stride_am, stride_ak = a_device.stride(0), a_device.stride(1) @@ -540,7 +528,7 @@ def test_runtime_gemm_async_pipelined(BLOCK_M, BLOCK_N, BLOCK_K, NUM_BUFFERS, M, stride_bk, stride_bn, # stride_cm, stride_cn, # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, # - NUM_BUFFERS=NUM_BUFFERS, USE_TDM=ASYNC_LOAD_TYPE == "TDM", IS_B_K_CONTIG=IS_B_K_CONTIG) + NUM_BUFFERS=NUM_BUFFERS, USE_TDM=ASYNC_LOAD_TYPE == "TDM") c_triton = c_device.cpu() c_torch = a.to(torch.float32) @ b.to(torch.float32) @@ -1693,9 +1681,7 @@ def tensor_descriptor_load_store_nd_kernel_host_tdm(out_desc, inp_desc): ttgl.amd.gfx1250.tdm.async_wait(0) -def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT, ROW_MAJOR): - if not ROW_MAJOR and TDM_TYPE == "HOST_TDM": - pytest.skip("NYI: Host TDM does not support non-row major layouts") +def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT): """Utility function to run TDM load/store tests with a given shared layout.""" alloc_shape = [1, 1, 3, 7, INNER_BLOCK][-ndim:] @@ -1704,10 +1690,6 @@ def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYP inp.data = inp.data[..., :INNER_BLOCK - 3] out = inp.new_empty(BLOCK_SHAPE) - if ndim > 1 and not ROW_MAJOR: - out = out.data.transpose(-2, -1).contiguous().transpose(-2, -1) - inp = inp.data.transpose(-2, -1).contiguous().transpose(-2, -1) - inp = inp.cuda() out = out.cuda() @@ -1742,23 +1724,19 @@ def _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYP @pytest.mark.parametrize("INNER_BLOCK", [4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("dtype_str", sorted(set(float_dtypes + int_dtypes))) @pytest.mark.parametrize("TDM_TYPE", ["DEVICE_TDM", "HOST_TDM"]) -@pytest.mark.parametrize("ROW_MAJOR", [False, True]) -def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, ROW_MAJOR): +def test_tensor_descriptor_load_store_nd(dtype_str, ndim, INNER_BLOCK, TDM_TYPE): """Test TDM load/store with swizzled shared layout.""" order = [ndim - 1 - i for i in range(ndim)] - if ndim > 1 and not ROW_MAJOR: - order[0], order[1] = order[1], order[0] SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=order) - _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT, ROW_MAJOR) + _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, TDM_TYPE, SHARED_LAYOUT) @pytest.mark.parametrize("ndim", [1, 2, 3, 4, 5]) @pytest.mark.parametrize("INNER_BLOCK", [16, 64]) @pytest.mark.parametrize("dtype_str", ["float16", "int32"]) -@pytest.mark.parametrize("ROW_MAJOR", [False, True]) -def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, ndim, INNER_BLOCK, ROW_MAJOR): +def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, ndim, INNER_BLOCK): """Test TDM load/store with padded shared memory layout. TDM store only supports padding when: 1. There is a single padding interval @@ -1768,12 +1746,9 @@ def test_tensor_descriptor_load_store_nd_with_padding(dtype_str, ndim, INNER_BLO BLOCK_SHAPE = (2, 2, 4, 8, INNER_BLOCK)[-ndim:] order = [ndim - 1 - i for i in range(ndim)] padding = [INNER_BLOCK, 8] - if ndim > 1 and not ROW_MAJOR: - order[0], order[1] = order[1], order[0] - padding = [8, INNER_BLOCK] PADDED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([padding], BLOCK_SHAPE, order) - _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, "DEVICE_TDM", PADDED_LAYOUT, ROW_MAJOR) + _run_tensor_descriptor_load_store_test(dtype_str, ndim, INNER_BLOCK, "DEVICE_TDM", PADDED_LAYOUT) def test_tensor_descriptor_load_store_invalid_blocksize():