Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/test_mlx_runtime_cce_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Unsloth Zoo - Utilities for Unsloth
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# SPDX-License-Identifier: AGPL-3.0-or-later

from __future__ import annotations

import sys

import pytest


mx = pytest.importorskip("mlx.core")
if "mlx_simulation" in str(getattr(mx, "__file__", "")):
pytest.skip("requires real MLX runtime", allow_module_level=True)


def _stable_norm(values):
max_abs = mx.array(0.0, dtype=mx.float32)
for value in values:
value32 = value.astype(mx.float32)
max_abs = mx.maximum(max_abs, mx.max(mx.abs(value32)))

denom = mx.maximum(max_abs, mx.array(1e-30, dtype=mx.float32))
norm_sq = mx.array(0.0, dtype=mx.float32)
for value in values:
scaled = value.astype(mx.float32) / denom
norm_sq = norm_sq + mx.sum(scaled * scaled)
return denom * mx.sqrt(norm_sq)


def _skip_torch_shim():
if any(name.startswith("mlx_simulation") for name in sys.modules):
pytest.skip("requires real MLX runtime")


def test_compiled_runtime_cce_preserves_aux_lse_for_gradients():
_skip_torch_shim()
from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=32,
)
hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0
weight = (mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0) - 1.0
targets = (mx.arange(64, dtype=mx.int32) * 7) % 128
targets = mx.where(mx.arange(64) % 11 == 0, -100, targets)
ntoks = mx.maximum(
mx.sum((targets != -100).astype(mx.float32)),
mx.array(1.0, dtype=mx.float32),
)

def loss_and_grad_norm(h, w):
def loss_fn(hh, ww):
losses = runtime_cce(hh, ww, targets)
return losses.astype(mx.float32).sum() / ntoks

loss, grads = mx.value_and_grad(loss_fn, argnums=(0, 1))(h, w)
return loss, _stable_norm(grads)

eager_loss, eager_norm = loss_and_grad_norm(hidden, weight)
compiled_loss, compiled_norm = mx.compile(loss_and_grad_norm)(hidden, weight)
mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm)

assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5)
assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4)


def test_compiled_quantized_runtime_cce_preserves_aux_lse_for_gradients():
_skip_torch_shim()
import mlx.nn as nn

from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss

linear = nn.Linear(32, 128, bias=False)
linear.weight = (
mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0
) - 1.0
qlinear = nn.QuantizedLinear.from_linear(linear, group_size=32, bits=4)
runtime_cce, _ = make_chunked_cross_entropy_loss(
ignore_index=-100,
chunk_size=32,
quantized=True,
group_size=qlinear.group_size,
bits=qlinear.bits,
)
hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0
targets = (mx.arange(64, dtype=mx.int32) * 7) % 128
targets = mx.where(mx.arange(64) % 11 == 0, -100, targets)
ntoks = mx.maximum(
mx.sum((targets != -100).astype(mx.float32)),
mx.array(1.0, dtype=mx.float32),
)

def loss_and_grad_norm(h):
def loss_fn(hh):
losses = runtime_cce(
hh,
qlinear.weight,
qlinear.scales,
qlinear.biases,
targets,
)
return losses.astype(mx.float32).sum() / ntoks

loss, grad = mx.value_and_grad(loss_fn)(h)
return loss, _stable_norm((grad,))

eager_loss, eager_norm = loss_and_grad_norm(hidden)
compiled_loss, compiled_norm = mx.compile(loss_and_grad_norm)(hidden)
mx.eval(eager_loss, eager_norm, compiled_loss, compiled_norm)

assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5)
assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4)
8 changes: 4 additions & 4 deletions tests/test_pr_a_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _ce_reference(hidden, weight, targets, ignore_index=-100, softcap=0.0):

def test_cce_forward_pure_python_matches_reference():
"""The kernel-less branch of _forward_chunked_fused_finalize."""
from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize

torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
Expand All @@ -72,7 +72,7 @@ def test_cce_forward_pure_python_matches_reference():

def test_cce_with_softcap():
"""Softcap path matches reference too."""
from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize

torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
Expand All @@ -93,7 +93,7 @@ def test_cce_with_softcap():

def test_cce_with_ignore_index():
"""Ignore-index zeros out loss for those positions."""
from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize

torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 32
Expand All @@ -114,7 +114,7 @@ def test_cce_with_ignore_index():

def test_cce_chunked_matches_unchunked():
"""Chunked online LSE must equal a single-shot computation."""
from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize

torch.manual_seed(0)
n, hidden_dim, vocab = 4, 8, 64
Expand Down
91 changes: 78 additions & 13 deletions tests/test_pr_a_deep_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
PR-A deeper component exercises: mlx_trainer, mlx_compile discovery,
PR-A deeper component exercises: trainer, compile discovery,
cce backward, and quantization helpers — beyond just imports.

If a test fails, the failing component identifies the next gap.
Expand All @@ -40,7 +40,7 @@ def _install_shim():
# ---------------------------------------------------------------------------

def test_mlx_training_config_is_dataclass_with_all_fields():
from unsloth_zoo.mlx_trainer import MLXTrainingConfig
from unsloth_zoo.mlx.trainer import MLXTrainingConfig
assert dataclasses.is_dataclass(MLXTrainingConfig)
fields = {f.name for f in dataclasses.fields(MLXTrainingConfig)}
# Required SFT-compat fields
Expand All @@ -67,37 +67,102 @@ def test_mlx_training_config_is_dataclass_with_all_fields():
@pytest.mark.parametrize("optim_name", ["adamw", "adam", "sgd", "adafactor"])
def test_mlx_training_config_each_optim(optim_name):
"""Every PR-A-supported optim string at least constructs cleanly in config."""
from unsloth_zoo.mlx_trainer import MLXTrainingConfig
from unsloth_zoo.mlx.trainer import MLXTrainingConfig
cfg = MLXTrainingConfig(optim=optim_name)
assert cfg.optim == optim_name


def test_trainer_drives_dynamic_lr_outside_optimizer_scheduler():
from unsloth_zoo.mlx.trainer import (
MLXTrainer,
MLXTrainingConfig,
)

trainer = MLXTrainer.__new__(MLXTrainer)
trainer.args = MLXTrainingConfig(
learning_rate=5e-5,
lr_scheduler_type="linear",
warmup_steps=5,
)
schedule = trainer._build_schedule(total_steps=8)
def value_at(step):
value = schedule(step)
return value.item() if hasattr(value, "item") else float(value)

assert value_at(0) > 0.0
assert value_at(4) < trainer.args.learning_rate
assert value_at(5) == pytest.approx(trainer.args.learning_rate)

trainer.model = object()
optimizer = trainer._build_optimizer(total_steps=8)
assert not callable(optimizer.learning_rate)
first_lr = float(optimizer.learning_rate)
trainer._set_optimizer_lr_for_step(optimizer, 1)
second_lr = float(optimizer.learning_rate)
assert second_lr > first_lr


@pytest.mark.parametrize(
("scheduler", "warmup"),
[
("linear", 0),
("linear", 5),
("cosine", 0),
("cosine", 5),
("constant", 0),
("constant", 5),
],
)
def test_scheduler_lr_is_nonzero_for_optimizer_update_steps(scheduler, warmup):
from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig

total_steps = 8
trainer = MLXTrainer.__new__(MLXTrainer)
trainer.args = MLXTrainingConfig(
learning_rate=5e-5,
lr_scheduler_type=scheduler,
warmup_steps=warmup,
)
schedule = trainer._build_schedule(total_steps=total_steps)

if callable(schedule):
raw_values = [schedule(step) for step in range(total_steps)]
else:
raw_values = [schedule] * total_steps
values = [
value.item() if hasattr(value, "item") else float(value)
for value in raw_values
]

assert all(value > 0.0 for value in values)


# ---------------------------------------------------------------------------
# 2. mlx_compile module-level discovery functions return sensible defaults
# 2. compile module-level discovery functions return sensible defaults
# on a host with no real MLX architectures.
# ---------------------------------------------------------------------------

def test_mlx_compile_discovers_no_archs_under_shim():
def test_compile_discovers_no_archs_under_shim():
"""No real mlx_vlm.models.* installed -> empty discovery, not crash."""
import unsloth_zoo.mlx_compile as mc
import unsloth_zoo.mlx.compile as mc
archs = mc.discover_architectures()
assert isinstance(archs, tuple)


def test_mlx_compile_patch_primitives_exist():
import unsloth_zoo.mlx_compile as mc
def test_compile_patch_primitives_exist():
import unsloth_zoo.mlx.compile as mc
primitives = mc.list_compile_patch_primitives()
assert len(primitives) > 0


def test_mlx_compile_protocol_requirements_exist():
import unsloth_zoo.mlx_compile as mc
def test_compile_protocol_requirements_exist():
import unsloth_zoo.mlx.compile as mc
reqs = mc.list_protocol_requirements()
assert len(reqs) > 0


def test_mlx_compile_summarize_qualifications_returns_dict():
import unsloth_zoo.mlx_compile as mc
def test_compile_summarize_qualifications_returns_dict():
import unsloth_zoo.mlx.compile as mc
s = mc.summarize_compile_qualifications()
assert isinstance(s, dict)
assert "architectures" in s
Expand All @@ -109,7 +174,7 @@ def test_mlx_compile_summarize_qualifications_returns_dict():

def test_cce_backward_via_torch_autograd():
"""Build a tiny CCE forward and verify torch.autograd traverses it."""
from unsloth_zoo.mlx_cce.runtime_cce import _forward_chunked_fused_finalize
from unsloth_zoo.mlx.cce.runtime_cce import _forward_chunked_fused_finalize

torch.manual_seed(0)
n, hd, vocab = 4, 8, 32
Expand Down
8 changes: 4 additions & 4 deletions tests/test_pr_a_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
PR-A integration: exercise unsloth_zoo.mlx_loader._dequantize_selected_mlx_modules.
PR-A integration: exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules.

Builds a synthetic MLX-style model with one QuantizedLinear submodule,
runs PR-A's dequantize-and-replace helper, verifies the result is
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_dequantize_selected_mlx_modules_swap():
"""Build a Module with a QuantizedLinear, dequant-replace, verify swap."""
import mlx.core as mx
import mlx.nn as nn
from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules

# 4-bit weight: 8 input dims (one group), 2 output dims, group_size=8.
bits, group_size = 4, 8
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(self):
def test_dequantize_predicate_filters():
"""Predicate should let some QuantizedLinear modules through unchanged."""
import mlx.nn as nn
from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules

bits, group_size = 4, 8
packed, scales, biases = _make_packed_4bit([0]*8, group_size, bits)
Expand All @@ -145,7 +145,7 @@ def __init__(self):

def test_dequantize_no_match_returns_zero():
import mlx.nn as nn
from unsloth_zoo.mlx_loader import _dequantize_selected_mlx_modules
from unsloth_zoo.mlx.loader import _dequantize_selected_mlx_modules

class Empty(nn.Module):
def __init__(self):
Expand Down
Loading
Loading