Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ test-distributed: all
test-gluon: all
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
$(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon

.PHONY: test-regression
test-regression: all
Expand Down
3 changes: 1 addition & 2 deletions include/triton/Analysis/Membar.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ struct AllocationSlice {
private:
// Bundles allocationInterval, the opaque pointer of accessTy, and a view of
// subsliceOffsets into a single tuple.
//
// The llvm::ArrayRef element is a non-owning view of the `subsliceOffsets`
// member, so the returned tuple must not be used after this object is
// destroyed or its offsets are modified.
std::tuple<Interval<size_t>, const void *, llvm::ArrayRef<int64_t>>
asTuple() const {
  // Construct the return value directly from the members. std::make_tuple
  // would decay-copy subsliceOffsets into a temporary SmallVector, and the
  // ArrayRef in the returned tuple would then point into that temporary after
  // it has been destroyed.
  return {allocationInterval, accessTy.getAsOpaquePointer(), subsliceOffsets};
}
// Offsets from subslice. Empty when offsets are unknown
SmallVector<int64_t> subsliceOffsets;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) {
return mmaEnc && mmaEnc.getFp4Padded();
}

/// Returns the TMA indices to use for a descriptor with the given encoding.
///
/// `indices` is taken by value, so the caller's vector is left untouched.
/// When isFp4Padded(encoding) holds, the last index of the returned vector is
/// replaced by that index multiplied by an i32 constant 2 (the ops are emitted
/// through `builder` at `loc`); otherwise the indices are returned unchanged.
/// NOTE(review): presumably callers always pass at least one index — confirm.
SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
Attribute encoding,
SmallVector<Value> indices);

gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> shape);

Expand Down
14 changes: 5 additions & 9 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,11 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp,
return createTMAAsyncCopy(
forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier,
waitOp, schedule,
[&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view,
[&](OpBuilderForStage &builder, Value desc, Value barrier, Value view,
Value pred) {
auto indices = ttng::translateTMAIndices(
builder, loadOp.getLoc(),
loadOp.getDesc().getType().getBlockType().getEncoding(),
loadOp.getIndices());
ttng::AsyncTMACopyGlobalToLocalOp::create(
builder, loadOp.getLoc(), /*multicastTargets*/ Value(), tmaPtr,
indices, barrier, view, pred);
builder, loadOp.getLoc(), /*multicastTargets*/ Value(), desc,
loadOp.getIndices(), barrier, view, pred);
});
}

Expand All @@ -262,10 +258,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,
CoarseSchedule &schedule) {
return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc,
insertIdx, extractIdx, barrier, waitOp, schedule,
[&](OpBuilderForStage &builder, Value tmaPtr,
[&](OpBuilderForStage &builder, Value desc,
Value barrier, Value view, Value pred) {
ttng::AsyncTMAGatherOp::create(
builder, gatherOp.getLoc(), tmaPtr,
builder, gatherOp.getLoc(), desc,
gatherOp.getXOffsets(), gatherOp.getYOffset(),
barrier, view, pred);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,
ttng::FenceAsyncSharedOp::create(builder, loc, false);
auto desc = store.desc;
if (auto storeOp = dyn_cast<tt::DescriptorStoreOp>(store.op)) {
auto indices = ttng::translateTMAIndices(
builder, storeOp.getLoc(),
storeOp.getDesc().getType().getBlockType().getEncoding(),
storeOp.getIndices());
ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc,
storeOp.getIndices(), alloc);
} else if (auto reduceOp = dyn_cast<tt::DescriptorReduceOp>(store.op)) {
auto indices = ttng::translateTMAIndices(
builder, reduceOp.getLoc(),
reduceOp.getDesc().getType().getBlockType().getEncoding(),
reduceOp.getIndices());
ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc,
reduceOp.getIndices(), alloc,
triton::EvictionPolicy::NORMAL);
Expand Down
44 changes: 40 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2009,10 +2009,25 @@ class TritonGPURemoveLayoutConversionsPass
}
continue;
}
// TODO: propagate through scf.yield by updating parent op result
// types, scf.for iter_args, and init values to match srcEnc.
if (isa<scf::YieldOp>(user))
// scf.yield passes values through to the parent op's results.
// For ForOp/WhileOp, the parent results are tied to block arguments
// and init operands via loop-carried dependencies — in-place type
// rewriting cannot safely update all of them, so block propagation.
// For IfOp, the results are simple branches with no loop-carried
// deps, so propagation is safe if we also follow the IfOp results.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
Operation *parent = yieldOp->getParentOp();
if (isa<scf::ForOp, scf::WhileOp>(parent))
return false;
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
continue;
}
return false;
}
// Any other user (dot, reduce, another convert, etc.) blocks
// propagation.
return false;
Expand All @@ -2034,6 +2049,7 @@ class TritonGPURemoveLayoutConversionsPass

// Collect all ops that need type rewriting (forward from convert users).
SmallVector<Operation *> opsToRewrite;
SetVector<Operation *> ifOpsToRewrite;
SmallVector<Value> worklist = {dst};
DenseSet<Value> visited;

Expand All @@ -2043,8 +2059,20 @@ class TritonGPURemoveLayoutConversionsPass
continue;
for (OpOperand &use : v.getUses()) {
Operation *user = use.getOwner();
if (isa<LocalStoreOp>(user) || isa<scf::YieldOp>(user))
if (isa<LocalStoreOp>(user))
continue;
// For scf.yield under scf.if, follow through to the IfOp results.
// ForOp/WhileOp yields are blocked by canPropagateSrcEncodingThroughUsers.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
ifOpsToRewrite.insert(ifOp.getOperation());
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
}
continue;
}
opsToRewrite.push_back(user);
for (Value result : user->getResults()) {
if (isa<RankedTensorType>(result.getType()))
Expand Down Expand Up @@ -2116,6 +2144,14 @@ class TritonGPURemoveLayoutConversionsPass
}
}
}
// Rewrite IfOp result types that we propagated through.
for (Operation *op : ifOpsToRewrite) {
for (Value result : op->getResults()) {
if (auto ty = dyn_cast<RankedTensorType>(result.getType())) {
result.setType(ty.cloneWithEncoding(srcEnc));
}
}
}

// Replace all uses of the convert result with the convert source.
dst.replaceAllUsesWith(src);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,10 @@ static void lowerTMACopy(PartitionBuilder &b, Partition &loadPartition,
Value barrier, Value view) {
Value truePred = b.boolCst(true);
if (auto load = dyn_cast<DescriptorLoadOp>(op)) {
auto indices = ttng::translateTMAIndices(
b, load.getLoc(), load.getDesc().getType().getBlockType().getEncoding(),
load.getIndices());
b.createInto<ttng::AsyncTMACopyGlobalToLocalOp>(
loadPartition, stageCluster,
/*multicastTargets*/ Value(), load.getDesc(), indices, barrier, view,
truePred);
/*multicastTargets*/ Value(), load.getDesc(), load.getIndices(),
barrier, view, truePred);
} else {
auto gather = cast<DescriptorGatherOp>(op);
b.createInto<ttng::AsyncTMAGatherOp>(
Expand Down
35 changes: 12 additions & 23 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,11 @@ class TMALoadLowering : public OpRewritePattern<DescriptorLoadOp> {
LogicalResult matchAndRewrite(DescriptorLoadOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
Value pred) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create(
rewriter, op.getLoc(), /*multicastTargets*/ Value(), tmaPtr, indices,
barrierAlloc, alloc, pred);
rewriter, op.getLoc(), /*multicastTargets*/ Value(), desc,
op.getIndices(), barrierAlloc, alloc, pred);
};
lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);
return success();
Expand All @@ -87,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern<DescriptorGatherOp> {

LogicalResult matchAndRewrite(DescriptorGatherOp op,
PatternRewriter &rewriter) const override {
auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
Value pred) {
triton::nvidia_gpu::AsyncTMAGatherOp::create(
rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(),
rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(),
barrierAlloc, alloc, pred);
};
lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);
Expand Down Expand Up @@ -148,13 +145,9 @@ struct TMAStoreLowering : public OpRewritePattern<DescriptorStoreOp> {

LogicalResult matchAndRewrite(DescriptorStoreOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create(
rewriter, op.getLoc(), tmaPtr, indices, alloc,
triton::EvictionPolicy::NORMAL);
rewriter, op.getLoc(), desc, op.getIndices(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
return success();
Expand All @@ -166,13 +159,9 @@ struct TMAReduceLowering : public OpRewritePattern<DescriptorReduceOp> {

LogicalResult matchAndRewrite(DescriptorReduceOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMAReduceOp::create(
rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc,
triton::EvictionPolicy::NORMAL);
rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
return success();
Expand All @@ -184,9 +173,9 @@ struct TMAScatterLowering : public OpRewritePattern<DescriptorScatterOp> {

LogicalResult matchAndRewrite(DescriptorScatterOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(),
tmaPtr, op.getXOffsets(),
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc,
op.getXOffsets(),
op.getYOffset(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
Expand Down
10 changes: 0 additions & 10 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu;

namespace mlir::triton::nvidia_gpu {

/// Adjusts descriptor indices for the given encoding before they are handed
/// to a TMA op.
///
/// \param builder  Builder used to emit any arithmetic ops that are needed.
/// \param loc      Location attached to the emitted ops.
/// \param encoding Encoding checked with isFp4Padded().
/// \param indices  Indices to translate; taken by value, so the caller's
///                 vector is never modified.
/// \returns The indices unchanged when the encoding is not fp4-padded;
///          otherwise the same indices with the last one multiplied by 2.
SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
                                       Attribute encoding,
                                       SmallVector<Value> indices) {
  // Nothing to rewrite unless the encoding is fp4-padded.
  if (!isFp4Padded(encoding))
    return indices;

  // Emit the i32 constant first, then the multiply that consumes it.
  // NOTE(review): assumes `indices` is non-empty here — confirm with callers.
  auto factor = arith::ConstantIntOp::create(builder, loc, 2, 32);
  auto &lastIndex = indices.back();
  lastIndex = arith::MulIOp::create(builder, loc, lastIndex, factor);
  return indices;
}

ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> shape) {
auto rank = shape.size();
Expand Down
Loading
Loading