Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
8317097
Add swizzle=0 TCGen5 operand-view memdesc rewrite and lit test
masahi Mar 24, 2026
1939857
cmake fix
masahi Mar 24, 2026
7d1e42c
works
masahi Mar 24, 2026
a86d083
make it work for other dot ops
masahi Mar 24, 2026
d2955e7
fix
masahi Mar 24, 2026
28d35fa
fix
masahi Mar 24, 2026
638c3b0
[TritonGPU] Match swizzle0 operand-view rewrite from local_load sourc…
masahi Mar 24, 2026
3375a12
[TritonGPU] Use source shared encoding for swizzle0 operand-view rewrite
masahi Mar 24, 2026
9f559e9
fix
masahi Mar 25, 2026
390b118
clean
masahi Mar 25, 2026
3782068
simplify
masahi Mar 25, 2026
8707f6d
remove pattern matching against desc load
masahi Mar 25, 2026
5ea9724
upd lit test
masahi Mar 25, 2026
12cb8e0
fix
masahi Mar 28, 2026
07119d3
fix for bw
masahi Mar 28, 2026
746c28a
update bw lit
masahi Mar 28, 2026
1d02e00
update for hop
masahi Mar 28, 2026
be6eb93
upd
masahi Mar 28, 2026
0fa2e71
upd
masahi Mar 28, 2026
5e45dac
clean test
masahi Mar 31, 2026
e7d54f8
refactoring operand update
masahi Mar 31, 2026
3291122
wip
masahi Mar 31, 2026
6637c0d
more
masahi Mar 31, 2026
9dcce40
refactor
masahi Mar 31, 2026
9144860
wip
masahi Mar 31, 2026
da8d60c
fix
masahi Mar 31, 2026
a41052a
more clean
masahi Mar 31, 2026
d3eee96
add comment
masahi Mar 31, 2026
b9b6eb4
remove stale include
masahi Mar 31, 2026
0699532
Merge branch 'main' into tma-mma-swizzle-0
masahi Mar 31, 2026
2cda92b
add comment describing the rewrite pattern
masahi Apr 1, 2026
dcf62c0
minor
masahi Apr 6, 2026
6163ab9
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 6, 2026
8aec72f
revert cmake change
masahi Apr 6, 2026
fbae09b
update comment to make it more accurate
masahi Apr 6, 2026
4b986f3
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 8, 2026
e01ce66
Make swizzle0 operand view rewrite sink-driven
masahi Apr 8, 2026
c388478
Clean up sink-driven dot operand rewrite
masahi Apr 8, 2026
b9bb708
Refine sink-driven operand rewrite checks
masahi Apr 8, 2026
1133abd
Generalize dot operand view rewrite naming
masahi Apr 8, 2026
ae6782c
Remove stale swizzle0 host descriptor test
masahi Apr 8, 2026
ffa4f6f
revert unnecessary test change
masahi Apr 8, 2026
9679359
Restore template dispatch for dot operand updates
masahi Apr 8, 2026
e315bf2
Use inferSrcEncoding in dot operand rewrite
masahi Apr 8, 2026
02dcdba
Simplify dot operand rewiring after rewrite
masahi Apr 8, 2026
68fe5ac
Move MMA operand view rewrite into NVIDIA pass
masahi Apr 9, 2026
df2f6f9
Simplify MMA operand view rewrite
masahi Apr 9, 2026
52f2848
precommit
masahi Apr 9, 2026
a77c439
Revert to the old backward inference impl, run the pass before ODE
masahi Apr 9, 2026
6e07bb6
pre commit
masahi Apr 9, 2026
9766e51
Merge branch 'main' into tma-mma-swizzle-0
masahi Apr 9, 2026
e093192
Update descriptor rewrite for new tensordesc type
masahi Apr 9, 2026
4f97dc1
Keep descriptor layouts non-transposed
masahi Apr 9, 2026
3dee2de
Simplify MMA operand view replay steps
masahi Apr 9, 2026
f70af5b
Use DotOpInterface in MMA view rewrite
masahi Apr 9, 2026
87eb143
Move MMA operand view rewrite into ODO
masahi Apr 10, 2026
72859e0
precommit
masahi Apr 10, 2026
3130b82
inline helpers
masahi Apr 10, 2026
6879828
Merge branch 'tma-mma-swizzle-0' into swizzle-0-fix
masahi Apr 15, 2026
a3e56e0
[TritonNvidiaGPU] Avoid fatal ODE layout probes
masahi Apr 15, 2026
102193e
[TritonGPU] Restrict operand-view rewrite to shared_linear
masahi Apr 16, 2026
30f12e1
Merge branch 'main' into swizzle-0-fix
masahi Apr 27, 2026
62b3a08
format
masahi Apr 27, 2026
de8af7a
[TritonGPU] Simplify TMA block shape diagnostics
masahi Apr 28, 2026
2bb6ccd
[TritonGPU] Simplify TMA block shape error helper
masahi Apr 28, 2026
6963198
[TritonGPU] Drop stale TMA helper suffixes
masahi Apr 28, 2026
ed11c36
minor change in LinearLayoutConversions.cpp
masahi Apr 28, 2026
ff09a8a
inline error emit
masahi Apr 28, 2026
b7f1acd
more inline error msg
masahi Apr 28, 2026
0171037
remove tryGetTMABlockShape
masahi Apr 28, 2026
58c3b95
removed tryNvmmaSharedToLinearLayout by adding a safe version of nvmm…
masahi Apr 28, 2026
157f580
Merge branch 'main' into swizzle-0-fix
masahi Apr 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
TMAMode mode,
bool disableSwizzle = false);
FailureOr<LinearLayout>
nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared, TMAMode mode,
bool disableSwizzle, bool emitErrors);

// Given a linear layout where the input dimensions contain a "block" dimension,
// this method sets the "block" dimension to 0 and removes the corresponding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ class AssignDescriptorMemoryLayouts {
CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> usageShape,
unsigned numCTAs);

protected:
// Default compatibility hook: hands `enc` back unchanged when
// isCompatibleSharedEncoding accepts it, and a null Attribute otherwise.
// `shape` and `elementType` are not consulted by this default implementation;
// they are part of the signature so that overrides can take them into account.
virtual Attribute getCompatibleSharedEncoding(Attribute enc,
                                              ArrayRef<int64_t> shape,
                                              Type elementType) {
  if (!isCompatibleSharedEncoding(enc))
    return Attribute();
  return enc;
}

private:
// Override with backend specific implementation
virtual Attribute buildFallbackSharedEncoding(mlir::MLIRContext *,
ArrayRef<int64_t>,
Expand Down
43 changes: 27 additions & 16 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4360,10 +4360,13 @@ getTMABlockShapeIm2Col(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
// H, W). Supporting pixelsPerColumn > 1024 would require computing offsets
// that depend on input tensor shape and padding, which is non-trivial.
if (blockShape[otherDim] > otherDimMax) {
return emitError() << "im2col mode: pixelsPerColumn dimension "
<< blockShape[otherDim]
<< " exceeds the maximum supported value of "
<< otherDimMax;
if (emitError) {
emitError() << Twine("im2col mode: pixelsPerColumn dimension ") +
Twine(blockShape[otherDim]) +
" exceeds the maximum supported value of " +
Twine(otherDimMax);
}
return failure();
}

// Clamp the contiguous dimension (channelsPerPixel) to max 256
Expand All @@ -4373,12 +4376,16 @@ getTMABlockShapeIm2Col(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
if (swizzleBytes != 0) {
auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
if (blockShape[contigDim] < contigDimSize) {
return emitError() << "im2col mode: block shape along the contiguous "
"dimension "
<< contigDim
<< " is too small for the swizzle byte size "
<< swizzleBytes << ", got " << blockShape[contigDim]
<< " but expected at least " << contigDimSize;
if (emitError) {
emitError() << Twine("im2col mode: block shape along the contiguous "
"dimension ") +
Twine(contigDim) +
" is too small for the swizzle byte size " +
Twine(swizzleBytes) + ", got " +
Twine(blockShape[contigDim]) +
" but expected at least " + Twine(contigDimSize);
}
return failure();
}
blockShape[contigDim] = contigDimSize;
}
Expand Down Expand Up @@ -4409,12 +4416,16 @@ getTMABlockShapeTiled(ArrayRef<int64_t> shapePerCTA, int elementBitWidth,
if (swizzleBytes != 0) {
auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
if (blockShape[contigDim] < contigDimSize) {
return emitError() << "block shape along the contiguous dimension "
<< contigDim
<< " is too small for the swizzle byte size "
<< swizzleBytes << " in an NVMMASharedLayout, got "
<< blockShape[contigDim] << " but expected at least "
<< contigDimSize;
if (emitError) {
emitError() << Twine("block shape along the contiguous dimension ") +
Twine(contigDim) +
" is too small for the swizzle byte size " +
Twine(swizzleBytes) +
" in an NVMMASharedLayout, got " +
Twine(blockShape[contigDim]) +
" but expected at least " + Twine(contigDimSize);
}
return failure();
}
blockShape[contigDim] = contigDimSize;
}
Expand Down
68 changes: 55 additions & 13 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,15 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
return LinearLayout({{S("offset"), bases2D}}, outDimNames);
}

LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
TMAMode mode, bool disableSwizzle) {
static FailureOr<LinearLayout> buildNvmmaSharedLinearLayout(
ArrayRef<int64_t> shape, NVMMASharedEncodingAttr shared,
ArrayRef<int64_t> tmaShape, bool disableSwizzle, bool emitErrors) {
if (!llvm::all_of(tmaShape, llvm::isPowerOf2_64))
return failure();
MLIRContext *ctx = shared.getContext();
int rank = shape.size();
auto shapePerCTA = getShapePerCTA(shared, shape);
auto kOffset = S("offset");
auto tmaShape =
triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
/*packedSize=*/true, mode);
if (shared.getSwizzlingByteWidth() == 0) {
auto outDimNames = standardOutDimNames(ctx, rank);
LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset,
Expand Down Expand Up @@ -234,20 +233,23 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
int packingFactor = shared.getFp4Padded() ? 2 : 1;
if (collapsedTmaShape[1] * packingFactor < tileCols ||
collapsedTmaShape[0] < tileRows) {
llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to "
"be at least ["
<< tileRows << ", " << (tileCols / packingFactor)
<< "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", "
<< collapsedTmaShape[1] << "]\n";
llvm::report_fatal_error("Illegal shared layout");
if (emitErrors) {
llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA "
"to be at least ["
<< tileRows << ", " << (tileCols / packingFactor)
<< "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", "
<< collapsedTmaShape[1] << "]\n";
}
return failure();
}

// Distribute the remaining rows and cols.
auto layout =
ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape);

// Reshape the layout to the N-D pre-transposed shape per CTA.
SmallVector<int64_t> maybeTransposedTmaShape = tmaShape;
SmallVector<int64_t> maybeTransposedTmaShape(tmaShape.begin(),
tmaShape.end());
if (shared.getTransposed()) {
// Move the outer dim to the inner position.
// TODO: we should move back to using `order` instead of transposed to make
Expand All @@ -256,6 +258,10 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
maybeTransposedTmaShape.begin() + 1,
maybeTransposedTmaShape.end());
}
// This condition can fail if a layout is speculatively constructed for
// equivalence checking.
if (layout.getTotalOutDimSize() != product(maybeTransposedTmaShape))
return failure();
Comment on lines +261 to +264
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where exactly can this happen? It feels like quite a big issue.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A concrete test case that fails this condition is this one: https://github.com/masahi/triton/blob/58c3b956958f572e1f6bfa3ddbd865c9cac40763/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir#L180-L188

We call buildNvmmaSharedLinearLayout with shape [1, 16, 1, 16] and various candidate nvmma_shared encodings. For some candidates, it seems ensureLayoutNotSmallerThan can return a layout that covers more output elements than the [1, 16, 1, 16] shape contains.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still get the feeling that there's a better place to catch this one than this late, but sure.

auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape);

if (shared.getTransposed()) {
Expand All @@ -272,6 +278,42 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
return combineCtaCgaWithShape(reshapedLayout, shared.getCGALayout(), shape);
}

LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
TMAMode mode, bool disableSwizzle) {
auto layout = nvmmaSharedToLinearLayout(shape, shared, mode, disableSwizzle,
/*emitErrors=*/true);
if (failed(layout))
llvm::report_fatal_error("Illegal shared layout");
Comment on lines +286 to +287
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is fine to keep that along with the emitError

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope the current code after 58c3b95 has addressed this comment

return *layout;
}

/// Fallible variant of nvmmaSharedToLinearLayout.
///
/// Computes the TMA block shape for the per-CTA shape of `shape` and forwards
/// it to buildNvmmaSharedLinearLayout, whose result is returned as-is.
///
/// When `emitErrors` is false, the block shape is queried through the
/// getTMABlockShape overload that accepts an `emitError` callback (nullptr is
/// passed here), and a failed query is propagated as failure(). When
/// `emitErrors` is true, the overload without that callback is used instead.
/// `disableSwizzle` and `emitErrors` are forwarded unchanged.
FailureOr<LinearLayout>
nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
                          NVMMASharedEncodingAttr shared, TMAMode mode,
                          bool disableSwizzle, bool emitErrors) {
  auto perCTAShape = getShapePerCTA(shared, shape);
  SmallVector<int64_t> blockShape;
  if (!emitErrors) {
    // Quiet path: no diagnostics callback, failure is reported to the caller.
    auto blockShapeOrFailure =
        getTMABlockShape(perCTAShape, shared.getElementBitWidth(),
                         shared.getSwizzlingByteWidth(), shared.getFp4Padded(),
                         shared.getTransposed(), /*packedSize=*/true,
                         /*emitError=*/nullptr, mode);
    if (failed(blockShapeOrFailure))
      return failure();
    blockShape = *blockShapeOrFailure;
  } else {
    blockShape =
        getTMABlockShape(perCTAShape, shared.getElementBitWidth(),
                         shared.getSwizzlingByteWidth(), shared.getFp4Padded(),
                         shared.getTransposed(), /*packedSize=*/true, mode);
  }

  return buildNvmmaSharedLinearLayout(shape, shared, blockShape, disableSwizzle,
                                      emitErrors);
}

/// Function to generate lane and warp layout for dot operands.
static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
ArrayRef<unsigned> shape,
Expand Down
46 changes: 31 additions & 15 deletions lib/Dialect/TritonGPU/Transforms/DescriptorMemoryLayouts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,36 @@ EncodingInfo AssignDescriptorMemoryLayouts::combineEncodings(

Attribute
AssignDescriptorMemoryLayouts::findLoadEncodingFromUsers(Operation *op) {
auto getCompatibleEncodingForType = [&](Type type) -> Attribute {
if (auto memDescTy = dyn_cast<MemDescType>(type)) {
return getCompatibleSharedEncoding(memDescTy.getEncoding(),
memDescTy.getShape(),
memDescTy.getElementType());
}
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
return getCompatibleSharedEncoding(tensorTy.getEncoding(),
tensorTy.getShape(),
tensorTy.getElementType());
}
return {};
};

// Check if there are any desired encodings available on the op
if (auto attr = op->getDiscardableAttr("tt.desired_encoding")) {
if (auto enc = dyn_cast<ttg::SharedEncodingTrait>(attr)) {
if (isCompatibleSharedEncoding(enc))
return enc;
}
if (auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType()))
if (auto compatible = getCompatibleSharedEncoding(
attr, resultTy.getShape(), resultTy.getElementType()))
return compatible;
}
// Ignore multiple users and just pick the first compatible layout
for (auto use : op->getUsers()) {
if (auto alloc = dyn_cast<ttg::LocalAllocOp>(use)) {
auto enc = alloc.getType().getEncoding();
if (isCompatibleSharedEncoding(enc))
return enc;
if (auto compatible = getCompatibleEncodingForType(alloc.getType()))
return compatible;
} else if (auto store = dyn_cast<ttg::LocalStoreOp>(use)) {
auto enc = store.getDst().getType().getEncoding();
if (isCompatibleSharedEncoding(enc))
return enc;
if (auto compatible =
getCompatibleEncodingForType(store.getDst().getType()))
return compatible;
}
}
return {};
Expand Down Expand Up @@ -436,7 +449,9 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
auto ctx = func.getContext();
auto numCTAs = triton::gpu::lookupNumCTAs(func);
for (auto &[desc, einfo] : valueToEncodingInfo) {
auto existingTy = desc.getType().getBlockType();
auto descTy = desc.getType();
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
Attribute newEncoding;
if (einfo->desiredEncoding) {
newEncoding = einfo->desiredEncoding;
Expand All @@ -454,10 +469,11 @@ void AssignDescriptorMemoryLayouts::runOnFunction(FuncOp &func) {
SmallVector<Type> resultTys(func.getResultTypes());
for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
if (auto descTy = dyn_cast<TensorDescType>(resultTy)) {
auto encoding =
getFallbackSharedEncoding(descTy.getBlockType(), {}, {}, numCTAs);
resultTys[i] = getTensorDescTypeWithEncoding(
nullptr, descTy.getBlockType(), encoding);
auto existingTy =
RankedTensorType::get(descTy.getShape(), descTy.getElementType());
auto encoding = getFallbackSharedEncoding(existingTy, {}, {}, numCTAs);
resultTys[i] =
getTensorDescTypeWithEncoding(nullptr, existingTy, encoding);
}
}
func.setFunctionType(FunctionType::get(ctx, argTys, resultTys));
Expand Down
Loading
Loading