Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2009,10 +2009,25 @@ class TritonGPURemoveLayoutConversionsPass
}
continue;
}
// TODO: propagate through scf.yield by updating parent op result
// types, scf.for iter_args, and init values to match srcEnc.
if (isa<scf::YieldOp>(user))
// scf.yield passes values through to the parent op's results.
// For ForOp/WhileOp, the parent results are tied to block arguments
// and init operands via loop-carried dependencies — in-place type
// rewriting cannot safely update all of them, so we block propagation.
// For IfOp, the results come from plain branches with no loop-carried
// deps, so propagation is safe if we also follow the IfOp results.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
Operation *parent = yieldOp->getParentOp();
if (isa<scf::ForOp, scf::WhileOp>(parent))
return false;
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
continue;
}
return false;
}
// Any other user (dot, reduce, another convert, etc.) blocks
// propagation.
return false;
Expand All @@ -2034,6 +2049,7 @@ class TritonGPURemoveLayoutConversionsPass

// Collect all ops that need type rewriting (forward from convert users).
SmallVector<Operation *> opsToRewrite;
SetVector<Operation *> ifOpsToRewrite;
SmallVector<Value> worklist = {dst};
DenseSet<Value> visited;

Expand All @@ -2043,8 +2059,20 @@ class TritonGPURemoveLayoutConversionsPass
continue;
for (OpOperand &use : v.getUses()) {
Operation *user = use.getOwner();
if (isa<LocalStoreOp>(user) || isa<scf::YieldOp>(user))
if (isa<LocalStoreOp>(user))
continue;
// For scf.yield under scf.if, follow through to the IfOp results.
// ForOp/WhileOp yields are blocked by canPropagateSrcEncodingThroughUsers.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
ifOpsToRewrite.insert(ifOp.getOperation());
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
}
continue;
}
opsToRewrite.push_back(user);
for (Value result : user->getResults()) {
if (isa<RankedTensorType>(result.getType()))
Expand Down Expand Up @@ -2116,6 +2144,14 @@ class TritonGPURemoveLayoutConversionsPass
}
}
}
// Rewrite IfOp result types that we propagated through.
for (Operation *op : ifOpsToRewrite) {
for (Value result : op->getResults()) {
if (auto ty = dyn_cast<RankedTensorType>(result.getType())) {
result.setType(ty.cloneWithEncoding(srcEnc));
}
}
}

// Replace all uses of the convert result with the convert source.
dst.replaceAllUsesWith(src);
Expand Down
4 changes: 3 additions & 1 deletion third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def make_ttgir(mod, metadata, opt, capability):
# Budget-aware layout conversion elimination — runs last to ensure
# converts whose scratch would exceed SMEM budget are eliminated
# after all other passes that may introduce layout conversions.
passes.ttgpuir.add_remove_layout_conversions(pm, smem_budget)
# TODO(njriasan): Re-enable once propagateSrcEncodingAndErase handles
# scf::ForOp/WhileOp loop-carried values correctly.
passes.ttgpuir.add_remove_layout_conversions(pm, 0)

pm.run(mod, 'make_ttgir')
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
Expand Down
Loading