-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[CONSAN] Add read before any write check #10167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2144,7 +2144,7 @@ void FunctionBuilder::createTransferVisibleReadsCall( | |
| void FunctionBuilder::createVerifyWriteVisibilityCall( | ||
| ImplicitLocOpBuilder &b, Value buf, uint32_t length, int thread, | ||
| StringRef operandName, Value pred, MemType memType, Operation *insertPoint, | ||
| Value recipientCTAs) { | ||
| Value recipientCTAs, bool allowNoWrite) { | ||
| if (auxData.buffers[(int)memType].empty() || | ||
| auxData.writeVisibility[(int)memType].empty() || | ||
| (auxData.hasNonTrivialAliasing[(int)memType] && | ||
|
|
@@ -2166,12 +2166,16 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( | |
| std::string message = "Buffer being accessed has outstanding writes."; | ||
| if (!operandName.empty()) | ||
| message += " Operand: " + operandName.str(); | ||
| std::string uninitializedMessage = "Buffer being read before any write."; | ||
| if (!operandName.empty()) | ||
| uninitializedMessage += " Operand: " + operandName.str(); | ||
| auto verifyWriteResultType = cast<RankedTensorType>( | ||
| writeVisibilityType.cloneWith(std::nullopt, b.getI1Type())); | ||
| AssertInfo assertInfo{message, verifyWriteResultType}; | ||
| Type aliasMatrixTypeBase; | ||
| auto buildVerifyWriteBody = [&writeVisibilityType, &aliasMatrixTypeBase, | ||
| verifyWriteResultType](bool useAlias) { | ||
| verifyWriteResultType](bool useAlias, | ||
| bool allowNoWrite) { | ||
| return [=](ImplicitLocOpBuilder &fb, Block *entryBlock) { | ||
| Value bufOffset = entryBlock->getArgument(0); | ||
| Value lengthVal = entryBlock->getArgument(1); | ||
|
|
@@ -2213,14 +2217,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( | |
| arith::AndIOp::create(fb, bufVisibility, bufferThreadBit); | ||
| bufferHasVisibility = arith::CmpIOp::create( | ||
| fb, arith::CmpIPredicate::eq, bufferHasVisibility, bufferThreadBit); | ||
| Value writeVisible = | ||
| arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); | ||
| Value allWritesVisible = reduceAll<arith::AndIOp>(fb, writeVisible); | ||
| Value result; | ||
| if (!allowNoWrite) { | ||
| Value rowOne = tti::createConstIntTensor( | ||
| fb, fb.getLoc(), 1, cast<RankedTensorType>(buffersEqBuf.getType())); | ||
| Value rowInitialized = | ||
| arith::XOrIOp::create(fb, noOneIsWriting, rowOne); | ||
| Value initializedRows = | ||
| arith::AndIOp::create(fb, rowInitialized, buffersEqBuf); | ||
| // Alias rows are alternatives within a CTA, but every selected CTA must | ||
| // have at least one initialized row. | ||
| Value initializedCTAs = | ||
| reduceLastDim<arith::OrIOp>(fb, initializedRows); | ||
| Value selectedCTAs = reduceLastDim<arith::OrIOp>(fb, buffersEqBuf); | ||
| Value ctaOne = tti::createConstIntTensor( | ||
| fb, fb.getLoc(), 1, cast<RankedTensorType>(selectedCTAs.getType())); | ||
| Value unmatchedCTAs = arith::XOrIOp::create(fb, selectedCTAs, ctaOne); | ||
| Value initializedOrUnmatched = | ||
| arith::OrIOp::create(fb, initializedCTAs, unmatchedCTAs); | ||
| result = reduceAll<arith::AndIOp>(fb, initializedOrUnmatched); | ||
| } else { | ||
| Value writeVisible = | ||
| arith::OrIOp::create(fb, noOneIsWriting, bufferHasVisibility); | ||
| result = reduceAll<arith::AndIOp>(fb, writeVisible); | ||
| } | ||
|
|
||
| Value vTrue = arith::ConstantOp::create( | ||
| fb, allWritesVisible.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); | ||
| fb, result.getType(), fb.getIntegerAttr(fb.getI1Type(), 1)); | ||
| Value predicatedWriteVisible = | ||
| arith::SelectOp::create(fb, pred, allWritesVisible, vTrue); | ||
| arith::SelectOp::create(fb, pred, result, vTrue); | ||
| predicatedWriteVisible = triton::SplatOp::create( | ||
| fb, verifyWriteResultType, predicatedWriteVisible); | ||
| triton::ReturnOp::create(fb, predicatedWriteVisible); | ||
|
|
@@ -2235,18 +2260,35 @@ void FunctionBuilder::createVerifyWriteVisibilityCall( | |
| SmallVector<Value> args = {bufOffset, lengthVal, pred, | ||
| threadVal, buffersVal, writeVisibilityVal, | ||
| recipientCTAs, aliasMatrixVal}; | ||
| if (!allowNoWrite) { | ||
| AssertInfo initializedAssertInfo{uninitializedMessage, | ||
| verifyWriteResultType}; | ||
| createCallToCachedFunction( | ||
| b, "verify_write_initialized", args, initializedAssertInfo, | ||
| {buffersType, writeVisibilityType, aliasMatrixType, | ||
| (uint64_t)memType}, | ||
| buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/false)); | ||
| } | ||
| createCallToCachedFunction( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we have to run both "verify_write_initialized" and "verify_write_visibility" when !allowNoWrite? Won't verify_write_initialized already check the visibility?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a minor thing: these functions accept only one assert each, so if we want the initialization check to have its own clear assert, we have to run both. Do you reckon I should make it possible to have several asserts per function?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. One of the problems is that I think "verify_write_initialized" also checks visibility, so if the buffer is not visible it will emit the "uninitialized" assert, which will be confusing. I think we need either a separate function to check for no-write, or multiple possible asserts, but then we also need the inner function to be able to return richer information than just a bool.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it's fine, unless I misunderstood you. This is already tested in [test reference missing from the page], which triggers the visibility check but not the initialised one.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense; re-reading the code, I see I was wrong. Thanks! |
||
| b, "verify_write_visibility", args, assertInfo, | ||
| {buffersType, writeVisibilityType, aliasMatrixType, (uint64_t)memType}, | ||
| buildVerifyWriteBody(/*useAlias=*/true)); | ||
| buildVerifyWriteBody(/*useAlias=*/true, /*allowNoWrite=*/true)); | ||
| } else { | ||
| SmallVector<Value> args = {bufOffset, lengthVal, pred, | ||
| threadVal, buffersVal, writeVisibilityVal, | ||
| recipientCTAs}; | ||
| if (!allowNoWrite) { | ||
| AssertInfo initializedAssertInfo{uninitializedMessage, | ||
| verifyWriteResultType}; | ||
| createCallToCachedFunction( | ||
| b, "verify_write_initialized_noalias", args, initializedAssertInfo, | ||
| {buffersType, writeVisibilityType, (uint64_t)memType}, | ||
| buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/false)); | ||
| } | ||
| createCallToCachedFunction( | ||
| b, "verify_write_visibility_noalias", args, assertInfo, | ||
| {buffersType, writeVisibilityType, (uint64_t)memType}, | ||
| buildVerifyWriteBody(/*useAlias=*/false)); | ||
| buildVerifyWriteBody(/*useAlias=*/false, /*allowNoWrite=*/true)); | ||
| } | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.