diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 84833f526320..f10e39770d78 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -69,8 +69,6 @@ getThreadsPerWarpWithUniqueData(Attribute layout, SmallVector getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); -SmallVector getThreadsPerCTA(Attribute layout); - SmallVector getOrder(Attribute layout); CTALayoutAttr getCTALayout(Attribute layout); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ad2061d78ffe..bf4c737f364a 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -209,10 +209,8 @@ bool ReduceOpHelper::isSupportedLayout() { if (srcLayout.isa()) { return true; } - if (auto mmaLayout = srcLayout.dyn_cast()) { - if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { - return true; - } + if (auto mmaLayout = srcLayout.dyn_cast()) { + return mmaLayout.supportReduction(); } if (auto sliceLayout = srcLayout.dyn_cast()) { return true; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 7970a7fbb878..1a5a47a2693e 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -39,16 +39,8 @@ namespace gpu { unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, Type eltTy) { - if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getTotalElemsPerThread(shape, eltTy); - } else if (auto sliceLayout = layout.dyn_cast()) { - return sliceLayout.getTotalElemsPerThread(shape, eltTy); - } else if (auto mmaLayout = layout.dyn_cast()) { - return mmaLayout.getTotalElemsPerThread(shape, eltTy); - } else if (auto sharedLayout = layout.dyn_cast()) { - return sharedLayout.getTotalElemsPerThread(shape, eltTy); - } else if (auto dotLayout = layout.dyn_cast()) { - return dotLayout.getTotalElemsPerThread(shape, eltTy); + if (auto tritonGPUAttr = 
layout.dyn_cast()) { + return tritonGPUAttr.getTotalElemsPerThread(shape, eltTy); } else { llvm::report_fatal_error("getElemsPerThread not implemented"); return 0; @@ -57,12 +49,8 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef shape, SmallVector getElemsPerThread(Attribute layout, ArrayRef shape, Type eltTy) { - if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getElemsPerThread(shape, eltTy); - } else if (auto sliceLayout = layout.dyn_cast()) { - return sliceLayout.getElemsPerThread(shape, eltTy); - } else if (auto mmaLayout = layout.dyn_cast()) { - return mmaLayout.getElemsPerThread(shape, eltTy); + if (auto tritonGPUAttr = layout.dyn_cast()) { + return tritonGPUAttr.getElemsPerThread(shape, eltTy); } else { llvm::report_fatal_error("getElemsPerThread not implemented"); return SmallVector(); @@ -86,30 +74,12 @@ unsigned getTotalElemsPerThread(Type type) { } SmallVector getThreadsPerWarp(Attribute layout) { - if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getThreadsPerWarp(); - } - if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isVolta()) - return {4, 8}; - if (mmaLayout.isAmpere()) - return {8, 4}; - if (mmaLayout.isHopper()) - return {8, 4}; - } - if (auto sliceLayout = layout.dyn_cast()) { - auto parent = sliceLayout.getParent(); - auto parentThreadsPerWarp = getThreadsPerWarp(parent); - assert(parentThreadsPerWarp.size() == 2 && - "getThreadsPerWarp only implemented for 2D slice layout"); - SmallVector threadsPerWarp = parentThreadsPerWarp; - threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); - for (unsigned i = 0; i < threadsPerWarp.size(); i++) - threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()]; - return threadsPerWarp; + if (auto distributedLayout = layout.dyn_cast()) { + return distributedLayout.getThreadsPerWarp(); + } else { + llvm::report_fatal_error("getThreadsPerWarp not implemented"); + return SmallVector(); } - 
llvm::report_fatal_error("getThreadsPerWarp not implemented"); - return {}; } SmallVector @@ -135,27 +105,12 @@ getThreadsPerWarpWithUniqueData(Attribute layout, } SmallVector getWarpsPerCTA(Attribute layout) { - if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getWarpsPerCTA(); - } - if (auto mmaLayout = layout.dyn_cast()) { - return mmaLayout.getWarpsPerCTA(); - } - if (auto sliceLayout = layout.dyn_cast()) { - auto parent = sliceLayout.getParent(); - auto parentWarpsPerCTA = getWarpsPerCTA(parent); - assert(parentWarpsPerCTA.size() == 2 || - parentWarpsPerCTA[sliceLayout.getDim()] == 1 && - "getWarpsPerCTA only implemented for 2D slice layout or the " - "slice dim must have 1 warp in the parent layout"); - SmallVector warpsPerCTA = parentWarpsPerCTA; - warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); - for (unsigned i = 0; i < warpsPerCTA.size(); i++) - warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()]; - return warpsPerCTA; + if (auto distributedLayout = layout.dyn_cast()) { + return distributedLayout.getWarpsPerCTA(); } + llvm::report_fatal_error("getWarpsPerCTA not implemented"); - return {}; + return SmallVector(); } SmallVector @@ -183,45 +138,8 @@ getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { } SmallVector getSizePerThread(Attribute layout) { - if (auto blockedLayout = layout.dyn_cast()) { - return blockedLayout.getSizePerThread(); - } else if (auto sliceLayout = layout.dyn_cast()) { - auto sizePerThread = getSizePerThread(sliceLayout.getParent()); - sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim()); - return sizePerThread; - } else if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isAmpere()) { - return {2, 2}; - } else if (mmaLayout.isVolta()) { - return {1, 2}; - } else if (mmaLayout.isHopper()) { - auto instrShape = mmaLayout.getInstrShape(); - // TODO(thomas): what are those magic numbers? 
- return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; - } else { - llvm_unreachable("Unexpected mma version"); - } - } else if (auto dotLayout = layout.dyn_cast()) { - auto parentLayout = dotLayout.getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto parentMmaLayout = parentLayout.dyn_cast()) { - assert(parentMmaLayout.isAmpere() && - "mmaLayout version = 1 is not implemented yet"); - auto opIdx = dotLayout.getOpIdx(); - if (opIdx == 0) { - return {2, 4}; - } else if (opIdx == 1) { - return {4, 1}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - return {}; - } - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-MmaEncodingAttr parent not " - "supported yet"); - return {}; - } + if (auto distributedLayout = layout.dyn_cast()) { + return distributedLayout.getSizePerThread(); } else { llvm::report_fatal_error("getSizePerThread not implemented"); return {}; @@ -264,79 +182,14 @@ SmallVector getUniqueContigPerThread(Attribute layout, return ret; } -SmallVector getThreadsPerCTA(Attribute layout) { - SmallVector threads; - if (auto blockedLayout = layout.dyn_cast()) { - for (int d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) - threads.push_back(blockedLayout.getThreadsPerWarp()[d] * - blockedLayout.getWarpsPerCTA()[d]); - } else if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isAmpere()) { - threads = {8 * mmaLayout.getWarpsPerCTA()[0], - 4 * mmaLayout.getWarpsPerCTA()[1]}; - } else - llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr"); - } else { - llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA"); - } - - return threads; -} - SmallVector getShapePerCTATile(Attribute layout, ArrayRef tensorShape) { - SmallVector shape; - if (auto blockedLayout = layout.dyn_cast()) { - for (unsigned d = 0, n = blockedLayout.getOrder().size(); d < n; ++d) - shape.push_back(blockedLayout.getSizePerThread()[d] * - 
blockedLayout.getThreadsPerWarp()[d] * - blockedLayout.getWarpsPerCTA()[d]); - } else if (auto sliceLayout = layout.dyn_cast()) { - shape = getShapePerCTATile(sliceLayout.getParent(), tensorShape); - shape.erase(shape.begin() + sliceLayout.getDim()); - } else if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isAmpere()) - return {16 * mmaLayout.getWarpsPerCTA()[0], - 8 * mmaLayout.getWarpsPerCTA()[1]}; - if (mmaLayout.isVolta()) { - assert(!tensorShape.empty() && "Volta needs the tensorShape"); - if (tensorShape.size() == 1) // must be SliceEncoding - return {static_cast(tensorShape[0]), - static_cast(tensorShape[0])}; - return {static_cast(tensorShape[0]), - static_cast(tensorShape[1])}; - } - if (mmaLayout.isHopper()) { - auto instrShape = mmaLayout.getInstrShape(); - return {16 * mmaLayout.getWarpsPerCTA()[0], - instrShape[1] * mmaLayout.getWarpsPerCTA()[1]}; - } - llvm::report_fatal_error("Unexpected MMA layout version found"); - } else if (auto dotLayout = layout.dyn_cast()) { - auto parentLayout = dotLayout.getParent(); - assert(parentLayout && "DotOperandEncodingAttr must have a parent"); - if (auto parentMmaLayout = parentLayout.dyn_cast()) { - assert(parentMmaLayout.isAmpere() && - "mmaLayout version = 1 is not implemented yet"); - auto parentShapePerCTATile = - getShapePerCTATile(parentLayout, tensorShape); - auto opIdx = dotLayout.getOpIdx(); - if (opIdx == 0) { - return {parentShapePerCTATile[0], 16}; - } else if (opIdx == 1) { - return {16, parentShapePerCTATile[1]}; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); - } - } else { - llvm::report_fatal_error( - "DotOperandEncodingAttr non-MmaEncodingAttr parent not " - "supported yet"); - } + if (auto distributedLayout = layout.dyn_cast()) { + return distributedLayout.getShapePerCTATile(tensorShape); } else { - llvm::report_fatal_error("Unimplemented usage of getShapePerCTATile"); + llvm::report_fatal_error("getShapePerCTATile not implemented"); + return 
SmallVector(); } - return shape; } bool isExpensiveView(Type srcType, Type dstType) { @@ -368,7 +221,7 @@ SmallVector getOrder(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getOrder().begin(), blockedLayout.getOrder().end()); - } else if (auto mmaLayout = layout.dyn_cast()) { + } else if (auto mmaLayout = layout.dyn_cast()) { return {1, 0}; } else if (auto dotLayout = layout.dyn_cast()) { return {1, 0}; @@ -395,19 +248,11 @@ SmallVector getOrder(Attribute layout) { }; CTALayoutAttr getCTALayout(Attribute layout) { - if (auto blockedLayout = layout.dyn_cast()) - return blockedLayout.getCTALayout(); - else if (auto sliceLayout = layout.dyn_cast()) - return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(sliceLayout), - getCTASplitNum(sliceLayout), - getCTAOrder(sliceLayout)); - else if (auto mmaLayout = layout.dyn_cast()) - return mmaLayout.getCTALayout(); - else if (auto dotLayout = layout.dyn_cast()) - return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(dotLayout), - getCTASplitNum(dotLayout), - getCTAOrder(dotLayout)); - else if (auto sharedLayout = layout.dyn_cast()) + if (auto distributedLayout = layout.dyn_cast()) { + return CTALayoutAttr::get( + layout.getContext(), getCTAsPerCGA(distributedLayout), + getCTASplitNum(distributedLayout), getCTAOrder(distributedLayout)); + } else if (auto sharedLayout = layout.dyn_cast()) return sharedLayout.getCTALayout(); else llvm::report_fatal_error("Unimplemented usage of getCTALayout"); @@ -416,34 +261,8 @@ CTALayoutAttr getCTALayout(Attribute layout) { SmallVector getCTAsPerCGA(Attribute layout) { ArrayRef ref; - if (auto blockedLayout = layout.dyn_cast()) - ref = blockedLayout.getCTALayout().getCTAsPerCGA(); - else if (auto sliceLayout = layout.dyn_cast()) { - auto parentCTAsPerCGA = getCTAsPerCGA(sliceLayout.getParent()); - if (parentCTAsPerCGA[sliceLayout.getDim()] == 1) { - parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + sliceLayout.getDim()); - 
return parentCTAsPerCGA; - } - /* For getCTAsPerCGA of a slice layout, we have two choices: - * (1) Return CTAsPerCGA of its parent. This is not a perfect solution - * because the rank of the returned CTAsPerCGA does not match the rank of - * tensorShape. - * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a - * perfect solution because the product of the returned CTAsPerCGA might not - * match numCTAs. - * To avoid introducing inconsistencies to the shape and - * layout system, the usage of directly getting CTAsPerCGA of a slice layout - * in which the sliced dim is not 1 is banned. You should always consider - * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) - * in the branch where layout is an instance of SliceEncodingAttr. This is - * inconvenient but safe. - */ - llvm::report_fatal_error( - "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); - } else if (auto mmaLayout = layout.dyn_cast()) - ref = mmaLayout.getCTALayout().getCTAsPerCGA(); - else if (auto dotLayout = layout.dyn_cast()) - return getCTAsPerCGA(dotLayout.getParent()); + if (auto distributedLayout = layout.dyn_cast()) + return distributedLayout.getCTAsPerCGA(); else if (auto sharedLayout = layout.dyn_cast()) ref = sharedLayout.getCTALayout().getCTAsPerCGA(); else @@ -453,43 +272,21 @@ SmallVector getCTAsPerCGA(Attribute layout) { SmallVector getCTASplitNum(Attribute layout) { SmallVector res; - - if (auto blockedLayout = layout.dyn_cast()) { - res.assign(blockedLayout.getCTALayout().getCTASplitNum().begin(), - blockedLayout.getCTALayout().getCTASplitNum().end()); - } else if (auto sliceLayout = layout.dyn_cast()) { - res = getCTASplitNum(sliceLayout.getParent()); - res.erase(res.begin() + sliceLayout.getDim()); - } else if (auto mmaLayout = layout.dyn_cast()) { - res.assign(mmaLayout.getCTALayout().getCTASplitNum().begin(), - mmaLayout.getCTALayout().getCTASplitNum().end()); - } else if (auto dotLayout = layout.dyn_cast()) { - res = 
getCTASplitNum(dotLayout.getParent()); - assert(res.size() == 2 && "Invalid dotLayout"); - - // Do not split CTA in K dimension - dotLayout.getOpIdx() == 0 ? res[1] = 1 : res[0] = 1; + if (auto distributedLayout = layout.dyn_cast()) { + return distributedLayout.getCTASplitNum(); } else if (auto sharedLayout = layout.dyn_cast()) { res.assign(sharedLayout.getCTALayout().getCTASplitNum().begin(), sharedLayout.getCTALayout().getCTASplitNum().end()); } else { assert(false && "Unimplemented usage of getCTASplitNum"); } - return res; } SmallVector getCTAOrder(Attribute layout) { ArrayRef ref; - if (auto blockedLayout = layout.dyn_cast()) { - ref = blockedLayout.getCTALayout().getCTAOrder(); - } else if (auto sliceLayout = layout.dyn_cast()) { - auto parentCTAOrder = getCTAOrder(sliceLayout.getParent()); - return eraseOrder(parentCTAOrder, sliceLayout.getDim()); - } else if (auto mmaLayout = layout.dyn_cast()) { - ref = mmaLayout.getCTALayout().getCTAOrder(); - } else if (auto dotLayout = layout.dyn_cast()) { - return getCTAOrder(dotLayout.getParent()); + if (auto distributedLayout = layout.dyn_cast()) { + ref = distributedLayout.getCTAOrder(); } else if (auto sharedLayout = layout.dyn_cast()) { ref = sharedLayout.getCTALayout().getCTAOrder(); } else { @@ -532,14 +329,16 @@ SmallVector getShapePerCTA(Type type) { } unsigned getNumWarpsPerCTA(Attribute layout) { - ArrayRef warpsPerCTA; + SmallVector warpsPerCTA; if (auto blockedLayout = layout.dyn_cast()) warpsPerCTA = blockedLayout.getWarpsPerCTA(); else if (auto sliceLayout = layout.dyn_cast()) return getNumWarpsPerCTA(sliceLayout.getParent()); - else if (auto mmaLayout = layout.dyn_cast()) - warpsPerCTA = mmaLayout.getWarpsPerCTA(); - else if (auto dotLayout = layout.dyn_cast()) + else if (auto mmaLayout = layout.dyn_cast()) { + // Use the distributed layout interface to get the number of warps per CTA. 
+ auto distributedLayout = layout.cast(); + warpsPerCTA = distributedLayout.getWarpsPerCTA(); + } else if (auto dotLayout = layout.dyn_cast()) return getNumWarpsPerCTA(dotLayout.getParent()); else if (auto sharedLayout = layout.dyn_cast()) llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr"); @@ -549,24 +348,11 @@ unsigned getNumWarpsPerCTA(Attribute layout) { } unsigned getNumCTAs(Attribute layout) { - ArrayRef CTAsPerCGA; - if (auto blockedLayout = layout.dyn_cast()) - CTAsPerCGA = blockedLayout.getCTALayout().getCTAsPerCGA(); - else if (auto sliceLayout = layout.dyn_cast()) - return getNumCTAs(sliceLayout.getParent()); - else if (auto mmaLayout = layout.dyn_cast()) - CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA(); - else if (auto dotLayout = layout.dyn_cast()) - return getNumCTAs(dotLayout.getParent()); - else if (auto sharedLayout = layout.dyn_cast()) - CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA(); - else - llvm::report_fatal_error("Unimplemented usage of getNumCTAs"); - return product(CTAsPerCGA); + return product(getCTAsPerCGA(layout)); } bool isaDistributedLayout(Attribute layout) { - return layout.isa() || layout.isa() || + return layout.isa() || layout.isa() || layout.isa(); } @@ -783,13 +569,16 @@ unsigned BlockedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(*this); + ArrayRef ref = getCTALayout().getCTAsPerCGA(); + return SmallVector(ref.begin(), ref.end()); } SmallVector BlockedEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(*this); + ArrayRef ref = getCTALayout().getCTAOrder(); + return SmallVector(ref.begin(), ref.end()); } SmallVector BlockedEncodingAttr::getCTASplitNum() const { - return ::getCTASplitNum(*this); + ArrayRef ref = getCTALayout().getCTASplitNum(); + return SmallVector(ref.begin(), ref.end()); } SmallVector BlockedEncodingAttr::getWarpsPerCTA() const 
{ return SmallVector(getWarpsPerCTA__().begin(), @@ -811,7 +600,11 @@ SmallVector BlockedEncodingAttr::getSizePerThread() const { } SmallVector BlockedEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - return ::getShapePerCTATile(*this, tensorShape); + SmallVector shape; + for (unsigned d = 0, n = getOrder().size(); d < n; ++d) + shape.push_back(getSizePerThread()[d] * getThreadsPerWarp()[d] * + getWarpsPerCTA()[d]); + return shape; } template @@ -848,32 +641,77 @@ unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return product(getElemsPerThread(shape, eltTy)); } SmallVector SliceEncodingAttr::getCTASplitNum() const { - return ::getCTASplitNum(*this); + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; } SmallVector SliceEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(*this); + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); } SmallVector SliceEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(*this); + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. 
You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); } SmallVector SliceEncodingAttr::getWarpsPerCTA() const { - return ::getWarpsPerCTA(*this); + auto parent = getParent(); + auto parentWarpsPerCTA = ::getWarpsPerCTA(parent); + assert(parentWarpsPerCTA.size() == 2 || + parentWarpsPerCTA[getDim()] == 1 && + "getWarpsPerCTA only implemented for 2D slice layout or the " + "slice dim must have 1 warp in the parent layout"); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + getDim()); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) + warpsPerCTA[i] *= parentWarpsPerCTA[getDim()]; + return warpsPerCTA; } SmallVector SliceEncodingAttr::getWarpOrder() const { return ::getOrder(*this); } SmallVector SliceEncodingAttr::getThreadsPerWarp() const { - return ::getThreadsPerWarp(*this); + auto parent = getParent(); + auto parentThreadsPerWarp = ::getThreadsPerWarp(parent); + assert(parentThreadsPerWarp.size() == 2 && + "getThreadsPerWarp only implemented for 2D slice layout"); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + getDim()); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) + threadsPerWarp[i] *= parentThreadsPerWarp[getDim()]; + return threadsPerWarp; } SmallVector SliceEncodingAttr::getThreadOrder() const { return ::getOrder(*this); } SmallVector SliceEncodingAttr::getSizePerThread() const { - return ::getSizePerThread(*this); + auto sizePerThread = ::getSizePerThread(getParent()); + sizePerThread.erase(sizePerThread.begin() + getDim()); + return sizePerThread; } SmallVector SliceEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - return ::getShapePerCTATile(*this, tensorShape); + SmallVector shape = 
::getShapePerCTATile(getParent(), tensorShape); + shape.erase(shape.begin() + getDim()); + return shape; } SmallVector @@ -979,7 +817,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { - if (auto mmaParent = getParent().dyn_cast()) { + if (auto mmaParent = getParent().dyn_cast()) { return mmaParent.getTotalElemsPerThreadForOperands(shape, eltTy, getOpIdx()); } @@ -1012,16 +850,22 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, return 0; } SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(*this); + return ::getCTAsPerCGA(getParent()); } SmallVector DotOperandEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(*this); + return ::getCTAOrder(getParent()); } SmallVector DotOperandEncodingAttr::getCTASplitNum() const { - return ::getCTASplitNum(*this); + SmallVector res = ::getCTASplitNum(getParent()); + assert(res.size() == 2 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + getOpIdx() == 0 ? 
res[1] = 1 : res[0] = 1; + return res; } SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { - return ::getWarpsPerCTA(*this); + llvm::report_fatal_error( + "getWarpsPerCTA not implemented for DotOperandEncodingAttr"); } SmallVector DotOperandEncodingAttr::getWarpOrder() const { return ::getOrder(*this); @@ -1031,7 +875,16 @@ SmallVector DotOperandEncodingAttr::getThreadOrder() const { } SmallVector DotOperandEncodingAttr::getShapePerCTATile( ArrayRef tensorShape) const { - return ::getShapePerCTATile(*this, tensorShape); + auto parentLayout = getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = parentLayout.dyn_cast()) { + return parentMmaLayout.getShapePerCTATileForDotOperands(tensorShape, + getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); + } } //===----------------------------------------------------------------------===// @@ -1304,13 +1157,16 @@ bool MmaEncodingAttr::isAmpere() const { return getVersionMajor() == 2; } bool MmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } SmallVector MmaEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(*this); + ArrayRef ref = getCTALayout().getCTAsPerCGA(); + return SmallVector(ref.begin(), ref.end()); } SmallVector MmaEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(*this); + ArrayRef ref = getCTALayout().getCTAOrder(); + return SmallVector(ref.begin(), ref.end()); } SmallVector MmaEncodingAttr::getCTASplitNum() const { - return ::getCTASplitNum(*this); + ArrayRef ref = getCTALayout().getCTASplitNum(); + return SmallVector(ref.begin(), ref.end()); } SmallVector MmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__().begin(), @@ -1320,17 +1176,48 @@ SmallVector MmaEncodingAttr::getWarpOrder() const { return ::getOrder(*this); } SmallVector MmaEncodingAttr::getThreadsPerWarp() const { - return 
::getThreadsPerWarp(*this); + if (isVolta()) + return {4, 8}; + if (isAmpere()) + return {8, 4}; + if (isHopper()) + return {8, 4}; + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for unknown Mma version "); } SmallVector MmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); } SmallVector MmaEncodingAttr::getSizePerThread() const { - return ::getSizePerThread(*this); + if (isAmpere()) { + return {2, 2}; + } else if (isVolta()) { + return {1, 2}; + } else if (isHopper()) { + auto instrShape = getInstrShape(); + // TODO(thomas): what are those magic numbers? + return SmallVector{instrShape[0] * 4 / 32, instrShape[1] / 4}; + } else { + llvm_unreachable("Unexpected mma version"); + } } SmallVector MmaEncodingAttr::getShapePerCTATile(ArrayRef tensorShape) const { - return ::getShapePerCTATile(*this, tensorShape); + if (isAmpere()) + return {16 * getWarpsPerCTA()[0], 8 * getWarpsPerCTA()[1]}; + if (isVolta()) { + assert(!tensorShape.empty() && "Volta needs the tensorShape"); + if (tensorShape.size() == 1) // must be SliceEncoding + return {static_cast(tensorShape[0]), + static_cast(tensorShape[0])}; + return {static_cast(tensorShape[0]), + static_cast(tensorShape[1])}; + } + if (isHopper()) { + auto instrShape = getInstrShape(); + return {16 * getWarpsPerCTA()[0], instrShape[1] * getWarpsPerCTA()[1]}; + } + llvm::report_fatal_error("Unexpected MMA layout version found"); } // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor @@ -1518,10 +1405,20 @@ MmaEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const { // DotOperand Encoding //===----------------------------------------------------------------------===// SmallVector DotOperandEncodingAttr::getThreadsPerWarp() const { - return ::getThreadsPerWarp(*this); + llvm::report_fatal_error( + "getThreadsPerWarp not implemented for DotOperandEncodingAttr"); } SmallVector DotOperandEncodingAttr::getSizePerThread() const { - return ::getSizePerThread(*this); + auto parentLayout 
= getParent(); + assert(parentLayout && "DotOperandEncodingAttr must have a parent"); + if (auto parentMmaLayout = parentLayout.dyn_cast()) { + return parentMmaLayout.getSizePerThreadForOperands(getOpIdx()); + } else { + llvm::report_fatal_error( + "DotOperandEncodingAttr non-MmaEncodingAttr parent not " + "supported yet"); + return {}; + } } Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) { @@ -1644,7 +1541,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (auto mmaAttr = attr.dyn_cast()) { + if (auto mmaAttr = attr.dyn_cast()) { os << "mma"; return AliasResult::FinalAlias; } else if (auto sharedAttr = attr.dyn_cast()) {