Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
681 changes: 681 additions & 0 deletions tests/test_fused_forward_install.py

Large diffs are not rendered by default.

32 changes: 22 additions & 10 deletions tests/test_upstream_source_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,14 @@ def test_compiler_per_layer_projection_inplace_regex():
def test_compiler_cross_entropy_lm_head_pattern_present():
"""``unsloth_zoo/compiler.py:1508-1525`` (cross_entropy_find_1)
expects ``logits = self.lm_head(hidden_states`` at the head of the
loss block in every ForCausalLM forward."""
loss block in every ForCausalLM forward.

Read on-disk modeling source: the fused-forward installer rewrites
``cls.forward`` at import time, but the upstream pattern compiler.py
pins still lives in the source file."""
pytest.importorskip("transformers")
import importlib
import pathlib
candidate_classes = [
"transformers.models.llama.modeling_llama.LlamaForCausalLM",
"transformers.models.llama4.modeling_llama4.Llama4ForCausalLM",
Expand All @@ -334,12 +339,12 @@ def test_compiler_cross_entropy_lm_head_pattern_present():
mod = importlib.import_module(mod_path)
except ImportError:
continue
cls = getattr(mod, cls_name, None)
if cls is None:
src_file = getattr(mod, "__file__", None)
if not src_file:
continue
try:
src = inspect.getsource(cls.forward)
except (OSError, TypeError):
src = pathlib.Path(src_file).read_text(encoding="utf-8")
except OSError:
continue
if needle in src:
found = True
Expand All @@ -357,9 +362,16 @@ def test_compiler_cross_entropy_lm_head_pattern_present():

def test_compiler_cross_entropy_find_2_loss_function_signature():
"""``unsloth_zoo/compiler.py:1593-1600`` (cross_entropy_find_2) pins
``loss = self.loss_function(...$LOGITS$, $LABELS$, $VOCABSIZE$...)``."""
``loss = self.loss_function(...$LOGITS$, $LABELS$, $VOCABSIZE$...)``.

Read the modeling module's on-disk source directly. The fused-forward
installer (forward_install.py) replaces ``*ForCausalLM.forward`` at
import time, so ``inspect.getsource(cls.forward)`` would return the
rewritten body; the upstream pattern this test pins still lives on
disk untouched."""
pytest.importorskip("transformers")
import importlib
import pathlib
candidate_classes = [
"transformers.models.llama.modeling_llama.LlamaForCausalLM",
"transformers.models.mistral.modeling_mistral.MistralForCausalLM",
Expand All @@ -374,12 +386,12 @@ def test_compiler_cross_entropy_find_2_loss_function_signature():
mod = importlib.import_module(mod_path)
except ImportError:
continue
cls = getattr(mod, cls_name, None)
if cls is None:
src_file = getattr(mod, "__file__", None)
if not src_file:
continue
try:
src = inspect.getsource(cls.forward)
except (OSError, TypeError):
src = pathlib.Path(src_file).read_text(encoding="utf-8")
except OSError:
continue
if needle in src:
return
Expand Down
9 changes: 9 additions & 0 deletions unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ def filter(self, x): return not (self.text in x.getMessage())
from .temporary_patches import (
encode_conversations_with_harmony,
)

# Fused lm_head + cross_entropy auto-installer. On by default; set
# UNSLOTH_FUSED_FORWARD=0 to disable.
try:
from .fused_losses.forward_install import install_modeling_import_hook as _install_fused_forward
_install_fused_forward()
del _install_fused_forward
except Exception:
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
pass
Comment on lines +388 to +393

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The broad try...except Exception: pass block around the fused forward installation can make debugging difficult if the installer fails for unexpected reasons. It is recommended to at least print the exception to aid in troubleshooting, especially for visibility in 'studio' environments, as this is an opt-in feature that users might want to verify.

References
  1. Use print instead of logger.info for messages that must be visible in 'studio' when working with llama.cpp, as logger.info messages may be filtered out.

from .rl_environments import (
check_python_modules,
create_locked_down_function,
Expand Down
9 changes: 9 additions & 0 deletions unsloth_zoo/fused_losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from .cross_entropy_loss import *
from .forward_adapter import EMPTY_LOGITS, unsloth_fused_lm_head_loss
from .forward_install import (
install_modeling_import_hook,
install_for_module,
install_for_class,
register_canonical,
audit,
is_enabled,
)
Loading
Loading