Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions flashinfer/jit/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,16 @@ def has_flashinfer_cubin() -> bool:
def _get_cubin_dir():
"""
Get the cubin directory path with the following priority:
1. flashinfer-cubin package if installed
2. Environment variable FLASHINFER_CUBIN_DIR
1. Environment variable FLASHINFER_CUBIN_DIR
2. flashinfer-cubin package if installed
3. Default cache directory
"""
# First check if flashinfer-cubin package is installed
# First check environment variable
env_dir = os.getenv("FLASHINFER_CUBIN_DIR")
if env_dir:
return pathlib.Path(env_dir)

# Then check if flashinfer-cubin package is installed
if has_flashinfer_cubin():
import flashinfer_cubin

Expand All @@ -82,11 +87,6 @@ def _get_cubin_dir():

return pathlib.Path(flashinfer_cubin.get_cubin_dir())

# Then check environment variable
env_dir = os.getenv("FLASHINFER_CUBIN_DIR")
if env_dir:
return pathlib.Path(env_dir)

# Fall back to default cache directory
return FLASHINFER_CACHE_DIR / "cubins"

Expand Down
103 changes: 103 additions & 0 deletions tests/test_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Regression tests for _get_cubin_dir() priority — issue #2976.

env.py imports CompilationContext (CUDA deps), so we load it in isolation
with lightweight stubs to keep tests runnable without a GPU.

python -m pytest tests/test_env.py -v --noconftest
"""

import importlib.util
import pathlib
import sys
import types

_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]


def _load_env_module():
"""Load flashinfer.jit.env with minimal stubs (no CUDA required)."""
stubs = {
"flashinfer": types.ModuleType("flashinfer"),
"flashinfer.jit": types.ModuleType("flashinfer.jit"),
"flashinfer.version": types.ModuleType("flashinfer.version"),
"flashinfer.compilation_context": types.ModuleType(
"flashinfer.compilation_context"
),
}
stubs["flashinfer"].__path__ = [str(_REPO_ROOT / "flashinfer")]
stubs["flashinfer.jit"].__path__ = [str(_REPO_ROOT / "flashinfer" / "jit")]
stubs["flashinfer.version"].__version__ = "0.0.0+test"
stubs["flashinfer.version"].__git_version__ = "test"

class _Stub:
def __init__(self):
self.TARGET_CUDA_ARCHS = set()

stubs["flashinfer.compilation_context"].CompilationContext = _Stub

saved = {k: sys.modules.get(k) for k in stubs}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of _load_env_module does not save and restore the state of sys.modules["flashinfer.jit.env"]. This can lead to test isolation issues if other tests in the same process attempt to import this module, as they will receive the mocked version instead of the real one. It is safer to include the module being tested in the saved dictionary so it can be properly cleaned up in the finally block.

Suggested change
saved = {k: sys.modules.get(k) for k in stubs}
modules_to_save = list(stubs.keys()) + ["flashinfer.jit.env"]
saved = {k: sys.modules.get(k) for k in modules_to_save}

sys.modules.update(stubs)
try:
spec = importlib.util.spec_from_file_location(
"flashinfer.jit.env", str(_REPO_ROOT / "flashinfer" / "jit" / "env.py")
)
mod = importlib.util.module_from_spec(spec)
sys.modules["flashinfer.jit.env"] = mod
spec.loader.exec_module(mod)
return mod
finally:
for k, v in saved.items():
if v is None:
sys.modules.pop(k, None)
else:
sys.modules[k] = v


_env = _load_env_module()
Comment on lines +38 to +56
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Prevent cross-test contamination from sys.modules mutation.

Line 45 installs flashinfer.jit.env into sys.modules, but _load_env_module() never restores any preexisting entry for that key. Because _env is created at import time (Line 56), later tests can accidentally reuse this stub-loaded module.

Proposed fix
 def _load_env_module():
@@
-    saved = {k: sys.modules.get(k) for k in stubs}
+    saved = {k: sys.modules.get(k) for k in stubs}
+    saved_env_module = sys.modules.get("flashinfer.jit.env")
     sys.modules.update(stubs)
     try:
         spec = importlib.util.spec_from_file_location(
             "flashinfer.jit.env", str(_REPO_ROOT / "flashinfer" / "jit" / "env.py")
         )
         mod = importlib.util.module_from_spec(spec)
         sys.modules["flashinfer.jit.env"] = mod
         spec.loader.exec_module(mod)
         return mod
     finally:
+        if saved_env_module is None:
+            sys.modules.pop("flashinfer.jit.env", None)
+        else:
+            sys.modules["flashinfer.jit.env"] = saved_env_module
         for k, v in saved.items():
             if v is None:
                 sys.modules.pop(k, None)
             else:
                 sys.modules[k] = v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_env.py` around lines 38 - 56, The helper _load_env_module mutates
sys.modules by inserting "flashinfer.jit.env" but never restores any prior
entry, causing cross-test contamination when _env is created at import time;
update _load_env_module to save the original value of "flashinfer.jit.env"
before assigning mod (e.g. original_env = sys.modules.get("flashinfer.jit.env")
or include that key in the saved dict) and in the finally block restore it
exactly as you do for other stubs (if original is None pop the key, else set it
back), ensuring "flashinfer.jit.env" is returned to its pre-call state and
preventing stale module reuse by _env.



def _fake_cubin_pkg(path):
"""Return a stub ``flashinfer_cubin`` module pointing at *path*."""
m = types.ModuleType("flashinfer_cubin")
m.__version__ = "0.0.0+test"
m.get_cubin_dir = lambda: path
return m


# -- priority tests (regression for #2976) ----------------------------------


def test_env_var_overrides_package(monkeypatch, tmp_path):
"""FLASHINFER_CUBIN_DIR must take priority over the installed package."""
env_dir = str(tmp_path / "env_cubins")
pkg_dir = str(tmp_path / "pkg_cubins")
monkeypatch.setenv("FLASHINFER_CUBIN_DIR", env_dir)
monkeypatch.setenv("FLASHINFER_DISABLE_VERSION_CHECK", "1")
monkeypatch.setattr(_env, "has_flashinfer_cubin", lambda: True)
monkeypatch.setitem(sys.modules, "flashinfer_cubin", _fake_cubin_pkg(pkg_dir))
assert _env._get_cubin_dir() == pathlib.Path(env_dir)


def test_package_used_when_no_env_var(monkeypatch, tmp_path):
"""Without the env var, the package path should be returned."""
pkg_dir = str(tmp_path / "pkg_cubins")
monkeypatch.delenv("FLASHINFER_CUBIN_DIR", raising=False)
monkeypatch.setenv("FLASHINFER_DISABLE_VERSION_CHECK", "1")
monkeypatch.setattr(_env, "has_flashinfer_cubin", lambda: True)
monkeypatch.setitem(sys.modules, "flashinfer_cubin", _fake_cubin_pkg(pkg_dir))
assert _env._get_cubin_dir() == pathlib.Path(pkg_dir)


def test_env_var_used_when_no_package(monkeypatch, tmp_path):
"""Env var should work even when the package is not installed."""
env_dir = str(tmp_path / "env_cubins")
monkeypatch.setenv("FLASHINFER_CUBIN_DIR", env_dir)
monkeypatch.setattr(_env, "has_flashinfer_cubin", lambda: False)
assert _env._get_cubin_dir() == pathlib.Path(env_dir)


def test_default_when_nothing_set(monkeypatch):
"""Fall back to the default cache directory."""
monkeypatch.delenv("FLASHINFER_CUBIN_DIR", raising=False)
monkeypatch.setattr(_env, "has_flashinfer_cubin", lambda: False)
assert _env._get_cubin_dir() == _env.FLASHINFER_CACHE_DIR / "cubins"
Loading