Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ static bool isValueAvailableInScope(Value value, Region *scope) {
constexpr int64_t kTileM = 8;
constexpr int64_t kTileN = 8;

// Creates a ttg::BarrierOp at `loc` through `rewriter`, passing the combined
// GlobalRead | GlobalWrite address-space flags. Callers use it around scratch
// stores and loads, where each warp may only write part of a tile, so the
// stores must be visible before they are read back.
void createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc) {
ttg::BarrierOp::create(
rewriter, loc, ttg::AddrSpace::GlobalRead | ttg::AddrSpace::GlobalWrite);
Comment on lines +49 to +50
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this `GlobalRead | GlobalWrite` does not do anything on NVIDIA, so perhaps it would be better to drop it and just use `Local`?

Copy link
Copy Markdown
Contributor Author

@pawelszczerbuk pawelszczerbuk Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, otherwise it is getting confusing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, no, we do need these for AMD. This is the common code path.

}

enum class UnaryOpId : uint64_t {
Exp = 0,
Log,
Expand Down Expand Up @@ -1319,6 +1324,7 @@ std::optional<scf::ForOp> emitMmaEmulationLoops(
Value accMasked = arith::MulIOp::create(rewriter, loc, accTileI, predInv);
Value outSelI = arith::AddIOp::create(rewriter, loc, outMasked, accMasked);
Value out = unembedToFloat(rewriter, loc, outSelI, accTileTy);
createGlobalScratchBarrier(rewriter, loc);
storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dStride);

return mLoop;
Expand Down Expand Up @@ -1612,9 +1618,7 @@ struct DotPattern : public OpRewritePattern<tt::DotOp> {

// Each warp may only store a subset of each tile's rows, so a barrier is
// needed to make all scratch stores visible before the loops read them.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

auto mLoop = emitMmaEmulationLoops(
rewriter, loc, aPtr, bPtr, dPtr, m, n, k, tileM, tileN, aTileTy,
Expand All @@ -1626,9 +1630,7 @@ struct DotPattern : public OpRewritePattern<tt::DotOp> {

// Same reason: each warp may only write a subset of D's rows in the loop,
// so synchronize before the final load.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

Value out = loadScratchStrided2D(rewriter, loc, dPtr, cTy, /*stride1=*/m);
if (!out)
Expand Down Expand Up @@ -1749,9 +1751,7 @@ struct DotScaledPattern : public OpRewritePattern<tt::DotScaledOp> {
{1, tileN}, bScaleTy.getElementType(), accLayout);
}

ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

auto mLoop = emitMmaEmulationLoops(
rewriter, loc, aPtr, bPtr, dPtr, m, n, k, tileM, tileN, aTileTy,
Expand All @@ -1761,9 +1761,7 @@ struct DotScaledPattern : public OpRewritePattern<tt::DotScaledOp> {
return failure();
rewriter.setInsertionPointAfter(*mLoop);

ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

Value out = loadScratchStrided2D(rewriter, loc, dPtr, cTy, /*stride1=*/m);
if (!out)
Expand Down Expand Up @@ -1793,6 +1791,8 @@ struct TMEMLoadPattern : public OpRewritePattern<ttng::TMEMLoadOp> {
if (!result)
return failure();

createGlobalScratchBarrier(rewriter, loc);

if (op.getNumResults() == 1) {
rewriter.replaceOp(op, result);
return success();
Expand Down Expand Up @@ -1827,6 +1827,8 @@ struct TMEMStorePattern : public OpRewritePattern<ttng::TMEMStoreOp> {
if (!createStoreScratchMemory(rewriter, loc, info->ptr, op.getSrc(), srcTy))
return failure();

createGlobalScratchBarrier(rewriter, loc);

if (op.getNumResults() == 0) {
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -1867,6 +1869,8 @@ struct TMEMCopyPattern : public OpRewritePattern<ttng::TMEMCopyOp> {
if (!createStoreScratchMemory(rewriter, loc, info->ptr, srcReg, srcRegTy))
return failure();

createGlobalScratchBarrier(rewriter, loc);

rewriter.eraseOp(op);
return success();
}
Expand Down Expand Up @@ -1961,9 +1965,7 @@ struct TCGen5MMAPattern : public OpRewritePattern<ttng::TCGen5MMAOp> {

// Each warp may only populate a subset of the operand scratch tiles, so
// synchronize before the emulation loops start reading them.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

auto mLoop = emitMmaEmulationLoops(
rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM,
Expand Down Expand Up @@ -2145,9 +2147,7 @@ struct TCGen5MMAScaledPattern

// The operand and scale scratch buffers are written cooperatively, so all
// warps must finish those stores before the emulation loop reads them.
ttg::BarrierOp::create(rewriter, loc,
ttg::AddrSpace::GlobalRead |
ttg::AddrSpace::GlobalWrite);
createGlobalScratchBarrier(rewriter, loc);

auto mLoop = emitMmaEmulationLoops(
rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM,
Expand Down
12 changes: 12 additions & 0 deletions test/TritonGPU/nvidia-fpsan.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK: ttg.barrier global_read|global_write
// CHECK-NEXT: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK: tt.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ttg.barrier global_read|global_write
// CHECK-NEXT: ttng.arrive_barrier
// CHECK-NOT: ttng.tc_gen5_mma
%true = arith.constant true
Expand All @@ -102,6 +106,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK: ttg.barrier global_read|global_write
// CHECK-NEXT: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK: tt.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ttg.barrier global_read|global_write
// CHECK-NEXT: ttng.arrive_barrier
// CHECK-NOT: ttng.tc_gen5_mma
%false = arith.constant false
Expand Down Expand Up @@ -129,6 +137,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// CHECK: ttg.barrier global_read|global_write
// CHECK-NEXT: scf.for
// CHECK: ttg.barrier global_read|global_write
// CHECK: tt.store
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ttg.barrier global_read|global_write
// CHECK-NEXT: ttng.arrive_barrier
// CHECK-NOT: ttng.tc_gen5_mma_scaled
%true = arith.constant true
Expand Down
Loading