diff --git a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp index c9068da03e12..047c51b9dca3 100644 --- a/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp @@ -95,6 +95,11 @@ struct ScratchInfo { RankedTensorType tensorType; }; +struct ScratchState { + std::optional canonical; + DenseMap byScope; +}; + class TmemScratchManager { public: ttg::BlockedEncodingAttr getScratchEncoding(PatternRewriter &rewriter, @@ -143,14 +148,19 @@ class TmemScratchManager { } if (auto alloc = memdesc.getDefiningOp()) { - auto it = scratchMap.find(memdesc); - if (it != scratchMap.end()) { - auto itRegion = it->second.find(scope); - if (itRegion != it->second.end()) { - if (itRegion->second.ptr && itRegion->second.ptr.getType()) - return itRegion->second; - it->second.erase(itRegion); - } + ScratchState &state = scratchMap[memdesc]; + auto itRegion = state.byScope.find(scope); + if (itRegion != state.byScope.end()) { + if (itRegion->second.ptr && itRegion->second.ptr.getType()) + return itRegion->second; + state.byScope.erase(itRegion); + } + if (state.canonical) { + Value ptr = + remapToScope(state.canonical->ptr, rewriter, scope, alloc.getLoc()); + ScratchInfo info{ptr, state.canonical->tensorType}; + state.byScope[scope] = info; + return info; } OpBuilder::InsertionGuard guard(rewriter); @@ -176,10 +186,11 @@ class TmemScratchManager { return std::nullopt; } - ptr = remapToScope(ptr, rewriter, scope, loc); + state.canonical = ScratchInfo{ptr, tensorTy}; + ptr = remapToScope(ptr, rewriter, scope, loc); ScratchInfo info{ptr, tensorTy}; - scratchMap[memdesc][scope] = info; + state.byScope[scope] = info; return info; } @@ -303,7 +314,7 @@ class TmemScratchManager { return scope->getArgument(captureIdx); } - DenseMap> scratchMap; + DenseMap scratchMap; }; Value createScratchAndStore(PatternRewriter &rewriter, Location loc, Value val, diff --git a/python/test/gluon/test_fpsan.py b/python/test/gluon/test_fpsan.py index 3b73ab30a99d..8fc432b8b180 100644 --- a/python/test/gluon/test_fpsan.py +++ b/python/test/gluon/test_fpsan.py @@ -1609,6 +1609,53 @@ def kernel(x_ptr, out_ptr): _assert_payload_equal(out, exp_bits) +@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell") +def test_tmem_store_in_warp_specialize_partition_visible_to_parent(device, fresh_knobs): + _require_cuda_backend(device) + + B = 64 + BLOCK = gl.constexpr(B) + + fresh_knobs.compilation.instrumentation_mode = "fpsan" + + @gluon.jit + def store_one_partition(tmem): + reg_layout: gl.constexpr = tmem.get_reg_layout() + one = gl.full((BLOCK, BLOCK), 1.0, gl.float32, reg_layout) + tmem.store(one) + + @gluon.jit + def default_partition(): + pass + + @gluon.jit + def kernel(out_ptr): + layout: gl.constexpr = gl.BlockedLayout([1, 1], [32, 1], [gl.num_warps(), 1], [1, 0]) + offs_m = gl.arange(0, BLOCK, layout=gl.SliceLayout(1, layout))[:, None] + offs_n = gl.arange(0, BLOCK, layout=gl.SliceLayout(0, layout))[None, :] + offs = offs_m * BLOCK + offs_n + + tmem_layout: gl.constexpr = TensorMemoryLayout((BLOCK, BLOCK), col_stride=1) + tmem = allocate_tensor_memory(gl.float32, [BLOCK, BLOCK], layout=tmem_layout) + reg_layout: gl.constexpr = tmem.get_reg_layout() + zero = gl.full((BLOCK, BLOCK), 0.0, gl.float32, reg_layout) + tmem.store(zero) + + gl.warp_specialize([ + (default_partition, ()), + (store_one_partition, (tmem, )), + ], [4], [32]) + + out = tmem.load() + out = gl.convert_layout(out, layout) + gl.store(out_ptr + offs, out) + + out = torch.empty((B, B), device=device, dtype=torch.float32) + kernel[(1, )](out, num_warps=4) + + torch.testing.assert_close(out, torch.ones_like(out), rtol=0, atol=0) + + def test_reduction(device, fresh_knobs): _require_cuda_backend(device)