|
24 | 24 | from tvm import topi |
25 | 25 | from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16 |
26 | 26 | from tvm.contrib import utils |
| 27 | +from tvm.script import tir as T |
27 | 28 | import tvm.testing |
28 | 29 | import pytest |
29 | 30 |
|
@@ -1068,5 +1069,50 @@ def check_cuda(n, lanes): |
1068 | 1069 | check_cuda(64, 2) |
1069 | 1070 |
|
1070 | 1071 |
|
| 1072 | +def test_cuda_thread_sync_inside_condition(): |
| 1073 | + @T.prim_func |
| 1074 | + def func1(A: T.Buffer((4, 4), "float32")) -> None: |
| 1075 | + A_shared = T.alloc_buffer((4, 4), "float32", scope="shared") |
| 1076 | + for bx in T.thread_binding(1, "blockIdx.x"): |
| 1077 | + for tx in T.thread_binding(32, "threadIdx.x"): |
| 1078 | + if A[0, 0] > 1.0: |
| 1079 | + for i, j in T.grid(4, 4): |
| 1080 | + A_shared[i, j] = A[i, j] |
| 1081 | + for i, j in T.grid(4, 4): |
| 1082 | + A[i, j] = A_shared[i, j] + 1.0 |
| 1083 | + |
| 1084 | + @T.prim_func |
| 1085 | + def func2(A: T.Buffer((4, 4), "float32")) -> None: |
| 1086 | + A_shared = T.alloc_buffer((4, 4), "float32", scope="shared") |
| 1087 | + for bx in T.thread_binding(1, "blockIdx.x"): |
| 1088 | + for tx in T.thread_binding(32, "threadIdx.x"): |
| 1089 | + if T.tvm_thread_invariant(A[0, 0] > 1.0): |
| 1090 | + for i, j in T.grid(4, 4): |
| 1091 | + A_shared[i, j] = A[i, j] |
| 1092 | + for i, j in T.grid(4, 4): |
| 1093 | + A[i, j] = A_shared[i, j] + 1.0 |
| 1094 | + |
| 1095 | + @T.prim_func |
| 1096 | + def func3(A: T.Buffer((4, 4), "float32")) -> None: |
| 1097 | + A_shared = T.alloc_buffer((4, 4), "float32", scope="shared") |
| 1098 | + for bx in T.thread_binding(1, "blockIdx.x"): |
| 1099 | + for tx in T.thread_binding(32, "threadIdx.x"): |
| 1100 | + while T.tvm_thread_invariant(A[0, 0] > 1.0): |
| 1101 | + for i, j in T.grid(4, 4): |
| 1102 | + A_shared[i, j] = A[i, j] |
| 1103 | + for i, j in T.grid(4, 4): |
| 1104 | + A[i, j] = A_shared[i, j] + 1.0 |
| 1105 | + |
| 1106 | + mod = tvm.IRModule({"main": func1}) |
| 1107 | + with pytest.raises(tvm.error.InternalError): |
| 1108 | + tvm.build(mod, target="cuda") |
| 1109 | + |
| 1110 | + mod = tvm.IRModule({"main": func2}) |
| 1111 | + tvm.build(mod, target="cuda") |
| 1112 | + |
| 1113 | + mod = tvm.IRModule({"main": func3}) |
| 1114 | + tvm.build(mod, target="cuda") |
| 1115 | + |
| 1116 | + |
1071 | 1117 | if __name__ == "__main__": |
1072 | 1118 | tvm.testing.main() |
0 commit comments