Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/transform/common/constr_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ struct ConstrSet {
}
}

/*! \brief Fold every constraint held by this set into one AND expression.
 *
 * \return Bool(true) when the set holds no constraints; otherwise the
 * left-to-right tir::And chain of each constraint's ToGenericConstr().
 */
PrimExpr ToConjunction() const {
  const size_t count = constrs_.size();
  if (count == 0) {
    // An empty conjunction is trivially true.
    return Bool(true);
  }
  // Seed with the first constraint, then AND the remaining ones in order.
  PrimExpr conj = constrs_[0].ToGenericConstr();
  size_t k = 1;
  while (k < count) {
    conj = tir::And(conj, constrs_[k].ToGenericConstr());
    ++k;
  }
  return conj;
}

void format(std::ostream &os) const {
os << "ConstrSet(size=" << constrs_.size() << ") {\n";
for (size_t i = 0; i < constrs_.size(); ++i) {
Expand Down
51 changes: 37 additions & 14 deletions src/transform/thread_storage_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1270,23 +1270,46 @@ struct TileLangThreadSyncPlanner : public ConstrVisitor {
}
}
if (has_same_index) {
bool range_is_equal = true;
arith::Analyzer prev_analyzer, curr_analyzer;
prev.cset.Populate(prev_analyzer);
curr.cset.Populate(curr_analyzer);
for (unsigned idx = 0; idx != 3; ++idx) {
Var prev_var = prev.threads[prev.threads.size() + idx - 3]->var;
Var curr_var = curr.threads[curr.threads.size() + idx - 3]->var;
auto prev_bound = prev_analyzer.const_int_bound(prev_var);
auto curr_bound = curr_analyzer.const_int_bound(curr_var);
if (prev_bound->min_value != curr_bound->min_value ||
prev_bound->max_value != curr_bound->max_value) {
range_is_equal = false;
break;
// Use Z3 to check if prev and curr constraints are equivalent.
// If equivalent, the same set of threads execute both accesses, so no
// sync is needed.
//
// Formally, let P(t) denote the predicate for prev's constraint set and
// C(t) denote the predicate for curr's constraint set, where t represents
// the thread indices (threadIdx.x, threadIdx.y, threadIdx.z).
//
// We check bidirectional implication:
// 1. P(t) => C(t): Every thread executing prev also executes curr
// 2. C(t) => P(t): Every thread executing curr also executes prev
//
// If both hold, then P(t) <=> C(t), meaning the exact same set of threads
// execute both accesses. Combined with has_same_index (same buffer index
// expression), this guarantees each thread only accesses locations it
// wrote itself, eliminating cross-thread conflicts.
PrimExpr prev_constr = prev.cset.ToConjunction();
PrimExpr curr_constr = curr.cset.ToConjunction();

arith::Analyzer analyzer;
for (const auto &iv : prev.threads) {
if (iv->dom.defined()) {
analyzer.Bind(iv->var, iv->dom);
}
}
if (range_is_equal)

// Check P => C: ¬P ∨ C
bool prev_implies_curr = analyzer.z3_prover.CanProve(
tir::Or(tir::Not(prev_constr), curr_constr));
// Check C => P: ¬C ∨ P
bool curr_implies_prev = analyzer.z3_prover.CanProve(
tir::Or(tir::Not(curr_constr), prev_constr));

if (prev_implies_curr && curr_implies_prev) {
// If constraints are equivalent, they are not in conflict
return false;
} else {
// If constraints are not equivalent, they are in conflict
return true;
}
}

for (size_t i = 0; i < prev.buffer_indices.size(); i++) {
Expand Down
11 changes: 9 additions & 2 deletions testing/python/issue/test_tilelang_issue_1106.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ def test_kernel(a: T.Tensor[(m,), dtype], b: T.Tensor[(m,), dtype]):
def test_issue_1106():
    """Check that the generated kernel places ``__syncthreads`` after the loop.

    The generated source is searched for the loop header ``for (int i = 0;``
    and for the first ``__syncthreads``. Both must be present, the sync must
    come after the loop header, and the last non-whitespace character before
    the sync must be a closing brace.
    """
    m = 200
    kernel = get_kernel(m)
    source = kernel.get_kernel_source()

    for_start = source.find("for (int i = 0;")
    syncthreads_pos = source.find("__syncthreads")

    # str.find returns -1 when the pattern is missing; check both explicitly
    # so a missing pattern cannot make the ordering assertion pass by accident.
    assert for_start != -1, "Expected for loop not found in generated source"
    assert syncthreads_pos != -1, "__syncthreads not found in generated source"

    # Ensure __syncthreads is not inside the for loop
    assert syncthreads_pos > for_start, "__syncthreads should be after the for loop, not inside it"

    # Check that __syncthreads appears after a closing brace, independent of
    # how the code generator indents or spaces the following line.
    preceding_content = source[:syncthreads_pos].rstrip()
    assert preceding_content.endswith("}"), "__syncthreads should not be inside any for loop"
Comment on lines +36 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

String-based validation is fragile and lacks robustness checks.

Several concerns with the validation logic:

  1. Missing not-found handling: find() returns -1 when the pattern isn't found. If for_start is -1 and for_end is a positive index, the assertion for_end > for_start passes incorrectly.

  2. Misleading variable name: for_end actually stores the position of __syncthreads, not the end of the for loop.

  3. Brittle slice check: source[for_end - 4 : for_end - 2] == "}\n" depends on exact whitespace/formatting and could break if code generation changes indentation or spacing.

🛠️ Proposed fix with explicit validation
 def test_issue_1106():
     m = 200
     kernel = get_kernel(m)
     source = kernel.get_kernel_source()
     # Ensure __syncthreads is not inside the for loop
     for_start = source.find("for (int i = 0;")
-    for_end = source.find("__syncthreads")
-    assert for_end > for_start, "__syncthreads should be after the for loop, not inside it"
-    # Check that __syncthreads appears after the closing brace of the outer for loop
-    assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop"
+    syncthreads_pos = source.find("__syncthreads")
+    assert for_start != -1, "Expected for loop not found in generated source"
+    assert syncthreads_pos != -1, "__syncthreads not found in generated source"
+    assert syncthreads_pos > for_start, "__syncthreads should be after the for loop, not inside it"
+    # Check that __syncthreads appears after the closing brace of the outer for loop
+    preceding_content = source[:syncthreads_pos].rstrip()
+    assert preceding_content.endswith("}"), "__syncthreads should not be inside any for loop"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for_start = source.find("for (int i = 0;")
for_end = source.find("__syncthreads")
assert for_end > for_start, "__syncthreads should be after the for loop, not inside it"
# Check that __syncthreads appears after the closing brace of the outer for loop
assert source[for_end - 4 : for_end - 2] == "}\n", "__syncthreads should not be inside any for loop"
for_start = source.find("for (int i = 0;")
syncthreads_pos = source.find("__syncthreads")
assert for_start != -1, "Expected for loop not found in generated source"
assert syncthreads_pos != -1, "__syncthreads not found in generated source"
assert syncthreads_pos > for_start, "__syncthreads should be after the for loop, not inside it"
# Check that __syncthreads appears after the closing brace of the outer for loop
preceding_content = source[:syncthreads_pos].rstrip()
assert preceding_content.endswith("}"), "__syncthreads should not be inside any for loop"
🤖 Prompt for AI Agents
In `@testing/python/issue/test_tilelang_issue_1106.py` around lines 36 - 40, The
current string-based checks using for_start and for_end are fragile and miss
not-found cases; update the test to first validate both patterns exist (ensure
source.find("for (int i = 0;") and source.find("__syncthreads") return >= 0),
then rename for_end to syncthreads_idx for clarity, locate the opening brace of
the for loop (e.g., find '{' after for_start), compute the matching closing
brace by scanning and counting braces (brace depth) to get loop_close_idx, and
assert syncthreads_idx > loop_close_idx to guarantee __syncthreads is after the
outer loop instead of relying on fixed slices or exact whitespace.



if __name__ == "__main__":
    # Use the standard tilelang test entry point instead of calling a single
    # test function directly.
    tilelang.testing.main()
Comment on lines 43 to +45
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Test harness is commented out.

Commenting out tilelang.testing.main() and calling test_issue_1106() directly may prevent this test from being discovered and run by CI/CD test frameworks that rely on the standard harness.

If this is intentional for debugging, consider restoring the harness before merging.

🔧 Suggested fix
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_issue_1106()
+    tilelang.testing.main()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_issue_1106()
if __name__ == "__main__":
tilelang.testing.main()
🤖 Prompt for AI Agents
In `@testing/python/issue/test_tilelang_issue_1106.py` around lines 43 - 45, The
test harness call has been commented out; restore the standard test entry so CI
discovers this test by uncommenting or re-adding tilelang.testing.main() in the
if __name__ == "__main__": block and avoid calling test_issue_1106() directly;
ensure the file uses the guard with tilelang.testing.main() while leaving direct
test calls removed or conditional for local debugging so the test runner invokes
the suite normally.

23 changes: 22 additions & 1 deletion testing/python/transform/test_tilelang_transform_thread_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from tilelang import tvm as tvm
import tilelang.testing
from tvm.script import tir as T
from tvm import te


def run_passes(func: tvm.tir.PrimFunc):
Expand Down Expand Up @@ -42,6 +41,28 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32"))
assert "T.tvm_storage_sync" in str(mod)


# Builds a TIR function in which a shared-buffer write is guarded by
# `threadIdx_x % 4 == 0`, followed by a read of the same shared index, runs
# it through run_passes, and expects a storage sync in the result.
@tilelang.testing.requires_cuda
def test_sync_if_with_same_index_with_modulo_if():
@T.prim_func(check_well_formed=False)
def func() -> None:
threadIdx_x = T.env_thread("threadIdx.x")
blockIdx_x = T.env_thread("blockIdx.x")
p0 = T.alloc_buffer([1], dtype="float32", scope="local")
result_local = T.alloc_buffer([1], dtype="float32", scope="local")
temp_shared = T.alloc_buffer([32], dtype="float32", scope="shared")
# Launch extents: 1 for blockIdx.x, 32 for threadIdx.x, 1 for threadIdx.y and threadIdx.z.
T.launch_thread(blockIdx_x, 1)
T.launch_thread(threadIdx_x, 32)
ty = T.launch_thread("threadIdx.y", 1)
tz = T.launch_thread("threadIdx.z", 1)
result_local[0] = T.float32(0)
# The write to shared memory happens only when threadIdx_x is a multiple of 4.
if threadIdx_x % 4 == 0:
temp_shared[threadIdx_x] = p0[0]
# NOTE(review): this read appears to sit outside the guard above (so every
# thread reads while only some write) — confirm against the original indentation.
result_local[0] = temp_shared[threadIdx_x]

mod = run_passes(func)
# The transformed module is expected to contain a storage sync call.
assert "T.tvm_storage_sync" in str(mod)


@tilelang.testing.requires_cuda
def test_sync_read_thread_id_independent_location():
@T.prim_func
Expand Down
Loading