From b47f9ee45feb68aa5c9314d8fdad6565082dad96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 17 May 2026 05:20:21 +0000 Subject: [PATCH 1/3] tests: callback signature drift detector Static AST check that fails fast when a producer in unsloth_zoo (or unsloth) changes the arity of a callback but a consumer callback def still declares the old arity. This was the exact shape of the MLX smoke-test bug PR #5498 fixes -- the trainer's try/except swallowed the TypeError silently and the symptom was a confusing downstream assertion several seconds later. What the detector does: * Producer side: walks every .py and finds classes that own a self.__callbacks list, populated via .append() from an add__callback method, and invoked via `for cb in self.__callbacks: cb(arg1, ..., argN)`. The arity at the call site is the canonical expected arity. * Consumer side: walks every .add__callback(fn) call, resolves fn to a def or lambda in the same file, and asserts arity matches. Consumers that use *args or **kwargs are tolerantly accepted as any arity. * Sources: REPO_ROOT (unsloth) plus UNSLOTH_ZOO_SRC env var (set by the Core workflow once it can be wired in), or sibling ../unsloth-zoo, or the installed wheel. Skips cleanly if no producer pattern found anywhere (the wheel may strip platform-specific submodules like unsloth_zoo/mlx/, so the detector is most useful against a fresh checkout). Validated end-to-end: * Reverted run_real_mlx_smoke.py to its 8-arg shape -- detector raises AssertionError citing exact file:line and the 8 vs 9 drift. * Restored the 9-arg shape -- detector PASSes. * Total runtime ~7 s in pytest. Suggested CI wiring (workflow file change held out of this commit because the pushing PAT lacks `workflow` scope; safe to apply via the GitHub web editor or a maintainer push): ```yaml - name: callback signature drift detector (HARD GATE) env: UNSLOTH_ZOO_SRC: ${{ runner.temp }}/unsloth-zoo run: | python -m pytest -v --tb=short tests/test_callback_signature_drift.py ``` Drop the step into .github/workflows/consolidated-tests-ci.yml right after the existing public-api drift detector step. UNSLOTH_ZOO_SRC reuses the same clone the Core workflow already prepares. --- tests/test_callback_signature_drift.py | 319 +++++++++++++++++++++++++ 1 file changed, 319 insertions(+) create mode 100644 tests/test_callback_signature_drift.py diff --git a/tests/test_callback_signature_drift.py b/tests/test_callback_signature_drift.py new file mode 100644 index 0000000000..c5f047dccf --- /dev/null +++ b/tests/test_callback_signature_drift.py @@ -0,0 +1,319 @@ +"""Static-analysis regression test: callback signature drift. + +Catches the class of bug where a producer (e.g. unsloth_zoo's MLXTrainer) +changes the number of args it passes to a registered callback but consumers +(unsloth tests / source) still declare the old arity. The producer's +``try / except Exception`` typically swallows the resulting TypeError, so +the callback silently never fires and the failure surfaces several seconds +later as a confusing downstream assertion. + +The check is pure AST (no imports of MLX modules etc), so it runs on every +OS / Python version that ships in CI. + +Pattern detected: + * Producer side: a class with ``self.__callbacks`` list, populated + via ``self.__callbacks.append(...)`` from an ``add__callback`` + method, and invoked via ``for cb in self.__callbacks: cb(arg1, ...)``. + The arity at the call site is the canonical expected arity. + * Consumer side: any ``.add__callback(fn)`` call where ``fn`` + resolves to a ``def`` or ``async def`` in the same file. Consumer arity + must equal canonical arity (or be variadic). + +Consumers handled tolerantly: + * ``*args`` / ``**kwargs``: accept any canonical arity. + * Methods (``self.fn``) and unresolved Name targets (imported from another + file): skipped with a note in the failure message rather than asserted. +""" + +from __future__ import annotations + +import ast +import importlib.util +import os +import pathlib +import sys + + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent +# Skip noisy paths during file discovery. +SKIP_PARTS = { + ".git", ".out", "temp", "node_modules", "build", "dist", + ".venv", "venv", ".pytest_cache", "__pycache__", + # Frontend tree under studio is JS/TS plus a few stub .py files; not worth walking. + "frontend", +} + + +def _iter_py(root: pathlib.Path): + root = pathlib.Path(root).resolve() + for p in root.rglob("*.py"): + try: + rel_parts = p.resolve().relative_to(root).parts + except ValueError: + rel_parts = p.parts + if any(part.startswith(".") and part not in (".", "..") for part in rel_parts): + continue + if any(part in SKIP_PARTS for part in rel_parts): + continue + yield p + + +# Module-level parse cache so discover_producers + check_registrations only +# pay the parse cost once per file across the whole test run. +_PARSE_CACHE: dict[pathlib.Path, ast.AST | None] = {} + + +def _safe_parse(path: pathlib.Path): + key = path.resolve() + if key in _PARSE_CACHE: + return _PARSE_CACHE[key] + try: + import warnings as _w + with _w.catch_warnings(): + # Suppress SyntaxWarning emitted while parsing third-party files + # that contain invalid escape sequences in regex / docstrings. + _w.simplefilter("ignore", SyntaxWarning) + tree = ast.parse(path.read_text(encoding="utf-8")) + except (SyntaxError, UnicodeDecodeError): + tree = None + _PARSE_CACHE[key] = tree + return tree + + +def _callback_list_attrs_in_class(cls: ast.ClassDef) -> set[str]: + """Find self.__callbacks attributes assigned or appended-to inside cls.""" + found = set() + for node in ast.walk(cls): + # self._x_callbacks = [...] + if isinstance(node, ast.Assign): + for t in node.targets: + if ( + isinstance(t, ast.Attribute) + and isinstance(t.value, ast.Name) + and t.value.id == "self" + and t.attr.startswith("_") + and t.attr.endswith("_callbacks") + ): + found.add(t.attr) + # self._x_callbacks.append(fn) + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "append" + and isinstance(node.func.value, ast.Attribute) + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id == "self" + and node.func.value.attr.startswith("_") + and node.func.value.attr.endswith("_callbacks") + ): + found.add(node.func.value.attr) + return found + + +def _producer_arities(tree: ast.AST) -> dict[str, int]: + """For each ``for cb in self._x_callbacks: cb(...)`` in the AST, return + {cb_list_attr: max_arity}. Multiple sites take the max so that variadic + branches do not lower the contract. + """ + out: dict[str, int] = {} + for cls in [n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]: + cb_lists = _callback_list_attrs_in_class(cls) + for cb_list in cb_lists: + for node in ast.walk(cls): + if not isinstance(node, ast.For): + continue + if not ( + isinstance(node.iter, ast.Attribute) + and isinstance(node.iter.value, ast.Name) + and node.iter.value.id == "self" + and node.iter.attr == cb_list + ): + continue + if not isinstance(node.target, ast.Name): + continue + cb_name = node.target.id + for inner in ast.walk(node): + if ( + isinstance(inner, ast.Call) + and isinstance(inner.func, ast.Name) + and inner.func.id == cb_name + ): + arity = len(inner.args) + out[cb_list] = max(out.get(cb_list, 0), arity) + return out + + +def _registration_attr_to_list(attr: str) -> str | None: + """add_step_callback -> _step_callbacks. Returns None if pattern doesn't match.""" + if attr.startswith("add_") and attr.endswith("_callback"): + middle = attr[len("add_") : -len("_callback")] + if middle: + return f"_{middle}_callbacks" + if attr.startswith("register_") and attr.endswith("_callback"): + middle = attr[len("register_") : -len("_callback")] + if middle: + return f"_{middle}_callbacks" + return None + + +def _func_arity(node: ast.AST) -> tuple[int, bool] | None: + """Return (positional_arity, accepts_var_positional). None if not a function def.""" + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)): + return None + args = node.args + arity = len(args.posonlyargs) + len(args.args) + accepts_var = args.vararg is not None + # Bound methods: drop the implicit self if this is a method-style def. + # We can't tell statically whether the def is a method without class + # context, so we conservatively do not subtract self here. The consumer + # check skips bare-Name registrations whose target is a `self.fn` attr + # anyway. + return arity, accepts_var + + +def discover_producers(roots: list[pathlib.Path]) -> dict[str, list[tuple[pathlib.Path, int]]]: + """Walk every .py under each root and return {cb_list_attr: [(file, arity), ...]}.""" + producers: dict[str, list[tuple[pathlib.Path, int]]] = {} + for root in roots: + if not root or not root.exists(): + continue + for src in _iter_py(root): + tree = _safe_parse(src) + if tree is None: + continue + for cb_list, arity in _producer_arities(tree).items(): + producers.setdefault(cb_list, []).append((src, arity)) + return producers + + +def check_registrations(roots: list[pathlib.Path], producers: dict[str, list[tuple[pathlib.Path, int]]]): + """Walk every .py under each root, find .add_*_callback(fn) where fn is a + bare Name resolvable to a def in the same file, and assert its arity + matches the producer's canonical arity. Returns (issues, skipped, ok_count). + """ + issues: list[str] = [] + skipped: list[str] = [] + ok_count = 0 + for root in roots: + if not root or not root.exists(): + continue + for src in _iter_py(root): + tree = _safe_parse(src) + if tree is None: + continue + # All function/lambda defs in this file by name (and by id for lambdas via assignment). + defs_by_name: dict[str, ast.AST] = {} + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + defs_by_name[node.name] = node + if isinstance(node, ast.Assign): + if isinstance(node.value, ast.Lambda) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + defs_by_name[node.targets[0].id] = node.value + # Find .add_*_callback(fn) sites + for call in ast.walk(tree): + if not isinstance(call, ast.Call): + continue + if not isinstance(call.func, ast.Attribute): + continue + cb_list = _registration_attr_to_list(call.func.attr) + if cb_list is None: + continue + if cb_list not in producers: + skipped.append( + f"{src}:{call.lineno}: {call.func.attr}(...) but no producer " + f"defines {cb_list} (third-party API?)" + ) + continue + # Only handle bare-Name registrations; bound methods / partials skipped. + if not (len(call.args) == 1 and isinstance(call.args[0], ast.Name)): + skipped.append( + f"{src}:{call.lineno}: {call.func.attr}(...) registers a " + f"non-Name callback (lambda/method/partial); arity not statically checkable" + ) + continue + cb_name = call.args[0].id + fn = defs_by_name.get(cb_name) + if fn is None: + skipped.append( + f"{src}:{call.lineno}: {call.func.attr}({cb_name}) but {cb_name} " + f"is not defined as a function/lambda in this file (imported?)" + ) + continue + arity_info = _func_arity(fn) + if arity_info is None: + continue + consumer_arity, accepts_var = arity_info + expected_arity = max(a for _, a in producers[cb_list]) + if accepts_var: + ok_count += 1 + continue + if consumer_arity != expected_arity: + issues.append( + f"{src}:{call.lineno}: {cb_name} declared with {consumer_arity} " + f"positional arg(s), but producer calls {cb_list} entries with " + f"{expected_arity} arg(s) " + f"({', '.join(str(p) for p, _ in producers[cb_list])})" + ) + else: + ok_count += 1 + return issues, skipped, ok_count + + +def _zoo_roots() -> list[pathlib.Path]: + """Where to look for unsloth_zoo source. We try, in order: + 1. ``UNSLOTH_ZOO_SRC`` env var (a local git checkout). + 2. ``../unsloth-zoo`` next to this repo (common monorepo-style layout). + 3. The pip-installed package (wheel may strip platform-specific submodules + like ``mlx/``, so this often misses MLX producers). + Every root that exists is scanned; duplicates are fine. + """ + roots: list[pathlib.Path] = [] + env_src = os.environ.get("UNSLOTH_ZOO_SRC") + if env_src: + p = pathlib.Path(env_src).expanduser().resolve() + if p.exists(): + roots.append(p) + sibling = (REPO_ROOT.parent / "unsloth-zoo").resolve() + if sibling.exists(): + roots.append(sibling) + spec = importlib.util.find_spec("unsloth_zoo") + if spec is not None and spec.origin is not None: + # spec.origin -> .../site-packages/unsloth_zoo/__init__.py + # we want the unsloth_zoo dir itself, NOT the site-packages root which + # contains every other installed pkg. + roots.append(pathlib.Path(spec.origin).resolve().parent) + return roots + + +def test_no_callback_signature_drift(): + roots = [REPO_ROOT, *_zoo_roots()] + producers = discover_producers(roots) + if not producers: + import pytest + + pytest.skip( + "no callback producer pattern (self._*_callbacks + cb(...)) found in " + "unsloth or unsloth_zoo. Set UNSLOTH_ZOO_SRC= " + "(the pip wheel strips platform-specific submodules like mlx/) to enable " + "the detector locally." + ) + issues, skipped, ok_count = check_registrations(roots, producers) + msg_parts = [ + f"producers discovered: {len(producers)} ({sorted(producers)})", + f"registrations matched: {ok_count}", + f"registrations skipped: {len(skipped)}", + ] + if issues: + msg_parts.append("") + msg_parts.append("Callback signature drift detected:") + msg_parts.extend(" " + i for i in issues) + raise AssertionError("\n".join(msg_parts)) + if "-v" in sys.argv or "--verbose" in sys.argv: + print("\n".join(msg_parts)) + + +if __name__ == "__main__": + # Allow running directly as a script for fast feedback. + sys.argv.append("-v") + test_no_callback_signature_drift() + print("PASS") From 30e4938cb36f543bd9ee2334c34ea5eec00cd296 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 May 2026 05:20:31 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_callback_signature_drift.py | 29 ++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_callback_signature_drift.py b/tests/test_callback_signature_drift.py index c5f047dccf..82226c30e6 100644 --- a/tests/test_callback_signature_drift.py +++ b/tests/test_callback_signature_drift.py @@ -37,8 +37,16 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent # Skip noisy paths during file discovery. SKIP_PARTS = { - ".git", ".out", "temp", "node_modules", "build", "dist", - ".venv", "venv", ".pytest_cache", "__pycache__", + ".git", + ".out", + "temp", + "node_modules", + "build", + "dist", + ".venv", + "venv", + ".pytest_cache", + "__pycache__", # Frontend tree under studio is JS/TS plus a few stub .py files; not worth walking. "frontend", } @@ -69,11 +77,12 @@ def _safe_parse(path: pathlib.Path): return _PARSE_CACHE[key] try: import warnings as _w + with _w.catch_warnings(): # Suppress SyntaxWarning emitted while parsing third-party files # that contain invalid escape sequences in regex / docstrings. _w.simplefilter("ignore", SyntaxWarning) - tree = ast.parse(path.read_text(encoding="utf-8")) + tree = ast.parse(path.read_text(encoding = "utf-8")) except (SyntaxError, UnicodeDecodeError): tree = None _PARSE_CACHE[key] = tree @@ -171,7 +180,9 @@ def _func_arity(node: ast.AST) -> tuple[int, bool] | None: return arity, accepts_var -def discover_producers(roots: list[pathlib.Path]) -> dict[str, list[tuple[pathlib.Path, int]]]: +def discover_producers( + roots: list[pathlib.Path], +) -> dict[str, list[tuple[pathlib.Path, int]]]: """Walk every .py under each root and return {cb_list_attr: [(file, arity), ...]}.""" producers: dict[str, list[tuple[pathlib.Path, int]]] = {} for root in roots: @@ -186,7 +197,9 @@ def discover_producers(roots: list[pathlib.Path]) -> dict[str, list[tuple[pathli return producers -def check_registrations(roots: list[pathlib.Path], producers: dict[str, list[tuple[pathlib.Path, int]]]): +def check_registrations( + roots: list[pathlib.Path], producers: dict[str, list[tuple[pathlib.Path, int]]] +): """Walk every .py under each root, find .add_*_callback(fn) where fn is a bare Name resolvable to a def in the same file, and assert its arity matches the producer's canonical arity. Returns (issues, skipped, ok_count). @@ -207,7 +220,11 @@ def check_registrations(roots: list[pathlib.Path], producers: dict[str, list[tup if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): defs_by_name[node.name] = node if isinstance(node, ast.Assign): - if isinstance(node.value, ast.Lambda) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + if ( + isinstance(node.value, ast.Lambda) + and len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + ): defs_by_name[node.targets[0].id] = node.value # Find .add_*_callback(fn) sites for call in ast.walk(tree): From 8f8e33d099c60c67cead2e733ae1367527cbdecc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 17 May 2026 08:27:54 +0000 Subject: [PATCH 3/3] ci: wire callback-signature drift detector into Core matrix Drops a 6-line pytest step right after the public-api drift detector, with UNSLOTH_ZOO_SRC pointed at the freshly cloned $RUNNER_TEMP/unsloth-zoo so the detector sees unsloth_zoo/mlx/ (the wheel strips it). Sub-second collection plus ~7 s detector run; fits inside the existing Core matrix budget without a new job. --- .github/workflows/consolidated-tests-ci.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/consolidated-tests-ci.yml b/.github/workflows/consolidated-tests-ci.yml index 6b008d4bb1..d0f60a8902 100644 --- a/.github/workflows/consolidated-tests-ci.yml +++ b/.github/workflows/consolidated-tests-ci.yml @@ -316,6 +316,22 @@ jobs: run: | python -m pytest -v --tb=short tests/test_public_api_surface.py + - name: callback signature drift detector (HARD GATE) + # Catches the MLX-style bug from PR #5498: a producer in + # unsloth_zoo (or unsloth) grows a callback arg, but a consumer + # callback def still declares the old arity. The producer's + # try/except swallows the resulting TypeError and the symptom is + # "callback never fires" -- usually diagnosed downstream as a + # confusing assertion several seconds later. This static AST + # check fails fast at PR time. UNSLOTH_ZOO_SRC points at the + # freshly cloned main so the detector sees platform-specific + # submodules (e.g. unsloth_zoo/mlx/) that the released wheel + # may strip. + env: + UNSLOTH_ZOO_SRC: ${{ runner.temp }}/unsloth-zoo + run: | + python -m pytest -v --tb=short tests/test_callback_signature_drift.py + - name: unsloth Bucket-A — CPU tests not in Repo tests (CPU) # 16 tests across 5 files. They live inside tests/saving/ and # tests/utils/, both of which Repo tests (CPU) excludes via --ignore