diff --git a/docs/changelog/2969.bugfix.rst b/docs/changelog/2969.bugfix.rst new file mode 100644 index 000000000..f55daaab2 --- /dev/null +++ b/docs/changelog/2969.bugfix.rst @@ -0,0 +1 @@ +Fix race condition in ``_virtualenv.py`` when file is overwritten during import, preventing ``NameError`` when ``_DISTUTILS_PATCH`` is accessed - by :user:`gracetyy`. diff --git a/src/virtualenv/create/via_global_ref/_virtualenv.py b/src/virtualenv/create/via_global_ref/_virtualenv.py index b61db3079..0d95b28e0 100644 --- a/src/virtualenv/create/via_global_ref/_virtualenv.py +++ b/src/virtualenv/create/via_global_ref/_virtualenv.py @@ -2,10 +2,11 @@ from __future__ import annotations +import contextlib import os import sys -VIRTUALENV_PATCH_FILE = os.path.join(__file__) +VIRTUALENV_PATCH_FILE = os.path.abspath(__file__) def patch_dist(dist): @@ -50,7 +51,14 @@ class _Finder: lock = [] # noqa: RUF012 def find_spec(self, fullname, path, target=None): # noqa: ARG002 - if fullname in _DISTUTILS_PATCH and self.fullname is None: # noqa: PLR1702 + # Guard against race conditions during file rewrite by checking if _DISTUTILS_PATCH is defined. + # This can happen when the file is being overwritten while it's being imported by another process. + # See https://github.com/pypa/virtualenv/issues/2969 for details. + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return None + if fullname in distutils_patch and self.fullname is None: # noqa: PLR1702 # initialize lock[0] lazily if len(self.lock) == 0: import threading # noqa: PLC0415 @@ -89,14 +97,26 @@ def find_spec(self, fullname, path, target=None): # noqa: ARG002 @staticmethod def exec_module(old, module): old(module) - if module.__name__ in _DISTUTILS_PATCH: - patch_dist(module) + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return + if module.__name__ in distutils_patch: + # patch_dist or its dependencies may not be defined during file rewrite + with contextlib.suppress(NameError): + patch_dist(module) @staticmethod def load_module(old, name): module = old(name) - if module.__name__ in _DISTUTILS_PATCH: - patch_dist(module) + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return module + if module.__name__ in distutils_patch: + # patch_dist or its dependencies may not be defined during file rewrite + with contextlib.suppress(NameError): + patch_dist(module) return module diff --git a/tests/integration/test_race_condition_simulation.py b/tests/integration/test_race_condition_simulation.py new file mode 100644 index 000000000..857de9aa5 --- /dev/null +++ b/tests/integration/test_race_condition_simulation.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import importlib.util +import shutil +import sys +from pathlib import Path + + +def test_race_condition_simulation(tmp_path): + """Test that simulates the race condition described in the issue. + + This test creates a temporary directory with _virtualenv.py and _virtualenv.pth, + then simulates the scenario where: + - One process imports and uses the _virtualenv module (simulating marimo) + - Another process overwrites the _virtualenv.py file (simulating uv venv) + + The test verifies that no NameError is raised for _DISTUTILS_PATCH. + """ + # Create the _virtualenv.py file + virtualenv_file = tmp_path / "_virtualenv.py" + source_file = Path(__file__).parents[2] / "src" / "virtualenv" / "create" / "via_global_ref" / "_virtualenv.py" + + shutil.copy(source_file, virtualenv_file) + + # Create the _virtualenv.pth file + pth_file = tmp_path / "_virtualenv.pth" + pth_file.write_text("import _virtualenv", encoding="utf-8") + + # Simulate the race condition by repeatedly importing + errors = [] + for _ in range(5): + # Try to import it + sys.path.insert(0, str(tmp_path)) + try: + if "_virtualenv" in sys.modules: + del sys.modules["_virtualenv"] + + import _virtualenv # noqa: F401, PLC0415 + + # Try to trigger find_spec + try: + importlib.util.find_spec("distutils.dist") + except NameError as e: + if "_DISTUTILS_PATCH" in str(e): + errors.append(str(e)) + finally: + if str(tmp_path) in sys.path: + sys.path.remove(str(tmp_path)) + + assert not errors, f"Race condition detected: {errors}" diff --git a/tests/unit/create/via_global_ref/_test_race_condition_helper.py b/tests/unit/create/via_global_ref/_test_race_condition_helper.py new file mode 100644 index 000000000..8027f17d3 --- /dev/null +++ b/tests/unit/create/via_global_ref/_test_race_condition_helper.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import ClassVar + + +class _Finder: + fullname = None + lock: ClassVar[list] = [] + + def find_spec(self, fullname, path, target=None): # noqa: ARG002 + # This should handle the NameError gracefully + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return + if fullname in distutils_patch and self.fullname is None: + return + return + + @staticmethod + def exec_module(old, module): + old(module) + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return + if module.__name__ in distutils_patch: + pass # Would call patch_dist(module) + + @staticmethod + def load_module(old, name): + module = old(name) + try: + distutils_patch = _DISTUTILS_PATCH + except NameError: + return module + if module.__name__ in distutils_patch: + pass # Would call patch_dist(module) + return module + + +finder = _Finder() diff --git a/tests/unit/create/via_global_ref/test_race_condition.py b/tests/unit/create/via_global_ref/test_race_condition.py new file mode 100644 index 000000000..3b044167c --- /dev/null +++ b/tests/unit/create/via_global_ref/test_race_condition.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +def test_virtualenv_py_race_condition_find_spec(tmp_path): + """Test that _Finder.find_spec handles NameError gracefully when _DISTUTILS_PATCH is not defined.""" + # Create a temporary file with partial _virtualenv.py content (simulating race condition) + venv_file = tmp_path / "_virtualenv_test.py" + + # Write a partial version of _virtualenv.py that has _Finder but not _DISTUTILS_PATCH + # This simulates the state during a race condition where the file is being rewritten + helper_file = Path(__file__).parent / "_test_race_condition_helper.py" + partial_content = helper_file.read_text(encoding="utf-8") + + venv_file.write_text(partial_content, encoding="utf-8") + + sys.path.insert(0, str(tmp_path)) + try: + import _virtualenv_test # noqa: PLC0415 + + finder = _virtualenv_test.finder + + # Try to call find_spec - this should not raise NameError + result = finder.find_spec("distutils.dist", None) + assert result is None, "find_spec should return None when _DISTUTILS_PATCH is not defined" + + # Create a mock module object + class MockModule: + __name__ = "distutils.dist" + + # Try to call exec_module - this should not raise NameError + def mock_old_exec(_x): + pass + + finder.exec_module(mock_old_exec, MockModule()) + + # Try to call load_module - this should not raise NameError + def mock_old_load(_name): + return MockModule() + + result = finder.load_module(mock_old_load, "distutils.dist") + assert result.__name__ == "distutils.dist" + + finally: + sys.path.remove(str(tmp_path)) + if "_virtualenv_test" in sys.modules: + del sys.modules["_virtualenv_test"] + + +def test_virtualenv_py_normal_operation(): + """Test that the fix doesn't break normal operation when _DISTUTILS_PATCH is defined.""" + # Read the actual _virtualenv.py file + virtualenv_py_path = ( + Path(__file__).parent.parent.parent.parent.parent + / "src" + / "virtualenv" + / "create" + / "via_global_ref" + / "_virtualenv.py" + ) + + if not virtualenv_py_path.exists(): + return # Skip if we can't find the file + + content = virtualenv_py_path.read_text(encoding="utf-8") + + # Verify the fix is present + assert "try:" in content + assert "distutils_patch = _DISTUTILS_PATCH" in content + assert "except NameError:" in content + assert "return None" in content or "return" in content