
Commit 22d8a8d

[UnitTest][NVPTX] Avoid cascading failures from CUDA postproc
Prior to this commit, the tests in `test_tir_transform_inject_ptx_async_copy.py` registered the `"tvm_callback_cuda_postproc"` function during pytest collection, and used a global variable to disable its functionality outside of the tests in this file. This had two major issues. First, if any other test also installed a postproc function, the postproc function required by these NVPTX tests would be overwritten. Second, if one of the NVPTX tests failed, the global variable controlling the postproc function would not be reset, causing any subsequent CUDA-related tests to also fail.

This commit updates these NVPTX tests to conditionally install the postproc function, to de-register it after the test instead of merely disabling its functionality, and to de-register it regardless of the test result.

This issue was initially found when debugging apache#15103, where a failure in `test_tir_transform_inject_ptx_async_copy.py::test_cp_async_in_if_then_else` caused failures in 32 unrelated tests ([CI link](https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-gpu/detail/PR-15103/7/tests)).
1 parent f14c61f commit 22d8a8d
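Condensed, the fix follows the usual pytest register-and-restore pattern: remember whatever callback was registered before the test, override it for the duration of the test, and put the old one back during teardown so that a failing test cannot leak state into later tests. The sketch below is illustrative only; the fixture name and the no-op callback body are placeholders, and the real fixture in the diff further down also gates installation on the device's compute version.

```python
import pytest
import tvm


@pytest.fixture
def cuda_postproc_fixture():
    # Placeholder name; the commit's fixture is `postproc_if_missing_async_support`.
    func_name = "tvm_callback_cuda_postproc"

    # Remember whatever callback (if any) was registered before this test.
    prev_postproc = tvm.get_global_func(func_name, allow_missing=True)

    @tvm.register_func(func_name, override=True)
    def tvm_callback_cuda_postproc(code, _):
        # Post-process the generated CUDA source; a no-op stand-in here.
        return code

    yield

    # Teardown runs even if the test body raised, so later tests see the
    # callback that was registered before this one.
    if prev_postproc is not None:
        tvm.register_func(func_name, prev_postproc, override=True)
```

Because the restore step sits after `yield`, pytest executes it whether the test passes or fails, which is exactly what prevents the cascading failures described above.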

1 file changed

tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py

Lines changed: 36 additions & 47 deletions
@@ -14,11 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import numpy as np
+
 import tvm
 import tvm.testing
 from tvm.script import tir as T
 
+import pytest
+import numpy as np
+
 
 def count_cp_async(stmt):
     num_alloc = [0]
@@ -351,36 +354,38 @@ def test_inject_async_copy_shared_dyn():
 """
 
 
-generated_code = ""
-support_async = True
+@pytest.fixture
+def postproc_if_missing_async_support():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    support_async = major >= 8
 
+    func_name = "tvm_callback_cuda_postproc"
+    prev_postproc = None
 
-@tvm.register_func
-def tvm_callback_cuda_postproc(code, _):
-    global generated_code
-    global support_async
-    generated_code = code
-    # return a dummy code so that device < sm80 could build correctly
     if not support_async:
-        ret = ""
-        for line in code.split("\n"):
-            ret += line + "\n"
-            if line.startswith('extern "C" __global__'):
-                break
-        ret += "}"
-        return ret
-    return code
+        prev_postproc = tvm.get_global_func(func_name, allow_missing=True)
 
+        @tvm.register_func(func_name, override=True)
+        def tvm_callback_cuda_postproc(code, _):
+            ret = []
+            for line in code.split("\n"):
+                ret.append(line)
+                ret.append("\n")
+                if line.startswith('extern "C" __global__') and line.endswith("{"):
+                    break
+            ret.append("}")
+            return "".join(ret)
 
-@tvm.testing.requires_cuda
-def test_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
+    yield
+
+    # Restore previous postproc func to avoid impacting other tests
+    if prev_postproc is not None:
+        tvm.register_func(func_name, prev_postproc, override=True)
 
+
+@tvm.testing.requires_cuda
+def test_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def simple_compute(
         A: T.Buffer((16, 14), "float32"),
@@ -421,23 +426,13 @@ def simple_compute(
 
     mod = tvm.IRModule.from_expr(simple_compute)
     with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
-        tvm.build(mod, target="cuda")
+        built = tvm.build(mod, target="cuda")
+    generated_code = built.imported_modules[0].get_source()
     assert generated_code == expected_cuda_script
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 @tvm.testing.requires_cuda
-def test_vectorize_cp_async_in_if_then_else():
-    global support_async
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    if major < 8:
-        # At least sm80 is required
-        support_async = False
-
+def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support):
     @T.prim_func
     def complex_compute(
         A: T.Buffer((2, 16, 16, 1280), "float16"),
@@ -886,17 +881,11 @@ def complex_compute(
 
     mod = tvm.IRModule.from_expr(complex_compute)
    with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
-        tvm.build(mod, target="cuda")
+        built = tvm.build(mod, target="cuda")
+    generated_code = built.imported_modules[0].get_source()
     # generated_code must contain " setp.ne.b32 p, %0, 0;"
     assert "setp.ne.b32" in generated_code
 
-    if not support_async:
-        # avoid return dummy code to other tests
-        support_async = True
-
 
 if __name__ == "__main__":
-    test_inject_async_copy()
-    test_inject_async_copy_shared_dyn()
-    test_cp_async_in_if_then_else()
-    test_vectorize_cp_async_in_if_then_else()
+    tvm.testing.main()
