Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
765ee49
Tighten MLX VLM training parity diagnostics
mmathew23 May 19, 2026
6fb702a
Match Qwen3-VL rotary precision in MLX
mmathew23 May 19, 2026
5bc745a
Disable Qwen3-VL MLX compile verification
mmathew23 May 19, 2026
1c487a6
Match HF AdamW decay filtering in MLX
mmathew23 May 19, 2026
9a38968
Preserve Qwen3-VL residual dtype in MLX vision block
mmathew23 May 19, 2026
6ec832e
update
mmathew23 May 20, 2026
5895b20
udpate vlm
mmathew23 May 20, 2026
5c82ee2
bring back correct loss curves
mmathew23 May 20, 2026
a93449f
update textdataset
mmathew23 May 20, 2026
dcd0a90
dataset ordering fix, lr fix
mmathew23 May 21, 2026
e16efc0
Merge branch 'main' into explore/mlx
mmathew23 May 21, 2026
b0a83b5
use proportional MLX grad value clipping
mmathew23 May 21, 2026
964be34
cast norm activation output back to original input dtype
mmathew23 May 21, 2026
ca08652
address mlx training review feedback
mmathew23 May 21, 2026
ec34323
fix(mlx): cast custom norm outputs
Lyxot May 21, 2026
f26cf37
feat: auto discover custom norm from model
Lyxot May 21, 2026
a24b8f3
fix(mlx): harden norm output cast discovery
Lyxot May 21, 2026
5465959
fix(mlx): preserve custom norm keyword calls
Lyxot May 21, 2026
7e0bee5
harden mlx custom norm output casting
mmathew23 May 22, 2026
c7a0956
Fix four loose ends for PR #684
May 24, 2026
c3de6bb
Merge remote-tracking branch 'origin/main' into pr-684-head
May 24, 2026
0753b11
Address reviewer round 1 P1 findings on PR #684
May 24, 2026
693b099
Preserve embedder position_ids in _vlm_cce_forward
May 24, 2026
4cb6ca6
Address reviewer round 2 findings on PR #684
May 24, 2026
6374bad
Extend HF parity decoupled weight decay to SGD/Muon/Lion for PR #684
May 24, 2026
d16ef24
Address remaining reviewer round 2 findings on PR #684
May 24, 2026
23751c8
Address reviewer round 3 P1/P2 findings on PR #684
May 24, 2026
8c27832
Restore mask = kwargs.get for 4 patched VLM get_input_embeddings
May 24, 2026
1d0c11e
Materialize multiple epochs of labeled batches when num_epochs>1 for …
May 24, 2026
f601a15
Rename PR-numbered tests and shorten verbose comments
May 25, 2026
f545a00
Add MLX max grad leaf norm clipping
mmathew23 May 26, 2026
6e6ab83
Restore collated VLM position ids for parity
mmathew23 May 26, 2026
7300649
Scope VLM position id override to collated ids
mmathew23 May 26, 2026
830cd37
Preserve returned VLM position ids
mmathew23 May 26, 2026
67b3504
Fix MLX VLM parity masking edge cases
mmathew23 May 26, 2026
35c0710
Route text-only VLM loads through text trainer
mmathew23 May 26, 2026
5f497dc
Match BNB nested NF4 scale quantization
mmathew23 May 26, 2026
af3d68d
Match CUDA VLM resize-min behavior in MLX
mmathew23 May 26, 2026
bf40c71
Match Gemma3 vision post norm epsilon
mmathew23 May 26, 2026
104e41d
Run Gemma3 vision SDPA in fp32 on MLX
mmathew23 May 26, 2026
f6d5f62
Match Gemma3 image feature scaling on MLX
mmathew23 May 26, 2026
e8e6c9b
Use Gemma3 image attention mask in MLX VLM CCE
mmathew23 May 26, 2026
3e2ee98
Clip MLX global grad norms in fp32
mmathew23 May 26, 2026
80b62cf
Match Gemma3 vision fp32 norm and activation math
mmathew23 May 26, 2026
6ba0a7f
Disable Gemma3 MLX training compile pending parity
mmathew23 May 26, 2026
e70e575
Match Gemma3 text RMSNorm fp32 math
mmathew23 May 26, 2026
31457a2
Preserve VLM hidden stack activation dtype
mmathew23 May 27, 2026
9cc6885
Restore Gemma3 MLX training compile qualification
mmathew23 May 27, 2026
0c0e567
Handle quantized CCE layer modes
mmathew23 May 27, 2026
c25c86b
Rename grad-clip test to reflect three-mode scope
May 27, 2026
458e62a
Merge main into explore/mlx; resolve mlx/utils.py conflicts
May 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/consolidated-tests-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ jobs:
- name: pytest tests/security (HARD GATE)
run: python -m pytest tests/security -v

- name: pytest tests/test_pr_a_imports + zoo-specific CPU tests
- name: pytest tests/test_mlx_module_exports + zoo-specific CPU tests
# Run as SEPARATE pytest invocation: tests/security/conftest.py installs a
# session-scoped network_blocker autouse fixture that would otherwise block
# test_pypi_version_sync from reaching pypi.org.
continue-on-error: true
run: |
python -m pytest \
tests/test_pr_a_imports.py \
tests/test_mlx_module_exports.py \
tests/test_rl_replacements_cpu.py \
tests/test_temporary_patches_imports.py \
tests/test_zoo_history_regressions.py \
Expand Down
287 changes: 287 additions & 0 deletions tests/test_mlx_batching_and_decay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
from __future__ import annotations

import inspect

import numpy as np
import pytest


mx = pytest.importorskip("mlx.core")
if "mlx_simulation" in str(getattr(mx, "__file__", "")):
pytest.skip("requires real MLX runtime", allow_module_level=True)


def _skip_if_mlx_core_was_replaced():
import mlx.core as current_mx
if current_mx is not mx:
pytest.skip("requires real MLX runtime without mlx_simulation monkeypatch")


class _TinyTokenizer:
pad_token_id = 2
eos_token_id = 2
unk_token_id = -1
image_token_id = 200

def encode(self, text):
return [int(part) for part in str(text).split()]

def convert_tokens_to_ids(self, token):
if isinstance(token, list):
return [self.convert_tokens_to_ids(item) for item in token]
return {"<image>": 200, "<|image_pad|>": 201}.get(token, self.unk_token_id)


class _ContentProcessor:
tokenizer = _TinyTokenizer()
image_processor = object()

def __call__(self, text, **_kwargs):
rows = [[int(item), 200, 2] for item in text]
masks = [[1, 1, 1] for _ in rows]
return {
"input_ids": np.array(rows, dtype=np.int32),
"attention_mask": np.array(masks, dtype=np.int32),
}


def test_vlm_ignore_ids_exclude_pad_even_when_pad_is_eos():
from unsloth_zoo.mlx.utils import _get_vlm_ignore_token_ids

ids = _get_vlm_ignore_token_ids(
processor=_ContentProcessor(),
config={"pad_token_id": 2, "image_token_id": 200},
)

assert 200 in ids
assert 2 not in ids


def test_vlm_label_mask_keeps_in_sequence_pad_eos_token():
from unsloth_zoo.mlx.utils import _apply_vlm_label_masks

batch = {
"input_ids": mx.array([[101, 2, 200, 2]], dtype=mx.int32),
"attention_mask": mx.array([[1, 1, 1, 0]], dtype=mx.int32),
}
out = _apply_vlm_label_masks(
batch,
labels=batch["input_ids"],
ignore_token_ids=[200],
)

assert out.tolist() == [[101, 2, -100, -100]]


def test_manual_weight_decay_accepts_scalar_lr_and_preserves_dtype():
from mlx.utils import tree_flatten
from unsloth_zoo.mlx.trainer import MLXTrainer

class TinyModel:
def __init__(self):
self.params = {
"layer": {
"weight": mx.array([10.0], dtype=mx.bfloat16),
"bias": mx.array([10.0], dtype=mx.bfloat16),
},
"norm": {"weight": mx.array([10.0], dtype=mx.float32)},
}

def trainable_parameters(self):
return self.params

def update(self, updates):
def merge(dst, src):
for key, value in src.items():
if isinstance(value, dict):
merge(dst[key], value)
else:
dst[key] = value
merge(self.params, updates)

class TinyOptimizer:
learning_rate = 0.1

model = TinyModel()
grad = {
"layer": {
"weight": mx.array([1.0], dtype=mx.bfloat16),
"bias": mx.array([1.0], dtype=mx.bfloat16),
},
"norm": {"weight": mx.array([1.0], dtype=mx.float32)},
}
trainer = object.__new__(MLXTrainer)
trainer._manual_weight_decay = 0.1

trainer._apply_manual_weight_decay(model, TinyOptimizer(), grad)
flat = dict(tree_flatten(model.trainable_parameters()))

assert flat["layer.weight"].dtype == mx.bfloat16
assert flat["layer.weight"].item() < 10.0
assert flat["layer.bias"].item() == pytest.approx(10.0)
assert flat["norm.weight"].item() == pytest.approx(10.0)


def test_nf4_dense_zero_group_dequantizes_to_zero_without_epsilon_scale():
_skip_if_mlx_core_was_replaced()
from unsloth_zoo.mlx.loader import _nf4_dense_dequantize_weight

weight = mx.zeros((1, 4), dtype=mx.float32)
out = _nf4_dense_dequantize_weight(weight, group_size=4)

assert out.tolist() == [[0.0, 0.0, 0.0, 0.0]]


def test_ordered_text_batches_raise_clear_error_when_all_rows_drop():
from unsloth_zoo.mlx.utils import create_ordered_batches

with pytest.raises(ValueError, match="no trainable token sequences"):
create_ordered_batches(
dataset=[{"text": "1"}],
tokenizer=_TinyTokenizer(),
batch_size=1,
max_seq_length=1,
dataset_order="sequential",
)


def test_ordered_text_torch_randperm_can_materialize_multiple_epochs():
_skip_if_mlx_core_was_replaced()
from unsloth_zoo.mlx.utils import create_ordered_batches

batches = create_ordered_batches(
dataset=[{"text": f"{i} {i + 10}"} for i in range(5)],
tokenizer=_TinyTokenizer(),
batch_size=1,
max_seq_length=4,
seed=None,
dataset_order="torch_randperm",
num_epochs=2,
)

first_epoch = [int(batch[0, 0].item()) for batch, _lengths, _labels in batches[:5]]
second_epoch = [int(batch[0, 0].item()) for batch, _lengths, _labels in batches[5:]]
assert len(batches) == 10
assert sorted(first_epoch) == [0, 1, 2, 3, 4]
assert sorted(second_epoch) == [0, 1, 2, 3, 4]
assert first_epoch != second_epoch


def test_vlm_torch_randperm_seed_none_and_multi_epoch_batches():
_skip_if_mlx_core_was_replaced()
from unsloth_zoo.mlx.utils import create_vlm_batches

batches = create_vlm_batches(
dataset=[{"text": str(i)} for i in range(5)],
processor=_ContentProcessor(),
config={"image_size": 16, "image_token_id": 200},
batch_size=1,
max_seq_length=8,
seed=None,
dataset_order="torch_randperm",
num_epochs=2,
)

first_epoch = [int(batch["input_ids"][0, 0].item()) for batch in batches[:5]]
second_epoch = [int(batch["input_ids"][0, 0].item()) for batch in batches[5:]]
assert len(batches) == 10
assert sorted(first_epoch) == [0, 1, 2, 3, 4]
assert sorted(second_epoch) == [0, 1, 2, 3, 4]
assert first_epoch != second_epoch


def test_compiler_review_guards_are_present():
import unsloth_zoo.compiler as compiler
import unsloth_zoo.mlx.compile as mlx_compile

compiler_source = inspect.getsource(compiler)
mlx_compile_source = inspect.getsource(mlx_compile)

assert (
'self.loss_function.__name__.endswith("ForCausalLMLoss") '
"and labels is not None and NOT_RETURN_LOGITS"
) in compiler_source
assert '"weight" in norm' not in mlx_compile_source
assert '"bias" in norm' not in mlx_compile_source
assert 'getattr(norm, "weight", None)' in mlx_compile_source


def test_norm_output_cast_discovers_custom_norms_from_loaded_model():
_skip_if_mlx_core_was_replaced()
import mlx.nn as nn

gemma3_text = pytest.importorskip("mlx_lm.models.gemma3_text")
stablelm = pytest.importorskip("mlx_lm.models.stablelm")
fastvlm_vision = pytest.importorskip("mlx_vlm.models.fastvlm.vision")
import unsloth_zoo.mlx.trainer as trainer_mod

class TinyModel(nn.Module):
def __init__(self):
super().__init__()
self.input_layernorm = gemma3_text.RMSNorm(4)
self.q_layernorm = stablelm.LayerNormPerHead(
head_dim=4, num_heads=2, eps=1e-5
)
self.norm = fastvlm_vision.LayerNormChannel(num_features=4)

trainer_mod._set_norm_output_cast_to_input_dtype(False)
model = TinyModel()
cases = [
(model.input_layernorm, mx.ones((2, 4), dtype=mx.bfloat16)),
(
model.q_layernorm,
mx.ones((1, 3, 2, 4), dtype=mx.bfloat16),
),
(
model.norm,
mx.ones((1, 2, 2, 4), dtype=mx.bfloat16),
),
]

norm_classes = trainer_mod._iter_norm_output_cast_classes(model)
for norm, x in cases:
assert type(norm) in norm_classes
raw = norm(x)
assert raw.dtype == mx.float32

try:
trainer_mod._set_norm_output_cast_to_input_dtype(True, model)
for norm, x in cases:
out = norm(x)
assert out.dtype == x.dtype
finally:
trainer_mod._set_norm_output_cast_to_input_dtype(False)


def test_norm_output_cast_does_not_double_patch_inherited_norm_call():
_skip_if_mlx_core_was_replaced()
import mlx.nn as nn
import unsloth_zoo.mlx.trainer as trainer_mod

class CustomRMSNorm(nn.RMSNorm):
pass

class TinyModel(nn.Module):
def __init__(self):
super().__init__()
self.input_layernorm = CustomRMSNorm(4)

trainer_mod._set_norm_output_cast_to_input_dtype(False)
model = TinyModel()
x = mx.ones((2, 4), dtype=mx.bfloat16)

try:
trainer_mod._set_norm_output_cast_to_input_dtype(True, model)
assert nn.RMSNorm in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES
assert CustomRMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES
assert model.input_layernorm(x).dtype == x.dtype
finally:
trainer_mod._set_norm_output_cast_to_input_dtype(False)

assert nn.RMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES
assert CustomRMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES
assert not getattr(
CustomRMSNorm.__call__,
"_unsloth_norm_output_cast_wrapper",
False,
)
11 changes: 5 additions & 6 deletions tests/test_pr_a_components.py → tests/test_mlx_cce_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
PR-A end-to-end component exercises through the shim.
"""End-to-end MLX CCE kernel exercises through the simulation shim.

Goes one level deeper than test_pr_a_imports.py: actually constructs
inputs and runs the function bodies. Each test focuses on one
critical PR-A code path.
Goes one level deeper than test_mlx_module_exports.py: actually
constructs inputs and runs the function bodies, one critical CCE
code path per test.
"""

from __future__ import annotations
Expand Down Expand Up @@ -202,7 +201,7 @@ def square_vjp(primals, cotangents, outputs):


# ---------------------------------------------------------------------------
# 4. mx.array isinstance contract that PR-A relies on.
# 4. mx.array isinstance contract that MLX trainer relies on.
# ---------------------------------------------------------------------------

def test_torch_tensor_is_mx_array():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
PR-A integration: exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules.
"""Exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules.

Builds a synthetic MLX-style model with one QuantizedLinear submodule,
runs PR-A's dequantize-and-replace helper, verifies the result is
a numerically correct nn.Linear with the dequantized weight.

This is the canonical PR-A code path: load_in_4bit=False (or
selective requantize) walks named_modules, finds QuantizedLinear,
calls mx.dequantize with mode='affine', and swaps in nn.Linear.
runs the dequantize-and-replace helper, and verifies the result is a
numerically correct nn.Linear with the dequantized weight. Mirrors the
load_in_4bit=False / selective requantize path: walks named_modules,
finds QuantizedLinear, calls mx.dequantize with mode='affine', swaps
in nn.Linear.
"""

from __future__ import annotations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
PR-A gated_delta_vjp end-to-end through the shim.
"""gated_delta_vjp end-to-end through the shim.

Exercises:
* mx.custom_function decorator + .vjp registration
Expand All @@ -24,8 +23,8 @@
* .astype(mx.float32) at ~30 sites
* mx.where / mx.expand_dims / mx.zeros_like

If forward + backward both produce finite tensors with the right shapes,
PR-A's VJP path is exercisable on Linux+CUDA.
If forward + backward both produce finite tensors with the right
shapes, the VJP path is exercisable on Linux+CUDA.
"""

from __future__ import annotations
Expand Down
Loading
Loading