[ThreadSync] Use Z3 for constraint equivalence checking#1760
Conversation
Replace const_int_bound based range comparison with Z3 SMT solver for checking thread constraint equivalence in FindConflict.
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughAdds Changes
Sequence Diagram(s)mermaid Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Added comments to improve understanding of the logic determining whether constraints are equivalent or in conflict in the thread storage synchronization process.
|
This is also a fix for #1758 |
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@testing/python/issue/test_tilelang_issue_1106.py`:
- Around line 43-45: The test harness call has been commented out; restore the
standard test entry so CI discovers this test by uncommenting or re-adding
tilelang.testing.main() in the if __name__ == "__main__": block and avoid
calling test_issue_1106() directly; ensure the file uses the guard with
tilelang.testing.main() while leaving direct test calls removed or conditional
for local debugging so the test runner invokes the suite normally.
- Around line 36-40: The current string-based checks using for_start and for_end
are fragile and miss not-found cases; update the test to first validate both
patterns exist (ensure source.find("for (int i = 0;") and
source.find("__syncthreads") return >= 0), then rename for_end to
syncthreads_idx for clarity, locate the opening brace of the for loop (e.g.,
find '{' after for_start), compute the matching closing brace by scanning and
counting braces (brace depth) to get loop_close_idx, and assert syncthreads_idx
> loop_close_idx to guarantee __syncthreads is after the outer loop instead of
relying on fixed slices or exact whitespace.
| for_start = source.find("for (int i = 0;") | ||
| for_end = source.find("__syncthreads") | ||
| assert for_end > for_start, "__syncthreads should be after the for loop, not inside it" | ||
| # Check that __syncthreads appears after the closing brace of the outer for loop | ||
| assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop" |
There was a problem hiding this comment.
String-based validation is fragile and lacks robustness checks.
Several concerns with the validation logic:
-
Missing not-found handling:
find()returns-1when the pattern isn't found. Iffor_startis-1andfor_endis a positive index, the assertionfor_end > for_startpasses incorrectly. -
Misleading variable name:
for_endactually stores the position of__syncthreads, not the end of the for loop. -
Brittle slice check:
source[for_end - 4 : for_end - 2] == "}\n"depends on exact whitespace/formatting and could break if code generation changes indentation or spacing.
🛠️ Proposed fix with explicit validation
def test_issue_1106():
m = 200
kernel = get_kernel(m)
source = kernel.get_kernel_source()
# Ensure __syncthreads is not inside the for loop
for_start = source.find("for (int i = 0;")
- for_end = source.find("__syncthreads")
- assert for_end > for_start, "__syncthreads should be after the for loop, not inside it"
- # Check that __syncthreads appears after the closing brace of the outer for loop
- assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop"
+ syncthreads_pos = source.find("__syncthreads")
+ assert for_start != -1, "Expected for loop not found in generated source"
+ assert syncthreads_pos != -1, "__syncthreads not found in generated source"
+ assert syncthreads_pos > for_start, "__syncthreads should be after the for loop, not inside it"
+ # Check that __syncthreads appears after the closing brace of the outer for loop
+ preceding_content = source[:syncthreads_pos].rstrip()
+ assert preceding_content.endswith("}"), "__syncthreads should not be inside any for loop"📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| for_start = source.find("for (int i = 0;") | |
| for_end = source.find("__syncthreads") | |
| assert for_end > for_start, "__syncthreads should be after the for loop, not inside it" | |
| # Check that __syncthreads appears after the closing brace of the outer for loop | |
| assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop" | |
| for_start = source.find("for (int i = 0;") | |
| syncthreads_pos = source.find("__syncthreads") | |
| assert for_start != -1, "Expected for loop not found in generated source" | |
| assert syncthreads_pos != -1, "__syncthreads not found in generated source" | |
| assert syncthreads_pos > for_start, "__syncthreads should be after the for loop, not inside it" | |
| # Check that __syncthreads appears after the closing brace of the outer for loop | |
| preceding_content = source[:syncthreads_pos].rstrip() | |
| assert preceding_content.endswith("}"), "__syncthreads should not be inside any for loop" |
🤖 Prompt for AI Agents
In `@testing/python/issue/test_tilelang_issue_1106.py` around lines 36 - 40, The
current string-based checks using for_start and for_end are fragile and miss
not-found cases; update the test to first validate both patterns exist (ensure
source.find("for (int i = 0;") and source.find("__syncthreads") return >= 0),
then rename for_end to syncthreads_idx for clarity, locate the opening brace of
the for loop (e.g., find '{' after for_start), compute the matching closing
brace by scanning and counting braces (brace depth) to get loop_close_idx, and
assert syncthreads_idx > loop_close_idx to guarantee __syncthreads is after the
outer loop instead of relying on fixed slices or exact whitespace.
| if __name__ == "__main__": | ||
| tilelang.testing.main() | ||
| # tilelang.testing.main() | ||
| test_issue_1106() |
There was a problem hiding this comment.
Test harness is commented out.
Commenting out tilelang.testing.main() and calling test_issue_1106() directly may prevent this test from being discovered and run by CI/CD test frameworks that rely on the standard harness.
If this is intentional for debugging, consider restoring the harness before merging.
🔧 Suggested fix
if __name__ == "__main__":
- # tilelang.testing.main()
- test_issue_1106()
+ tilelang.testing.main()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if __name__ == "__main__": | |
| tilelang.testing.main() | |
| # tilelang.testing.main() | |
| test_issue_1106() | |
| if __name__ == "__main__": | |
| tilelang.testing.main() |
🤖 Prompt for AI Agents
In `@testing/python/issue/test_tilelang_issue_1106.py` around lines 43 - 45, The
test harness call has been commented out; restore the standard test entry so CI
discovers this test by uncommenting or re-adding tilelang.testing.main() in the
if __name__ == "__main__": block and avoid calling test_issue_1106() directly;
ensure the file uses the guard with tilelang.testing.main() while leaving direct
test calls removed or conditional for local debugging so the test runner invokes
the suite normally.
This PR replaces the
const_int_boundbased range comparison with Z3 SMT solver for checking thread constraint equivalence inFindConflict.The previous implementation used
const_int_boundto compare thread variable ranges, which only checks the min/max bounds without considering additional constraints (e.g.,threadIdx_x % 4 == 0). This led to incorrect equivalence judgments when constraints narrowed the actual executing thread set.Approach
We use Z3 to formally verify constraint set equivalence through bidirectional implication:
Let$P(t)$ denote the predicate for the previous access's constraint set and $C(t)$ denote the predicate for the current access's constraint set, where $t = (\texttt{threadIdx.x}, \texttt{threadIdx.y}, \texttt{threadIdx.z})$ .
We check:
prevalso executescurrcurralso executesprevIf both hold, then$P(t) \Leftrightarrow C(t)$ , meaning the exact same set of threads execute both accesses. Combined with
has_same_index(identical buffer index expressions), this guarantees each thread only accesses memory locations it wrote itself, eliminating cross-thread RAW/WAR conflicts.Implementation
The implication$P \Rightarrow C$ is checked by proving $\neg P \vee C$ using the Z3 prover:
Changes
ToConjunction()method toConstrSetfor converting constraint sets to a single conjunction expressionconst_int_boundrange comparison with Z3-based equivalence checking inFindConflictSummary by CodeRabbit
Refactor
API
Tests
✏️ Tip: You can customize this high-level summary in your review settings.