Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
8317097
Add swizzle=0 TCGen5 operand-view memdesc rewrite and lit test
masahi Mar 24, 2026
1939857
cmake fix
masahi Mar 24, 2026
7d1e42c
works
masahi Mar 24, 2026
a86d083
make it work for other dot ops
masahi Mar 24, 2026
d2955e7
fix
masahi Mar 24, 2026
28d35fa
fix
masahi Mar 24, 2026
638c3b0
[TritonGPU] Match swizzle0 operand-view rewrite from local_load sourc…
masahi Mar 24, 2026
3375a12
[TritonGPU] Use source shared encoding for swizzle0 operand-view rewrite
masahi Mar 24, 2026
9f559e9
fix
masahi Mar 25, 2026
390b118
clean
masahi Mar 25, 2026
3782068
simplify
masahi Mar 25, 2026
8707f6d
remove pattern matching against desc load
masahi Mar 25, 2026
5ea9724
upd lit test
masahi Mar 25, 2026
12cb8e0
fix
masahi Mar 28, 2026
07119d3
fix for bw
masahi Mar 28, 2026
746c28a
update bw lit
masahi Mar 28, 2026
1d02e00
update for hop
masahi Mar 28, 2026
be6eb93
upd
masahi Mar 28, 2026
0fa2e71
upd
masahi Mar 28, 2026
5e45dac
clean test
masahi Mar 31, 2026
e7d54f8
refactoring operand update
masahi Mar 31, 2026
3291122
wip
masahi Mar 31, 2026
6637c0d
more
masahi Mar 31, 2026
9dcce40
refactor
masahi Mar 31, 2026
9144860
wip
masahi Mar 31, 2026
da8d60c
fix
masahi Mar 31, 2026
a41052a
more clean
masahi Mar 31, 2026
d3eee96
add comment
masahi Mar 31, 2026
b9b6eb4
remove stale include
masahi Mar 31, 2026
0699532
Merge branch 'main' into tma-mma-swizzle-0
masahi Mar 31, 2026
2cda92b
add comment describing the rewrite pattern
masahi Apr 1, 2026
dcf62c0
minor
masahi Apr 6, 2026
6163ab9
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 6, 2026
8aec72f
revert cmake change
masahi Apr 6, 2026
fbae09b
update comment to make it more accurate
masahi Apr 6, 2026
4b986f3
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 8, 2026
e01ce66
Make swizzle0 operand view rewrite sink-driven
masahi Apr 8, 2026
c388478
Clean up sink-driven dot operand rewrite
masahi Apr 8, 2026
b9bb708
Refine sink-driven operand rewrite checks
masahi Apr 8, 2026
1133abd
Generalize dot operand view rewrite naming
masahi Apr 8, 2026
ae6782c
Remove stale swizzle0 host descriptor test
masahi Apr 8, 2026
ffa4f6f
revert unnecessary test change
masahi Apr 8, 2026
9679359
Restore template dispatch for dot operand updates
masahi Apr 8, 2026
e315bf2
Use inferSrcEncoding in dot operand rewrite
masahi Apr 8, 2026
02dcdba
Simplify dot operand rewiring after rewrite
masahi Apr 8, 2026
68fe5ac
Move MMA operand view rewrite into NVIDIA pass
masahi Apr 9, 2026
df2f6f9
Simplify MMA operand view rewrite
masahi Apr 9, 2026
52f2848
precommit
masahi Apr 9, 2026
a77c439
Revert to the old backward inference impl, run the pass before ODE
masahi Apr 9, 2026
6e07bb6
pre commit
masahi Apr 9, 2026
9766e51
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 9, 2026
e093192
Update descriptor rewrite for new tensordesc type
masahi Apr 9, 2026
4f97dc1
Keep descriptor layouts non-transposed
masahi Apr 9, 2026
3dee2de
Simplify MMA operand view replay steps
masahi Apr 9, 2026
f70af5b
Use DotOpInterface in MMA view rewrite
masahi Apr 9, 2026
87eb143
Move MMA operand view rewrite into ODO
masahi Apr 10, 2026
72859e0
precommit
masahi Apr 10, 2026
3130b82
inline helpers
masahi Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ class AssignDescriptorMemoryLayouts {
CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> usageShape,
unsigned numCTAs);

protected:
virtual Attribute getCompatibleSharedEncoding(Attribute enc,
ArrayRef<int64_t> shape,
Type elementType) {
return isCompatibleSharedEncoding(enc) ? enc : Attribute();
}

private:
// Override with backend specific implementation
virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *,
ArrayRef<int64_t>,
Expand Down
46 changes: 31 additions & 15 deletions lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,23 +250,36 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings(

Attribute
AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) {
  // Extract a compatible shared encoding from a user's result/destination
  // type. Handles both memdesc-typed users (e.g. local_alloc) and
  // tensor-typed ones; returns a null Attribute when the type carries no
  // compatible encoding.
  auto getCompatibleEncodingForType = [&](Type type) -> Attribute {
    if (auto memDescTy = dyn_cast<MemDescType>(type)) {
      return getCompatibleSharedEncoding(memDescTy.getEncoding(),
                                         memDescTy.getShape(),
                                         memDescTy.getElementType());
    }
    if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
      return getCompatibleSharedEncoding(tensorTy.getEncoding(),
                                         tensorTy.getShape(),
                                         tensorTy.getElementType());
    }
    return {};
  };

  // Check if there are any desired encodings available on the op. The
  // attribute is validated against the op's result tensor shape/element type
  // before being accepted.
  if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) {
    if (auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType()))
      if (auto compatible = getCompatibleSharedEncoding(
              attr, resultTy.getShape(), resultTy.getElementType()))
        return compatible;
  }
  // Ignore multiple users and just pick the first compatible layout found on
  // a local_alloc result type or a local_store destination type.
  for (auto use : op->getUsers()) {
    if (auto alloc = dyn_cast<ttg::LocalAllocOp>(use)) {
      if (auto compatible = getCompatibleEncodingForType(alloc.getType()))
        return compatible;
    } else if (auto store = dyn_cast<ttg::LocalStoreOp>(use)) {
      if (auto compatible =
              getCompatibleEncodingForType(store.getDst().getType()))
        return compatible;
    }
  }
  return {};
}
Expand Down Expand Up @@ -442,7 +455,9 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
auto ctx = func.getContext();
auto numCTAs = triton::gpu::lookupNumCTAs(func);
for (auto &[desc, einfo] : valueToEncodingInfo) {
auto existingTy = desc.getType().getBlockType();
auto descTy = desc.getType();
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
Attribute newEncoding;
if (einfo->desiredEncoding) {
newEncoding = einfo->desiredEncoding;
Expand All @@ -460,10 +475,11 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
SmallVector<Type> resultTys(func.getResultTypes());
for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
if (auto descTy = dyn_cast<TensorDescType>(resultTy)) {
auto encoding =
getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs);
resultTys[i] = getTensorDescTypeWithEncoding(
nullptr, descTy.getBlockType(), encoding);
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs);
resultTys[i] =
getTensorDescTypeWithEncoding(nullptr, existingTy, encoding);
}
}
func.setFunctionType(FunctionType::get(ctx, argTys, resultTys));
Expand Down
118 changes: 114 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include <memory>
#include <algorithm>
#include <cassert>

namespace mlir::triton::gpu {

Expand Down Expand Up @@ -108,7 +107,10 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
return failure();

MemDescType allocType = allocOp.getType();
auto allocEncoding = cast<NVMMASharedEncodingAttr>(allocType.getEncoding());
auto allocEncoding =
dyn_cast<NVMMASharedEncodingAttr>(allocType.getEncoding());
if (!allocEncoding)
return failure();
RankedTensorType srcTy = trans.getSrc().getType();

auto ctx = getContext();
Expand Down Expand Up @@ -180,6 +182,113 @@ class ReshapeMemDesc : public OpRewritePattern<LocalAllocOp> {
}
};

// Rewrite
//   tt.reshape / tt.trans -> local_alloc -> [memdesc views] -> mma
// into
//   local_alloc -> memdesc reshape / trans -> [memdesc views] -> mma
//
// The MMA operand layout is determined by the sink memdesc already feeding the
// dot-like op. This pattern back-propagates that layout through the tensor
// reshape/transpose chain, hoists local_alloc to the base tensor feeding that
// view chain, and replays those tensor views as memdesc reshape/transpose
// ops so the original local_alloc type is preserved.
class RewriteMmaOperandViewsToMemDescForDotOp
    : public OpInterfaceRewritePattern<triton::DotOpInterface> {
public:
  using OpInterfaceRewritePattern<
      triton::DotOpInterface>::OpInterfaceRewritePattern;

  // Attempts the rewrite independently on the A and B operands; reports
  // success when at least one of them was transformed so the greedy driver
  // knows the IR changed.
  LogicalResult matchAndRewrite(triton::DotOpInterface dotOp,
                                PatternRewriter &rewriter) const override {
    // Only NVIDIA dot-like ops are handled; other DotOpInterface
    // implementations are left untouched.
    if (!isa<triton::nvidia_gpu::TCGen5MMAOp,
             triton::nvidia_gpu::TCGen5MMAScaledOp,
             triton::nvidia_gpu::WarpGroupDotOp>(dotOp))
      return failure();

    bool changed = false;

    if (rewriteOperand(dotOp.getA(), rewriter).succeeded())
      changed = true;

    if (rewriteOperand(dotOp.getB(), rewriter).succeeded())
      changed = true;

    return success(changed);
  }

private:
  // Rewrites a single MMA operand. Walks backwards from the operand through
  // any trailing memdesc views to find the producing local_alloc, then walks
  // the tensor-level reshape/trans chain feeding that alloc, inferring the
  // memdesc type at each step, and finally re-emits the alloc on the base
  // tensor followed by equivalent memdesc view ops.
  LogicalResult rewriteOperand(Value operand, PatternRewriter &rewriter) const {
    if (!isa<MemDescType>(operand.getType()))
      return failure();

    // Skip over memdesc reshape/trans views already sitting between the
    // local_alloc and the dot op; they are kept as-is.
    Value beforeTrailing = operand;
    while (auto view = beforeTrailing.getDefiningOp()) {
      if (auto reshape = dyn_cast<MemDescReshapeOp>(view)) {
        beforeTrailing = reshape.getSrc();
        continue;
      }
      if (auto trans = dyn_cast<MemDescTransOp>(view)) {
        beforeTrailing = trans.getSrc();
        continue;
      }
      break;
    }

    // The chain must bottom out at a local_alloc with a tensor source;
    // allocs without a source have nothing to hoist over.
    auto localAlloc = beforeTrailing.getDefiningOp<LocalAllocOp>();
    if (!localAlloc || !localAlloc.getSrc())
      return failure();

    // Walk the tensor-level view chain feeding the alloc, collecting the
    // view ops to replay later and back-propagating the alloc's memdesc type
    // through each view to compute the type of the hoisted alloc.
    Value baseTensor = localAlloc.getSrc();
    SmallVector<Operation *> tensorReplaySteps;
    MemDescType baseMemTy = localAlloc.getType();
    while (auto view = baseTensor.getDefiningOp()) {
      if (auto reshape = dyn_cast<triton::ReshapeOp>(view)) {
        // Infer the memdesc type before the reshape by reshaping baseMemTy
        // back to the view's source shape.
        MemDescType srcTy;
        auto inferred = MemDescReshapeOp::inferReturnTypes(
            getContext(), reshape.getLoc(), baseMemTy,
            reshape.getSrc().getType().getShape(), srcTy);
        assert(succeeded(inferred) && "backward memdesc reshape inference "
                                      "must succeed");
        (void)inferred;
        baseMemTy = srcTy;
      } else if (auto trans = dyn_cast<triton::TransOp>(view)) {
        // Undo the transpose on the encoding; bail out if the source
        // encoding cannot be inferred.
        Attribute srcEnc = inferSrcEncoding(view, baseMemTy.getEncoding());
        if (!srcEnc)
          return failure();
        baseMemTy = MemDescType::get(
            trans.getSrc().getType().getShape(), baseMemTy.getElementType(),
            srcEnc, baseMemTy.getMemorySpace(), baseMemTy.getMutableMemory());
      } else {
        break;
      }
      tensorReplaySteps.push_back(view);
      baseTensor = view->getOperand(0);
    }
    // Nothing to do when the alloc source was not produced by any
    // reshape/trans view.
    if (tensorReplaySteps.empty())
      return failure();

    // Steps were collected sink-to-source; replay them source-to-sink.
    std::reverse(tensorReplaySteps.begin(), tensorReplaySteps.end());

    PatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(localAlloc);

    // Hoist the alloc onto the base tensor, then replay each tensor view as
    // the corresponding memdesc view so the final type matches the original
    // local_alloc result type.
    Value rewritten = LocalAllocOp::create(rewriter, localAlloc.getLoc(),
                                           baseMemTy, baseTensor);
    for (Operation *op : tensorReplaySteps) {
      if (auto reshape = dyn_cast<triton::ReshapeOp>(op)) {
        rewritten = MemDescReshapeOp::create(rewriter, op->getLoc(), rewritten,
                                             reshape.getType().getShape());
      } else {
        auto trans = cast<triton::TransOp>(op);
        rewritten = MemDescTransOp::create(rewriter, op->getLoc(), rewritten,
                                           trans.getOrder());
      }
    }
    rewriter.replaceOp(localAlloc, rewritten);
    return success();
  }
};

// Inject TMEM copy instructions into IR to efficiently load blocked scales for
// scaled dot
class UseShmemForScales
Expand Down Expand Up @@ -341,6 +450,7 @@ class TritonGPUOptimizeDotOperandsPass
mlir::RewritePatternSet patterns(context);
patterns.add<SwizzleShmemConvert>(context);
patterns.add<FuseTransMMAV3Plus, ReshapeMemDesc>(context);
patterns.add<RewriteMmaOperandViewsToMemDescForDotOp>(context);
patterns.add<UseShmemForScales>(context);
ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsGreedily(m, std::move(patterns))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,60 @@ class NvidiaGPUAssignDescriptorMemoryLayouts
ArrayRef<unsigned> order,
ttg::CGAEncodingAttr cgaLayout,
Type elementType) override;
Attribute getCompatibleSharedEncoding(Attribute enc, ArrayRef<int64_t> shape,
Type elementType) override;
bool isCompatibleSharedEncoding(Attribute enc) override;
};

// An encoding is directly usable for a TMA descriptor only when it is a
// non-transposed NVMMA shared encoding; everything else is rejected.
// (The scraped diff had merged the old braced `if` with the new brace-less
// one, duplicating the condition line — this is the clean post-change form.)
bool NvidiaGPUAssignDescriptorMemoryLayouts::isCompatibleSharedEncoding(
    Attribute enc) {
  if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(enc))
    return !nvmma.getTransposed();
  return false;
}

// Resolve `enc` to a TMA-compatible shared encoding for the given shape and
// element type. Returns `enc` itself when it is already compatible, an
// equivalent non-transposed NVMMA encoding when `enc` is a shared_linear
// layout that matches one, and a null Attribute otherwise.
Attribute NvidiaGPUAssignDescriptorMemoryLayouts::getCompatibleSharedEncoding(
    Attribute enc, ArrayRef<int64_t> shape, Type elementType) {
  // Fast path: the encoding is already usable as-is.
  if (isCompatibleSharedEncoding(enc))
    return enc;

  // Only shared_linear encodings can be re-expressed as NVMMA layouts here.
  auto linearEnc = dyn_cast<ttg::SharedLinearEncodingAttr>(enc);
  if (!linearEnc)
    return {};

  auto *context = enc.getContext();
  auto cga = ttg::getCGALayout(linearEnc);
  auto dimOrder = ttg::getOrder(linearEnc, shape);

  auto isEquivalent = [&](ttg::NVMMASharedEncodingAttr candidate) {
    return ttg::areLayoutsEquivalent(shape, linearEnc, candidate);
  };

  // TMA descriptors only support non-transposed layouts. Preserve Triton's
  // default shape/order-based choice when it already matches this
  // shared_linear layout. The full candidate scan below is only a fallback for
  // equivalent non-transposed layouts not selected by the heuristic builder.
  SmallVector<ttg::NVMMASharedEncodingAttr> heuristicPicks;
  for (bool fp4Padded : {false, true}) {
    auto pick = ttg::NVMMASharedEncodingAttr::get(context, shape, dimOrder,
                                                  cga, elementType, fp4Padded);
    if (isEquivalent(pick))
      return pick;
    heuristicPicks.push_back(pick);
  }

  // Fallback: scan every swizzle width, skipping candidates identical to a
  // heuristic pick that was already rejected above.
  unsigned bitWidth = std::max(8u, elementType.getIntOrFloatBitWidth());
  for (bool fp4Padded : {false, true}) {
    for (unsigned swizzle : {0u, 32u, 64u, 128u}) {
      auto candidate = ttg::NVMMASharedEncodingAttr::get(
          context, swizzle, /*transposed=*/false, bitWidth, fp4Padded, cga);
      if (!llvm::is_contained(heuristicPicks, candidate) &&
          isEquivalent(candidate))
        return candidate;
    }
  }

  return {};
}

// Build fallback encoding given shape, order, cga layout and element type
Attribute NvidiaGPUAssignDescriptorMemoryLayouts::buildFallbackSharedEncoding(
mlir::MLIRContext *ctx, ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
Expand Down
Loading
Loading