diff --git a/tests/test_habana_profiler_unit.py b/tests/test_habana_profiler_unit.py index 646e604866..d334878b5a 100644 --- a/tests/test_habana_profiler_unit.py +++ b/tests/test_habana_profiler_unit.py @@ -15,6 +15,7 @@ import os import shutil +import weakref from unittest.mock import MagicMock import pytest @@ -44,7 +45,7 @@ def patched_profiler(monkeypatch): @pytest.fixture(autouse=True) def cleanup(): shutil.rmtree(PROFILER_OUTPUT_DIR, ignore_errors=True) - HabanaProfile._profilers = [] + HabanaProfile._profilers = weakref.WeakSet() def run_profiling(profiler): @@ -118,8 +119,9 @@ def test_two_profilers_can_run_sequentially(): assert len(os.listdir(PROFILER_OUTPUT_DIR)) == 2 -def test_cannot_start_profiler_when_another_is_running(patched_profiler): +def test_starting_new_profiler_stops_previous(patched_profiler): another_profiler = HabanaProfile(warmup=1, active=1) patched_profiler.start() - with pytest.raises(RuntimeError): - another_profiler.start() + another_profiler.start() + assert not patched_profiler._running + assert another_profiler._running