Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 11 additions & 67 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,78 +142,22 @@ def _install_device_type_stub(name: str) -> None:


# ---------------------------------------------------------------------------
# Apply the peft + transformers-4.x stub-injection fix before pytest collects
# tests that import peft.utils.transformers_weight_conversion. Production runs
# this via unsloth/_gpu_init.py, but the GPU-free harness above skips full
# package init, so we load just the standalone import-fixes module by path.
# Apply ALL upstream-drift fixes (vllm GuidedDecodingParams alias, triton
# CompiledKernel attr wrap, peft transformers_weight_conversion stub, etc.)
# by triggering ``import unsloth``. Fixes live on ``unsloth/import_fixes.py``
# and apply at unsloth import time. The GPU-free harness above pre-spoofs
# the device-type chain so ``import unsloth`` survives on a CPU-only runner.
# Suites without unsloth installed (e.g. security-only) keep passing --
# the ImportError is swallowed and the drift detectors will surface any
# pathology the missing patches would have hidden.
# ---------------------------------------------------------------------------


def _apply_unsloth_peft_import_fix_for_tests() -> None:
import importlib.util as _ilu

def _apply_upstream_import_fixes_for_tests() -> None:
try:
pkg_spec = _ilu.find_spec("unsloth")
import unsloth # noqa: F401 # runs unsloth/import_fixes.py
except Exception:
return
if pkg_spec is None or not pkg_spec.submodule_search_locations:
return
fix_path = os.path.join(
pkg_spec.submodule_search_locations[0],
"import_fixes.py",
)
if not os.path.exists(fix_path):
return

mod_name = "unsloth.import_fixes"
_installed_skeleton = False
if mod_name in sys.modules:
mod = sys.modules[mod_name]
else:
# Submodule import needs SOME parent ``unsloth`` entry; reuse or
# install a bare skeleton and pop on exit so later ``import unsloth``
# calls hit the real package init.
if "unsloth" not in sys.modules:
pkg = types.ModuleType("unsloth")
pkg.__path__ = list(pkg_spec.submodule_search_locations)
pkg.__spec__ = pkg_spec
pkg.__package__ = "unsloth"
pkg.__file__ = os.path.join(
pkg_spec.submodule_search_locations[0],
"__init__.py",
)
sys.modules["unsloth"] = pkg
_installed_skeleton = True
spec = _ilu.spec_from_file_location(mod_name, fix_path)
if spec is None or spec.loader is None:
if _installed_skeleton:
sys.modules.pop("unsloth", None)
return
mod = _ilu.module_from_spec(spec)
sys.modules[mod_name] = mod
try:
spec.loader.exec_module(mod)
except Exception:
sys.modules.pop(mod_name, None)
if _installed_skeleton:
sys.modules.pop("unsloth", None)
return

fix = getattr(mod, "fix_peft_transformers_weight_conversion_import", None)
if fix is None:
if _installed_skeleton:
sys.modules.pop("unsloth", None)
return
try:
fix()
except Exception:
# Individual fix is internally guarded; don't take pytest down.
pass
Comment on lines 159 to 160
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad, silent exception handler (except Exception: pass) can mask unexpected issues such as syntax errors or logic bugs within the unsloth package initialization. Since the goal is to support environments where unsloth is not installed, catching ModuleNotFoundError specifically and checking the module name is more appropriate. Additionally, per the general rules, logging the exception at a debug level would aid in troubleshooting if the import fails for other reasons.

References
  1. Avoid using broad, silent exception handlers like except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging.
  2. When catching an ImportError for an optional dependency, prefer catching the more specific ModuleNotFoundError and check the module name to avoid suppressing unrelated import errors.

finally:
# Drop scratch skeleton; import_fixes itself stays cached as
# ``unsloth.import_fixes`` without an active parent.
if _installed_skeleton:
sys.modules.pop("unsloth", None)


_apply_unsloth_peft_import_fix_for_tests()
_apply_upstream_import_fixes_for_tests()
23 changes: 17 additions & 6 deletions tests/test_import_fixes_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,28 @@ def test_triton_compiled_kernel_has_num_ctas_and_cluster_dims():
tc = pytest.importorskip("triton.compiler.compiler")

ck_cls = tc.CompiledKernel
# Healthy if class has num_ctas directly; otherwise the fix installs
# at instance __init__ time and we cannot cheaply observe that on CPU.
# Healthy if either: pre-3.6 class attr present, or unsloth wrapped
# ``__init__`` to install num_ctas + cluster_dims per instance (the
# post-3.6 shape ``fix_triton_compiled_kernel_missing_attrs`` lands).
if hasattr(ck_cls, "num_ctas"):
return
init = getattr(ck_cls, "__init__", None)
if init is not None:
code = getattr(init, "__code__", None)
freevars = set(getattr(code, "co_freevars", ()) or ())
co_names = set(getattr(code, "co_names", ()) or ())
if "_orig_init" in freevars or {"num_ctas", "cluster_dims"}.issubset(
co_names
):
return
Comment on lines +315 to +321
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block will raise an AttributeError if code is None. This occurs when __init__ is a built-in method or a slot wrapper (e.g., object.__init__), which do not have a __code__ attribute. A guard should be added to ensure code is not None before accessing co_freevars or co_names.


pytest.fail(
"DRIFT DETECTED: triton.CompiledKernel lacks the `num_ctas` "
"class attribute; fix_triton_compiled_kernel_missing_attrs "
"patches __init__ to inject num_ctas and cluster_dims so "
"torch._inductor.runtime.triton_heuristics.make_launcher "
"stops crashing under torch.compile."
"class attribute AND ``__init__`` has not been wrapped by "
"fix_triton_compiled_kernel_missing_attrs; torch Inductor's "
"``make_launcher`` will crash on the eager "
"``binary.metadata.num_ctas, *binary.metadata.cluster_dims`` "
"unpack under torch.compile."
)


Expand Down
Loading