Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/consolidated-tests-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,17 @@ jobs:
run: |
python -m pytest -v --tb=short tests/test_import_fixes_drift.py

- name: public-api surface drift detectors (9 tests, HARD GATE)
# Companion to test_import_fixes_drift.py: that file catches
# third-party drift; this one catches drift in unsloth's OWN
# public surface (FastLanguageModel / FastVisionModel /
# FastModel + their classmethods + is_bf16_supported). A
# rename here would silently break the unslothai/notebooks tree
# one PR cycle later -- this gate catches it BEFORE the
# breakage reaches users.
run: |
python -m pytest -v --tb=short tests/test_public_api_surface.py

- name: unsloth Bucket-A — CPU tests not in Repo tests (CPU)
# 16 tests across 5 files. They live inside tests/saving/ and
# tests/utils/, both of which Repo tests (CPU) excludes via --ignore
Expand Down
216 changes: 216 additions & 0 deletions tests/test_public_api_surface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Unsloth - 2x faster, 60% less VRAM LLM training and finetuning
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.

"""Public-API surface drift detectors for unsloth itself.

Companion to tests/test_import_fixes_drift.py: that file catches drift
in THIRD-PARTY libraries (transformers / trl / triton / peft / etc.)
that unsloth's import_fixes patches around. This file catches drift in
unsloth's OWN public-surface API -- the top-10 symbols and classmethods
that the unslothai/notebooks tree (and therefore every user on Colab)
calls. If a refactor on this repo renames FastLanguageModel.from_pretrained
or drops one of the documented kwargs, the test fires DRIFT DETECTED
here BEFORE the breakage reaches users.

Call-site counts measured against unslothai/notebooks @ main:
FastLanguageModel.from_pretrained 506
FastLanguageModel.for_inference 370
FastLanguageModel.get_peft_model 304
FastVisionModel.for_inference 183
FastVisionModel.from_pretrained 176
FastVisionModel.get_peft_model 99
FastVisionModel.for_training 60
FastModel.from_pretrained 103
FastModel.get_peft_model 67

Mirrors the unsloth-zoo / unsloth drift-detector skeleton:
``pytest.importorskip("unsloth")`` to gate, assert the healthy upstream
shape, ``pytest.fail("DRIFT DETECTED: ...")`` (never ``pytest.skip``) on
regression so the matrix cell goes red.
"""

from __future__ import annotations

import inspect

import pytest


def _signature_param_names(callable_obj) -> set[str]:
try:
sig = inspect.signature(callable_obj)
except (TypeError, ValueError):
return set()
Comment on lines +48 to +53
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The helper function _signature_param_names is defined but not used anywhere in this test file. It should be removed to keep the codebase clean and avoid maintaining dead code.

return set(sig.parameters)


def _accepts(callable_obj, kwargs: set[str]) -> tuple[bool, set[str]]:
"""True if every name in ``kwargs`` is either a named parameter on
``callable_obj`` OR the callable's signature has a ``**kwargs``
catch-all. Returns (ok, missing_set)."""
try:
sig = inspect.signature(callable_obj)
except (TypeError, ValueError):
Comment on lines +62 to +63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Returning True, set() when inspect.signature fails is too permissive for a drift detector. If the object's signature cannot be inspected (for example, if the attribute was replaced by a non-callable value like None), the test would silently pass. Returning False and the set of required arguments ensures that such regressions are correctly identified as drift. Consider centralizing this signature validation logic into a helper function to ensure consistency across the codebase.

Suggested change
sig = inspect.signature(callable_obj)
except (TypeError, ValueError):
except (TypeError, ValueError):
return False, kwargs
References
  1. Centralize recurring or complex logical checks into a single helper function and reuse it across the codebase to ensure consistency and simplify maintenance.

return True, set()
params = sig.parameters
has_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
if has_var_kw:
return True, set()
missing = kwargs - set(params)
return (not missing), missing


# ===========================================================================
# FastLanguageModel: the headline class. 506 from_pretrained + 370
# for_inference + 304 get_peft_model call sites across the notebooks.
# ===========================================================================


def test_fast_language_model_class_present():
unsloth = pytest.importorskip("unsloth")
if not hasattr(unsloth, "FastLanguageModel"):
pytest.fail(
"DRIFT DETECTED: unsloth.FastLanguageModel is missing; every "
"LoRA notebook fails at the first import cell."
)


def test_fast_language_model_from_pretrained_kwargs():
"""from_pretrained must accept the canonical kwargs the notebooks pass."""
unsloth = pytest.importorskip("unsloth")
required = {"model_name", "max_seq_length", "dtype", "load_in_4bit"}
ok, missing = _accepts(unsloth.FastLanguageModel.from_pretrained, required)
if not ok:
pytest.fail(
f"DRIFT DETECTED: FastLanguageModel.from_pretrained dropped "
f"kwargs {sorted(missing)}; 506 notebook call sites would "
f"crash with TypeError."
)


def test_fast_language_model_get_peft_model_kwargs():
unsloth = pytest.importorskip("unsloth")
required = {
"r",
"lora_alpha",
"lora_dropout",
"target_modules",
"bias",
"use_gradient_checkpointing",
"random_state",
}
ok, missing = _accepts(unsloth.FastLanguageModel.get_peft_model, required)
if not ok:
pytest.fail(
f"DRIFT DETECTED: FastLanguageModel.get_peft_model dropped "
f"kwargs {sorted(missing)}; 304 notebook call sites would crash."
)


def test_fast_language_model_for_inference_callable():
unsloth = pytest.importorskip("unsloth")
if not callable(getattr(unsloth.FastLanguageModel, "for_inference", None)):
pytest.fail(
"DRIFT DETECTED: FastLanguageModel.for_inference is missing; "
"370 inference-cell call sites would crash."
)


# ===========================================================================
# FastVisionModel: 183 + 176 + 99 + 60 call sites across vision notebooks.
# ===========================================================================


def test_fast_vision_model_class_and_methods():
unsloth = pytest.importorskip("unsloth")
if not hasattr(unsloth, "FastVisionModel"):
pytest.fail(
"DRIFT DETECTED: unsloth.FastVisionModel is missing; every "
"vision fine-tuning notebook fails at import."
)
cls = unsloth.FastVisionModel
missing = [
m
for m in ("from_pretrained", "get_peft_model", "for_inference", "for_training")
if not callable(getattr(cls, m, None))
]
if missing:
pytest.fail(f"DRIFT DETECTED: FastVisionModel is missing methods {missing}.")


def test_fast_vision_model_get_peft_model_vision_kwargs():
"""Vision-specific kwargs the notebooks pass on the vision LoRA path."""
unsloth = pytest.importorskip("unsloth")
required = {
"finetune_vision_layers",
"finetune_language_layers",
"finetune_attention_modules",
"finetune_mlp_modules",
}
ok, missing = _accepts(unsloth.FastVisionModel.get_peft_model, required)
if not ok:
pytest.fail(
f"DRIFT DETECTED: FastVisionModel.get_peft_model dropped "
f"vision kwargs {sorted(missing)}."
)


# ===========================================================================
# FastModel: the modern unified entry point. 103 + 67 call sites.
# ===========================================================================


def test_fast_model_class_and_methods():
unsloth = pytest.importorskip("unsloth")
if not hasattr(unsloth, "FastModel"):
pytest.fail(
"DRIFT DETECTED: unsloth.FastModel is missing; the modern "
"unified entry point used by 100+ notebooks would crash."
)
missing = [
m
for m in ("from_pretrained", "get_peft_model")
if not callable(getattr(unsloth.FastModel, m, None))
]
if missing:
pytest.fail(f"DRIFT DETECTED: FastModel is missing methods {missing}.")


def test_fast_model_from_pretrained_kwargs():
unsloth = pytest.importorskip("unsloth")
required = {"model_name", "max_seq_length", "dtype", "load_in_4bit"}
ok, missing = _accepts(unsloth.FastModel.from_pretrained, required)
if not ok:
pytest.fail(
f"DRIFT DETECTED: FastModel.from_pretrained dropped kwargs "
f"{sorted(missing)}; 103 notebook call sites would crash."
)


# ===========================================================================
# Bf16 helper alias (renamed once already; keep both accepted).
# ===========================================================================


def test_is_bf16_supported_or_alias_callable():
"""48 notebook import sites for is_bf16_supported plus 8 for the
legacy is_bfloat16_supported alias. Either must remain importable."""
unsloth = pytest.importorskip("unsloth")
has_new = callable(getattr(unsloth, "is_bf16_supported", None))
has_old = callable(getattr(unsloth, "is_bfloat16_supported", None))
if not (has_new or has_old):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Require both bf16 aliases to stay importable

When only one alias is exported, this test still passes even though the other import form remains part of the public surface: the docstring cites notebook call sites for both is_bf16_supported and is_bfloat16_supported, and repo-wide search shows tests importing from unsloth import is_bf16_supported. In the scenario where only the legacy name remains (or only the new name remains), those imports crash while this new drift detector stays green, so this should assert both public names are callable unless the removed name is deliberately no longer supported.

Useful? React with 👍 / 👎.

pytest.fail(
"DRIFT DETECTED: neither unsloth.is_bf16_supported nor "
"unsloth.is_bfloat16_supported is callable; dtype probing "
"in 50+ notebooks fails."
)
Loading