From 2e9fbdb188400e8f4f89b267176442fc5b0b2e20 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 29 Apr 2026 10:33:06 +0200 Subject: [PATCH 1/4] [CONSAN] Add read before any write check We add a check that a buffer had been written before its first read. This fixes a multicast repro that had been flaky for sometime. Note that this is analogous to the "see the previous write" race conditions we might find if mbarriers are not correct only that this one catches when races when there was no previous write. --- .../TritonInstrument/IR/FunctionBuilder.h | 6 +-- .../TritonInstrument/IR/FunctionBuilder.cpp | 50 +++++++++++++++---- .../Transforms/ConcurrencySanitizer.cpp | 11 ++-- python/test/gluon/test_consan.py | 35 +++++++++++-- 4 files changed, 82 insertions(+), 20 deletions(-) diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index e383c185fe1b..e4520e52d84e 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -199,13 +199,13 @@ class FunctionBuilder { void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, uint64_t threadMask, Value pred, MemType memType, Operation *insertPoint); - // verifyWriteVisibility: ensure the thread either sees the latest write or no - // other thread is writing the buffer. + // verifyWriteVisibility: ensure the thread sees the latest write. When + // allowNoWrite is true, also allow rows that have not been written yet. void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs); + Value recipientCTAs, bool allowNoWrite); // verifyReadVisibility: ensure all reads from the buffer are visible to the // thread. void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index db3ace2f1df6..b5b17dfbcf76 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -2144,7 +2144,7 @@ void FunctionBuilder::createTransferVisibleReadsCall( void FunctionBuilder::createVerifyWriteVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value recipientCTAs, bool allowNoWrite) { if (auxData.buffers[(int)memType].empty() || auxData.writeVisibility[(int)memType].empty() || (auxData.hasNonTrivialAliasing[(int)memType] && @@ -2166,12 +2166,16 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( std::string message = "Buffer being accessed has outstanding writes."; if (!operandName.empty()) message += " Operand: " + operandName.str(); + std::string uninitializedMessage = "Buffer being read before any write."; + if (!operandName.empty()) + uninitializedMessage += " Operand: " + operandName.str(); auto verifyWriteResultType = cast( writeVisibilityType.cloneWith(std::nullopt, b.getI1Type())); AssertInfo assertInfo{message, verifyWriteResultType}; Type aliasMatrixTypeBase; auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase, - verifyWriteResultType](bool useAlias) { + verifyWriteResultType](bool useAlias, + bool allowNoWrite) { return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); @@ -2213,14 +2217,25 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( arith::AndIOp::create(fb, bufVisibility, bufferThreadBit); bufferHasVisibility = arith::CmpIOp::create( fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit); - Value writeVisible = - arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); - Value allWritesVisible = reduceAll(fb, writeVisible); + Value result; + if (!allowNoWrite) { + Value one = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(buffersEqBuf.getType())); + Value unmatchedRows = arith::XOrIOp::create(fb, buffersEqBuf, one); + Value rowInitialized = arith::XOrIOp::create(fb, noOneIsWriting, one); + Value initializedOrUnmatched = + arith::OrIOp::create(fb, rowInitialized, unmatchedRows); + result = reduceAll(fb, initializedOrUnmatched); + } else { + Value writeVisible = + arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); + result = reduceAll(fb, writeVisible); + } Value vTrue = arith::ConstantOp::create( - fb, allWritesVisible.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); + fb, result.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); Value predicatedWriteVisible = - arith::SelectOp::create(fb, pred, allWritesVisible, vTrue); + arith::SelectOp::create(fb, pred, result, vTrue); predicatedWriteVisible = triton::SplatOp::create( fb, verifyWriteResultType, predicatedWriteVisible); triton::ReturnOp::create(fb, predicatedWriteVisible); @@ -2235,18 +2250,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffersVal, writeVisibilityVal, recipientCTAs, aliasMatrixVal}; + if (!allowNoWrite) { + AssertInfo initializedAssertInfo{uninitializedMessage, + verifyWriteResultType}; + createCallToCachedFunction( + b, "verify_write_initialized", args, initializedAssertInfo, + {buffersType, writeVisibilityType, aliasMatrixType, + (uint64_t)memType}, + buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/false)); + } createCallToCachedFunction( b, "verify_write_visibility", args, assertInfo, {buffersType, writeVisibilityType, aliasMatrixType, (uint64_t)memType}, - buildVerifyWriteBody(/*useAlias=*/true)); + buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/true)); } else { SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffersVal, writeVisibilityVal, recipientCTAs}; + if (!allowNoWrite) { + AssertInfo initializedAssertInfo{uninitializedMessage, + verifyWriteResultType}; + createCallToCachedFunction( + b, "verify_write_initialized_noalias", args, initializedAssertInfo, + {buffersType, writeVisibilityType, (uint64_t)memType}, + buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/false)); + } createCallToCachedFunction( b, "verify_write_visibility_noalias", args, assertInfo, {buffersType, writeVisibilityType, (uint64_t)memType}, - buildVerifyWriteBody(/*useAlias=*/false)); + buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/true)); } } diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index 992e41185184..067130a94502 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -483,7 +483,7 @@ class ConcurrencySanitizerImpl { // is writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, - opInfo->commitKind); + /*allowNoWrite=*/false, opInfo->commitKind); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { funcBuilder.createSetReadVisibilityCall( b, buf, effect.length, getThreadPeersMask(thread), pred, memType, @@ -502,7 +502,7 @@ class ConcurrencySanitizerImpl { // is reading or writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, - opInfo->commitKind); + /*allowNoWrite=*/true, opInfo->commitKind); addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, opInfo->commitKind); @@ -578,10 +578,11 @@ class ConcurrencySanitizerImpl { tti::FunctionBuilder &funcBuilder, Operation *op, Value buf, uint32_t length, Value pred, MemType memType, int thread, const std::string &operandName, - Value recipientCTAs, + Value recipientCTAs, bool allowNoWrite, CommitKind::Kind opCommitKind = CommitKind::None) { - funcBuilder.createVerifyWriteVisibilityCall( - b, buf, length, thread, operandName, pred, memType, op, recipientCTAs); + funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op, + recipientCTAs, allowNoWrite); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { for (const auto &commitKindDesc : diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 5e9baee683cc..bd52c3c4a5ba 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -203,18 +203,47 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") +def test_local_load_before_first_write(device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_local_load_before_first_write, (device, False, monkeypatch, num_ctas)) + assert_expected_cuda_failure(result.exc) + assert "Buffer being read before any write" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(output): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=cga_layout) + smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout) + val = smem.load(blocked_layout) + offs_m = ttgl.arange(0, block_m, ttgl.SliceLayout(1, blocked_layout))[:, None] + offs_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :] + ttgl.store(output + offs_m * XBLOCK + offs_n, val) + + output = torch.empty((XBLOCK.value * num_ctas, XBLOCK.value), device=device, dtype=torch.float16) + kernel[(1, )](output, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: pytest.skip("Need at least 2 CTAs for multicast in this test") - if FAILURE and num_ctas == 4: - pytest.skip("Temporarily disabled: flaky with 4 CTAs when FAILURE=True") if run_wrapper: result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas)) if FAILURE: assert_expected_cuda_failure(result.exc) - assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + assert ("Buffer being read before any write" in result.driver_stderr_output + or "Buffer being accessed has outstanding writes" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" From daffe6bde6965fd82dd3eb9cfe03d3fd04f63466 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 29 Apr 2026 14:02:39 +0200 Subject: [PATCH 2/4] Initialise tests that were not being initialised. Fix the logic --- .../TritonInstrument/IR/FunctionBuilder.cpp | 18 ++- python/test/gluon/test_consan.py | 113 ++++++++++++++---- 2 files changed, 103 insertions(+), 28 deletions(-) diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index b5b17dfbcf76..9a131b5111bc 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -2219,12 +2219,22 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit); Value result; if (!allowNoWrite) { - Value one = tti::createConstIntTensor( + Value rowOne = tti::createConstIntTensor( fb, fb.getLoc(), 1, cast(buffersEqBuf.getType())); - Value unmatchedRows = arith::XOrIOp::create(fb, buffersEqBuf, one); - Value rowInitialized = arith::XOrIOp::create(fb, noOneIsWriting, one); + Value rowInitialized = + arith::XOrIOp::create(fb, noOneIsWriting, rowOne); + Value initializedRows = + arith::AndIOp::create(fb, rowInitialized, buffersEqBuf); + // Alias rows are alternatives within a CTA, but every selected CTA must + // have at least one initialized row. + Value initializedCTAs = + reduceLastDim(fb, initializedRows); + Value selectedCTAs = reduceLastDim(fb, buffersEqBuf); + Value ctaOne = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(selectedCTAs.getType())); + Value unmatchedCTAs = arith::XOrIOp::create(fb, selectedCTAs, ctaOne); Value initializedOrUnmatched = - arith::OrIOp::create(fb, rowInitialized, unmatchedRows); + arith::OrIOp::create(fb, initializedCTAs, unmatchedCTAs); result = reduceAll(fb, initializedOrUnmatched); } else { Value writeVisible = diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index bd52c3c4a5ba..19b4bf66516d 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -687,9 +687,11 @@ def kernel(output_desc, FAILURE: ttgl.constexpr): cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout) - smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_m, XBLOCK], ttgl.float16, blocked_layout)) val = ttgl.full([block_m, XBLOCK], 42, ttgl.float16, blocked_layout) tma.async_copy_shared_to_global(output_desc, [0, 0], smem.index(0)) tma.async_copy_shared_to_global(output_desc, [0, 0], smem.index(1)) @@ -722,7 +724,11 @@ def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, TWO_CTAS, device, run_wrapper, mon if MEM_ACCESS_KIND == "tma_cp": # shmem operands are being read by the tcgen05_mma assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output - elif MEM_ACCESS_KIND in ["tmem_load", "tmem_store"]: + elif MEM_ACCESS_KIND == "tmem_load": + # tmem is being written by the tcgen05_mma + assert ("Buffer being accessed has outstanding writes" in result.driver_stderr_output + or "Buffer being read before any write. Operand: B" in result.driver_stderr_output) + elif MEM_ACCESS_KIND == "tmem_store": # tmem is being written by the tcgen05_mma assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output else: @@ -745,19 +751,28 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt cga_layout=mma_cga_layout(ttgl.num_ctas(), 2, TWO_CTAS), two_ctas=TWO_CTAS, ) - smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout( - size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0, TWO_CTAS)) + smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=input_desc.layout.cga_layout) acc_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=acc_layout.cga_layout) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_blocked_layout)) + smem_b_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1, TWO_CTAS)) + smem_b_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_b_layout.cga_layout) smemB = ttgl.allocate_shared_memory( ttgl.float16, [XBLOCK, block_n], - ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)), + smem_b_layout, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_blocked_layout), ) + if TWO_CTAS: + ttgl.barrier(cluster=True) mma_bar = mbarrier.allocate_mbarrier() acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) mbarrier.init(mma_bar, count=1) @@ -989,12 +1004,22 @@ def kernel(input_desc, BUF_IDX: ttgl.constexpr, BAR_IDX: ttgl.constexpr): acc_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=acc_layout.cga_layout) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=input_desc.layout.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_blocked_layout)) + smem_b_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + smem_b_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_b_layout.cga_layout) smemB = ttgl.allocate_shared_memory( ttgl.float16, [XBLOCK, block_n], - ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + smem_b_layout, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_blocked_layout), ) bar = mbarrier.allocate_mbarrier(batch=4) acc = blackwell.allocate_tensor_memory(ttgl.float32, [2, block_m, block_n], acc_layout) @@ -1352,6 +1377,8 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) for i in range(2): mbarrier.init(bar.index(i), count=1) @@ -1406,6 +1433,8 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) for i in range(2): mbarrier.init(bar.index(i), count=1) @@ -1469,6 +1498,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1530,6 +1561,8 @@ def kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) mbarrier.init(bar.index(0), count=2) mbarrier.init(bar.index(1), count=1) @@ -1614,6 +1647,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=4) for i in range(4): mbarrier.init(bar.index(i), count=1) @@ -1773,6 +1808,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [4, block_x], smem_layout) + for i in ttgl.static_range(4): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=4) for i in range(4): mbarrier.init(bar.index(i), count=2) @@ -1847,6 +1884,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1976,6 +2015,8 @@ def kernel(input, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[block_x], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) ttgl.warp_specialize([ @@ -2341,11 +2382,15 @@ def test_aliasing_shared_visibility_outstanding_write(MISSING_BAR, OVERLAP, devi if run_wrapper: result = run_in_process(test_aliasing_shared_visibility_outstanding_write, (MISSING_BAR, OVERLAP, device, False, monkeypatch, num_ctas)) - if MISSING_BAR and OVERLAP: + if not OVERLAP: + assert_expected_cuda_failure(result.exc) + assert "Buffer being read before any write" in result.driver_stderr_output + elif MISSING_BAR: assert result.exc is not None assert_expected_cuda_failure(result.exc) # The race can be reported from either side depending on timing. - assert "Buffer being accessed has outstanding" in result.driver_stderr_output + assert ("Buffer being read before any write" in result.driver_stderr_output + or "Buffer being accessed has outstanding" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" @@ -2444,10 +2489,15 @@ def kernel(FAILURE: ttgl.constexpr): blocked_layout_write: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK // 2], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + blocked_layout_full: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK * 2], + threads_per_warp=[32, 1], warps_per_cta=[4, 1], + order=[0, 1], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float32, [block_m, XBLOCK], smem_layout) tmem_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK * 2], col_stride=1, cga_layout=cga_layout) - tmem = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, XBLOCK * 2], tmem_layout) + tmem = blackwell.allocate_tensor_memory( + ttgl.float32, [block_m, XBLOCK * 2], tmem_layout, value=ttgl.zeros([block_m, XBLOCK * 2], ttgl.float32, + blocked_layout_full)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) alias0 = tmem.slice(0, XBLOCK) @@ -2530,20 +2580,29 @@ def kernel(input, MISSING_WAIT: ttgl.constexpr, OVERLAP: ttgl.constexpr): @gluon.jit def async_copy_mma_write_after_read_kernel(a_ptr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr): + a_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 0)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)) + cga_layout=a_smem_layout.cga_layout) a_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_M, BLOCK_K], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)), + a_smem_layout, + value=ttgl.zeros([BLOCK_M, BLOCK_K], ttgl.float16, blocked_layout), ) + b_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + blocked_layout_b: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=b_smem_layout.cga_layout) b_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_K, BLOCK_N], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + b_smem_layout, + value=ttgl.zeros([BLOCK_K, BLOCK_N], ttgl.float16, blocked_layout_b), ) bar = mbarrier.allocate_mbarrier() @@ -2581,17 +2640,23 @@ def test_mma_read_async_copy_write(run_wrapper, monkeypatch, num_ctas): @gluon.jit def load_local_alloc_mma_write_after_read_kernel(a_ptr, K, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr): - blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], - warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)) a_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, cga_layout=mma_cga_layout( ttgl.num_ctas(), 0)) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=a_smem_layout.cga_layout) + b_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + blocked_layout_b: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=b_smem_layout.cga_layout) b_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_K, BLOCK_N], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + b_smem_layout, + value=ttgl.zeros([BLOCK_K, BLOCK_N], ttgl.float16, blocked_layout_b), ) bar = mbarrier.allocate_mbarrier() From a93ec2b3284de1a9497dff479320a26f7c6d1a6e Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 29 Apr 2026 15:39:11 +0200 Subject: [PATCH 3/4] fix hopper --- python/test/gluon/test_consan.py | 33 ++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 19b4bf66516d..07aa45c86ea7 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -905,8 +905,16 @@ def kernel(input, FAILURE: ttgl.constexpr): cga_layout=cga_layout_a) smem_layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout_b) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a) - smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) @@ -952,8 +960,16 @@ def kernel(input, FAILURE: ttgl.constexpr): cga_layout=cga_layout_a) smem_layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout_b) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a) - smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) @@ -2079,10 +2095,19 @@ def kernel(FAILURE: ttgl.constexpr): cga_layout=cga_layout_b) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16], cga_layout=cga_layout_c) smemA = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout_a) smemB = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK, block_n], smem_layout_b) + for i in ttgl.static_range(2): + smemA.index(i).store(ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB.index(i).store(ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) ttgl.warp_specialize([ From b053ca1a61f2bc51652621a45af861f8920cd38e Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 29 Apr 2026 16:42:12 +0200 Subject: [PATCH 4/4] more wgmma fixes --- python/test/gluon/test_consan.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 07aa45c86ea7..8c8820b35b68 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -1733,6 +1733,8 @@ def kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1977,6 +1979,8 @@ def kernel(input, FAILURE: ttgl.constexpr): smem = ttgl.allocate_shared_memory(ttgl.float16, [4, block_x], smem_layout) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[block_x], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) + for i in ttgl.static_range(4): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) ttgl.warp_specialize([ (ws_prog, (input, smem, FAILURE, blocked_layout, 0)), (ws_prog, (input, smem, FAILURE, blocked_layout, 2)),