Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,13 @@ class FunctionBuilder {
void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
uint64_t threadMask, Value pred,
MemType memType, Operation *insertPoint);
// verifyWriteVisibility: ensure the thread either sees the latest write or no
// other thread is writing the buffer.
// verifyWriteVisibility: ensure the thread sees the latest write. When
// allowNoWrite is true, also allow rows that have not been written yet.
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
uint32_t length, int thread,
StringRef operandName, Value pred,
MemType memType, Operation *insertPoint,
Value recipientCTAs);
Value recipientCTAs, bool allowNoWrite);
// verifyReadVisibility: ensure all reads from the buffer are visible to the
// thread.
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
Expand Down
60 changes: 51 additions & 9 deletions lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2144,7 +2144,7 @@ void FunctionBuilder::createTransferVisibleReadsCall(
void FunctionBuilder::createVerifyWriteVisibilityCall(
ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread,
StringRef operandName, Value pred, MemType memType, Operation *insertPoint,
Value recipientCTAs) {
Value recipientCTAs, bool allowNoWrite) {
if (auxData.buffers[(int)memType].empty() ||
auxData.writeVisibility[(int)memType].empty() ||
(auxData.hasNonTrivialAliasing[(int)memType] &&
Expand All @@ -2166,12 +2166,16 @@ void FunctionBuilder::createVerifyWriteVisibilityCall(
std::string message = "Buffer being accessed has outstanding writes.";
if (!operandName.empty())
message += " Operand: " + operandName.str();
std::string uninitializedMessage = "Buffer being read before any write.";
if (!operandName.empty())
uninitializedMessage += " Operand: " + operandName.str();
auto verifyWriteResultType = cast<RankedTensorType>(
writeVisibilityType.cloneWith(std::nullopt, b.getI1Type()));
AssertInfo assertInfo{message, verifyWriteResultType};
Type aliasMatrixTypeBase;
auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase,
verifyWriteResultType](bool useAlias) {
verifyWriteResultType](bool useAlias,
bool allowNoWrite) {
return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) {
Value bufOffset = entryBlock->getArgument(0);
Value lengthVal = entryBlock->getArgument(1);
Expand Down Expand Up @@ -2213,14 +2217,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall(
arith::AndIOp::create(fb, bufVisibility, bufferThreadBit);
bufferHasVisibility = arith::CmpIOp::create(
fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit);
Value writeVisible =
arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility);
Value allWritesVisible = reduceAll<arith::AndIOp>(fb, writeVisible);
Value result;
if (!allowNoWrite) {
Value rowOne = tti::createConstIntTensor(
fb, fb.getLoc(), 1, cast<RankedTensorType>(buffersEqBuf.getType()));
Value rowInitialized =
arith::XOrIOp::create(fb, noOneIsWriting, rowOne);
Value initializedRows =
arith::AndIOp::create(fb, rowInitialized, buffersEqBuf);
// Alias rows are alternatives within a CTA, but every selected CTA must
// have at least one initialized row.
Value initializedCTAs =
reduceLastDim<arith::OrIOp>(fb, initializedRows);
Value selectedCTAs = reduceLastDim<arith::OrIOp>(fb, buffersEqBuf);
Value ctaOne = tti::createConstIntTensor(
fb, fb.getLoc(), 1, cast<RankedTensorType>(selectedCTAs.getType()));
Value unmatchedCTAs = arith::XOrIOp::create(fb, selectedCTAs, ctaOne);
Value initializedOrUnmatched =
arith::OrIOp::create(fb, initializedCTAs, unmatchedCTAs);
result = reduceAll<arith::AndIOp>(fb, initializedOrUnmatched);
} else {
Value writeVisible =
arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility);
result = reduceAll<arith::AndIOp>(fb, writeVisible);
}

Value vTrue = arith::ConstantOp::create(
fb, allWritesVisible.getType(), fb.getIntegerAttr(fb.getI1Type(), 1));
fb, result.getType(), fb.getIntegerAttr(fb.getI1Type(), 1));
Value predicatedWriteVisible =
arith::SelectOp::create(fb, pred, allWritesVisible, vTrue);
arith::SelectOp::create(fb, pred, result, vTrue);
predicatedWriteVisible = triton::SplatOp::create(
fb, verifyWriteResultType, predicatedWriteVisible);
triton::ReturnOp::create(fb, predicatedWriteVisible);
Expand All @@ -2235,18 +2260,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall(
SmallVector<Value> args = {bufOffset, lengthVal, pred,
threadVal, buffersVal, writeVisibilityVal,
recipientCTAs, aliasMatrixVal};
if (!allowNoWrite) {
AssertInfo initializedAssertInfo{uninitializedMessage,
verifyWriteResultType};
createCallToCachedFunction(
b, "verify_write_initialized", args, initializedAssertInfo,
{buffersType, writeVisibilityType, aliasMatrixType,
(uint64_t)memType},
buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/false));
Comment thread
lezcano marked this conversation as resolved.
}
createCallToCachedFunction(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we have to run both "verify_write_initialized" and "verify_write_visibility" when !allowNoWrite? Won't verify_write_initialized already check the visibility?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a minor thing, and it's because these functions just accept one assert, so if we want to have its own nice assert, we have to run both. Should I make it so that we can have several asserts per function you reckon?

Copy link
Copy Markdown
Contributor

@pawelszczerbuk pawelszczerbuk Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the problems is that I think "verify_write_initialized" checks also visibility, so in case of buffer not being visible it will emit the "uninitialized" assert, which will be confusing. I think we either need separate function to check for no-write, or multiple possible asserts, but then we also need the inner function to be able to return richer information than just bool

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine, unless I misunderstood you.
We have:

 rowInitialized = !noOneIsWriting
 writeVisible = noOneIsWriting || bufferHasVisibility

This is already tested in

python/test/gluon/test_consan.py::test_async_tma_kernel[1ctas-True]

which triggers the visiblity point but not the initialised one.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, re-reading the code I see I was wrong. Thanks!

b, "verify_write_visibility", args, assertInfo,
{buffersType, writeVisibilityType, aliasMatrixType, (uint64_t)memType},
buildVerifyWriteBody(/*useAlias=*/true));
buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/true));
} else {
SmallVector<Value> args = {bufOffset, lengthVal, pred,
threadVal, buffersVal, writeVisibilityVal,
recipientCTAs};
if (!allowNoWrite) {
AssertInfo initializedAssertInfo{uninitializedMessage,
verifyWriteResultType};
createCallToCachedFunction(
b, "verify_write_initialized_noalias", args, initializedAssertInfo,
{buffersType, writeVisibilityType, (uint64_t)memType},
buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/false));
}
createCallToCachedFunction(
b, "verify_write_visibility_noalias", args, assertInfo,
{buffersType, writeVisibilityType, (uint64_t)memType},
buildVerifyWriteBody(/*useAlias=*/false));
buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/true));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ class ConcurrencySanitizerImpl {
// is writing to the same buffer.
addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType,
thread, effect.operandName, effectRecipientCTAs,
opInfo->commitKind);
/*allowNoWrite=*/false, opInfo->commitKind);
if (opInfo->trackingKind == MemEffectsOpInfo::TrackingKind::Barrier) {
funcBuilder.createSetReadVisibilityCall(
b, buf, effect.length, getThreadPeersMask(thread), pred, memType,
Expand All @@ -502,7 +502,7 @@ class ConcurrencySanitizerImpl {
// is reading or writing to the same buffer.
addWriteChecks(b, funcBuilder, op, buf, effect.length, pred, memType,
thread, effect.operandName, effectRecipientCTAs,
opInfo->commitKind);
/*allowNoWrite=*/true, opInfo->commitKind);
addReadChecks(b, funcBuilder, op, buf, effect.length, pred, memType,
thread, effect.operandName, effectRecipientCTAs,
opInfo->commitKind);
Expand Down Expand Up @@ -578,10 +578,11 @@ class ConcurrencySanitizerImpl {
tti::FunctionBuilder &funcBuilder, Operation *op,
Value buf, uint32_t length, Value pred, MemType memType,
int thread, const std::string &operandName,
Value recipientCTAs,
Value recipientCTAs, bool allowNoWrite,
CommitKind::Kind opCommitKind = CommitKind::None) {
funcBuilder.createVerifyWriteVisibilityCall(
b, buf, length, thread, operandName, pred, memType, op, recipientCTAs);
funcBuilder.createVerifyWriteVisibilityCall(b, buf, length, thread,
operandName, pred, memType, op,
recipientCTAs, allowNoWrite);
// commit-num-based synchronization is only supported for shared memory
if (memType == MemType::SHARED_MEM) {
for (const auto &commitKindDesc :
Expand Down
Loading
Loading