diff --git a/tests/conftest.py b/tests/conftest.py index 16a01f8aa3..5e3a9bd02c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,6 +53,35 @@ def _is_torch_fx_available(): ) +# A device-side assert / illegal access poisons the process-wide CUDA context, so +# every later GPU test errors at setup. Abort the session instead of cascading. +_CUDA_FATAL_MARKERS = ( + "device-side assert triggered", + "an illegal memory access was encountered", + "misaligned address", +) +_cuda_context_poisoned = False + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): # pylint: disable=unused-argument + outcome = yield + report = outcome.get_result() + global _cuda_context_poisoned # pylint: disable=global-statement + if report.failed and call.excinfo is not None: + if any(marker in str(call.excinfo.value) for marker in _CUDA_FATAL_MARKERS): + _cuda_context_poisoned = True + + +def pytest_runtest_setup(item): + if _cuda_context_poisoned: + item.session.shouldstop = ( + "CUDA context corrupted by an earlier test; aborting to avoid " + "cascading setup errors. Re-run the job." + ) + pytest.skip("CUDA context corrupted by an earlier test; aborting suite.") + + def retry_on_request_exceptions(max_retries=3, delay=1): def decorator(func): @functools.wraps(func)