Skip to content

Commit e2e33dd

Browse files
authored
[Bugfix] Disable SingleEnvThreadVerifier (#16361)
During TensorIR scheduling, the `IterVar`s that represent environment threads may duplicate, i.e. it is legal to have two env threads with the same name tag, which may fail the `SingleEnvThreadVerifier` check during schedule creation. This PR disables this check in this case. In the future, it may be worthwhile to bring it back against post-scheduling TIR.
1 parent 8e67e2a commit e2e33dd

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/tir/analysis/verify_well_formed.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
347347
}
348348

349349
if (!UndefinedVarVerifier::Verify(func, assert_mode)) return false;
350-
if (!SingleEnvThreadVerifier::Verify(func, assert_mode)) return false;
351350

352351
// TODO(Siyuan): add more checks here.
353352
return true;
@@ -364,7 +363,6 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
364363
}
365364

366365
if (!UndefinedVarVerifier::Verify(mod, assert_mode)) return false;
367-
if (!SingleEnvThreadVerifier::Verify(mod, assert_mode)) return false;
368366

369367
return true;
370368
}

tests/python/tir-schedule/test_tir_schedule_error.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=missing-function-docstring,missing-module-docstring
18-
import sys
19-
2018
import pytest
19+
2120
import tvm
2221
import tvm.testing
2322
from tvm import tir
@@ -41,6 +40,25 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
4140
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
4241

4342

43+
@T.prim_func
44+
def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32):
45+
T.func_attr({"tir.noalias": T.bool(True)})
46+
A = T.match_buffer(var_A, (1, seq_len * 8), "int32")
47+
B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8)
48+
with T.block("exclusive_scan"):
49+
T.reads()
50+
T.writes()
51+
s8: T.int32 = seq_len * 8
52+
if s8 == 0:
53+
blockIdx_x = T.launch_thread("blockIdx.x", 1)
54+
else:
55+
with T.launch_thread("threadIdx.x", 1024) as threadIdx_x:
56+
blockIdx_x = T.launch_thread("blockIdx.x", T.ceildiv(s8, 1024))
57+
i: T.int32 = blockIdx_x * 1024 + threadIdx_x
58+
if i < s8:
59+
B[i // s8, i % s8] = A[i // s8, i % s8]
60+
61+
4462
# pylint: enable=no-member,invalid-name,unused-variable
4563

4664

@@ -74,5 +92,9 @@ def test_tir_schedule_attribute_error():
7492
sch.non_existent_field()
7593

7694

95+
def test_tir_schedule_two_kernels():
96+
tir.Schedule(two_kernels)
97+
98+
7799
if __name__ == "__main__":
78100
tvm.testing.main()

0 commit comments

Comments
 (0)