diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index 784ef92ef176..445ba45c4b5f 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -45,6 +45,11 @@ static bool isValueAvailableInScope(Value value, Region *scope) { constexpr int64_t kTileM = 8; constexpr int64_t kTileN = 8; +void createGlobalScratchBarrier(PatternRewriter &rewriter, Location loc) { + ttg::BarrierOp::create( + rewriter, loc, ttg::AddrSpace::GlobalRead | ttg::AddrSpace::GlobalWrite); +} + enum class UnaryOpId : uint64_t { Exp = 0, Log, @@ -1319,6 +1324,7 @@ std::optional emitMmaEmulationLoops( Value accMasked = arith::MulIOp::create(rewriter, loc, accTileI, predInv); Value outSelI = arith::AddIOp::create(rewriter, loc, outMasked, accMasked); Value out = unembedToFloat(rewriter, loc, outSelI, accTileTy); + createGlobalScratchBarrier(rewriter, loc); storeScratchStrided2D(rewriter, loc, dTilePtr, out, accTileTy, dStride); return mLoop; @@ -1612,9 +1618,7 @@ struct DotPattern : public OpRewritePattern { // Each warp may only store a subset of each tile's rows, so a barrier is // needed to make all scratch stores visible before the loops read them. - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); auto mLoop = emitMmaEmulationLoops( rewriter, loc, aPtr, bPtr, dPtr, m, n, k, tileM, tileN, aTileTy, @@ -1626,9 +1630,7 @@ struct DotPattern : public OpRewritePattern { // Same reason: each warp may only write a subset of D's rows in the loop, // so synchronize before the final load. - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); Value out = loadScratchStrided2D(rewriter, loc, dPtr, cTy, /*stride1=*/m); if (!out) @@ -1749,9 +1751,7 @@ struct DotScaledPattern : public OpRewritePattern { {1, tileN}, bScaleTy.getElementType(), accLayout); } - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); auto mLoop = emitMmaEmulationLoops( rewriter, loc, aPtr, bPtr, dPtr, m, n, k, tileM, tileN, aTileTy, @@ -1761,9 +1761,7 @@ struct DotScaledPattern : public OpRewritePattern { return failure(); rewriter.setInsertionPointAfter(*mLoop); - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); Value out = loadScratchStrided2D(rewriter, loc, dPtr, cTy, /*stride1=*/m); if (!out) @@ -1793,6 +1791,8 @@ struct TMEMLoadPattern : public OpRewritePattern { if (!result) return failure(); + createGlobalScratchBarrier(rewriter, loc); + if (op.getNumResults() == 1) { rewriter.replaceOp(op, result); return success(); @@ -1827,6 +1827,8 @@ struct TMEMStorePattern : public OpRewritePattern { if (!createStoreScratchMemory(rewriter, loc, info->ptr, op.getSrc(), srcTy)) return failure(); + createGlobalScratchBarrier(rewriter, loc); + if (op.getNumResults() == 0) { rewriter.eraseOp(op); return success(); @@ -1867,6 +1869,8 @@ struct TMEMCopyPattern : public OpRewritePattern { if (!createStoreScratchMemory(rewriter, loc, info->ptr, srcReg, srcRegTy)) return failure(); + createGlobalScratchBarrier(rewriter, loc); + rewriter.eraseOp(op); return success(); } @@ -1961,9 +1965,7 @@ struct TCGen5MMAPattern : public OpRewritePattern { // Each warp may only populate a subset of the operand scratch tiles, so // synchronize before the emulation loops start reading them. - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); auto mLoop = emitMmaEmulationLoops( rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, @@ -2145,9 +2147,7 @@ struct TCGen5MMAScaledPattern // The operand and scale scratch buffers are written cooperatively, so all // warps must finish those stores before the emulation loop reads them. - ttg::BarrierOp::create(rewriter, loc, - ttg::AddrSpace::GlobalRead | - ttg::AddrSpace::GlobalWrite); + createGlobalScratchBarrier(rewriter, loc); auto mLoop = emitMmaEmulationLoops( rewriter, loc, aScratch->ptr, bScratch->ptr, dInfo->ptr, m, n, k, tileM, diff --git a/test/TritonGPU/nvidia-fpsan.mlir b/test/TritonGPU/nvidia-fpsan.mlir index 4f5ac4382d7d..54ff38886001 100644 --- a/test/TritonGPU/nvidia-fpsan.mlir +++ b/test/TritonGPU/nvidia-fpsan.mlir @@ -77,6 +77,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for // CHECK: ttg.barrier global_read|global_write + // CHECK: tt.store + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma %true = arith.constant true @@ -102,6 +106,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for // CHECK: ttg.barrier global_read|global_write + // CHECK: tt.store + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma %false = arith.constant false @@ -129,6 +137,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // CHECK: ttg.barrier global_read|global_write // CHECK-NEXT: scf.for // CHECK: ttg.barrier global_read|global_write + // CHECK: tt.store + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: ttg.barrier global_read|global_write // CHECK-NEXT: ttng.arrive_barrier // CHECK-NOT: ttng.tc_gen5_mma_scaled %true = arith.constant true