diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 00d840d7ccf7..345e929a45b0 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -87,8 +87,8 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { return; auto srcBlocked = dyn_cast(srcType.getEncoding()); - auto dstDotOp = - dyn_cast(dstType.getEncoding()); + auto dstEncoding = dstType.getEncoding(); + auto dstDotOp = dyn_cast(dstEncoding); if (srcBlocked && dstDotOp) { // FIXME [Dot LL] // We support this one via LLs, as the LocalLoad path is buggy @@ -99,15 +99,21 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) { return; } } - + auto srcOrder = triton::gpu::getOrder(srcBlocked); + auto rank = srcOrder.size(); + SmallVector sharedOrder; + if (rank == 3) { + sharedOrder = gpu::getThreadOrder(dstEncoding); + } else { + sharedOrder = srcOrder; + } Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( dstType.getShape(), dstType.getElementType(), triton::gpu::SharedEncodingAttr::get( - module.getContext(), dstDotOp, srcType.getShape(), - srcBlocked.getOrder(), srcBlocked.getCTALayout(), - srcType.getElementType()), + module.getContext(), dstDotOp, srcType.getShape(), sharedOrder, + srcBlocked.getCTALayout(), srcType.getElementType()), sharedMemorySpace); auto tmp = builder.create( cvtOp.getLoc(), tmpType, cvtOp.getSrc()); diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index 3a406c3cc28e..8accdf9f777f 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -36,30 +36,29 @@ class TritonGPUReduceDataDuplicationPass auto srcType = cast(cvtOp.getSrc().getType()); auto dstType = cast(cvtOp.getType()); auto srcEncoding = srcType.getEncoding(); + auto dstEncoding = dstType.getEncoding(); if (isa(srcEncoding)) return; auto dstDotOp = - dyn_cast(dstType.getEncoding()); + dyn_cast(dstEncoding); if (!dstDotOp) return; if (!cvtNeedsSharedMemory(srcType, dstType)) return; // FIXME [Dot LL] // We support this one via LLs, as the LocalLoad path is buggy - bool largeKWidth = - dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64; - if (largeKWidth) { - return; + if (auto mma = dyn_cast(dstDotOp.getParent())) { + bool largeKWidth = + dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64; + if (mma.isAmpere() && largeKWidth) { + return; + } } auto srcOrder = triton::gpu::getOrder(srcEncoding); auto rank = srcOrder.size(); SmallVector sharedOrder; if (rank == 3) { - // add all elements except the element that is zero - for (unsigned i = 0; i < rank; ++i) - if (srcOrder[i] != 0) - sharedOrder.emplace_back(srcOrder[i]); - sharedOrder.emplace_back(0); + sharedOrder = gpu::getThreadOrder(dstEncoding); } else { sharedOrder = srcOrder; }