Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ class AssignDescriptorMemoryLayouts {
CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> usageShape,
unsigned numCTAs);

protected:
virtual Attribute getCompatibleSharedEncoding(Attribute enc,
ArrayRef<int64_t> shape,
Type elementType) {
return isCompatibleSharedEncoding(enc) ? enc : Attribute();
}

private:
// Override with backend specific implementation
virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *,
ArrayRef<int64_t>,
Expand Down
46 changes: 15 additions & 31 deletions lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,36 +250,23 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings(

Attribute
AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) {
auto getCompatibleEncodingForType = [&](Type type) -> Attribute {
if (auto memDescTy = dyn_cast<MemDescType>(type)) {
return getCompatibleSharedEncoding(memDescTy.getEncoding(),
memDescTy.getShape(),
memDescTy.getElementType());
}
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
return getCompatibleSharedEncoding(tensorTy.getEncoding(),
tensorTy.getShape(),
tensorTy.getElementType());
}
return {};
};

// Check if there are any desired encodings available on the op
if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) {
if (auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType()))
if (auto compatible = getCompatibleSharedEncoding(
attr, resultTy.getShape(), resultTy.getElementType()))
return compatible;
if (auto enc = dyn_cast<ttg::SharedEncodingTrait>(attr)) {
if (isCompatibleSharedEncoding(enc))
return enc;
}
}
// Ignore multiple users and just pick the first compatible layout
for (auto use : op->getUsers()) {
if (auto alloc = dyn_cast<ttg::LocalAllocOp>(use)) {
if (auto compatible = getCompatibleEncodingForType(alloc.getType()))
return compatible;
auto enc = alloc.getType().getEncoding();
if (isCompatibleSharedEncoding(enc))
return enc;
} else if (auto store = dyn_cast<ttg::LocalStoreOp>(use)) {
if (auto compatible =
getCompatibleEncodingForType(store.getDst().getType()))
return compatible;
auto enc = store.getDst().getType().getEncoding();
if (isCompatibleSharedEncoding(enc))
return enc;
}
}
return {};
Expand Down Expand Up @@ -455,9 +442,7 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
auto ctx = func.getContext();
auto numCTAs = triton::gpu::lookupNumCTAs(func);
for (auto &[desc, einfo] : valueToEncodingInfo) {
auto descTy = desc.getType();
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
auto existingTy = desc.getType().getBlockType();
Attribute newEncoding;
if (einfo->desiredEncoding) {
newEncoding = einfo->desiredEncoding;
Expand All @@ -475,11 +460,10 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
SmallVector<Type> resultTys(func.getResultTypes());
for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
if (auto descTy = dyn_cast<TensorDescType>(resultTy)) {
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs);
resultTys[i] =
getTensorDescTypeWithEncoding(nullptr, existingTy, encoding);
auto encoding =
getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs);
resultTys[i] = getTensorDescTypeWithEncoding(
nullptr, descTy.getBlockType(), encoding);
}
}
func.setFunctionType(FunctionType::get(ctx, argTys, resultTys));
Expand Down
118 changes: 4 additions & 114 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include <algorithm>
#include <cassert>
#include <memory>

namespace mlir::triton::gpu {

Expand Down Expand Up @@ -106,10 +107,7 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
return failure();

MemDescType allocType = allocOp.getType();
auto allocEncoding =
dyn_cast<NVMMASharedEncodingAttr>(allocType.getEncoding());
if (!allocEncoding)
return failure();
auto allocEncoding = cast<NVMMASharedEncodingAttr>(allocType.getEncoding());
RankedTensorType srcTy = trans.getSrc().getType();

Dialect &dialect = allocEncoding.getDialect();
Expand Down Expand Up @@ -178,113 +176,6 @@ class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
}
};

// Rewrite
// tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma
// into
// local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma
//
// The MMA operand layout is determined by the sink memdesc already feeding the
// dot-like op. This pattern back-propagates that layout through the tensor
// reshape/transpose chain, hoists local_alloc to the base tensor feeding that
// view chain, and replays those tensor views as memdesc reshape/transpose
// ops so the original local_alloc type is preserved.
class RewriteMmaOperandViewsToMemDescForDotOp
: public OpInterfaceRewritePattern<triton::DotOpInterface> {
public:
using OpInterfaceRewritePattern<
triton::DotOpInterface>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(triton::DotOpInterface dotOp,
PatternRewriter &rewriter) const override {
if (!isa<triton::nvidia_gpu::TCGen5MMAOp,
triton::nvidia_gpu::TCGen5MMAScaledOp,
triton::nvidia_gpu::WarpGroupDotOp>(dotOp))
return failure();

bool changed = false;

if (rewriteOperand(dotOp.getA(), rewriter).succeeded())
changed = true;

if (rewriteOperand(dotOp.getB(), rewriter).succeeded())
changed = true;

return success(changed);
}

private:
LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const {
if (!isa<MemDescType>(operand.getType()))
return failure();

Value beforeTrailing = operand;
while (auto view = beforeTrailing.getDefiningOp()) {
if (auto reshape = dyn_cast<MemDescReshapeOp>(view)) {
beforeTrailing = reshape.getSrc();
continue;
}
if (auto trans = dyn_cast<MemDescTransOp>(view)) {
beforeTrailing = trans.getSrc();
continue;
}
break;
}

auto localAlloc = beforeTrailing.getDefiningOp<LocalAllocOp>();
if (!localAlloc || !localAlloc.getSrc())
return failure();

Value baseTensor = localAlloc.getSrc();
SmallVector<Operation *> tensorReplaySteps;
MemDescType baseMemTy = localAlloc.getType();
while (auto view = baseTensor.getDefiningOp()) {
if (auto reshape = dyn_cast<triton::ReshapeOp>(view)) {
MemDescType srcTy;
auto inferred = MemDescReshapeOp::inferReturnTypes(
getContext(), reshape.getLoc(), baseMemTy,
reshape.getSrc().getType().getShape(), srcTy);
assert(succeeded(inferred) && "backward memdesc reshape inference "
"must succeed");
(void)inferred;
baseMemTy = srcTy;
} else if (auto trans = dyn_cast<triton::TransOp>(view)) {
Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding());
if (!srcEnc)
return failure();
baseMemTy = MemDescType::get(
trans.getSrc().getType().getShape(), baseMemTy.getElementType(),
srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory());
} else {
break;
}
tensorReplaySteps.push_back(view);
baseTensor = view->getOperand(0);
}
if (tensorReplaySteps.empty())
return failure();

std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end());

PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(localAlloc);

Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(),
baseMemTy, baseTensor);
for (Operation *op : tensorReplaySteps) {
if (auto reshape = dyn_cast<triton::ReshapeOp>(op)) {
rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten,
reshape.getType().getShape());
} else {
auto trans = cast<triton::TransOp>(op);
rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten,
trans.getOrder());
}
}
rewriter.replaceOp(localAlloc, rewritten);
return success();
}
};

// Inject TMEM copy instructions into IR to efficiently load blocked scales for
// scaled dot
class UseShmemForScales
Expand Down Expand Up @@ -446,7 +337,6 @@ class TritonGPUOptimizeDotOperandsPass
mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
patterns.add<FuseTransMMAV3Plus, ReshapeMemDesc>(context);
patterns.add<RewriteMmaOperandViewsToMemDescForDotOp>(context);
patterns.add<UseShmemForScales>(context);
ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(m, std::move(patterns))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,58 +20,15 @@ class NvidiaGPUAssignDescriptorMemoryLayouts
ArrayRef<unsigned> order,
ttg::CGAEncodingAttr cgaLayout,
Type elementType) override;
Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef<int64_t> shape,
Type elementType) override;
bool isCompatibleSharedEncoding(Attribute enc) override;
};

bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding(
Attribute enc) {
if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(enc))
if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(enc)) {
return !nvmma.getTransposed();
return false;
}

Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding(
Attribute enc, ArrayRef<int64_t> shape, Type elementType) {
if (isCompatibleSharedEncoding(enc))
return enc;

auto sharedLinear = dyn_cast<ttg::SharedLinearEncodingAttr>(enc);
if (!sharedLinear)
return {};

auto *ctx = enc.getContext();
auto cgaLayout = ttg::getCGALayout(sharedLinear);
auto order = ttg::getOrder(sharedLinear, shape);

SmallVector<ttg::NVMMASharedEncodingAttr> preferredCandidates;
// TMA descriptors only support non-transposed layouts. Preserve Triton's
// default shape/order-based choice when it already matches this
// shared_linear layout. The full candidate scan below is only a fallback for
// equivalent non-transposed layouts not selected by the heuristic builder.
for (bool fp4Padded : {false, true}) {
auto preferred = ttg::NVMMASharedEncodingAttr::get(
ctx, shape, order, cgaLayout, elementType, fp4Padded);
preferredCandidates.push_back(preferred);
if (ttg::areLayoutsEquivalent(shape, sharedLinear, preferred))
return preferred;
}

unsigned elementBitWidth = std::max(8u, elementType.getIntOrFloatBitWidth());
for (bool fp4Padded : {false, true}) {
for (unsigned swizzle : {0u, 32u, 64u, 128u}) {
auto candidate = ttg::NVMMASharedEncodingAttr::get(
ctx, swizzle, /*transposed=*/false, elementBitWidth, fp4Padded,
cgaLayout);
if (llvm::is_contained(preferredCandidates, candidate))
continue;
if (ttg::areLayoutsEquivalent(shape, sharedLinear, candidate))
return candidate;
}
}

return {};
return false;
}

// Build fallback encoding given shape, order, cga layout and element type
Expand Down
Loading
Loading