diff --git a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h index e383c185fe1b..e4520e52d84e 100644 --- a/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h +++ b/include/triton/Dialect/TritonInstrument/IR/FunctionBuilder.h @@ -199,13 +199,13 @@ class FunctionBuilder { void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar, uint64_t threadMask, Value pred, MemType memType, Operation *insertPoint); - // verifyWriteVisibility: ensure the thread either sees the latest write or no - // other thread is writing the buffer. + // verifyWriteVisibility: ensure the thread sees the latest write. When + // allowNoWrite is true, also allow rows that have not been written yet. void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs); + Value recipientCTAs, bool allowNoWrite); // verifyReadVisibility: ensure all reads from the buffer are visible to the // thread. void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf, diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index db3ace2f1df6..9a131b5111bc 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -2144,7 +2144,7 @@ void FunctionBuilder::createTransferVisibleReadsCall( void FunctionBuilder::createVerifyWriteVisibilityCall( ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, StringRef operandName, Value pred, MemType memType, Operation *insertPoint, - Value recipientCTAs) { + Value recipientCTAs, bool allowNoWrite) { if (auxData.buffers[(int)memType].empty() || auxData.writeVisibility[(int)memType].empty() || (auxData.hasNonTrivialAliasing[(int)memType] && @@ -2166,12 +2166,16 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( std::string message = "Buffer being accessed has outstanding writes."; if (!operandName.empty()) message += " Operand: " + operandName.str(); + std::string uninitializedMessage = "Buffer being read before any write."; + if (!operandName.empty()) + uninitializedMessage += " Operand: " + operandName.str(); auto verifyWriteResultType = cast( writeVisibilityType.cloneWith(std::nullopt, b.getI1Type())); AssertInfo assertInfo{message, verifyWriteResultType}; Type aliasMatrixTypeBase; auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase, - verifyWriteResultType](bool useAlias) { + verifyWriteResultType](bool useAlias, + bool allowNoWrite) { return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); @@ -2213,14 +2217,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( arith::AndIOp::create(fb, bufVisibility, bufferThreadBit); bufferHasVisibility = arith::CmpIOp::create( fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit); - Value writeVisible = - arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); - Value allWritesVisible = reduceAll(fb, writeVisible); + Value result; + if (!allowNoWrite) { + Value rowOne = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(buffersEqBuf.getType())); + Value rowInitialized = + arith::XOrIOp::create(fb, noOneIsWriting, rowOne); + Value initializedRows = + arith::AndIOp::create(fb, rowInitialized, buffersEqBuf); + // Alias rows are alternatives within a CTA, but every selected CTA must + // have at least one initialized row. + Value initializedCTAs = + reduceLastDim(fb, initializedRows); + Value selectedCTAs = reduceLastDim(fb, buffersEqBuf); + Value ctaOne = tti::createConstIntTensor( + fb, fb.getLoc(), 1, cast(selectedCTAs.getType())); + Value unmatchedCTAs = arith::XOrIOp::create(fb, selectedCTAs, ctaOne); + Value initializedOrUnmatched = + arith::OrIOp::create(fb, initializedCTAs, unmatchedCTAs); + result = reduceAll(fb, initializedOrUnmatched); + } else { + Value writeVisible = + arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); + result = reduceAll(fb, writeVisible); + } Value vTrue = arith::ConstantOp::create( - fb, allWritesVisible.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); + fb, result.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); Value predicatedWriteVisible = - arith::SelectOp::create(fb, pred, allWritesVisible, vTrue); + arith::SelectOp::create(fb, pred, result, vTrue); predicatedWriteVisible = triton::SplatOp::create( fb, verifyWriteResultType, predicatedWriteVisible); triton::ReturnOp::create(fb, predicatedWriteVisible); @@ -2235,18 +2260,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffersVal, writeVisibilityVal, recipientCTAs, aliasMatrixVal}; + if (!allowNoWrite) { + AssertInfo initializedAssertInfo{uninitializedMessage, + verifyWriteResultType}; + createCallToCachedFunction( + b, "verify_write_initialized", args, initializedAssertInfo, + {buffersType, writeVisibilityType, aliasMatrixType, + (uint64_t)memType}, + buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/false)); + } createCallToCachedFunction( b, "verify_write_visibility", args, assertInfo, {buffersType, writeVisibilityType, aliasMatrixType, (uint64_t)memType}, - buildVerifyWriteBody(/*useAlias=*/true)); + buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/true)); } else { SmallVector args = {bufOffset, lengthVal, pred, threadVal, buffersVal, writeVisibilityVal, recipientCTAs}; + if (!allowNoWrite) { + AssertInfo initializedAssertInfo{uninitializedMessage, + verifyWriteResultType}; + createCallToCachedFunction( + b, "verify_write_initialized_noalias", args, initializedAssertInfo, + {buffersType, writeVisibilityType, (uint64_t)memType}, + buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/false)); + } createCallToCachedFunction( b, "verify_write_visibility_noalias", args, assertInfo, {buffersType, writeVisibilityType, (uint64_t)memType}, - buildVerifyWriteBody(/*useAlias=*/false)); + buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/true)); } } diff --git a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp index 992e41185184..067130a94502 100644 --- a/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp +++ b/lib/Dialect/TritonInstrument/Transforms/ConcurrencySanitizer.cpp @@ -483,7 +483,7 @@ class ConcurrencySanitizerImpl { // is writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, - opInfo->commitKind); + /*allowNoWrite=*/false, opInfo->commitKind); if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) { funcBuilder.createSetReadVisibilityCall( b, buf, effect.length, getThreadPeersMask(thread), pred, memType, @@ -502,7 +502,7 @@ class ConcurrencySanitizerImpl { // is reading or writing to the same buffer. addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, - opInfo->commitKind); + /*allowNoWrite=*/true, opInfo->commitKind); addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType, thread, effect.operandName, effectRecipientCTAs, opInfo->commitKind); @@ -578,10 +578,11 @@ class ConcurrencySanitizerImpl { tti::FunctionBuilder &funcBuilder, Operation *op, Value buf, uint32_t length, Value pred, MemType memType, int thread, const std::string &operandName, - Value recipientCTAs, + Value recipientCTAs, bool allowNoWrite, CommitKind::Kind opCommitKind = CommitKind::None) { - funcBuilder.createVerifyWriteVisibilityCall( - b, buf, length, thread, operandName, pred, memType, op, recipientCTAs); + funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread, + operandName, pred, memType, op, + recipientCTAs, allowNoWrite); // commit-num-based synchronization is only supported for shared memory if (memType == MemType::SHARED_MEM) { for (const auto &commitKindDesc : diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 5e9baee683cc..8c8820b35b68 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -203,18 +203,47 @@ def kernel(input_desc, out, FAILURE: ttgl.constexpr): kernel[(1, )](input_desc, output, FAILURE=FAILURE, num_warps=4, num_ctas=num_ctas) +@pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") +def test_local_load_before_first_write(device, run_wrapper, monkeypatch, num_ctas): + if run_wrapper: + result = run_in_process(test_local_load_before_first_write, (device, False, monkeypatch, num_ctas)) + assert_expected_cuda_failure(result.exc) + assert "Buffer being read before any write" in result.driver_stderr_output + return + + monkeypatch.setenv("TRITON_INSTRUMENTATION_MODE", "consan") + monkeypatch.setenv("CUDA_LAUNCH_BLOCKING", "1") + knobs.refresh_knobs() + + @gluon.jit + def kernel(output): + block_m: ttgl.constexpr = XBLOCK * ttgl.num_ctas() + cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, + cga_layout=cga_layout) + smem = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout) + val = smem.load(blocked_layout) + offs_m = ttgl.arange(0, block_m, ttgl.SliceLayout(1, blocked_layout))[:, None] + offs_n = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, blocked_layout))[None, :] + ttgl.store(output + offs_m * XBLOCK + offs_n, val) + + output = torch.empty((XBLOCK.value * num_ctas, XBLOCK.value), device=device, dtype=torch.float16) + kernel[(1, )](output, num_warps=4, num_ctas=num_ctas) + + @pytest.mark.skipif(not is_cuda() or torch.cuda.get_device_capability()[0] < 9, reason="Requires hopper or newer") @pytest.mark.parametrize("FAILURE", [True, False]) def test_async_tma_multicast_kernel(FAILURE, device, run_wrapper, monkeypatch, num_ctas): if num_ctas == 1: pytest.skip("Need at least 2 CTAs for multicast in this test") - if FAILURE and num_ctas == 4: - pytest.skip("Temporarily disabled: flaky with 4 CTAs when FAILURE=True") if run_wrapper: result = run_in_process(test_async_tma_multicast_kernel, (FAILURE, device, False, monkeypatch, num_ctas)) if FAILURE: assert_expected_cuda_failure(result.exc) - assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output + assert ("Buffer being read before any write" in result.driver_stderr_output + or "Buffer being accessed has outstanding writes" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" @@ -658,9 +687,11 @@ def kernel(output_desc, FAILURE: ttgl.constexpr): cga_layout: ttgl.constexpr = default_cga_layout(ttgl.num_ctas(), 2) smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout) - smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_m, XBLOCK], ttgl.float16, blocked_layout)) val = ttgl.full([block_m, XBLOCK], 42, ttgl.float16, blocked_layout) tma.async_copy_shared_to_global(output_desc, [0, 0], smem.index(0)) tma.async_copy_shared_to_global(output_desc, [0, 0], smem.index(1)) @@ -693,7 +724,11 @@ def test_tcgen5_mma(FAILURE, MEM_ACCESS_KIND, TWO_CTAS, device, run_wrapper, mon if MEM_ACCESS_KIND == "tma_cp": # shmem operands are being read by the tcgen05_mma assert "Buffer being accessed has outstanding reads" in result.driver_stderr_output - elif MEM_ACCESS_KIND in ["tmem_load", "tmem_store"]: + elif MEM_ACCESS_KIND == "tmem_load": + # tmem is being written by the tcgen05_mma + assert ("Buffer being accessed has outstanding writes" in result.driver_stderr_output + or "Buffer being read before any write. Operand: B" in result.driver_stderr_output) + elif MEM_ACCESS_KIND == "tmem_store": # tmem is being written by the tcgen05_mma assert "Buffer being accessed has outstanding writes" in result.driver_stderr_output else: @@ -716,19 +751,28 @@ def kernel(input_desc, output_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: tt cga_layout=mma_cga_layout(ttgl.num_ctas(), 2, TWO_CTAS), two_ctas=TWO_CTAS, ) - smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout( - size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0, TWO_CTAS)) + smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=input_desc.layout.cga_layout) acc_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=acc_layout.cga_layout) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_blocked_layout)) + smem_b_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1, TWO_CTAS)) + smem_b_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_b_layout.cga_layout) smemB = ttgl.allocate_shared_memory( ttgl.float16, [XBLOCK, block_n], - ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1, TWO_CTAS)), + smem_b_layout, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_blocked_layout), ) + if TWO_CTAS: + ttgl.barrier(cluster=True) mma_bar = mbarrier.allocate_mbarrier() acc = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, block_n], acc_layout) mbarrier.init(mma_bar, count=1) @@ -861,8 +905,16 @@ def kernel(input, FAILURE: ttgl.constexpr): cga_layout=cga_layout_a) smem_layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout_b) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a) - smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) @@ -908,8 +960,16 @@ def kernel(input, FAILURE: ttgl.constexpr): cga_layout=cga_layout_a) smem_layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2, cga_layout=cga_layout_b) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a) - smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], smem_layout_a, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, block_n], smem_layout_b, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) @@ -960,12 +1020,22 @@ def kernel(input_desc, BUF_IDX: ttgl.constexpr, BAR_IDX: ttgl.constexpr): acc_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=acc_layout.cga_layout) - smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout) + smem_a_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=input_desc.layout.cga_layout) + smemA = ttgl.allocate_shared_memory(ttgl.float16, [block_m, XBLOCK], input_desc.layout, + value=ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_blocked_layout)) + smem_b_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + smem_b_blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_b_layout.cga_layout) smemB = ttgl.allocate_shared_memory( ttgl.float16, [XBLOCK, block_n], - ttgl.NVMMASharedLayout.get_default_for([XBLOCK, block_n], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + smem_b_layout, + value=ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_blocked_layout), ) bar = mbarrier.allocate_mbarrier(batch=4) acc = blackwell.allocate_tensor_memory(ttgl.float32, [2, block_m, block_n], acc_layout) @@ -1323,6 +1393,8 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) for i in range(2): mbarrier.init(bar.index(i), count=1) @@ -1377,6 +1449,8 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) for i in range(2): mbarrier.init(bar.index(i), count=1) @@ -1440,6 +1514,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1501,6 +1577,8 @@ def kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=2) mbarrier.init(bar.index(0), count=2) mbarrier.init(bar.index(1), count=1) @@ -1585,6 +1663,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=4) for i in range(4): mbarrier.init(bar.index(i), count=1) @@ -1653,6 +1733,8 @@ def kernel(output, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1744,6 +1826,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [4, block_x], smem_layout) + for i in ttgl.static_range(4): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=4) for i in range(4): mbarrier.init(bar.index(i), count=2) @@ -1818,6 +1902,8 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [3, block_x], smem_layout) + for i in ttgl.static_range(3): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=3) for i in range(3): mbarrier.init(bar.index(i), count=1) @@ -1893,6 +1979,8 @@ def kernel(input, FAILURE: ttgl.constexpr): smem = ttgl.allocate_shared_memory(ttgl.float16, [4, block_x], smem_layout) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[block_x], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) + for i in ttgl.static_range(4): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) ttgl.warp_specialize([ (ws_prog, (input, smem, FAILURE, blocked_layout, 0)), (ws_prog, (input, smem, FAILURE, blocked_layout, 2)), @@ -1947,6 +2035,8 @@ def kernel(input, FAILURE: ttgl.constexpr): blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[block_x], threads_per_warp=[32], warps_per_cta=[4], order=[0], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float16, [2, block_x], smem_layout) + for i in ttgl.static_range(2): + smem.index(i).store(ttgl.zeros([block_x], ttgl.float16, blocked_layout)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) ttgl.warp_specialize([ @@ -2009,10 +2099,19 @@ def kernel(FAILURE: ttgl.constexpr): cga_layout=cga_layout_b) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout_a) + smem_a_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_a.cga_layout) + smem_b_init_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[4, 1], order=[0, 1], + cga_layout=smem_layout_b.cga_layout) mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16], cga_layout=cga_layout_c) smemA = ttgl.allocate_shared_memory(ttgl.float16, [2, block_m, XBLOCK], smem_layout_a) smemB = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK, block_n], smem_layout_b) + for i in ttgl.static_range(2): + smemA.index(i).store(ttgl.zeros([block_m, XBLOCK], ttgl.float16, smem_a_init_layout)) + smemB.index(i).store(ttgl.zeros([XBLOCK, block_n], ttgl.float16, smem_b_init_layout)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) ttgl.warp_specialize([ @@ -2312,11 +2411,15 @@ def test_aliasing_shared_visibility_outstanding_write(MISSING_BAR, OVERLAP, devi if run_wrapper: result = run_in_process(test_aliasing_shared_visibility_outstanding_write, (MISSING_BAR, OVERLAP, device, False, monkeypatch, num_ctas)) - if MISSING_BAR and OVERLAP: + if not OVERLAP: + assert_expected_cuda_failure(result.exc) + assert "Buffer being read before any write" in result.driver_stderr_output + elif MISSING_BAR: assert result.exc is not None assert_expected_cuda_failure(result.exc) # The race can be reported from either side depending on timing. - assert "Buffer being accessed has outstanding" in result.driver_stderr_output + assert ("Buffer being read before any write" in result.driver_stderr_output + or "Buffer being accessed has outstanding" in result.driver_stderr_output) else: assert result.exc is None assert result.driver_stderr_output == "" @@ -2415,10 +2518,15 @@ def kernel(FAILURE: ttgl.constexpr): blocked_layout_write: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK // 2], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1], cga_layout=cga_layout) + blocked_layout_full: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK * 2], + threads_per_warp=[32, 1], warps_per_cta=[4, 1], + order=[0, 1], cga_layout=cga_layout) smem = ttgl.allocate_shared_memory(ttgl.float32, [block_m, XBLOCK], smem_layout) tmem_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK * 2], col_stride=1, cga_layout=cga_layout) - tmem = blackwell.allocate_tensor_memory(ttgl.float32, [block_m, XBLOCK * 2], tmem_layout) + tmem = blackwell.allocate_tensor_memory( + ttgl.float32, [block_m, XBLOCK * 2], tmem_layout, value=ttgl.zeros([block_m, XBLOCK * 2], ttgl.float32, + blocked_layout_full)) bar = mbarrier.allocate_mbarrier(batch=1) mbarrier.init(bar.index(0), count=1) alias0 = tmem.slice(0, XBLOCK) @@ -2501,20 +2609,29 @@ def kernel(input, MISSING_WAIT: ttgl.constexpr, OVERLAP: ttgl.constexpr): @gluon.jit def async_copy_mma_write_after_read_kernel(a_ptr, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr): + a_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 0)) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)) + cga_layout=a_smem_layout.cga_layout) a_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_M, BLOCK_K], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)), + a_smem_layout, + value=ttgl.zeros([BLOCK_M, BLOCK_K], ttgl.float16, blocked_layout), ) + b_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + blocked_layout_b: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=b_smem_layout.cga_layout) b_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_K, BLOCK_N], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + b_smem_layout, + value=ttgl.zeros([BLOCK_K, BLOCK_N], ttgl.float16, blocked_layout_b), ) bar = mbarrier.allocate_mbarrier() @@ -2552,17 +2669,23 @@ def test_mma_read_async_copy_write(run_wrapper, monkeypatch, num_ctas): @gluon.jit def load_local_alloc_mma_write_after_read_kernel(a_ptr, K, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, BLOCK_K: ttgl.constexpr): - blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], - warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], - cga_layout=mma_cga_layout(ttgl.num_ctas(), 0)) a_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], ttgl.float16, cga_layout=mma_cga_layout( ttgl.num_ctas(), 0)) + blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=a_smem_layout.cga_layout) + b_smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, + cga_layout=mma_cga_layout( + ttgl.num_ctas(), 1)) + blocked_layout_b: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[32, 1], + warps_per_cta=[ttgl.num_warps(), 1], order=[0, 1], + cga_layout=b_smem_layout.cga_layout) b_smem = ttgl.allocate_shared_memory( ttgl.float16, [BLOCK_K, BLOCK_N], - ttgl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], ttgl.float16, - cga_layout=mma_cga_layout(ttgl.num_ctas(), 1)), + b_smem_layout, + value=ttgl.zeros([BLOCK_K, BLOCK_N], ttgl.float16, blocked_layout_b), ) bar = mbarrier.allocate_mbarrier()