Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ test-distributed: all
test-gluon: all
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
$(PYTEST) --tb=short -n $(NUM_PROCS) -vs python/tutorials/gluon

.PHONY: test-regression
test-regression: all
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ inline bool isFp4Padded(Attribute encoding) {
return mmaEnc && mmaEnc.getFp4Padded();
}

SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
Attribute encoding,
SmallVector<Value> indices);

gpu::CGAEncodingAttr updateCGALayoutForShape(gpu::CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> shape);

Expand Down
24 changes: 10 additions & 14 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,14 @@ void createTMAAsyncLoad(scf::ForOp forOp, tt::DescriptorLoadOp loadOp,
Value alloc, Value insertIdx, Value extractIdx,
Value barrier, Operation *waitOp,
CoarseSchedule &schedule) {
return createTMAAsyncCopy(
forOp, loadOp, loadOp.getDesc(), alloc, insertIdx, extractIdx, barrier,
waitOp, schedule,
[&](OpBuilderForStage &builder, Value tmaPtr, Value barrier, Value view,
Value pred) {
auto indices = ttng::translateTMAIndices(
builder, loadOp.getLoc(),
loadOp.getDesc().getType().getBlockType().getEncoding(),
loadOp.getIndices());
ttng::AsyncTMACopyGlobalToLocalOp::create(
builder, loadOp.getLoc(), tmaPtr, indices, barrier, view, pred);
});
return createTMAAsyncCopy(forOp, loadOp, loadOp.getDesc(), alloc, insertIdx,
extractIdx, barrier, waitOp, schedule,
[&](OpBuilderForStage &builder, Value desc,
Value barrier, Value view, Value pred) {
ttng::AsyncTMACopyGlobalToLocalOp::create(
builder, loadOp.getLoc(), desc,
loadOp.getIndices(), barrier, view, pred);
});
}

void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,
Expand All @@ -263,10 +259,10 @@ void createTMAAsyncGather(scf::ForOp forOp, tt::DescriptorGatherOp gatherOp,
CoarseSchedule &schedule) {
return createTMAAsyncCopy(forOp, gatherOp, gatherOp.getDesc(), alloc,
insertIdx, extractIdx, barrier, waitOp, schedule,
[&](OpBuilderForStage &builder, Value tmaPtr,
[&](OpBuilderForStage &builder, Value desc,
Value barrier, Value view, Value pred) {
ttng::AsyncTMAGatherOp::create(
builder, gatherOp.getLoc(), tmaPtr,
builder, gatherOp.getLoc(), desc,
gatherOp.getXOffsets(), gatherOp.getYOffset(),
barrier, view, pred);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,9 @@ static void createTMAAsyncCopy(scf::ForOp forOp, const TMAStore &store,
ttng::FenceAsyncSharedOp::create(builder, loc, false);
auto desc = store.desc;
if (auto storeOp = dyn_cast<tt::DescriptorStoreOp>(store.op)) {
auto indices = ttng::translateTMAIndices(
builder, storeOp.getLoc(),
storeOp.getDesc().getType().getBlockType().getEncoding(),
storeOp.getIndices());
ttng::AsyncTMACopyLocalToGlobalOp::create(builder, loc, desc,
storeOp.getIndices(), alloc);
} else if (auto reduceOp = dyn_cast<tt::DescriptorReduceOp>(store.op)) {
auto indices = ttng::translateTMAIndices(
builder, reduceOp.getLoc(),
reduceOp.getDesc().getType().getBlockType().getEncoding(),
reduceOp.getIndices());
ttng::AsyncTMAReduceOp::create(builder, loc, reduceOp.getKind(), desc,
reduceOp.getIndices(), alloc);
} else {
Expand Down
32 changes: 12 additions & 20 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,11 @@ class TMALoadLowering : public OpRewritePattern<DescriptorLoadOp> {
LogicalResult matchAndRewrite(DescriptorLoadOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
Value pred) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create(
rewriter, op.getLoc(), tmaPtr, indices, barrierAlloc, alloc, pred);
rewriter, op.getLoc(), desc, op.getIndices(), barrierAlloc, alloc,
pred);
};
lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);
return success();
Expand All @@ -86,10 +84,10 @@ struct TMAGatherLowering : public OpRewritePattern<DescriptorGatherOp> {

LogicalResult matchAndRewrite(DescriptorGatherOp op,
PatternRewriter &rewriter) const override {
auto createLoad = [&](Value tmaPtr, Value barrierAlloc, Value alloc,
auto createLoad = [&](Value desc, Value barrierAlloc, Value alloc,
Value pred) {
triton::nvidia_gpu::AsyncTMAGatherOp::create(
rewriter, op.getLoc(), tmaPtr, op.getXOffsets(), op.getYOffset(),
rewriter, op.getLoc(), desc, op.getXOffsets(), op.getYOffset(),
barrierAlloc, alloc, pred);
};
lowerTMALoad(op, op.getType(), op.getDesc(), createLoad, rewriter);
Expand Down Expand Up @@ -122,12 +120,9 @@ struct TMAStoreLowering : public OpRewritePattern<DescriptorStoreOp> {

LogicalResult matchAndRewrite(DescriptorStoreOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp::create(
rewriter, op.getLoc(), tmaPtr, indices, alloc);
rewriter, op.getLoc(), desc, op.getIndices(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
return success();
Expand All @@ -139,12 +134,9 @@ struct TMAReduceLowering : public OpRewritePattern<DescriptorReduceOp> {

LogicalResult matchAndRewrite(DescriptorReduceOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMAReduceOp::create(
rewriter, op.getLoc(), op.getKind(), tmaPtr, indices, alloc);
rewriter, op.getLoc(), op.getKind(), desc, op.getIndices(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
return success();
Expand All @@ -156,9 +148,9 @@ struct TMAScatterLowering : public OpRewritePattern<DescriptorScatterOp> {

LogicalResult matchAndRewrite(DescriptorScatterOp op,
PatternRewriter &rewriter) const override {
auto createStore = [&](Value tmaPtr, Value alloc) {
triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(),
tmaPtr, op.getXOffsets(),
auto createStore = [&](Value desc, Value alloc) {
triton::nvidia_gpu::AsyncTMAScatterOp::create(rewriter, op.getLoc(), desc,
op.getXOffsets(),
op.getYOffset(), alloc);
};
lowerTMAStore(op, op.getSrc(), op.getDesc(), createStore, rewriter);
Expand Down
10 changes: 0 additions & 10 deletions lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ namespace ttg = mlir::triton::gpu;

namespace mlir::triton::nvidia_gpu {

SmallVector<Value> translateTMAIndices(OpBuilder &builder, Location loc,
Attribute encoding,
SmallVector<Value> indices) {
if (isFp4Padded(encoding)) {
auto two = arith::ConstantIntOp::create(builder, loc, 2, 32);
indices.back() = arith::MulIOp::create(builder, loc, indices.back(), two);
}
return indices;
}

ttg::CGAEncodingAttr updateCGALayoutForShape(ttg::CGAEncodingAttr cgaLayout,
ArrayRef<int64_t> shape) {
auto rank = shape.size();
Expand Down
36 changes: 4 additions & 32 deletions python/tutorials/gluon/11-tcgen05-mma-scaled.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,6 @@ def simple_mma_scaled_kernel(a_desc, b_desc, c_desc, a_scale_ptr, a_scale_stride
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE

# When issuing a TMA transaction to TMA tensor descriptors with fp4 padded operands, we need to multiply
# the offset along the contiguous dimension by 2 to account for the padding. This applies to async TMA
# loads, stores, gather, and scatter. Failing to do this can result in illegal instruction errors. If you
# catch the illegal instruction error inside `cuda-gdb`, it may point to the TMA instruction or the
# `mbarrier.wait` on the instruction completion barrier. When breaking on the illegal instruction error,
# you can use `x/i $pc` to print the instruction at the faulting address, and for example use `x/-50i $pc`
# to print the previous 50 instructions.
if a_desc.layout.fp4_padded:
off_k_a *= 2
if b_desc.layout.fp4_padded:
off_k_b *= 2

# Load the A and B tiles.
mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem)
Expand Down Expand Up @@ -495,10 +483,6 @@ def mma_scaled_contig_kernel(a_desc, b_desc, c_desc, a_scale_ptr, b_scale_ptr, V
for k in range(0, K, BLOCK_K):
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE
if a_desc.layout.fp4_padded:
off_k_a *= 2
if b_desc.layout.fp4_padded:
off_k_b *= 2

mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
tma.async_copy_global_to_shared(a_desc, [off_m, off_k_a], bar, a_smem)
Expand Down Expand Up @@ -741,13 +725,9 @@ def mma_scaled_packed_block_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale
for k in range(0, K, BLOCK_K):
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE
if a_desc.layout.fp4_padded:
off_k_a *= 2
if b_desc.layout.fp4_padded:
off_k_b *= 2
# Index the K subtile along REP_K for each scale.
off_k_a_scale = k // BLOCK_K * A_REP_K
off_k_b_scale = k // BLOCK_K * B_REP_K
off_k_a_scale = (k // BLOCK_K) * A_REP_K
off_k_b_scale = (k // BLOCK_K) * B_REP_K

mbarrier.expect(
bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes +
Expand Down Expand Up @@ -1029,12 +1009,8 @@ def mma_scaled_tcgen05_copy_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale
for k in range(0, K, BLOCK_K):
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE
if a_desc.layout.fp4_padded:
off_k_a *= 2
if b_desc.layout.fp4_padded:
off_k_b *= 2
off_k_a_scale = k // BLOCK_K * A_REP_K
off_k_b_scale = k // BLOCK_K * B_REP_K
off_k_a_scale = (k // BLOCK_K) * A_REP_K
off_k_b_scale = (k // BLOCK_K) * B_REP_K

mbarrier.expect(
bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + a_scale_desc.block_type.nbytes +
Expand Down Expand Up @@ -1213,10 +1189,6 @@ def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale
off_n_b_scale = pid_n * REP_N
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE
if a_desc.layout.fp4_padded:
off_k_a *= 2
if b_desc.layout.fp4_padded:
off_k_b *= 2
off_k_a_scale = (k // BLOCK_K) * A_REP_K
off_k_b_scale = (k // BLOCK_K) * B_REP_K

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,6 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder,
builder.setInsertionPoint(tmaLoad);
auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(),
buffer, bufferIdx, true);
// FIXME: translateTMAIndices
copy = builder.createWithAsyncTaskIds<ttng::AsyncTMACopyGlobalToLocalOp>(
loc, tmaLoad.getDesc(), tmaLoad.getIndices(), prodBarrier,
pipelineBuffer, pred);
Expand Down
21 changes: 1 addition & 20 deletions third_party/nvidia/lib/Dialect/NVWS/Transforms/LowerAref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,27 +286,8 @@ getSubViews(ArefValue arefVal, Value stage, Location loc, OpBuilder &rewriter,

void createTMALoad(triton::nvws::DescriptorLoadOp op, PatternRewriter &rewriter,
Value barrierAlloc, Value pred) {
auto indices = translateTMAIndices(
rewriter, op.getLoc(),
op.getDesc().getType().getBlockType().getEncoding(), op.getIndices());
for (auto [newIdx, oldIdx] : llvm::zip(indices, op.getIndices())) {
// translateTMAIndices may create ops, we need to annotated them
if (newIdx != oldIdx) {
auto partitionIds = getPartitionWsTagIds(op);
auto stageCluster = getStageCluster(op);
assignStageCluster(newIdx.getDefiningOp(), partitionIds, stageCluster,
rewriter);
for (auto val : newIdx.getDefiningOp()->getOperands()) {
if (auto op = val.getDefiningOp()) {
if (!hasPartition(op)) {
assignStageCluster(op, partitionIds, stageCluster, rewriter);
}
}
}
}
}
auto newLoadOp = triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp::create(
rewriter, op.getLoc(), op.getDesc(), indices, barrierAlloc,
rewriter, op.getLoc(), op.getDesc(), op.getIndices(), barrierAlloc,
op.getResult(), pred);
assignStageCluster(newLoadOp, getPartitionWsTagIds(op), getStageCluster(op),
rewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1412,10 +1412,16 @@ struct AsyncTMACopyGlobalToLocalOpConversion
auto offsets = applyLinearLayout(loc, rewriter, msgToOffset,
{{kMsg, copyIdxVal}, {kBlock, ctaId}});
int operandIdx = 3;
auto encoding = op.getDesc().getType().getBlockType().getEncoding();
bool fp4Padded = nvidia_gpu::isFp4Padded(encoding);
for (int i = 0; i < rank; i++) {
Value coord = adaptor.getCoord()[rank - i - 1];
if (fp4Padded && i == 0) {
coord = b.mul(coord, b.i32_val(2));
}
if (i < offsets.size())
coord = b.add(coord, offsets[offsets.size() - i - 1].second);

operands.push_back(ptxBuilderTMA.newOperand(coord, "r"));
tmaInst += "$" + std::to_string(operandIdx++);
if (i != rank - 1)
Expand Down Expand Up @@ -1496,8 +1502,12 @@ LogicalResult convertTMAStoreLikeOp(Operation *op,

auto offsets = applyLinearLayout(loc, rewriter, msgToOffset,
{{kMsg, copyIdxVal}, {kBlock, ctaId}});
bool fp4Padded = nvidia_gpu::isFp4Padded(srcTy.getEncoding());
for (int i = 0; i < rank; i++) {
Value coord = coords[rank - i - 1];
if (fp4Padded && i == 0) {
coord = b.mul(coord, b.i32_val(2));
}
if (i < offsets.size())
coord = b.add(coord, offsets[offsets.size() - i - 1].second);
operands.push_back(ptxBuilderTMA.newOperand(coord, "r"));
Expand Down Expand Up @@ -1623,8 +1633,11 @@ static LogicalResult iterateGatherScatterIndices(
return op->emitError("memdesc shape must match alloc shape");
// `NVMMASharedEncodingAttr` means the core matrix tiles are placed next to
// each other in shared memory, which lines up with how `gather4` loads data.
if (!isa<NVMMASharedEncodingAttr>(smemType.getEncoding()))
auto enc = dyn_cast<NVMMASharedEncodingAttr>(smemType.getEncoding());
if (!enc)
return op->emitError("requires dst encoding NVMMASharedEncodingAttr");
if (enc.getFp4Padded())
yOffsetValue = b.mul(yOffsetValue, b.i32_val(2));
Type llvmElemTy = typeConverter.convertType(smemType.getElementType());
Type elemPtrTy = ptr_ty(ctx, /*addrspace=*/3);
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, smemObjValue,
Expand Down
Loading