From 765ee4913ac823afc58375c351daee218f52cbd6 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 19 May 2026 00:42:26 -0500 Subject: [PATCH 01/48] Tighten MLX VLM training parity diagnostics --- tests/test_pr_a_deep_components.py | 146 ++++++++++++++++++++++++++++- unsloth_zoo/mlx/compile.py | 39 +++++++- unsloth_zoo/mlx/loader.py | 121 ++++++++++++++++++++++++ unsloth_zoo/mlx/trainer.py | 28 +++++- unsloth_zoo/mlx/utils.py | 46 ++++++--- 5 files changed, 364 insertions(+), 16 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index a8b8bc119..97f9bf9ad 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -113,7 +113,7 @@ def value_at(step): ("constant", 5), ], ) -def test_scheduler_lr_is_nonzero_for_optimizer_update_steps(scheduler, warmup): +def test_scheduler_lr_matches_expected_optimizer_update_steps(scheduler, warmup): from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig total_steps = 8 @@ -134,7 +134,149 @@ def test_scheduler_lr_is_nonzero_for_optimizer_update_steps(scheduler, warmup): for value in raw_values ] - assert all(value > 0.0 for value in values) + if scheduler == "linear" and warmup == 0: + expected = [ + 0.0, + trainer.args.learning_rate, + trainer.args.learning_rate * 6 / 7, + trainer.args.learning_rate * 5 / 7, + trainer.args.learning_rate * 4 / 7, + trainer.args.learning_rate * 3 / 7, + trainer.args.learning_rate * 2 / 7, + trainer.args.learning_rate * 1 / 7, + ] + assert values == pytest.approx(expected) + else: + assert all(value > 0.0 for value in values) + + +def test_mlx_text_dataset_does_not_append_eos(monkeypatch): + """Studio formatting owns EOS decisions; MLX batching must not add one.""" + import sys + + class CacheDataset: + def __init__(self, data): + self._data = data + self._cache = {} + + def __len__(self): + return len(self._data) + + def __getitem__(self, idx): + if idx not in self._cache: + self._cache[idx] = self._data.process(self._data[idx]) + return self._cache[idx] + + def itemlen(self, idx): + return len(self[idx][0]) + + monkeypatch.setattr(sys.modules["mlx_lm.tuner.datasets"], "CacheDataset", CacheDataset) + + from unsloth_zoo.mlx.utils import _prepare_dataset + + class Tokenizer: + eos_token_id = 99 + chat_template = None + + def encode(self, text): + assert text == "hello" + return [1, 2, 3] + + dataset = _prepare_dataset([{"text": "hello"}], Tokenizer()) + + assert dataset[0] == ([1, 2, 3], 0) + + +def test_mlx_text_loss_masks_exclude_position_at_sequence_length(): + import inspect + from unsloth_zoo.mlx import utils as mlx_utils + + source = inspect.getsource(mlx_utils.make_baseline_loss_fn) + assert "steps < lengths[:, 1:]" in source + + +def test_mlx_train_result_reports_base_quantization(): + import inspect + from unsloth_zoo.mlx.trainer import MLXTrainer + + source = inspect.getsource(MLXTrainer._train_inner) + assert '"base_quantization_config"' in source + assert '"base_quantization_policy"' in source + assert '"base_quantized_source"' in source + + +def test_mlx_loader_exposes_dense_nf4_diagnostic_mode(): + import mlx.core as mx + from unsloth_zoo.mlx.loader import ( + _MLX_QUANT_MODE_DEFAULTS, + _nf4_dense_dequantize_weight, + ) + + assert _MLX_QUANT_MODE_DEFAULTS["nf4_dense"] == (64, 4) + + weight = mx.array([[-1.0, -0.6961928, 0.0, 0.72295684]], dtype=mx.float32) + dequantized = _nf4_dense_dequantize_weight(weight, group_size=4) + assert dequantized.shape == weight.shape + assert dequantized.reshape((-1,)).tolist() == pytest.approx( + weight.reshape((-1,)).tolist() + ) + + +def test_mlx_loader_keeps_norm_parameters_float32(): + import mlx.core as mx + from unsloth_zoo.mlx.loader import _keep_norm_parameters_float32 + + class TinyModel: + def __init__(self): + self._parameters = { + "vision_tower": { + "blocks": { + "0": { + "norm1": { + "weight": mx.array([1.0], dtype=mx.bfloat16), + "bias": mx.array([0.0], dtype=mx.bfloat16), + }, + "attn": { + "qkv": { + "weight": mx.array([[1.0]], dtype=mx.bfloat16), + }, + }, + }, + }, + }, + "language_model": { + "model": { + "layers": { + "0": { + "input_layernorm": { + "weight": mx.array([1.0], dtype=mx.bfloat16), + }, + }, + }, + }, + }, + } + + def parameters(self): + return self._parameters + + def update(self, parameters): + self._parameters = parameters + + model = TinyModel() + _keep_norm_parameters_float32(model) + params = model.parameters() + + assert params["vision_tower"]["blocks"]["0"]["norm1"]["weight"].dtype == mx.float32 + assert params["vision_tower"]["blocks"]["0"]["norm1"]["bias"].dtype == mx.float32 + assert ( + params["language_model"]["model"]["layers"]["0"]["input_layernorm"]["weight"].dtype + == mx.float32 + ) + assert ( + params["vision_tower"]["blocks"]["0"]["attn"]["qkv"]["weight"].dtype + == mx.bfloat16 + ) # --------------------------------------------------------------------------- diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index e14d7aff6..54fe9b176 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2674,14 +2674,49 @@ def patched_qwen3_attention(self, x, cu_seqlens, rotary_pos_emb=None): attn_outputs = [] for q_chunk, k_chunk, v_chunk in zip(*splits): + # MLX fused SDPA currently has a forward/value mismatch under + # value_and_grad for Qwen3-VL vision chunks. Use explicit attention + # here so training loss and plain forward loss agree. + scores = ( + q_chunk.astype(mx.float32) + @ mx.swapaxes(k_chunk.astype(mx.float32), -1, -2) + ) * self.scale + probs = mx.softmax(scores, axis=-1).astype(q_chunk.dtype) attn_outputs.append( - vision_module.ensure_fused_sdpa(q_chunk, k_chunk, v_chunk, self.scale) + probs @ v_chunk ) output = mx.concatenate(attn_outputs, axis=2) output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1) return self.proj(output) + def _qwen3_torch_like_layer_norm(norm, x): + """Match PyTorch bf16 LayerNorm: fp32 stats/affine, cast result back.""" + import mlx.core as mx + + source_dtype = x.dtype + x_f = x.astype(mx.float32) + mean = mx.mean(x_f, axis=-1, keepdims=True) + centered = x_f - mean + var = mx.mean(centered * centered, axis=-1, keepdims=True) + y = centered * mx.rsqrt(var + norm.eps) + if "weight" in norm: + y = y * norm.weight.astype(mx.float32) + if "bias" in norm: + y = y + norm.bias.astype(mx.float32) + return y.astype(source_dtype) + + def patched_qwen3_vision_block_call(self, hidden_states, cu_seqlens, rotary_pos_emb): + hidden_states = hidden_states + self.attn( + _qwen3_torch_like_layer_norm(self.norm1, hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp( + _qwen3_torch_like_layer_norm(self.norm2, hidden_states) + ) + return hidden_states + def patched_qwen3_rot_pos_emb(self, grid_thw): import mlx.core as mx @@ -2918,6 +2953,7 @@ def patched_qwen35_get_input_embeddings(self, input_ids=None, pixel_values=None, _patch_staticmethod(module.Model, "merge_input_ids_with_image_features", merge_qwen3) _patch_staticmethod(qwen35_module.Model, "merge_input_ids_with_image_features", merge_qwen3) _patch_method(vision_module.Attention, "__call__", patched_qwen3_attention) + _patch_method(vision_module.Qwen3VLMoEVisionBlock, "__call__", patched_qwen3_vision_block_call) _patch_method(vision_module.VisionModel, "rot_pos_emb", patched_qwen3_rot_pos_emb) _patch_method(vision_module.VisionModel, "fast_pos_embed_interpolate", patched_qwen3_fast_pos_embed_interpolate) _patch_method(vision_module.VisionModel, "__call__", patched_qwen3_vision_call) @@ -2928,6 +2964,7 @@ def patched_qwen35_get_input_embeddings(self, input_ids=None, pixel_values=None, qwen3moe_module.masked_scatter = _masked_scatter_no_numpy _patch_staticmethod(qwen3moe_module.Model, "merge_input_ids_with_image_features", merge_qwen3) _patch_method(qwen3moe_vision_module.Attention, "__call__", patched_qwen3_attention) + _patch_method(qwen3moe_vision_module.Qwen3VLMoEVisionBlock, "__call__", patched_qwen3_vision_block_call) _patch_method(qwen3moe_vision_module.VisionModel, "rot_pos_emb", patched_qwen3_rot_pos_emb) _patch_method(qwen3moe_vision_module.VisionModel, "fast_pos_embed_interpolate", patched_qwen3_fast_pos_embed_interpolate) _patch_method(qwen3moe_vision_module.VisionModel, "__call__", patched_qwen3_vision_call) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 9b2151aa0..a466d5416 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -112,6 +112,38 @@ def _convert_mlx_dtype(model, target_dtype) -> None: mx.eval(model.parameters()) +def _is_norm_parameter_path(path) -> bool: + """Return whether a parameter path belongs to a normalization module.""" + parts = str(path).lower().split(".") + return any("norm" in part for part in parts[:-1]) + + +def _keep_norm_parameters_float32(model) -> None: + """Keep LM/VLM normalization parameters in fp32 across FT/LoRA/QLoRA.""" + import mlx.core as mx + from mlx.utils import tree_flatten, tree_map_with_path + + needs_cast = False + for k, v in tree_flatten(model.parameters()): + if ( + _is_norm_parameter_path(k) + and mx.issubdtype(v.dtype, mx.floating) + and v.dtype != mx.float32 + ): + needs_cast = True + break + if not needs_cast: + return + + model.update(tree_map_with_path( + lambda k, v: v.astype(mx.float32) + if _is_norm_parameter_path(k) and mx.issubdtype(v.dtype, mx.floating) + else v, + model.parameters(), + )) + mx.eval(model.parameters()) + + def _seed_mlx_random_state(random_state): try: seed = int(random_state) @@ -725,6 +757,9 @@ def patched_set_dtype(self, dtype): _MLX_QUANT_MODE_DEFAULTS = { "affine": (64, 4), + # Diagnostic CUDA bitsandbytes NF4 parity mode. This quantizes and then + # immediately dequantizes into dense Linear weights; it is not memory-saving. + "nf4_dense": (64, 4), "mxfp4": (32, 4), "nvfp4": (16, 4), "mxfp8": (32, 8), @@ -1412,6 +1447,87 @@ def _dequantize_selected_mlx_modules(model, predicate): return len(replacements) +def _nf4_dense_dequantize_weight(weight, group_size=64): + import mlx.core as mx + + codebook = mx.array( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=mx.float32, + ) + original_shape = weight.shape + original_dtype = weight.dtype + flat = weight.astype(mx.float32).reshape((-1,)) + original_size = ( + flat.numel() + if callable(getattr(flat, "numel", None)) + else (flat.size() if callable(getattr(flat, "size", None)) else flat.size) + ) + pad = (-original_size) % group_size + if pad: + flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)]) + groups = flat.reshape((-1, group_size)) + absmax = mx.max(mx.abs(groups), axis=1, keepdims=True) + denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32)) + scaled = groups / denom + indices = mx.argmin(mx.abs(scaled[..., None] - codebook), axis=-1) + dequantized = (codebook[indices] * absmax).reshape((-1,))[:original_size] + return dequantized.reshape(original_shape).astype(original_dtype) + + +def _apply_dense_nf4_quantization(model, config, spec: _MLXQuantizationSpec, predicate): + import mlx.core as mx + + quantized = {} + for path, module in model.named_modules(): + if not predicate(path, module): + continue + weight = getattr(module, "weight", None) + if weight is None or len(getattr(weight, "shape", ())) != 2: + continue + module.weight = _nf4_dense_dequantize_weight(weight, spec.group_size or 64) + quantized[path] = { + "bits": 4, + "group_size": spec.group_size or 64, + "mode": "nf4_dense", + "storage": "dense_dequantized", + } + mx.eval(module.weight) + + updated_config = dict(config or {}) if isinstance(config, dict) else {} + updated_config["quantization"] = quantized + updated_config["quantization_config"] = quantized + model._config = updated_config + model._unsloth_quantization_config = quantized + model._unsloth_quantization_policy = { + **spec.to_metadata(), + "storage": "dense_dequantized", + "warning": ( + "nf4_dense is a diagnostic CUDA bitsandbytes NF4 parity mode. " + "Weights are stored densely after quantize/dequantize and this " + "does not reduce memory like QLoRA." + ), + } + model._unsloth_quantized_source = "runtime_dense_nf4" + return model, updated_config + + def _apply_mlx_quantization(model, config, spec: _MLXQuantizationSpec, *, is_vlm, user_predicate=None): if not spec.enabled: model._unsloth_quantization_config = None @@ -1448,6 +1564,8 @@ def _apply_mlx_quantization(model, config, spec: _MLXQuantizationSpec, *, is_vlm if is_vlm or spec.quantize_modules is not None or user_predicate is not None: config = dict(config or {}) config.setdefault("quantization", {}) + if spec.mode == "nf4_dense": + return _apply_dense_nf4_quantization(model, config, spec, predicate) model, updated_config = quantize_model( model, config, @@ -2435,6 +2553,7 @@ def from_pretrained( model._unsloth_quantized_source = adapter_cfg.get( "base_quantized_source" ) + _keep_norm_parameters_float32(model) _patch_mlx_saving(model, tokenizer) return model, tokenizer except Exception as e: @@ -2537,6 +2656,7 @@ def from_pretrained( elif want_runtime_quant: import mlx.core as mx mx.eval(model.parameters()) + _keep_norm_parameters_float32(model) from .utils import ( normalize_mlx_chat_template, @@ -2648,6 +2768,7 @@ def from_pretrained( elif want_runtime_quant: import mlx.core as mx mx.eval(model.parameters()) + _keep_norm_parameters_float32(model) from .utils import normalize_mlx_chat_template tokenizer = normalize_mlx_chat_template( diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 4edd27b7c..10621103d 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -322,6 +322,22 @@ def _build_schedule(self, total_steps): decay_steps = max(total_steps - warmup, 1) + if sched_type == "linear" and warmup == 0: + # Match the Studio CUDA/Trainer path observed in fixed-fixture + # probes: linear/no-warmup starts with a zero-LR optimizer step, + # then decays from the requested LR over the remaining steps. + decay_after_zero = max(total_steps - 1, 1) + + def main_schedule(step): + step = mx.array(step) + decay = mx.maximum( + mx.array(total_steps, dtype=mx.float32) - step, + mx.array(0.0, dtype=mx.float32), + ) / mx.array(decay_after_zero, dtype=mx.float32) + return mx.where(step <= 0, mx.array(0.0, dtype=mx.float32), lr * decay) + + return main_schedule + if sched_type == "cosine": main_schedule = optim.cosine_decay(lr, decay_steps, end=0.0) elif sched_type == "linear": @@ -745,7 +761,8 @@ def _train_inner(self): _direct_single_step_update = ( grad_accum == 1 and not _needs_grad_scaling and - max_grad_norm <= 0 + max_grad_norm <= 0 and + not _clip_grad_value ) def _grad_leaf_scale(name, safe_toks_f, clip_scale=None, dtype=None): @@ -1261,6 +1278,15 @@ def step_fn(batch_data, prev_state, do_update): ), "compile_auto_tune_applied": list(getattr(self, "_compile_auto_tune_applied", [])), "memory_limits_applied": dict(getattr(self, "_memory_limits_applied", {})), + "base_quantization_config": getattr( + self.model, "_unsloth_quantization_config", None, + ), + "base_quantization_policy": getattr( + self.model, "_unsloth_quantization_policy", None, + ), + "base_quantized_source": getattr( + self.model, "_unsloth_quantized_source", None, + ), } def _prepare_data(self, is_vlm): diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 3caf75688..62de823b1 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -353,7 +353,7 @@ def loss_fn(model, batch, lengths, labels=None): sc = layer.scales bi = layer.biases if _has_biases else mx.zeros_like(layer.scales) steps = mx.arange(1, targets.shape[1] + 1) - length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps < lengths[:, 1:]) if labels is None: mask = length_mask else: @@ -386,7 +386,7 @@ def loss_fn(model, batch, lengths, labels=None): if _skip_weight_grad: w = mx.stop_gradient(w) steps = mx.arange(1, targets.shape[1] + 1) - length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps < lengths[:, 1:]) if labels is None: mask = length_mask else: @@ -422,7 +422,7 @@ def loss_fn(model, batch, lengths, labels=None): targets = labels[:, 1:] logits = model(inputs) steps = mx.arange(1, targets.shape[1] + 1) - length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps < lengths[:, 1:]) if labels is None: mask = length_mask.astype(mx.float32) else: @@ -744,13 +744,13 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, attention_mask = batch_dict.get("attention_mask") labels = batch_dict.get("labels") - inputs = input_ids[:, :-1] - # Shift attention_mask so any 4D causal/image mask the embedder builds - # has q/kv dims matching the (shifted) inputs. Otherwise models like - # Gemma3 see (B,1,S,S-1) vs (B,H,S-1,S-1) at SDPA and crash. + # Match the standard VLM forward semantics: run the full multimodal + # sequence, then use hidden[:, :-1] to predict labels[:, 1:]. Qwen3-VL + # image/mRoPE/deepstack state depends on the complete sequence; trimming + # input_ids before the multimodal forward produces a different loss from + # the full-logits path and from CUDA. + inputs = input_ids fwd_attn_mask = attention_mask - if attention_mask is not None and attention_mask.shape[-1] == input_ids.shape[-1]: - fwd_attn_mask = attention_mask[:, :-1] # Collect extra keys (e.g. image_grid_thw for Qwen) that some models need. extra_kwargs = { @@ -767,7 +767,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, **extra_kwargs, ) merged_embeds, backbone_kwargs = _unpack_embed_result(embed_result, model) - if "position_ids" in extra_kwargs and "position_ids" not in backbone_kwargs: + if "position_ids" in extra_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] hidden = _forward_text_hidden_states( @@ -776,6 +776,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, inputs_embeds=merged_embeds, **backbone_kwargs, ) + hidden = hidden[:, :-1, :] if labels is not None: # train_on_responses_only: labels already encode instruction and @@ -2423,7 +2424,7 @@ def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", Returns: A CacheDataset ready for ``iterate_batches``. """ - from mlx_lm.tuner.datasets import TextDataset, CacheDataset + from mlx_lm.tuner.datasets import CacheDataset normalize_mlx_chat_template( tokenizer, @@ -2460,7 +2461,24 @@ def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", "a formatting_func that returns text." ) - return CacheDataset(TextDataset(formatted, tokenizer, text_key="text")) + class _StudioTextDataset: + """TextDataset variant that does not append EOS behind Studio's back.""" + + def __init__(self, data, tokenizer, text_key="text"): + self._data = data + self.tokenizer = tokenizer + self.text_key = text_key + + def process(self, item): + return (self.tokenizer.encode(item[self.text_key]), 0) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + return CacheDataset(_StudioTextDataset(formatted, tokenizer, text_key="text")) def create_batches(dataset, tokenizer, batch_size, max_seq_length, @@ -2497,6 +2515,8 @@ def create_batches(dataset, tokenizer, batch_size, max_seq_length, loop=(num_batches is not None), seed=seed, ): + max_length = int(mx.max(lengths_info[:, 1]).item()) + batch = batch[:, :max_length] batch_pairs.append((batch, lengths_info, None)) if num_batches is not None and len(batch_pairs) >= num_batches: break @@ -2531,6 +2551,8 @@ def iterate_training_batches(dataset, tokenizer, batch_size, max_seq_length, loop=True, seed=seed, ): + max_length = int(mx.max(lengths_info[:, 1]).item()) + batch = batch[:, :max_length] yield batch, lengths_info, None From 6fb702af7145335c5fbfd0abed7dbedc5215bbd3 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 19 May 2026 01:00:48 -0500 Subject: [PATCH 02/48] Match Qwen3-VL rotary precision in MLX --- tests/test_pr_a_deep_components.py | 14 ++++++++++++++ unsloth_zoo/mlx/compile.py | 22 ++++++++++++++++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 97f9bf9ad..a34316e8e 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -279,6 +279,20 @@ def update(self, parameters): ) +def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): + import inspect + import unsloth_zoo.mlx.compile as mc + + source = inspect.getsource(mc._install_qwen3_family_compile_patches) + + assert "def _qwen3_vision_rotary_fp32" in source + assert "tensor_f = tensor.astype(mx.float32)" in source + assert "freqs_f = freqs.astype(mx.float32)" in source + assert "return rotated.astype(orig_dtype)" in source + assert "q = _qwen3_vision_rotary_fp32(q, rotary_pos_emb)" in source + assert "k = _qwen3_vision_rotary_fp32(k, rotary_pos_emb)" in source + + # --------------------------------------------------------------------------- # 2. compile module-level discovery functions return sensible defaults # on a host with no real MLX architectures. diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 54fe9b176..e194b03db 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2662,8 +2662,8 @@ def patched_qwen3_attention(self, x, cu_seqlens, rotary_pos_emb=None): ) q, k, v = mx.split(qkv, 3) - q = vision_module.apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0] - k = vision_module.apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0] + q = _qwen3_vision_rotary_fp32(q, rotary_pos_emb) + k = _qwen3_vision_rotary_fp32(k, rotary_pos_emb) q = q.transpose(0, 2, 1, 3) k = k.transpose(0, 2, 1, 3) @@ -2690,6 +2690,24 @@ def patched_qwen3_attention(self, x, cu_seqlens, rotary_pos_emb=None): output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1) return self.proj(output) + def _qwen3_vision_rotary_fp32(tensor, freqs): + """Match Transformers Qwen3-VL rotary: fp32 math, cast back.""" + import mlx.core as mx + + orig_dtype = tensor.dtype + tensor_f = tensor.astype(mx.float32) + freqs_f = freqs.astype(mx.float32) + cos = mx.cos(freqs_f) + sin = mx.sin(freqs_f) + + cos = mx.expand_dims(cos, axis=1) + cos = mx.tile(cos, (1, 1, 2)) + sin = mx.expand_dims(sin, axis=1) + sin = mx.tile(sin, (1, 1, 2)) + + rotated = (tensor_f * cos) + (vision_module.rotate_half(tensor_f) * sin) + return rotated.astype(orig_dtype) + def _qwen3_torch_like_layer_norm(norm, x): """Match PyTorch bf16 LayerNorm: fp32 stats/affine, cast result back.""" import mlx.core as mx From 5bc745aeab961152f23490fdc885c9ef329b3ea1 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 19 May 2026 01:34:12 -0500 Subject: [PATCH 03/48] Disable Qwen3-VL MLX compile verification --- tests/test_pr_a_deep_components.py | 8 ++++++++ unsloth_zoo/mlx/compile.py | 26 ++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index a34316e8e..b6aea5a60 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -293,6 +293,14 @@ def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): assert "k = _qwen3_vision_rotary_fp32(k, rotary_pos_emb)" in source +def test_qwen3_vl_training_compile_not_verified_until_real_parity(): + import unsloth_zoo.mlx.compile as mc + + assert "qwen3_vl" not in mc._VERIFIED_TRAINING_ARCHES + assert "qwen3_vl_moe" not in mc._VERIFIED_TRAINING_ARCHES + assert "real 10-step training parity" in mc._ARCH_TRAINING_COMPILE_BLOCK_REASONS["qwen3_vl"] + + # --------------------------------------------------------------------------- # 2. compile module-level discovery functions return sensible defaults # on a host with no real MLX architectures. diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index e194b03db..efaa7241b 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -59,8 +59,12 @@ # Architectures explicitly verified for mlx compile support. # Training verification currently covers: # - qwen2_5_vl: real end-to-end compiled training via train.py -# - qwen3_vl / qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: +# - qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: # compiled synthetic forward+backward +# Qwen3-VL/Qwen3-VL-MoE stay patched but unqualified for real training compile: +# a fixed-fixture Qwen3-VL full-FT probe showed compiled vision backward/update +# drift relative to patched eager while sampled language gradients stayed +# aligned. Re-promote only after real 10-step training parity is verified. # SmolVLM has processor/template support, but real mlx-vlm training still hits # MLX primitive-less-array failures after a compiled call. Keep it patched but # unqualified until a real dataset compile run passes. @@ -93,8 +97,6 @@ "qwen2_5_vl", "qwen3_5", "qwen3_5_moe", - "qwen3_vl_moe", - "qwen3_vl", } _VERIFIED_GENERATION_ARCHES: set[str] = set() @@ -105,6 +107,19 @@ ), ) +_ARCH_TRAINING_COMPILE_BLOCK_REASONS: dict[str, str] = { + "qwen3_vl": ( + "Qwen3-VL real 10-step training parity is not verified: compiled " + "vision backward/update drifts relative to patched eager" + ), + "qwen3_vl_moe": ( + "Qwen3-VL-MoE shares the Qwen3-VL vision/deepstack training path; " + "keep compile disabled until Qwen3-VL real 10-step training parity " + "is verified" + ), +} + + _BACKEND_CONFIG_KEYS = ( "text_config", "language_config", @@ -1335,7 +1350,10 @@ def _build_compile_qualification( name for name in matched_patterns if name in _PATCHED_PATTERN_BUNDLES ) - if training_ok or generation_ok: + arch_block_reason = _ARCH_TRAINING_COMPILE_BLOCK_REASONS.get(arch) + if arch_block_reason is not None and not generation_ok: + reason = arch_block_reason + elif training_ok or generation_ok: ready = [] if training_ok: ready.append("training") From 1c487a63c8490468a283a7bfac2d3d35440b8699 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 19 May 2026 02:20:42 -0500 Subject: [PATCH 04/48] Match HF AdamW decay filtering in MLX --- tests/test_pr_a_deep_components.py | 25 +++++++++++++++++ tests/test_pr_a_imports.py | 5 +++- unsloth_zoo/mlx/trainer.py | 43 +++++++++++++++++++++++++++--- 3 files changed, 68 insertions(+), 5 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index b6aea5a60..118460fba 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -102,6 +102,31 @@ def value_at(step): assert second_lr > first_lr +def test_adamw_weight_decay_uses_hf_bias_norm_filter(): + from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig + + class DummyModel: + def trainable_parameters(self): + return {} + + trainer = MLXTrainer.__new__(MLXTrainer) + trainer.model = DummyModel() + trainer.args = MLXTrainingConfig( + optim="adamw", + weight_decay=0.1, + ) + + optimizer = trainer._build_optimizer(total_steps=8) + + assert trainer._manual_adamw_weight_decay == pytest.approx(0.1) + if hasattr(optimizer, "_kw"): + assert optimizer._kw["weight_decay"] == 0.0 + assert MLXTrainer._should_apply_weight_decay("layers.0.mlp.down_proj.weight") + assert not MLXTrainer._should_apply_weight_decay("layers.0.mlp.down_proj.bias") + assert not MLXTrainer._should_apply_weight_decay("layers.0.input_layernorm.weight") + assert not MLXTrainer._should_apply_weight_decay("vision.blocks.0.norm1.weight") + + @pytest.mark.parametrize( ("scheduler", "warmup"), [ diff --git a/tests/test_pr_a_imports.py b/tests/test_pr_a_imports.py index 7eef0ebf8..c166bd61e 100644 --- a/tests/test_pr_a_imports.py +++ b/tests/test_pr_a_imports.py @@ -223,4 +223,7 @@ def trainable_parameters(self): args=MLXTrainingConfig(optim=optim_name), ) optimizer = trainer._build_optimizer(total_steps=10) - assert optimizer._kw["bias_correction"] is True + if hasattr(optimizer, "_kw"): + assert optimizer._kw["bias_correction"] is True + else: + assert optimizer.bias_correction is True diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 10621103d..92f489aa1 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -377,14 +377,16 @@ def _set_optimizer_lr_for_step(self, optimizer, step): def _build_optimizer(self, total_steps): """Create MLX optimizer with LR schedule from config. - For optimizers that support weight_decay, wraps with - optim.decay_weight to exclude bias and norm parameters - (matching HuggingFace Trainer behavior). + For AdamW, MLX applies weight decay inside the leaf update without a + parameter-group filter. Keep MLX AdamW's built-in decay disabled and + apply decoupled decay ourselves so bias and norm parameters match + HuggingFace Trainer behavior. """ schedule = self._build_schedule(total_steps) initial_lr = self._schedule_value(schedule, 0) self._lr_schedule = schedule if callable(schedule) else None wd = self.args.weight_decay + self._manual_adamw_weight_decay = 0.0 opt_name = _normalize_mlx_optimizer_name(self.args.optim) if opt_name == "adafactor": @@ -411,9 +413,10 @@ def _build_optimizer(self, total_steps): elif opt_name == "adamw": # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction # to False, which makes early warmup updates much larger. + self._manual_adamw_weight_decay = float(wd or 0.0) optimizer = optim.AdamW( learning_rate=initial_lr, - weight_decay=wd, + weight_decay=0.0, bias_correction=True, ) elif opt_name == "adam": @@ -430,6 +433,36 @@ def _build_optimizer(self, total_steps): self._resolved_optimizer_name = opt_name return optimizer + @staticmethod + def _should_apply_weight_decay(name, parameter=None): + """HF-style AdamW decay filter: decay weights, skip bias and norms.""" + parts = [part.lower() for part in str(name).split(".") if part] + leaf = parts[-1] if parts else str(name).lower() + if leaf == "bias": + return False + if any("norm" in part for part in parts): + return False + return True + + def _apply_manual_adamw_weight_decay(self, model, optimizer, grad): + """Apply decoupled AdamW decay to trainable non-bias/non-norm leaves.""" + wd = float(getattr(self, "_manual_adamw_weight_decay", 0.0) or 0.0) + if wd <= 0: + return + + flat_grad = dict(tree_flatten(grad)) + decayed = [] + for name, parameter in tree_flatten(model.trainable_parameters()): + if name not in flat_grad: + continue + if not self._should_apply_weight_decay(name, parameter): + continue + lr = optimizer.learning_rate.astype(flat_grad[name].dtype) + scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) + decayed.append((name, parameter * scale)) + if decayed: + model.update(tree_unflatten(decayed)) + @staticmethod def _adafactor_unsupported_parameters(model): """Return trainable params that MLX Adafactor cannot update safely. @@ -860,6 +893,7 @@ def _apply_update(grad, toks_f): lambda g: mx.clip(g, -max_grad_value, max_grad_value), final_grad, ) + self._apply_manual_adamw_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) return grad_norm @@ -885,6 +919,7 @@ def _apply_update_direct(grad): lambda g: mx.clip(g, -max_grad_value, max_grad_value), grad, ) + self._apply_manual_adamw_weight_decay(model, optimizer, grad) optimizer.update(model, grad) return grad_norm From 9a38968165ebc945313062a3a763b56bff60c7b0 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 19 May 2026 14:29:39 -0500 Subject: [PATCH 05/48] Preserve Qwen3-VL residual dtype in MLX vision block --- unsloth_zoo/mlx/compile.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index efaa7241b..d22f62674 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2743,14 +2743,17 @@ def _qwen3_torch_like_layer_norm(norm, x): return y.astype(source_dtype) def patched_qwen3_vision_block_call(self, hidden_states, cu_seqlens, rotary_pos_emb): - hidden_states = hidden_states + self.attn( + residual_dtype = hidden_states.dtype + attn_output = self.attn( _qwen3_torch_like_layer_norm(self.norm1, hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, ) - hidden_states = hidden_states + self.mlp( + hidden_states = (hidden_states + attn_output.astype(residual_dtype)).astype(residual_dtype) + mlp_output = self.mlp( _qwen3_torch_like_layer_norm(self.norm2, hidden_states) ) + hidden_states = (hidden_states + mlp_output.astype(residual_dtype)).astype(residual_dtype) return hidden_states def patched_qwen3_rot_pos_emb(self, grid_thw): From 6ec832e2959ecec3ffe4f7b5b8bd667bf7c9b56f Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 10:09:16 -0500 Subject: [PATCH 06/48] update --- unsloth_zoo/mlx/trainer.py | 55 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 92f489aa1..f93676367 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -118,6 +118,8 @@ class MLXTrainingConfig: # Optimization optim: str = "adamw" # "adafactor", "adamw", "adam", "sgd", "muon", "lion" weight_decay: float = 0.001 + adam_beta1: float | None = None + adam_beta2: float | None = None max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf, no # cross-leaf reduction). Set 0.0 to disable. Default 1.0: |g_i| > 5 @@ -196,6 +198,14 @@ def __init__( # Auto-detect VLM self._is_vlm = _is_vlm_model(model) + if self._is_vlm: + # VLM callers pass the processor through the tokenizer slot to + # mirror HF Trainer APIs. The loss is constructed before data + # preparation, so attach it now for image-token masking. + if self.processor is None: + self.processor = tokenizer + if self.processor is not None: + self.model._processor = self.processor # Constructor params override args if provided if dataset_text_field is not None: @@ -387,6 +397,14 @@ def _build_optimizer(self, total_steps): self._lr_schedule = schedule if callable(schedule) else None wd = self.args.weight_decay self._manual_adamw_weight_decay = 0.0 + adam_beta1 = getattr(self.args, "adam_beta1", None) + adam_beta2 = getattr(self.args, "adam_beta2", None) + adam_kwargs = {} + if adam_beta1 is not None or adam_beta2 is not None: + adam_kwargs["betas"] = ( + float(0.9 if adam_beta1 is None else adam_beta1), + float(0.999 if adam_beta2 is None else adam_beta2), + ) opt_name = _normalize_mlx_optimizer_name(self.args.optim) if opt_name == "adafactor": @@ -418,11 +436,13 @@ def _build_optimizer(self, total_steps): learning_rate=initial_lr, weight_decay=0.0, bias_correction=True, + **adam_kwargs, ) elif opt_name == "adam": optimizer = optim.Adam( learning_rate=initial_lr, bias_correction=True, + **adam_kwargs, ) elif opt_name == "sgd": optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) @@ -798,6 +818,39 @@ def _train_inner(self): not _clip_grad_value ) + def _is_norm_parameter_name(name): + return any( + "norm" in part.lower() + for part in str(name).split(".") + if part + ) + + _restore_storage_after_norm_clip = max_grad_norm > 0 + _trainable_storage_dtypes = ( + { + name: value.dtype + for name, value in tree_flatten(model.trainable_parameters()) + if not _is_norm_parameter_name(name) + } + if _restore_storage_after_norm_clip + else {} + ) + + def _restore_trainable_storage_dtypes(): + """Keep norm-clipped MLX updates from promoting non-norm params.""" + if not _restore_storage_after_norm_clip: + return + recast = [] + needs_update = False + for name, value in tree_flatten(model.trainable_parameters()): + dtype = _trainable_storage_dtypes.get(name) + if dtype is not None and value.dtype != dtype: + value = value.astype(dtype) + needs_update = True + recast.append((name, value)) + if needs_update: + model.update(tree_unflatten(recast)) + def _grad_leaf_scale(name, safe_toks_f, clip_scale=None, dtype=None): """Return the exact scalar applied to one grad leaf before update. @@ -895,6 +948,7 @@ def _apply_update(grad, toks_f): ) self._apply_manual_adamw_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) + _restore_trainable_storage_dtypes() return grad_norm def _apply_update_direct(grad): @@ -921,6 +975,7 @@ def _apply_update_direct(grad): ) self._apply_manual_adamw_weight_decay(model, optimizer, grad) optimizer.update(model, grad) + _restore_trainable_storage_dtypes() return grad_norm # Unified step function for both VLM and text training. From 5895b20cce0dab8685dce5bc7230fed39560a0fb Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 15:44:35 -0500 Subject: [PATCH 07/48] udpate vlm --- tests/test_mlx_vlm_label_masks.py | 131 ++++++++++ tests/test_pr_a_deep_components.py | 7 +- unsloth_zoo/compiler.py | 16 +- unsloth_zoo/mlx/compile.py | 25 +- unsloth_zoo/mlx/trainer.py | 92 +++++--- unsloth_zoo/mlx/utils.py | 368 ++++++++++++++++++----------- 6 files changed, 435 insertions(+), 204 deletions(-) create mode 100644 tests/test_mlx_vlm_label_masks.py diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py new file mode 100644 index 000000000..451de7ab3 --- /dev/null +++ b/tests/test_mlx_vlm_label_masks.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import numpy as np +import pytest +from pathlib import Path + + +mx = pytest.importorskip("mlx.core") +if "mlx_simulation" in str(getattr(mx, "__file__", "")): + pytest.skip("requires real MLX runtime", allow_module_level=True) + + +class _FakeTokenizer: + pad_token_id = 0 + unk_token_id = -1 + image_token = "" + + _vocab = { + "": 200, + "<|image_pad|>": 201, + } + + def convert_tokens_to_ids(self, tokens): + if isinstance(tokens, list): + return [self._vocab.get(token, self.unk_token_id) for token in tokens] + return self._vocab.get(tokens, self.unk_token_id) + + +class _FakeProcessor: + tokenizer = _FakeTokenizer() + image_processor = object() + chat_template = "{{ messages }}" + + def __call__(self, text, **_kwargs): + rows = [] + masks = [] + for idx, _ in enumerate(text): + if idx == 0: + row = [101, 10, 200, 11, 0] + mask = [1, 1, 1, 1, 0] + else: + row = [101, 12, 13, 0, 0] + mask = [1, 1, 1, 0, 0] + rows.append(row) + masks.append(mask) + return { + "input_ids": np.array(rows, dtype=np.int32), + "attention_mask": np.array(masks, dtype=np.int32), + } + + +def test_vlm_collate_creates_sft_labels_and_masks_special_tokens(): + from unsloth_zoo.mlx.utils import ( + _collate_vlm_batch, + _get_vlm_ignore_token_ids, + ) + + processor = _FakeProcessor() + ignore_ids = _get_vlm_ignore_token_ids( + processor=processor, + config={"image_token_id": 200}, + ) + batch = _collate_vlm_batch( + [{"text": "first"}, {"text": "second"}], + processor, + max_seq_length=8, + image_size=16, + ignore_token_ids=ignore_ids, + ) + + assert "labels" in batch + assert batch["input_ids"].tolist() == [ + [101, 10, 200, 11, 0], + [101, 12, 13, 0, 0], + ] + assert batch["labels"].tolist() == [ + [101, 10, -100, 11, -100], + [101, 12, 13, -100, -100], + ] + + +def test_vlm_response_mask_reapplies_special_token_masks(): + from unsloth_zoo.mlx.utils import _apply_response_mask_to_vlm_batch + + batch = { + "input_ids": mx.array([[101, 200, 13, 0]], dtype=mx.int32), + "attention_mask": mx.array([[1, 1, 1, 0]], dtype=mx.int32), + "labels": mx.array([[101, -100, 13, -100]], dtype=mx.int32), + } + + def mask_fn(_batch): + return {"labels": [[-100, 200, 13, 0]]} + + out = _apply_response_mask_to_vlm_batch( + batch, + mask_fn, + ignore_token_ids=[0, 200], + ) + + assert out["labels"].tolist() == [[-100, -100, 13, -100]] + + +def test_token_expansion_masks_inserted_label_positions(): + from unsloth_zoo.mlx.utils import _expand_token_runs + + input_ids = mx.array([[1, 200, 3, 0]], dtype=mx.int32) + attention_mask = mx.array([[1, 1, 1, 0]], dtype=mx.int32) + labels = mx.array([[1, -100, 3, -100]], dtype=mx.int32) + + expanded_ids, expanded_mask, expanded_labels = _expand_token_runs( + input_ids=input_ids, + attention_mask=attention_mask, + replacements_by_batch=(((1, 2, 200, 3),),), + labels=labels, + ) + + assert expanded_ids.tolist() == [[1, 200, 200, 200, 3, 0]] + assert expanded_mask.tolist() == [[1, 1, 1, 1, 1, 0]] + assert expanded_labels.tolist() == [[1, -100, -100, -100, 3, -100]] + + +def test_mlx_trainer_does_not_attach_processor_for_loss_masking(): + trainer_source = ( + Path(__file__).resolve().parents[1] + / "unsloth_zoo" + / "mlx" + / "trainer.py" + ).read_text() + + assert "self.model._processor =" not in trainer_source + assert "_get_vlm_ignore_token_ids(" in trainer_source diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 118460fba..396ce6916 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -318,12 +318,11 @@ def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): assert "k = _qwen3_vision_rotary_fp32(k, rotary_pos_emb)" in source -def test_qwen3_vl_training_compile_not_verified_until_real_parity(): +def test_qwen3_vl_training_compile_verified(): import unsloth_zoo.mlx.compile as mc - assert "qwen3_vl" not in mc._VERIFIED_TRAINING_ARCHES - assert "qwen3_vl_moe" not in mc._VERIFIED_TRAINING_ARCHES - assert "real 10-step training parity" in mc._ARCH_TRAINING_COMPILE_BLOCK_REASONS["qwen3_vl"] + assert "qwen3_vl" in mc._VERIFIED_TRAINING_ARCHES + assert "qwen3_vl_moe" in mc._VERIFIED_TRAINING_ARCHES # --------------------------------------------------------------------------- diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 7ff6ad552..ff6492a2f 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1744,7 +1744,7 @@ def mask_attention_mask_out(labels = None, attention_mask = None): num_items_in_batch = n_items, logit_softcapping = None if (\\4) == () else (\\4), ) -elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: +elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: lm_head_weight = self.lm_head.weight lm_head_bias = getattr(self.lm_head, "bias", None) @@ -1767,20 +1767,6 @@ def mask_attention_mask_out(labels = None, attention_mask = None): logit_scale_divide = (\\3) if (\\3) != () else 0, logit_softcapping = (\\4) if (\\4) != () else 0, ) -elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: - # UNSLOTH_RETURN_LOGITS=1 path. Prepended `logits = self.lm_head(...)` - # already materialised the full lm_head matmul; apply the captured logit - # scale/softcap transforms and route loss through self.loss_function on - # those logits instead of letting unsloth_fused_ce_loss redo the matmul. - if (\\2) != (): - logits = logits * (\\2) - if (\\3) != (): - logits = logits / (\\3) - if (\\4) not in (None, (),): - logits = logits / (\\4) - logits = torch.tanh(logits) - logits = logits * (\\4) - loss = self.loss_function(logits, labels.to(self.lm_head.weight.device), vocab_size=\\8, **\\9) else: logits = self.lm_head(hidden_states\\1) if (\\2) != (): diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index d22f62674..81ba8d5b3 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -59,12 +59,8 @@ # Architectures explicitly verified for mlx compile support. # Training verification currently covers: # - qwen2_5_vl: real end-to-end compiled training via train.py -# - qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: +# - qwen3_vl / qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: # compiled synthetic forward+backward -# Qwen3-VL/Qwen3-VL-MoE stay patched but unqualified for real training compile: -# a fixed-fixture Qwen3-VL full-FT probe showed compiled vision backward/update -# drift relative to patched eager while sampled language gradients stayed -# aligned. Re-promote only after real 10-step training parity is verified. # SmolVLM has processor/template support, but real mlx-vlm training still hits # MLX primitive-less-array failures after a compiled call. Keep it patched but # unqualified until a real dataset compile run passes. @@ -97,6 +93,8 @@ "qwen2_5_vl", "qwen3_5", "qwen3_5_moe", + "qwen3_vl_moe", + "qwen3_vl", } _VERIFIED_GENERATION_ARCHES: set[str] = set() @@ -107,18 +105,6 @@ ), ) -_ARCH_TRAINING_COMPILE_BLOCK_REASONS: dict[str, str] = { - "qwen3_vl": ( - "Qwen3-VL real 10-step training parity is not verified: compiled " - "vision backward/update drifts relative to patched eager" - ), - "qwen3_vl_moe": ( - "Qwen3-VL-MoE shares the Qwen3-VL vision/deepstack training path; " - "keep compile disabled until Qwen3-VL real 10-step training parity " - "is verified" - ), -} - _BACKEND_CONFIG_KEYS = ( "text_config", @@ -1350,10 +1336,7 @@ def _build_compile_qualification( name for name in matched_patterns if name in _PATCHED_PATTERN_BUNDLES ) - arch_block_reason = _ARCH_TRAINING_COMPILE_BLOCK_REASONS.get(arch) - if arch_block_reason is not None and not generation_ok: - reason = arch_block_reason - elif training_ok or generation_ok: + if training_ok or generation_ok: ready = [] if training_ok: ready.append("training") diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index f93676367..1457c71c4 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -63,6 +63,7 @@ iterate_vlm_training_batches, normalize_mlx_chat_template, normalize_vlm_processor_chat_template, + _get_vlm_ignore_token_ids, collect_mlx_texts, save_lora_adapters, apply_gradient_checkpointing, @@ -198,14 +199,6 @@ def __init__( # Auto-detect VLM self._is_vlm = _is_vlm_model(model) - if self._is_vlm: - # VLM callers pass the processor through the tokenizer slot to - # mirror HF Trainer APIs. The loss is constructed before data - # preparation, so attach it now for image-token masking. - if self.processor is None: - self.processor = tokenizer - if self.processor is not None: - self.model._processor = self.processor # Constructor params override args if provided if dataset_text_field is not None: @@ -358,8 +351,8 @@ def main_schedule(step): if warmup > 0: def warmup_fn(step): step = mx.array(step) - step = mx.minimum(step + 1, mx.array(warmup)) - return step * (lr / (warmup + 1)) + step = mx.minimum(step, mx.array(warmup)) + return step * (lr / warmup) if callable(main_schedule): return optim.join_schedules( [warmup_fn, main_schedule], [warmup] @@ -445,10 +438,16 @@ def _build_optimizer(self, total_steps): **adam_kwargs, ) elif opt_name == "sgd": + # TODO: For HF Trainer parity, consider applying the same + # bias/norm weight-decay exclusion used by AdamW to SGD. optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) elif opt_name == "muon": + # TODO: For HF Trainer parity, consider applying the same + # bias/norm weight-decay exclusion used by AdamW to Muon. optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=wd) elif opt_name == "lion": + # TODO: For HF Trainer parity, consider applying the same + # bias/norm weight-decay exclusion used by AdamW to Lion. optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=wd) self._resolved_optimizer_name = opt_name return optimizer @@ -719,15 +718,32 @@ def _train_inner(self): # Pick loss function — returns (loss, ntoks) tuples use_cce = args.use_cce + _vlm_ignore_token_ids = None if is_vlm: + processor = self._resolve_vlm_processor() + # VLM collation owns label creation/masking. These IDs should be + # redundant for normal SFT batches and are only a loss-side + # compatibility backstop for missing or externally supplied labels. + _vlm_ignore_token_ids = _get_vlm_ignore_token_ids( + processor=processor, + config=getattr(model, "_config", {}), + ) _atid = args.assistant_token_id if args.train_on_completions else 0 if use_cce: - loss_fn = make_vlm_cce_loss_fn(model, assistant_token_id=_atid) + loss_fn = make_vlm_cce_loss_fn( + model, + assistant_token_id=_atid, + ignore_token_ids=_vlm_ignore_token_ids, + ) cce_backend = getattr(loss_fn, "_unsloth_cce_backend", "unknown") print(f"Unsloth: Using VLM CCE loss ({cce_backend}) for memory-efficient training.") else: - loss_fn = make_vlm_baseline_loss_fn(model, assistant_token_id=_atid) + loss_fn = make_vlm_baseline_loss_fn( + model, + assistant_token_id=_atid, + ignore_token_ids=_vlm_ignore_token_ids, + ) print("Unsloth: Using VLM standard cross-entropy loss.") else: if use_cce: @@ -1082,7 +1098,7 @@ def step_fn(batch_data, prev_state, do_update): if _labeled_eval is not None: eval_batches = _labeled_eval elif is_vlm: - processor = self.processor or getattr(self.model, "_processor", None) + processor = self._resolve_vlm_processor() config = getattr(self.model, "_config", {}) _vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None) eval_batches = create_vlm_batches( @@ -1379,6 +1395,40 @@ def step_fn(batch_data, prev_state, do_update): ), } + def _resolve_vlm_processor(self): + """Resolve the processor used for VLM collation without mutating model.""" + args = self.args + config = getattr(self.model, "_config", {}) + model_type = config.get("model_type") if isinstance(config, dict) else None + model_name = getattr(self.model, "_hf_repo", None) + + processor = self.processor + if processor is None and ( + hasattr(self.tokenizer, "image_processor") + or ( + hasattr(self.tokenizer, "tokenizer") + and hasattr(self.tokenizer, "apply_chat_template") + ) + ): + processor = self.tokenizer + if processor is None: + processor = getattr(self.model, "_processor", None) + if processor is None: + raise ValueError( + "VLM training requires a processor. Pass processor= to MLXTrainer " + "or load the model with FastLanguageModel.from_pretrained()." + ) + + processor = normalize_vlm_processor_chat_template( + processor, + chat_template=getattr(args, "vlm_chat_template", None), + model_name=model_name, + model_type=model_type, + strict=False, + ) + self.processor = processor + return processor + def _prepare_data(self, is_vlm): """Prepare training data. Returns (batches, batch_iter).""" args = self.args @@ -1387,21 +1437,7 @@ def _prepare_data(self, is_vlm): model_name = getattr(self.model, "_hf_repo", None) if is_vlm: - processor = self.processor or getattr(self.model, "_processor", None) - if processor is None: - raise ValueError( - "VLM training requires a processor. Pass processor= to MLXTrainer " - "or load the model with FastLanguageModel.from_pretrained()." - ) - processor = normalize_vlm_processor_chat_template( - processor, - chat_template=getattr(args, "vlm_chat_template", None), - model_name=model_name, - model_type=model_type, - strict=False, - ) - self.processor = processor - self.model._processor = processor + processor = self._resolve_vlm_processor() else: self.tokenizer = normalize_mlx_chat_template( self.tokenizer, diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 62de823b1..ca72df134 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -450,66 +450,143 @@ def loss_fn(model, batch, lengths, labels=None): "<|vision_end|>", # Qwen "<|vision_pad|>", # Qwen "<|image_pad|>", # Qwen + "<|video_pad|>", # Qwen "", # PaliGemma, Llava, InternVL "", # InternVL + "[IMG]", # Mistral + "[IMG_BREAK]", # Mistral + "[IMG_END]", # Mistral "", # Gemma 3 "", # Gemma 3 "", # Gemma 3 + "<|START_OF_IMG|>", # Cohere + "<|END_OF_IMG|>", # Cohere + "<|IMG_LINE_BREAK|>", # Cohere + "<|IMG_PATCH|>", # Cohere ) -def _get_image_token_ids(model): - """Resolve image token IDs from model's processor/tokenizer. +def _append_unique_int(ids, value): + if value is None: + return + if isinstance(value, (list, tuple, set)): + for item in value: + _append_unique_int(ids, item) + return + try: + value = int(value) + except (TypeError, ValueError): + return + if value not in ids: + ids.append(value) - Returns an mx.array of token IDs to mask from loss, or None if - no image tokens are found (non-VLM or tokenizer doesn't have them). - """ - processor = getattr(model, "_processor", None) - tokenizer = getattr(processor, "tokenizer", processor) if processor else None - if tokenizer is None: - return None - ids = [] - for tok_str in _IMAGE_TOKEN_STRINGS: +def _convert_token_to_id(tokenizer, token): + try: + token_ids = tokenizer.convert_tokens_to_ids([token]) + except Exception: try: - tok_ids = tokenizer.convert_tokens_to_ids([tok_str]) - if tok_ids and tok_ids[0] is not None: - # Some tokenizers return the unk_token_id for unknown tokens - unk_id = getattr(tokenizer, "unk_token_id", None) - if tok_ids[0] != unk_id: - ids.append(tok_ids[0]) + token_ids = tokenizer.convert_tokens_to_ids(token) except Exception: - continue + return None + if isinstance(token_ids, (list, tuple)): + token_id = token_ids[0] if token_ids else None + else: + token_id = token_ids + if token_id is None: + return None + unk_id = getattr(tokenizer, "unk_token_id", None) + if unk_id is not None and token_id == unk_id: + return None + return token_id - # Also check config for image_token_index / image_token_id - config = getattr(model, "_config", {}) - for key in ("image_token_index", "image_token_id"): - val = config.get(key) - if val is not None and val not in ids: - ids.append(val) + +def _get_vlm_ignore_token_ids(processor=None, config=None, model=None): + """Resolve VLM token IDs that should be ignored by SFT loss labels. + + Mirrors the CUDA vision collator's best-effort token masking without making + the loss depend on processor state attached to the model. + """ + if processor is None and model is not None: + processor = getattr(model, "_processor", None) + if config is None and model is not None: + config = getattr(model, "_config", None) + + ids = [] + tokenizer = _get_processor_tokenizer(processor) + if tokenizer is not None: + for tok_str in _IMAGE_TOKEN_STRINGS: + _append_unique_int(ids, _convert_token_to_id(tokenizer, tok_str)) + + for attr in ( + "image_token", + "video_token", + "audio_token", + "boi_token", + "eoi_token", + ): + token = getattr(tokenizer, attr, None) + if token is not None: + _append_unique_int(ids, _convert_token_to_id(tokenizer, token)) + + for attr in ( + "pad_token_id", + "image_token_id", + "video_token_id", + "audio_token_id", + ): + _append_unique_int(ids, getattr(tokenizer, attr, None)) + + for key in ( + "image_token_index", + "image_token_id", + "video_token_index", + "video_token_id", + "audio_token_index", + "audio_token_id", + "boi_token_index", + "boi_token_id", + "eoi_token_index", + "eoi_token_id", + "pad_token_id", + ): + _append_unique_int(ids, _config_get(config, key, None)) if not ids: return None return ids # plain Python list; avoids mx.eval in the hot path -def _mask_image_tokens(targets, image_token_ids): - """Set image token positions in targets to -100 (ignore_index). +def _get_image_token_ids(model): + """Backward-compatible wrapper for legacy callers.""" + return _get_vlm_ignore_token_ids(model=model) - Prevents the model from training to predict image placeholder tokens, - which are fixed special tokens that provide no useful gradient signal. - Args: - targets: mx.array of token IDs. - image_token_ids: plain Python list of int token IDs, or None. - """ - if not image_token_ids: +def _mask_label_token_ids(targets, ignore_token_ids, ignore_index=-100): + if not ignore_token_ids: return targets - # Build a mask: True where target is any image token - is_image = targets == image_token_ids[0] - for tok_id in image_token_ids[1:]: - is_image = is_image | (targets == tok_id) - return mx.where(is_image, -100, targets) + should_ignore = targets == ignore_token_ids[0] + for tok_id in ignore_token_ids[1:]: + should_ignore = should_ignore | (targets == tok_id) + return mx.where(should_ignore, ignore_index, targets) + + +def _mask_image_tokens(targets, image_token_ids): + """Set image/vision token positions in targets to -100.""" + return _mask_label_token_ids(targets, image_token_ids) + + +def _apply_vlm_label_masks(batch_dict, labels=None, ignore_token_ids=None, + ignore_index=-100): + if labels is None: + labels = batch_dict["input_ids"].astype(mx.int32) + else: + labels = labels.astype(mx.int32) + labels = _mask_label_token_ids(labels, ignore_token_ids, ignore_index) + attention_mask = batch_dict.get("attention_mask") + if attention_mask is not None: + labels = mx.where(attention_mask == 0, mx.array(ignore_index), labels) + return labels.astype(mx.int32) def _mask_prompt_tokens(targets, assistant_token_id): @@ -572,7 +649,8 @@ def _trim_sequence_aligned_vlm_kwargs(extra_kwargs, seq_len): return extra_kwargs -def make_vlm_baseline_loss_fn(model=None, assistant_token_id=0): +def make_vlm_baseline_loss_fn(model=None, assistant_token_id=0, + ignore_token_ids=None): """Create a standard cross-entropy loss function for VLMs. Takes a batch dict with input_ids, pixel_values, attention_mask. @@ -580,7 +658,11 @@ def make_vlm_baseline_loss_fn(model=None, assistant_token_id=0): Returns: A function (model, batch_dict) -> (loss, ntoks). """ - _image_token_ids = _get_image_token_ids(model) if model is not None else None + _image_token_ids = ( + ignore_token_ids + if ignore_token_ids is not None + else (_get_image_token_ids(model) if model is not None else None) + ) _assistant_token_id = assistant_token_id def loss_fn(model, batch_dict): @@ -609,8 +691,9 @@ def loss_fn(model, batch_dict): logits = logits.astype(mx.float32) if labels is not None: - # train_on_responses_only: labels encode instruction/padding masking. - # Still mask image placeholder tokens. + # Labels encode instruction/padding/special-token masking when + # produced by MLX VLM collation. The extra mask keeps legacy + # externally supplied labels compatible. targets = labels[:, 1:] targets = _mask_image_tokens(targets, _image_token_ids) logits, targets = _align_logits_with_labels(logits, targets) @@ -779,9 +862,9 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, hidden = hidden[:, :-1, :] if labels is not None: - # train_on_responses_only: labels already encode instruction and - # padding masking. Still need to mask image placeholder tokens - # since they provide no useful gradient signal. + # Labels are the source of truth. Collation should already encode + # instruction/padding/special-token masking; the extra mask preserves + # compatibility for externally supplied labels. targets = labels[:, 1:] masked_targets = _mask_image_tokens(targets, image_token_ids) ntoks = (masked_targets != -100).sum() @@ -875,71 +958,8 @@ def _normalize_int_tuple(values): return tuple(int(x) for x in values) -def _expand_image_token_sequences( - input_ids, - attention_mask, - image_token_id, - repeat_count, - labels=None, -): - input_ids_np = np.asarray(input_ids) - attention_mask_np = ( - np.asarray(attention_mask) - if attention_mask is not None - else np.ones_like(input_ids_np, dtype=np.int32) - ) - labels_np = np.asarray(labels) if labels is not None else None - - expanded_ids = [] - expanded_masks = [] - expanded_labels = [] if labels_np is not None else None - max_len = 0 - for row_idx, (row_ids, row_mask) in enumerate(zip(input_ids_np, attention_mask_np)): - new_ids = [] - new_mask = [] - new_labels = [] if labels_np is not None else None - row_labels_list = labels_np[row_idx].tolist() if labels_np is not None else None - for pos, (token_id, mask_value) in enumerate(zip(row_ids.tolist(), row_mask.tolist())): - if int(token_id) == int(image_token_id): - new_ids.extend([int(image_token_id)] * int(repeat_count)) - new_mask.extend([int(mask_value)] * int(repeat_count)) - if new_labels is not None: - new_labels.extend([-100] * int(repeat_count)) - else: - new_ids.append(int(token_id)) - new_mask.append(int(mask_value)) - if new_labels is not None: - new_labels.append(int(row_labels_list[pos])) - expanded_ids.append(new_ids) - expanded_masks.append(new_mask) - if expanded_labels is not None: - expanded_labels.append(new_labels) - max_len = max(max_len, len(new_ids)) - - padded_ids = np.zeros((len(expanded_ids), max_len), dtype=np.int32) - padded_masks = np.zeros((len(expanded_masks), max_len), dtype=np.int32) - padded_labels = ( - np.full((len(expanded_labels), max_len), -100, dtype=np.int32) - if expanded_labels is not None else None - ) - for row_idx, (row_ids, row_mask) in enumerate(zip(expanded_ids, expanded_masks)): - row_len = len(row_ids) - padded_ids[row_idx, :row_len] = row_ids - padded_masks[row_idx, :row_len] = row_mask - if padded_labels is not None: - padded_labels[row_idx, :row_len] = expanded_labels[row_idx] - - if padded_labels is not None: - return mx.array(padded_ids), mx.array(padded_masks), mx.array(padded_labels) - return mx.array(padded_ids), mx.array(padded_masks) - - -def _expand_token_runs( - input_ids, - attention_mask, - replacements_by_batch, - labels=None, -): +def _expand_token_replacements(input_ids, attention_mask, replacements_by_batch, + labels=None): input_ids_np = np.asarray(input_ids) attention_mask_np = ( np.asarray(attention_mask) @@ -966,17 +986,20 @@ def _expand_token_runs( row_ids_list = row_ids.tolist() row_mask_list = row_mask.tolist() for start, end, token_id, repeat in replacements: + start = int(start) + end = int(end) + repeat = int(repeat) if start > prev: new_ids.extend(row_ids_list[prev:start]) new_mask.extend(row_mask_list[prev:start]) if new_labels is not None: new_labels.extend(row_labels_list[prev:start]) - new_ids.extend([int(token_id)] * int(repeat)) + new_ids.extend([int(token_id)] * repeat) fill_mask = int(row_mask_list[start]) if start < len(row_mask_list) else 1 - new_mask.extend([fill_mask] * int(repeat)) + new_mask.extend([fill_mask] * repeat) if new_labels is not None: - new_labels.extend([-100] * int(repeat)) - prev = int(end) + new_labels.extend([-100] * repeat) + prev = end if prev < len(row_ids_list): new_ids.extend(row_ids_list[prev:]) new_mask.extend(row_mask_list[prev:]) @@ -1006,6 +1029,43 @@ def _expand_token_runs( return mx.array(padded_ids), mx.array(padded_masks) +def _expand_image_token_sequences( + input_ids, + attention_mask, + image_token_id, + repeat_count, + labels=None, +): + input_ids_np = np.asarray(input_ids) + replacements_by_batch = [] + for row in input_ids_np: + replacements = [] + for pos, token_id in enumerate(row.tolist()): + if int(token_id) == int(image_token_id): + replacements.append((pos, pos + 1, image_token_id, repeat_count)) + replacements_by_batch.append(tuple(replacements)) + return _expand_token_replacements( + input_ids=input_ids, + attention_mask=attention_mask, + replacements_by_batch=tuple(replacements_by_batch), + labels=labels, + ) + + +def _expand_token_runs( + input_ids, + attention_mask, + replacements_by_batch, + labels=None, +): + return _expand_token_replacements( + input_ids=input_ids, + attention_mask=attention_mask, + replacements_by_batch=replacements_by_batch, + labels=labels, + ) + + def _build_qwen_position_ids( input_ids, attention_mask, @@ -1447,7 +1507,7 @@ def _prepare_vlm_batch_for_compile(batch_dict, config): return batch_dict -def make_vlm_cce_loss_fn(model, assistant_token_id=0): +def make_vlm_cce_loss_fn(model, assistant_token_id=0, ignore_token_ids=None): """Create a CCE loss function for VLMs. Uses model.get_input_embeddings() to get merged vision+text embeddings, @@ -1473,7 +1533,11 @@ def make_vlm_cce_loss_fn(model, assistant_token_id=0): "falling back to baseline CE loss.", stacklevel=2, ) - return make_vlm_baseline_loss_fn(model, assistant_token_id=assistant_token_id) + return make_vlm_baseline_loss_fn( + model, + assistant_token_id=assistant_token_id, + ignore_token_ids=ignore_token_ids, + ) tm = _get_text_model(model) if getattr(tm, "model", None) is None and not _has_direct_hidden_stack(model): @@ -1483,7 +1547,11 @@ def make_vlm_cce_loss_fn(model, assistant_token_id=0): "falling back to baseline CE loss.", stacklevel=2, ) - return make_vlm_baseline_loss_fn(model, assistant_token_id=assistant_token_id) + return make_vlm_baseline_loss_fn( + model, + assistant_token_id=assistant_token_id, + ignore_token_ids=ignore_token_ids, + ) softcap = _get_logit_softcap(model) lm_layer = _get_lm_head_layer(model) @@ -1492,7 +1560,11 @@ def make_vlm_cce_loss_fn(model, assistant_token_id=0): # Must be called after LoRA setup. _skip_weight_grad = not _is_lm_head_trainable(model) - _image_token_ids = _get_image_token_ids(model) + _image_token_ids = ( + ignore_token_ids + if ignore_token_ids is not None + else _get_image_token_ids(model) + ) if _image_token_ids is not None: print(f"Unsloth: Masking {len(_image_token_ids)} image token IDs from VLM loss.") _assistant_token_id = assistant_token_id @@ -1606,6 +1678,10 @@ def _has_chat_template(obj): def _get_processor_tokenizer(processor): + if processor is None: + return None + if hasattr(processor, "_tokenizer"): + return processor._tokenizer return getattr(processor, "tokenizer", processor) @@ -2181,7 +2257,8 @@ def _processor_vlm_inputs(processor, texts, all_images, max_seq_length, suffixes return processor(**proc_kwargs) -def _collate_vlm_prompt_completion_batch(items, processor, max_seq_length, image_size): +def _collate_vlm_prompt_completion_batch(items, processor, max_seq_length, image_size, + ignore_token_ids=None): prompt_texts = [] combined_texts = [] all_images = [] @@ -2214,8 +2291,12 @@ def _collate_vlm_prompt_completion_batch(items, processor, max_seq_length, image processor, prompt_texts, all_images, max_seq_length ) batch = _to_mx_vlm_batch(combined_inputs) + batch["labels"] = _apply_vlm_label_masks( + batch, + ignore_token_ids=ignore_token_ids, + ) - labels_np = np.array(batch["input_ids"].tolist(), dtype=np.int32) + labels_np = np.array(batch["labels"].tolist(), dtype=np.int32) prompt_batch = _to_mx_vlm_batch(prompt_inputs) prompt_mask = prompt_batch.get("attention_mask") prompt_ids = prompt_batch["input_ids"] @@ -2225,14 +2306,12 @@ def _collate_vlm_prompt_completion_batch(items, processor, max_seq_length, image else: prompt_len = int(mx.sum(prompt_ids[row] != 0).item()) labels_np[row, :prompt_len] = -100 - labels = mx.array(labels_np) - if "attention_mask" in batch: - labels = mx.where(batch["attention_mask"] == 0, mx.array(-100), labels) - batch["labels"] = labels.astype(mx.int32) + batch["labels"] = mx.array(labels_np).astype(mx.int32) return batch -def _collate_vlm_batch(items, processor, max_seq_length, image_size, formatting_func=None): +def _collate_vlm_batch(items, processor, max_seq_length, image_size, + formatting_func=None, ignore_token_ids=None): """Collate a batch of VLM samples using the processor directly. Mirrors Unsloth's GPU UnslothVisionDataCollator: @@ -2255,7 +2334,8 @@ def _collate_vlm_batch(items, processor, max_seq_length, image_size, formatting_ and "completion" in formatted_items[0] ): return _collate_vlm_prompt_completion_batch( - formatted_items, processor, max_seq_length, image_size + formatted_items, processor, max_seq_length, image_size, + ignore_token_ids=ignore_token_ids, ) all_texts = [] @@ -2284,10 +2364,15 @@ def _collate_vlm_batch(items, processor, max_seq_length, image_size, formatting_ processor, all_texts, all_images, max_seq_length, suffixes=all_suffixes, ) - return _to_mx_vlm_batch(inputs) + batch = _to_mx_vlm_batch(inputs) + batch["labels"] = _apply_vlm_label_masks( + batch, + ignore_token_ids=ignore_token_ids, + ) + return batch -def _apply_response_mask_to_vlm_batch(batch_dict, mask_fn): +def _apply_response_mask_to_vlm_batch(batch_dict, mask_fn, ignore_token_ids=None): """Apply response masking to a VLM batch dict, adding 'labels' key. Converts input_ids to plain lists, runs the masking closure from @@ -2300,12 +2385,11 @@ def _apply_response_mask_to_vlm_batch(batch_dict, mask_fn): labels_list = result["labels"] if hasattr(labels_list, "tolist"): labels_list = labels_list.tolist() - attention_mask = batch_dict.get("attention_mask") - if attention_mask is not None: - labels_array = mx.where(attention_mask == 0, mx.array(-100), mx.array(labels_list)) - else: - labels_array = mx.array(labels_list) - batch_dict["labels"] = labels_array + batch_dict["labels"] = _apply_vlm_label_masks( + batch_dict, + labels=mx.array(labels_list), + ignore_token_ids=ignore_token_ids, + ) return batch_dict @@ -2320,6 +2404,7 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, import numpy as np image_size = _get_vlm_image_size(config, processor) + ignore_token_ids = _get_vlm_ignore_token_ids(processor=processor, config=config) indices = list(range(len(dataset))) np.random.seed(seed) @@ -2337,10 +2422,15 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, batch_dict = _collate_vlm_batch( items, processor, max_seq_length, image_size, formatting_func=formatting_func, + ignore_token_ids=ignore_token_ids, ) batch_dict = _prepare_vlm_batch_for_compile(batch_dict, config) if response_mask_fn is not None: - batch_dict = _apply_response_mask_to_vlm_batch(batch_dict, response_mask_fn) + batch_dict = _apply_response_mask_to_vlm_batch( + batch_dict, + response_mask_fn, + ignore_token_ids=ignore_token_ids, + ) batch_list.append(batch_dict) if num_batches is not None and len(batch_list) >= num_batches: break @@ -2369,15 +2459,21 @@ def iterate_vlm_training_batches(dataset, processor, config, batch_size, import numpy as np image_size = _get_vlm_image_size(config, processor) + ignore_token_ids = _get_vlm_ignore_token_ids(processor=processor, config=config) def _emit(items): batch_dict = _collate_vlm_batch( items, processor, max_seq_length, image_size, formatting_func=formatting_func, + ignore_token_ids=ignore_token_ids, ) batch_dict = _prepare_vlm_batch_for_compile(batch_dict, config) if response_mask_fn is not None: - batch_dict = _apply_response_mask_to_vlm_batch(batch_dict, response_mask_fn) + batch_dict = _apply_response_mask_to_vlm_batch( + batch_dict, + response_mask_fn, + ignore_token_ids=ignore_token_ids, + ) return batch_dict if hasattr(dataset, "__len__"): From 5c82ee2088a6d4f7b24b0b9c1e66e504ecd51aed Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 16:31:39 -0500 Subject: [PATCH 08/48] bring back correct loss curves --- tests/test_pr_a_deep_components.py | 22 +++++++++++++++++++++- unsloth_zoo/mlx/trainer.py | 28 +++++++++++++++++++--------- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 396ce6916..1848ea5d3 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -89,7 +89,8 @@ def value_at(step): value = schedule(step) return value.item() if hasattr(value, "item") else float(value) - assert value_at(0) > 0.0 + assert value_at(0) == pytest.approx(0.0) + assert value_at(1) > value_at(0) assert value_at(4) < trainer.args.learning_rate assert value_at(5) == pytest.approx(trainer.args.learning_rate) @@ -127,6 +128,22 @@ def trainable_parameters(self): assert not MLXTrainer._should_apply_weight_decay("vision.blocks.0.norm1.weight") +def test_norm_clip_dtype_restore_keeps_lora_and_norms_promotable(): + from unsloth_zoo.mlx.trainer import MLXTrainer + + def should_restore_original_dtype(name): + return ( + not MLXTrainer._is_norm_parameter_name(name) + and not MLXTrainer._is_lora_parameter_name(name) + ) + + assert should_restore_original_dtype("model.layers.0.mlp.down_proj.weight") + assert not should_restore_original_dtype("model.layers.0.self_attn.q_proj.lora_a") + assert not should_restore_original_dtype("model.layers.0.self_attn.q_proj.lora_b") + assert not should_restore_original_dtype("model.layers.0.input_layernorm.weight") + assert not should_restore_original_dtype("vision.blocks.0.norm1.weight") + + @pytest.mark.parametrize( ("scheduler", "warmup"), [ @@ -171,6 +188,9 @@ def test_scheduler_lr_matches_expected_optimizer_update_steps(scheduler, warmup) trainer.args.learning_rate * 1 / 7, ] assert values == pytest.approx(expected) + elif warmup > 0: + assert values[0] == pytest.approx(0.0) + assert all(value > 0.0 for value in values[1:]) else: assert all(value > 0.0 for value in values) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 1457c71c4..4631cd3bd 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -463,6 +463,22 @@ def _should_apply_weight_decay(name, parameter=None): return False return True + @staticmethod + def _is_norm_parameter_name(name): + return any( + "norm" in part.lower() + for part in str(name).split(".") + if part + ) + + @staticmethod + def _is_lora_parameter_name(name): + return any( + "lora" in part.lower() + for part in str(name).split(".") + if part + ) + def _apply_manual_adamw_weight_decay(self, model, optimizer, grad): """Apply decoupled AdamW decay to trainable non-bias/non-norm leaves.""" wd = float(getattr(self, "_manual_adamw_weight_decay", 0.0) or 0.0) @@ -834,26 +850,20 @@ def _train_inner(self): not _clip_grad_value ) - def _is_norm_parameter_name(name): - return any( - "norm" in part.lower() - for part in str(name).split(".") - if part - ) - _restore_storage_after_norm_clip = max_grad_norm > 0 _trainable_storage_dtypes = ( { name: value.dtype for name, value in tree_flatten(model.trainable_parameters()) - if not _is_norm_parameter_name(name) + if not self._is_norm_parameter_name(name) + and not self._is_lora_parameter_name(name) } if _restore_storage_after_norm_clip else {} ) def _restore_trainable_storage_dtypes(): - """Keep norm-clipped MLX updates from promoting non-norm params.""" + """Keep norm-clipped MLX updates from promoting base params.""" if not _restore_storage_after_norm_clip: return recast = [] From a93449f5186df3ab09c21d91814c2eaa331ec5b8 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 16:43:22 -0500 Subject: [PATCH 09/48] update textdataset --- unsloth_zoo/mlx/utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index ca72df134..caa8f5c17 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -697,7 +697,11 @@ def loss_fn(model, batch_dict): targets = labels[:, 1:] targets = _mask_image_tokens(targets, _image_token_ids) logits, targets = _align_logits_with_labels(logits, targets) - mask = (targets != -100).astype(mx.float32) + if attention_mask is not None: + length_mask = attention_mask[:, 1:][:, :targets.shape[1]] + else: + length_mask = mx.ones_like(targets, dtype=mx.float32) + mask = mx.logical_and(targets != -100, length_mask).astype(mx.float32) else: targets = input_ids[:, 1:] @@ -867,6 +871,15 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, # compatibility for externally supplied labels. targets = labels[:, 1:] masked_targets = _mask_image_tokens(targets, image_token_ids) + if attention_mask is not None: + length_mask = attention_mask[:, 1:][:, :masked_targets.shape[1]] + else: + length_mask = mx.ones_like(masked_targets, dtype=mx.bool_) + masked_targets = mx.where( + mx.logical_and(masked_targets != -100, length_mask), + masked_targets, + -100, + ) ntoks = (masked_targets != -100).sum() else: targets = input_ids[:, 1:] @@ -2511,8 +2524,7 @@ def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", model_name=None, model_type=None): """Wrap a HuggingFace dataset into mlx-lm's dataset classes. - Uses TextDataset + CacheDataset from mlx_lm so that tokenization - (including EOS appending) matches mlx-lm's own training pipeline exactly. + Uses CacheDataset from mlx_lm while leaving rendered text token-exact. If a formatting_func is provided, each item is pre-formatted into a ``{"text": ...}`` dict before wrapping. @@ -2566,6 +2578,7 @@ def __init__(self, data, tokenizer, text_key="text"): self.text_key = text_key def process(self, item): + # Studio/chat templates own EOS; adding one here changes labels. return (self.tokenizer.encode(item[self.text_key]), 0) def __getitem__(self, idx): From dcd0a9001c281416092c34886602031b64670f9c Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 22:52:56 -0500 Subject: [PATCH 10/48] dataset ordering fix, lr fix --- tests/test_mlx_vlm_label_masks.py | 73 +++++++++ tests/test_pr_a_deep_components.py | 2 + unsloth_zoo/mlx/trainer.py | 92 ++++++----- unsloth_zoo/mlx/utils.py | 252 +++++++++++++++++++++++++---- 4 files changed, 349 insertions(+), 70 deletions(-) diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py index 451de7ab3..353f8bc8d 100644 --- a/tests/test_mlx_vlm_label_masks.py +++ b/tests/test_mlx_vlm_label_masks.py @@ -100,6 +100,79 @@ def mask_fn(_batch): assert out["labels"].tolist() == [[-100, -100, 13, -100]] +def test_vlm_processor_inputs_flattens_qwen_style_images(): + from unsloth_zoo.mlx.utils import _processor_vlm_inputs + + class QwenLikeProcessor: + __module__ = "mlx_vlm.models.qwen3_vl.processing_qwen3_vl" + + def __init__(self): + self.seen_images = None + + def __call__(self, text, images=None, **_kwargs): + self.seen_images = images + return { + "input_ids": np.ones((len(text), 2), dtype=np.int32), + "attention_mask": np.ones((len(text), 2), dtype=np.int32), + } + + processor = QwenLikeProcessor() + _processor_vlm_inputs(processor, ["a", "b"], [["img0"], ["img1"]], 8) + + assert processor.seen_images == ["img0", "img1"] + + +def test_vlm_processor_inputs_preserves_nested_image_processors(): + from unsloth_zoo.mlx.utils import _processor_vlm_inputs + + class PixtralLikeProcessor: + __module__ = "mlx_vlm.models.pixtral.processing_pixtral" + + def __init__(self): + self.seen_images = None + + def __call__(self, text, images=None, **_kwargs): + self.seen_images = images + return { + "input_ids": np.ones((len(text), 2), dtype=np.int32), + "attention_mask": np.ones((len(text), 2), dtype=np.int32), + } + + processor = PixtralLikeProcessor() + _processor_vlm_inputs(processor, ["a", "b"], [["img0"], ["img1", "img2"]], 8) + + assert processor.seen_images == [["img0"], ["img1", "img2"]] + + +@pytest.mark.parametrize( + "module_name, expected", + ( + ("mlx_vlm.models.qwen2_5_vl.processing_qwen2_5_vl", ["img0", "img1"]), + ("mlx_vlm.models.qwen3_5.processing_qwen3_vl", ["img0", "img1"]), + ("mlx_vlm.models.gemma4.processing_gemma4", ["img0", "img1"]), + ("mlx_vlm.models.gemma3.processing_gemma3", [["img0"], ["img1"]]), + ("mlx_vlm.models.idefics3.processing_idefics3", [["img0"], ["img1"]]), + ("mlx_vlm.models.deepseek_vl_v2.processing_deepsek_vl_v2", [["img0"], ["img1"]]), + ("mlx_vlm.models.falcon_ocr.processing_falcon_ocr", [["img0"], ["img1"]]), + ), +) +def test_vlm_processor_inputs_known_arch_image_layouts(module_name, expected): + from unsloth_zoo.mlx.utils import _processor_vlm_inputs + + def call(self, text, images=None, **_kwargs): + self.seen_images = images + return { + "input_ids": np.ones((len(text), 2), dtype=np.int32), + "attention_mask": np.ones((len(text), 2), dtype=np.int32), + } + + Processor = type("Processor", (), {"__module__": module_name, "__call__": call}) + processor = Processor() + _processor_vlm_inputs(processor, ["a", "b"], [["img0"], ["img1"]], 8) + + assert processor.seen_images == expected + + def test_token_expansion_masks_inserted_label_positions(): from unsloth_zoo.mlx.utils import _expand_token_runs diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 1848ea5d3..645f248e6 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -60,6 +60,8 @@ def test_mlx_training_config_is_dataclass_with_all_fields(): "use_cce", "compile", "gradient_checkpointing", + "dataset_order", + "preserve_dataset_order", ): assert must_have in fields, f"missing field: {must_have}" diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 4631cd3bd..1ad3414d7 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -58,6 +58,7 @@ make_vlm_cce_loss_fn, make_vlm_baseline_loss_fn, create_batches, + create_ordered_batches, iterate_training_batches, create_vlm_batches, iterate_vlm_training_batches, @@ -160,6 +161,8 @@ class MLXTrainingConfig: compile_trace: bool = True gradient_checkpointing: bool = True streaming: bool = False # Use streaming iterator instead of materializing batches + dataset_order: str = "default" # "default", "sequential", or "torch_randperm" + preserve_dataset_order: bool = False # Match Studio CUDA SequentialSampler order memory_limit_gb: float | None = None # None = auto Metal guard (~85% of recommended working set); <= 0 disables cache_limit_gb: float | None = None # Optional MLX Metal cache cap in GB; <= 0 disables override wired_limit_gb: float | None = None # None = min(recommended working set, memory limit); <= 0 disables @@ -323,47 +326,39 @@ def _build_schedule(self, total_steps): if sched_type == "constant" and warmup == 0: return lr - decay_steps = max(total_steps - warmup, 1) - - if sched_type == "linear" and warmup == 0: - # Match the Studio CUDA/Trainer path observed in fixed-fixture - # probes: linear/no-warmup starts with a zero-LR optimizer step, - # then decays from the requested LR over the remaining steps. - decay_after_zero = max(total_steps - 1, 1) - - def main_schedule(step): - step = mx.array(step) - decay = mx.maximum( - mx.array(total_steps, dtype=mx.float32) - step, - mx.array(0.0, dtype=mx.float32), - ) / mx.array(decay_after_zero, dtype=mx.float32) - return mx.where(step <= 0, mx.array(0.0, dtype=mx.float32), lr * decay) - - return main_schedule - - if sched_type == "cosine": - main_schedule = optim.cosine_decay(lr, decay_steps, end=0.0) - elif sched_type == "linear": - main_schedule = optim.linear_schedule(lr, 0.0, decay_steps) - else: # constant - main_schedule = lr - - if warmup > 0: - def warmup_fn(step): - step = mx.array(step) - step = mx.minimum(step, mx.array(warmup)) - return step * (lr / warmup) - if callable(main_schedule): - return optim.join_schedules( - [warmup_fn, main_schedule], [warmup] - ) + def warmup_multiplier(step): + if warmup <= 0: + return mx.array(1.0, dtype=mx.float32) + return step / mx.array(max(warmup, 1), dtype=mx.float32) + + def decay_progress(step): + return ( + step - mx.array(warmup, dtype=mx.float32) + ) / mx.array(max(total_steps - warmup, 1), dtype=mx.float32) + + def schedule(step): + # Match HuggingFace/Trainer LR as seen by the optimizer before + # each update. ``step`` is zero-based optimizer-step index. + step = mx.array(step).astype(mx.float32) + if warmup > 0: + warm = lr * warmup_multiplier(step) else: - const_fn = optim.linear_schedule(lr, lr, decay_steps) - return optim.join_schedules( - [warmup_fn, const_fn], [warmup] + warm = mx.array(lr, dtype=mx.float32) + + progress = decay_progress(step) + if sched_type == "cosine": + decay = mx.array(0.5, dtype=mx.float32) * ( + mx.array(1.0, dtype=mx.float32) + mx.cos(mx.array(math.pi) * progress) ) + elif sched_type == "linear": + decay = mx.array(1.0, dtype=mx.float32) - progress + else: # constant with warmup + decay = mx.array(1.0, dtype=mx.float32) + decay = mx.maximum(decay, mx.array(0.0, dtype=mx.float32)) + main = mx.array(lr, dtype=mx.float32) * decay + return mx.where(step < warmup, warm, main) - return main_schedule + return schedule @staticmethod def _schedule_value(schedule, step): @@ -1468,6 +1463,11 @@ def _prepare_data(self, is_vlm): if is_vlm: _vlm_mask_fn = getattr(self, '_vlm_response_mask_fn', None) + vlm_dataset_order = ( + "sequential" + if getattr(args, "preserve_dataset_order", False) + else getattr(args, "dataset_order", "default") + ) if args.streaming: return None, iterate_vlm_training_batches( dataset=self.train_dataset, @@ -1478,6 +1478,7 @@ def _prepare_data(self, is_vlm): seed=args.seed, response_mask_fn=_vlm_mask_fn, formatting_func=self.formatting_func, + dataset_order=vlm_dataset_order, ) else: batches = create_vlm_batches( @@ -1490,6 +1491,7 @@ def _prepare_data(self, is_vlm): seed=args.seed, response_mask_fn=_vlm_mask_fn, formatting_func=self.formatting_func, + dataset_order=vlm_dataset_order, ) if _vlm_mask_fn is not None and batches: _check_vlm_all_masked(batches) @@ -1510,7 +1512,7 @@ def _prepare_data(self, is_vlm): model_type=model_type, ) else: - batches = create_batches( + batch_kwargs = dict( dataset=self.train_dataset, tokenizer=self.tokenizer, batch_size=args.per_device_train_batch_size, @@ -1523,6 +1525,18 @@ def _prepare_data(self, is_vlm): model_name=model_name, model_type=model_type, ) + if ( + getattr(args, "preserve_dataset_order", False) + or getattr(args, "dataset_order", "default") != "default" + ): + batch_kwargs["dataset_order"] = ( + "sequential" + if getattr(args, "preserve_dataset_order", False) + else getattr(args, "dataset_order", "default") + ) + batches = create_ordered_batches(**batch_kwargs) + else: + batches = create_batches(**batch_kwargs) return batches, None def save_model(self, output_dir=None): diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index caa8f5c17..9fb25da23 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -2219,10 +2219,63 @@ def _extract_vlm_images(item, messages, image_size): return _resize_vlm_images(images, image_size) -def _format_vlm_images_for_processor(all_images): +def _flatten_vlm_images(all_images): + flattened = [] + for images in all_images: + if isinstance(images, (list, tuple)): + flattened.extend(images) + else: + flattened.append(images) + return flattened + + +def _nest_vlm_images_by_sample(all_images): + nested = [] + for images in all_images: + if images is None: + nested.append([]) + elif isinstance(images, (list, tuple)): + nested.append(list(images)) + else: + nested.append([images]) + return nested + + +def _vlm_processor_prefers_nested_images(processor): + cls = processor.__class__ + marker = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__name__', '')}".lower() + # Some mlx-vlm processors count images per prompt and require + # images=[[sample0_img0, ...], [sample1_img0, ...]]. Qwen/LLaVA/Gemma4-style + # processors consume a flat image stream and are handled by the default path. + return any( + name in marker + for name in ( + "deepseek_vl", + "falcon", + "gemma3", + "gemma3n", + "idefics", + "lfm2_vl", + "minicpmo", + "mistral", + "mllama", + "paligemma", + "pixtral", + "smolvlm", + ) + ) + + +def _format_vlm_images_for_processor(all_images, processor=None, image_layout=None): if not any(all_images): return None - return all_images + if image_layout == "nested": + return _nest_vlm_images_by_sample(all_images) + if image_layout == "flat": + return _flatten_vlm_images(all_images) + if processor is not None and _vlm_processor_prefers_nested_images(processor): + return _nest_vlm_images_by_sample(all_images) + return _flatten_vlm_images(all_images) def _to_mx_vlm_batch(inputs): @@ -2254,7 +2307,7 @@ def _to_mx_vlm_batch(inputs): def _processor_vlm_inputs(processor, texts, all_images, max_seq_length, suffixes=None): - proc_kwargs = dict( + base_kwargs = dict( text=texts, padding=True, truncation=True, @@ -2262,12 +2315,35 @@ def _processor_vlm_inputs(processor, texts, all_images, max_seq_length, suffixes return_tensors="np", add_special_tokens=False, ) - images = _format_vlm_images_for_processor(all_images) + images = _format_vlm_images_for_processor(all_images, processor=processor) if images is not None: - proc_kwargs["images"] = images + image_layouts = ( + ("nested", "flat") + if _vlm_processor_prefers_nested_images(processor) + else ("flat", "nested") + ) + else: + image_layouts = (None,) if suffixes is not None and any(suffix is not None for suffix in suffixes): - proc_kwargs["suffix"] = [suffix or "" for suffix in suffixes] - return processor(**proc_kwargs) + base_kwargs["suffix"] = [suffix or "" for suffix in suffixes] + + first_error = None + for image_layout in image_layouts: + proc_kwargs = dict(base_kwargs) + if image_layout is not None: + proc_kwargs["images"] = _format_vlm_images_for_processor( + all_images, + processor=processor, + image_layout=image_layout, + ) + try: + return processor(**proc_kwargs) + except Exception as exc: + if first_error is None: + first_error = exc + if len(image_layouts) == 1: + raise + raise first_error def _collate_vlm_prompt_completion_batch(items, processor, max_seq_length, image_size, @@ -2408,7 +2484,7 @@ def _apply_response_mask_to_vlm_batch(batch_dict, mask_fn, ignore_token_ids=None def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, num_batches=None, seed=42, response_mask_fn=None, - formatting_func=None): + formatting_func=None, dataset_order="default"): """Pre-materialize VLM training batches using the processor directly. Mirrors Unsloth's GPU UnslothVisionDataCollator: @@ -2419,18 +2495,36 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, image_size = _get_vlm_image_size(config, processor) ignore_token_ids = _get_vlm_ignore_token_ids(processor=processor, config=config) - indices = list(range(len(dataset))) - np.random.seed(seed) - if num_batches is not None: - np.random.shuffle(indices) - - batch_indices = [ - indices[i : i + batch_size] - for i in range(0, len(indices), batch_size) - ] - batch_list = [] - for bi in batch_indices: + seen = 0 + epoch = 0 + indices = list(range(len(dataset))) + if dataset_order == "torch_randperm": + indices = _torch_randperm_order(len(dataset), seed) + elif dataset_order in (None, "default"): + if num_batches is not None: + np.random.seed(seed) + np.random.shuffle(indices) + elif dataset_order != "sequential": + raise ValueError(f"Unsupported MLX VLM dataset_order: {dataset_order!r}") + + while num_batches is None or len(batch_list) < num_batches: + if seen >= len(indices): + if num_batches is None: + break + epoch += 1 + seen = 0 + indices = list(range(len(dataset))) + if dataset_order == "torch_randperm": + indices = _torch_randperm_order(len(dataset), int(seed) + epoch) + elif dataset_order in (None, "default"): + np.random.seed(int(seed) + epoch) + np.random.shuffle(indices) + + bi = indices[seen : seen + batch_size] + seen += len(bi) + if not bi: + break items = [dataset[idx] for idx in bi] batch_dict = _collate_vlm_batch( items, processor, max_seq_length, image_size, @@ -2445,8 +2539,6 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, ignore_token_ids=ignore_token_ids, ) batch_list.append(batch_dict) - if num_batches is not None and len(batch_list) >= num_batches: - break # Evaluate all tensors all_tensors = [] @@ -2463,7 +2555,8 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, def iterate_vlm_training_batches(dataset, processor, config, batch_size, max_seq_length, seed=42, response_mask_fn=None, - formatting_func=None): + formatting_func=None, + dataset_order="default"): """Streaming VLM batch generator using processor directly. Yields batch dicts with input_ids, pixel_values, attention_mask, @@ -2490,18 +2583,32 @@ def _emit(items): return batch_dict if hasattr(dataset, "__len__"): - indices = list(range(len(dataset))) - batch_indices = [ - indices[i : i + batch_size] - for i in range(0, len(indices), batch_size) - ] - if not batch_indices: + if len(dataset) <= 0: raise ValueError("Unsloth MLX VLM: streaming dataset produced no rows.") + epoch = 0 while True: - order = np.random.permutation(len(batch_indices)) - for b in order: - items = [dataset[idx] for idx in batch_indices[b]] + if dataset_order == "torch_randperm": + indices = _torch_randperm_order(len(dataset), int(seed) + epoch) + elif dataset_order == "sequential": + indices = list(range(len(dataset))) + elif dataset_order in (None, "default"): + indices = list(range(len(dataset))) + batch_indices = [ + indices[i : i + batch_size] + for i in range(0, len(indices), batch_size) + ] + order = np.random.permutation(len(batch_indices)) + for b in order: + items = [dataset[idx] for idx in batch_indices[b]] + yield _emit(items) + epoch += 1 + continue + else: + raise ValueError(f"Unsupported MLX VLM dataset_order: {dataset_order!r}") + for start in range(0, len(indices), batch_size): + items = [dataset[idx] for idx in indices[start : start + batch_size]] yield _emit(items) + epoch += 1 else: while True: pending = [] @@ -2634,6 +2741,89 @@ def create_batches(dataset, tokenizer, batch_size, max_seq_length, return batch_pairs +def _torch_randperm_order(length, seed): + try: + import torch + except Exception as exc: + raise ImportError( + "Unsloth MLX: dataset_order='torch_randperm' requires torch so MLX " + "Studio can mirror CUDA Studio batch order." + ) from exc + generator = torch.Generator() + generator.manual_seed(3407 if seed is None else int(seed)) + return torch.randperm(length, generator=generator).tolist() + + +def create_ordered_batches(dataset, tokenizer, batch_size, max_seq_length, + num_batches=None, seed=None, dataset_order="sequential", + dataset_text_field="text", + formatting_func=None, chat_template=None, + model_name=None, model_type=None): + """Create text batches with an explicit dataset order. + + Studio uses this to mirror CUDA's effective sampler stream without + changing generic mlx-lm batching behavior. + """ + + ds = _prepare_dataset( + dataset, tokenizer, dataset_text_field, formatting_func, + chat_template=chat_template, + model_name=model_name, + model_type=model_type, + ) + + tokenized = [] + for row in ds: + ids = row[0] if isinstance(row, (tuple, list)) else row + ids = list(ids)[:max_seq_length] + if len(ids) >= 2: + tokenized.append(ids) + + if not tokenized: + return [] + + def make_order(epoch): + base_seed = 3407 if seed is None else int(seed) + if dataset_order == "torch_randperm": + return _torch_randperm_order(len(tokenized), base_seed + epoch) + if dataset_order not in (None, "sequential"): + raise ValueError(f"Unsupported MLX dataset_order: {dataset_order!r}") + return list(range(len(tokenized))) + + batch_pairs = [] + epoch = 0 + order = make_order(epoch) + order_pos = 0 + seen = 0 + while num_batches is None or len(batch_pairs) < num_batches: + batch_items = [] + for _ in range(batch_size): + if order_pos >= len(order): + epoch += 1 + order = make_order(epoch) + order_pos = 0 + batch_items.append(tokenized[order[order_pos]]) + order_pos += 1 + seen += 1 + if num_batches is None and seen >= len(tokenized): + break + + max_length = max(len(ids) for ids in batch_items) + batch_ids = [] + lengths = [] + for ids in batch_items: + length = len(ids) + batch_ids.append(ids + [0] * (max_length - length)) + lengths.append([0, length]) + batch_pairs.append((mx.array(batch_ids), mx.array(lengths), None)) + + if num_batches is None and seen >= len(tokenized): + break + + mx.eval([b for b, l, _ in batch_pairs] + [l for _, l, _ in batch_pairs]) + return batch_pairs + + def iterate_training_batches(dataset, tokenizer, batch_size, max_seq_length, seed=42, dataset_text_field="text", formatting_func=None, chat_template=None, From b0a83b52cf7763e5a85b4333b94d5c8b1f1fdd52 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Wed, 20 May 2026 23:33:03 -0500 Subject: [PATCH 11/48] use proportional MLX grad value clipping --- unsloth_zoo/mlx/trainer.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 46aa8323d..8ece77577 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -123,7 +123,8 @@ class MLXTrainingConfig: adam_beta1: float | None = None adam_beta2: float | None = None max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead - # Elementwise clip ([-max_grad_value, max_grad_value], per-leaf). + # Proportional per-tensor clipping. This is cheaper than global norm, but + # preserves each tensor's gradient direction unlike elementwise value clip. # None (default) keeps the cheap MLX default of 1.0 unless the user # passes max_grad_norm > 0, in which case global-norm clipping wins. # 0.0 disables. A positive float opts in explicitly and overrides @@ -952,6 +953,17 @@ def _can_report_optimizer_state_norm(): # This avoids adding a second consumer to the lazy backward graph. return getattr(optimizer, "betas", None) + def _clip_grad_by_leaf_norm(grad): + if not _clip_grad_value: + return grad + def _clip_leaf_norm(g): + g_f = g.astype(mx.float32) + norm = mx.sqrt(mx.sum(g_f * g_f)) + scale = mx.minimum(max_grad_value / (norm + 1e-6), 1.0) + return g * scale.astype(g.dtype) + + return tree_map(_clip_leaf_norm, grad) + def _apply_update(grad, toks_f): """Common gradient post-processing and optimizer update. @@ -977,11 +989,7 @@ def _apply_update(grad, toks_f): final_grad, max_norm=max_grad_norm ) if _clip_grad_value: - # Elementwise clip after norm-scaling, before optimizer step. - final_grad = tree_map( - lambda g: mx.clip(g, -max_grad_value, max_grad_value), - final_grad, - ) + final_grad = _clip_grad_by_leaf_norm(final_grad) self._apply_manual_adamw_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) _restore_trainable_storage_dtypes() @@ -1004,11 +1012,7 @@ def _apply_update_direct(grad): if max_grad_norm > 0: grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) if _clip_grad_value: - # Elementwise clip per leaf — free memory (no cross-leaf reduction). - grad = tree_map( - lambda g: mx.clip(g, -max_grad_value, max_grad_value), - grad, - ) + grad = _clip_grad_by_leaf_norm(grad) self._apply_manual_adamw_weight_decay(model, optimizer, grad) optimizer.update(model, grad) _restore_trainable_storage_dtypes() From 964be34ec2e570417c8518d822da0715fb400113 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Thu, 21 May 2026 10:07:51 -0500 Subject: [PATCH 12/48] cast norm activation output back to original input dtype --- unsloth_zoo/mlx/trainer.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 8ece77577..2840abe6f 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -93,6 +93,39 @@ def _normalize_mlx_optimizer_name(name): return opt_name +def _set_norm_output_cast_to_input_dtype(enabled: bool) -> None: + """Control whether norm outputs are cast back to activation dtype. + + Norm parameters can stay in fp32 for stability, but letting fp32 norm + outputs flow through the rest of the graph promotes downstream + intermediates and materially increases LoRA/QLoRA memory. Casting the + result back matches PyTorch autocast behavior more closely: fp32 norm math, + bf16/fp16 downstream activations. + """ + for norm_cls in (nn.RMSNorm, nn.LayerNorm): + patched = getattr(norm_cls, "_unsloth_cast_output_to_input_dtype", False) + if enabled: + if patched: + continue + original_call = norm_cls.__call__ + + def norm_call_cast_output(self, x, *args, _original_call=original_call, **kwargs): + out = _original_call(self, x, *args, **kwargs) + if hasattr(x, "dtype") and hasattr(out, "dtype") and out.dtype != x.dtype: + return out.astype(x.dtype) + return out + + norm_cls._unsloth_original_call = original_call + norm_cls.__call__ = norm_call_cast_output + norm_cls._unsloth_cast_output_to_input_dtype = True + elif patched: + original_call = getattr(norm_cls, "_unsloth_original_call", None) + if original_call is not None: + norm_cls.__call__ = original_call + norm_cls._unsloth_original_call = None + norm_cls._unsloth_cast_output_to_input_dtype = False + + def _normalize_mlx_scheduler_type(name): sched_type = str(name or "linear").strip().lower() if sched_type not in SUPPORTED_MLX_LR_SCHEDULERS: @@ -168,6 +201,7 @@ class MLXTrainingConfig: cache_limit_gb: float | None = None # Optional MLX Metal cache cap in GB; <= 0 disables override wired_limit_gb: float | None = None # None = min(recommended working set, memory limit); <= 0 disables disable_memory_limits: bool = False + cast_norm_output_to_input_dtype: bool = True # fp32 norm storage/math, bf16/fp16 downstream activations # VLM / completion masking train_on_completions: bool = False # Mask prompt tokens in loss @@ -652,6 +686,10 @@ def train(self): """ args = self.args model = self.model + cast_norm_output = bool(getattr(args, "cast_norm_output_to_input_dtype", True)) + _set_norm_output_cast_to_input_dtype(cast_norm_output) + if cast_norm_output: + print("Unsloth: Casting MLX norm outputs back to activation dtype.") args.patch_mode = normalize_mlx_patch_mode(getattr(args, "patch_mode", "patched")) model._unsloth_patch_mode = args.patch_mode From ca08652226e3d0355cfe132360c0187ab1bbc77d Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Thu, 21 May 2026 12:23:33 -0500 Subject: [PATCH 13/48] address mlx training review feedback --- tests/test_mlx_pr684_review_fixes.py | 206 +++++++++++++++++++++++++++ unsloth_zoo/compiler.py | 2 +- unsloth_zoo/mlx/compile.py | 10 +- unsloth_zoo/mlx/loader.py | 19 +-- unsloth_zoo/mlx/trainer.py | 38 ++++- unsloth_zoo/mlx/utils.py | 42 ++++-- 6 files changed, 284 insertions(+), 33 deletions(-) create mode 100644 tests/test_mlx_pr684_review_fixes.py diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_pr684_review_fixes.py new file mode 100644 index 000000000..2930b3b3f --- /dev/null +++ b/tests/test_mlx_pr684_review_fixes.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import inspect + +import numpy as np +import pytest + + +mx = pytest.importorskip("mlx.core") +if "mlx_simulation" in str(getattr(mx, "__file__", "")): + pytest.skip("requires real MLX runtime", allow_module_level=True) + + +def _skip_if_mlx_core_was_replaced(): + import mlx.core as current_mx + if current_mx is not mx: + pytest.skip("requires real MLX runtime without mlx_simulation monkeypatch") + + +class _TinyTokenizer: + pad_token_id = 2 + eos_token_id = 2 + unk_token_id = -1 + image_token_id = 200 + + def encode(self, text): + return [int(part) for part in str(text).split()] + + def convert_tokens_to_ids(self, token): + if isinstance(token, list): + return [self.convert_tokens_to_ids(item) for item in token] + return {"": 200, "<|image_pad|>": 201}.get(token, self.unk_token_id) + + +class _ContentProcessor: + tokenizer = _TinyTokenizer() + image_processor = object() + + def __call__(self, text, **_kwargs): + rows = [[int(item), 200, 2] for item in text] + masks = [[1, 1, 1] for _ in rows] + return { + "input_ids": np.array(rows, dtype=np.int32), + "attention_mask": np.array(masks, dtype=np.int32), + } + + +def test_vlm_ignore_ids_exclude_pad_even_when_pad_is_eos(): + from unsloth_zoo.mlx.utils import _get_vlm_ignore_token_ids + + ids = _get_vlm_ignore_token_ids( + processor=_ContentProcessor(), + config={"pad_token_id": 2, "image_token_id": 200}, + ) + + assert 200 in ids + assert 2 not in ids + + +def test_vlm_label_mask_keeps_in_sequence_pad_eos_token(): + from unsloth_zoo.mlx.utils import _apply_vlm_label_masks + + batch = { + "input_ids": mx.array([[101, 2, 200, 2]], dtype=mx.int32), + "attention_mask": mx.array([[1, 1, 1, 0]], dtype=mx.int32), + } + out = _apply_vlm_label_masks( + batch, + labels=batch["input_ids"], + ignore_token_ids=[200], + ) + + assert out.tolist() == [[101, 2, -100, -100]] + + +def test_manual_adamw_weight_decay_accepts_scalar_lr_and_preserves_dtype(): + from mlx.utils import tree_flatten + from unsloth_zoo.mlx.trainer import MLXTrainer + + class TinyModel: + def __init__(self): + self.params = { + "layer": { + "weight": mx.array([10.0], dtype=mx.bfloat16), + "bias": mx.array([10.0], dtype=mx.bfloat16), + }, + "norm": {"weight": mx.array([10.0], dtype=mx.float32)}, + } + + def trainable_parameters(self): + return self.params + + def update(self, updates): + def merge(dst, src): + for key, value in src.items(): + if isinstance(value, dict): + merge(dst[key], value) + else: + dst[key] = value + merge(self.params, updates) + + class TinyOptimizer: + learning_rate = 0.1 + + model = TinyModel() + grad = { + "layer": { + "weight": mx.array([1.0], dtype=mx.bfloat16), + "bias": mx.array([1.0], dtype=mx.bfloat16), + }, + "norm": {"weight": mx.array([1.0], dtype=mx.float32)}, + } + trainer = object.__new__(MLXTrainer) + trainer._manual_adamw_weight_decay = 0.1 + + trainer._apply_manual_adamw_weight_decay(model, TinyOptimizer(), grad) + flat = dict(tree_flatten(model.trainable_parameters())) + + assert flat["layer.weight"].dtype == mx.bfloat16 + assert flat["layer.weight"].item() < 10.0 + assert flat["layer.bias"].item() == pytest.approx(10.0) + assert flat["norm.weight"].item() == pytest.approx(10.0) + + +def test_nf4_dense_zero_group_dequantizes_to_zero_without_epsilon_scale(): + _skip_if_mlx_core_was_replaced() + from unsloth_zoo.mlx.loader import _nf4_dense_dequantize_weight + + weight = mx.zeros((1, 4), dtype=mx.float32) + out = _nf4_dense_dequantize_weight(weight, group_size=4) + + assert out.tolist() == [[0.0, 0.0, 0.0, 0.0]] + + +def test_ordered_text_batches_raise_clear_error_when_all_rows_drop(): + from unsloth_zoo.mlx.utils import create_ordered_batches + + with pytest.raises(ValueError, match="no trainable token sequences"): + create_ordered_batches( + dataset=[{"text": "1"}], + tokenizer=_TinyTokenizer(), + batch_size=1, + max_seq_length=1, + dataset_order="sequential", + ) + + +def test_ordered_text_torch_randperm_can_materialize_multiple_epochs(): + _skip_if_mlx_core_was_replaced() + from unsloth_zoo.mlx.utils import create_ordered_batches + + batches = create_ordered_batches( + dataset=[{"text": f"{i} {i + 10}"} for i in range(5)], + tokenizer=_TinyTokenizer(), + batch_size=1, + max_seq_length=4, + seed=None, + dataset_order="torch_randperm", + num_epochs=2, + ) + + first_epoch = [int(batch[0, 0].item()) for batch, _lengths, _labels in batches[:5]] + second_epoch = [int(batch[0, 0].item()) for batch, _lengths, _labels in batches[5:]] + assert len(batches) == 10 + assert sorted(first_epoch) == [0, 1, 2, 3, 4] + assert sorted(second_epoch) == [0, 1, 2, 3, 4] + assert first_epoch != second_epoch + + +def test_vlm_torch_randperm_seed_none_and_multi_epoch_batches(): + _skip_if_mlx_core_was_replaced() + from unsloth_zoo.mlx.utils import create_vlm_batches + + batches = create_vlm_batches( + dataset=[{"text": str(i)} for i in range(5)], + processor=_ContentProcessor(), + config={"image_size": 16, "image_token_id": 200}, + batch_size=1, + max_seq_length=8, + seed=None, + dataset_order="torch_randperm", + num_epochs=2, + ) + + first_epoch = [int(batch["input_ids"][0, 0].item()) for batch in batches[:5]] + second_epoch = [int(batch["input_ids"][0, 0].item()) for batch in batches[5:]] + assert len(batches) == 10 + assert sorted(first_epoch) == [0, 1, 2, 3, 4] + assert sorted(second_epoch) == [0, 1, 2, 3, 4] + assert first_epoch != second_epoch + + +def test_pr684_compiler_review_guards_are_present(): + import unsloth_zoo.compiler as compiler + import unsloth_zoo.mlx.compile as mlx_compile + + compiler_source = inspect.getsource(compiler) + mlx_compile_source = inspect.getsource(mlx_compile) + + assert ( + 'self.loss_function.__name__.endswith("ForCausalLMLoss") ' + "and labels is not None and NOT_RETURN_LOGITS" + ) in compiler_source + assert '"weight" in norm' not in mlx_compile_source + assert '"bias" in norm' not in mlx_compile_source + assert 'getattr(norm, "weight", None)' in mlx_compile_source diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 6e9945ea1..8ec17437b 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1749,7 +1749,7 @@ def mask_attention_mask_out(labels = None, attention_mask = None): num_items_in_batch = n_items, logit_softcapping = None if (\\4) == () else (\\4), ) -elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: +elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: lm_head_weight = self.lm_head.weight lm_head_bias = getattr(self.lm_head, "bias", None) diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 81ba8d5b3..1f430c16b 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2719,10 +2719,12 @@ def _qwen3_torch_like_layer_norm(norm, x): centered = x_f - mean var = mx.mean(centered * centered, axis=-1, keepdims=True) y = centered * mx.rsqrt(var + norm.eps) - if "weight" in norm: - y = y * norm.weight.astype(mx.float32) - if "bias" in norm: - y = y + norm.bias.astype(mx.float32) + weight = getattr(norm, "weight", None) + if weight is not None: + y = y * weight.astype(mx.float32) + bias = getattr(norm, "bias", None) + if bias is not None: + y = y + bias.astype(mx.float32) return y.astype(source_dtype) def patched_qwen3_vision_block_call(self, hidden_states, cu_seqlens, rotary_pos_emb): diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index b86e517b6..4ca5e5a61 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -1518,7 +1518,7 @@ def _nf4_dense_dequantize_weight(weight, group_size=64): flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)]) groups = flat.reshape((-1, group_size)) absmax = mx.max(mx.abs(groups), axis=1, keepdims=True) - denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32)) + denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax)) scaled = groups / denom indices = mx.argmin(mx.abs(scaled[..., None] - codebook), axis=-1) dequantized = (codebook[indices] * absmax).reshape((-1,))[:original_size] @@ -3089,14 +3089,15 @@ def get_peft_model( from .utils import apply_gradient_checkpointing apply_gradient_checkpointing(model) - import mlx.utils - trainable = sum(v.size for _, v in mlx.utils.tree_flatten(model.trainable_parameters())) - total = sum(v.size for _, v in mlx.utils.tree_flatten(model.parameters())) - pct = 100.0 * trainable / total if total > 0 else 0 - print( - f"Unsloth: LoRA applied — {trainable:,} trainable params " - f"({pct:.2f}% of {total:,} total)" - ) + if hasattr(model, "trainable_parameters") and hasattr(model, "parameters"): + import mlx.utils + trainable = sum(v.size for _, v in mlx.utils.tree_flatten(model.trainable_parameters())) + total = sum(v.size for _, v in mlx.utils.tree_flatten(model.parameters())) + pct = 100.0 * trainable / total if total > 0 else 0 + print( + f"Unsloth: LoRA applied — {trainable:,} trainable params " + f"({pct:.2f}% of {total:,} total)" + ) return model diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 2840abe6f..70c84b734 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -522,9 +522,15 @@ def _apply_manual_adamw_weight_decay(self, model, optimizer, grad): continue if not self._should_apply_weight_decay(name, parameter): continue - lr = optimizer.learning_rate.astype(flat_grad[name].dtype) - scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) - decayed.append((name, parameter * scale)) + if not mx.issubdtype(parameter.dtype, mx.floating): + continue + lr_value = optimizer.learning_rate + if hasattr(lr_value, "astype"): + lr = lr_value.astype(mx.float32) + else: + lr = mx.array(lr_value, dtype=mx.float32) + scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32) + decayed.append((name, (parameter.astype(mx.float32) * scale).astype(parameter.dtype))) if decayed: model.update(tree_unflatten(decayed)) @@ -805,6 +811,7 @@ def _train_inner(self): print("Unsloth: Using standard cross-entropy loss.") # Prepare data — determine total_steps first + self._prepared_batches_include_epochs = False batches, batch_iter = self._prepare_data(is_vlm) if batches is not None and not batches: @@ -817,7 +824,9 @@ def _train_inner(self): total_steps = args.max_steps elif batches is not None: n_batches = len(batches) - if args.num_train_epochs > 0: + if getattr(self, "_prepared_batches_include_epochs", False): + total_steps = n_batches // grad_accum + elif args.num_train_epochs > 0: # Epoch-based: total micro-batches = epochs * batches_per_epoch total_steps = (n_batches * args.num_train_epochs) // grad_accum else: @@ -1525,6 +1534,15 @@ def _prepare_data(self, is_vlm): if getattr(args, "preserve_dataset_order", False) else getattr(args, "dataset_order", "default") ) + vlm_num_epochs = ( + args.num_train_epochs + if ( + args.max_steps <= 0 + and args.num_train_epochs > 0 + and vlm_dataset_order == "torch_randperm" + ) + else None + ) if args.streaming: return None, iterate_vlm_training_batches( dataset=self.train_dataset, @@ -1538,6 +1556,7 @@ def _prepare_data(self, is_vlm): dataset_order=vlm_dataset_order, ) else: + self._prepared_batches_include_epochs = vlm_num_epochs is not None batches = create_vlm_batches( dataset=self.train_dataset, processor=processor, @@ -1549,6 +1568,7 @@ def _prepare_data(self, is_vlm): response_mask_fn=_vlm_mask_fn, formatting_func=self.formatting_func, dataset_order=vlm_dataset_order, + num_epochs=vlm_num_epochs, ) if _vlm_mask_fn is not None and batches: _check_vlm_all_masked(batches) @@ -1586,11 +1606,19 @@ def _prepare_data(self, is_vlm): getattr(args, "preserve_dataset_order", False) or getattr(args, "dataset_order", "default") != "default" ): - batch_kwargs["dataset_order"] = ( + text_dataset_order = ( "sequential" if getattr(args, "preserve_dataset_order", False) else getattr(args, "dataset_order", "default") ) + batch_kwargs["dataset_order"] = text_dataset_order + if ( + args.max_steps <= 0 + and args.num_train_epochs > 0 + and text_dataset_order == "torch_randperm" + ): + batch_kwargs["num_epochs"] = args.num_train_epochs + self._prepared_batches_include_epochs = True batches = create_ordered_batches(**batch_kwargs) else: batches = create_batches(**batch_kwargs) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 88b31301f..dfd7012f8 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -43,6 +43,10 @@ def _safe_token_denominator(ntoks): return mx.maximum(ntoks.astype(mx.float32), mx.array(1.0, dtype=mx.float32)) +def _normalize_seed(seed, default=3407): + return default if seed is None else int(seed) + + def _get_transformer_layers(model): """Find transformer layers, unwrapping VLM wrappers if needed. @@ -540,7 +544,6 @@ def _get_vlm_ignore_token_ids(processor=None, config=None, model=None): _append_unique_int(ids, _convert_token_to_id(tokenizer, token)) for attr in ( - "pad_token_id", "image_token_id", "video_token_id", "audio_token_id", @@ -558,7 +561,6 @@ def _get_vlm_ignore_token_ids(processor=None, config=None, model=None): "boi_token_id", "eoi_token_index", "eoi_token_id", - "pad_token_id", ): _append_unique_int(ids, _config_get(config, key, None)) @@ -2494,7 +2496,8 @@ def _apply_response_mask_to_vlm_batch(batch_dict, mask_fn, ignore_token_ids=None def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, num_batches=None, seed=42, response_mask_fn=None, - formatting_func=None, dataset_order="default"): + formatting_func=None, dataset_order="default", + num_epochs=None): """Pre-materialize VLM training batches using the processor directly. Mirrors Unsloth's GPU UnslothVisionDataCollator: @@ -2508,27 +2511,29 @@ def create_vlm_batches(dataset, processor, config, batch_size, max_seq_length, batch_list = [] seen = 0 epoch = 0 + base_seed = _normalize_seed(seed) + target_epochs = 1 if num_batches is None and num_epochs is None else num_epochs indices = list(range(len(dataset))) if dataset_order == "torch_randperm": - indices = _torch_randperm_order(len(dataset), seed) + indices = _torch_randperm_order(len(dataset), base_seed) elif dataset_order in (None, "default"): if num_batches is not None: - np.random.seed(seed) + np.random.seed(base_seed) np.random.shuffle(indices) elif dataset_order != "sequential": raise ValueError(f"Unsupported MLX VLM dataset_order: {dataset_order!r}") while num_batches is None or len(batch_list) < num_batches: if seen >= len(indices): - if num_batches is None: + if num_batches is None and target_epochs is not None and epoch + 1 >= target_epochs: break epoch += 1 seen = 0 indices = list(range(len(dataset))) if dataset_order == "torch_randperm": - indices = _torch_randperm_order(len(dataset), int(seed) + epoch) + indices = _torch_randperm_order(len(dataset), base_seed + epoch) elif dataset_order in (None, "default"): - np.random.seed(int(seed) + epoch) + np.random.seed(base_seed + epoch) np.random.shuffle(indices) bi = indices[seen : seen + batch_size] @@ -2576,6 +2581,7 @@ def iterate_vlm_training_batches(dataset, processor, config, batch_size, image_size = _get_vlm_image_size(config, processor) ignore_token_ids = _get_vlm_ignore_token_ids(processor=processor, config=config) + base_seed = _normalize_seed(seed) def _emit(items): batch_dict = _collate_vlm_batch( @@ -2598,7 +2604,7 @@ def _emit(items): epoch = 0 while True: if dataset_order == "torch_randperm": - indices = _torch_randperm_order(len(dataset), int(seed) + epoch) + indices = _torch_randperm_order(len(dataset), base_seed + epoch) elif dataset_order == "sequential": indices = list(range(len(dataset))) elif dataset_order in (None, "default"): @@ -2768,7 +2774,8 @@ def create_ordered_batches(dataset, tokenizer, batch_size, max_seq_length, num_batches=None, seed=None, dataset_order="sequential", dataset_text_field="text", formatting_func=None, chat_template=None, - model_name=None, model_type=None): + model_name=None, model_type=None, + num_epochs=None): """Create text batches with an explicit dataset order. Studio uses this to mirror CUDA's effective sampler stream without @@ -2790,10 +2797,13 @@ def create_ordered_batches(dataset, tokenizer, batch_size, max_seq_length, tokenized.append(ids) if not tokenized: - return [] + raise ValueError( + "Unsloth MLX: ordered dataset produced no trainable token sequences " + "(need at least two tokens after formatting/truncation)." + ) def make_order(epoch): - base_seed = 3407 if seed is None else int(seed) + base_seed = _normalize_seed(seed) if dataset_order == "torch_randperm": return _torch_randperm_order(len(tokenized), base_seed + epoch) if dataset_order not in (None, "sequential"): @@ -2805,6 +2815,10 @@ def make_order(epoch): order = make_order(epoch) order_pos = 0 seen = 0 + target_items = ( + len(tokenized) * (1 if num_epochs is None else int(num_epochs)) + if num_batches is None else None + ) while num_batches is None or len(batch_pairs) < num_batches: batch_items = [] for _ in range(batch_size): @@ -2815,7 +2829,7 @@ def make_order(epoch): batch_items.append(tokenized[order[order_pos]]) order_pos += 1 seen += 1 - if num_batches is None and seen >= len(tokenized): + if num_batches is None and target_items is not None and seen >= target_items: break max_length = max(len(ids) for ids in batch_items) @@ -2827,7 +2841,7 @@ def make_order(epoch): lengths.append([0, length]) batch_pairs.append((mx.array(batch_ids), mx.array(lengths), None)) - if num_batches is None and seen >= len(tokenized): + if num_batches is None and target_items is not None and seen >= target_items: break mx.eval([b for b, l, _ in batch_pairs] + [l for _, l, _ in batch_pairs]) From ec343232aca5f4af3aeff2a0a031223df75f9918 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 22 May 2026 02:56:34 +0800 Subject: [PATCH 14/48] fix(mlx): cast custom norm outputs --- tests/test_mlx_pr684_review_fixes.py | 35 +++++++++++++ unsloth_zoo/mlx/trainer.py | 73 +++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_pr684_review_fixes.py index 2930b3b3f..30b089f7c 100644 --- a/tests/test_mlx_pr684_review_fixes.py +++ b/tests/test_mlx_pr684_review_fixes.py @@ -204,3 +204,38 @@ def test_pr684_compiler_review_guards_are_present(): assert '"weight" in norm' not in mlx_compile_source assert '"bias" in norm' not in mlx_compile_source assert 'getattr(norm, "weight", None)' in mlx_compile_source + + +def test_norm_output_cast_includes_custom_norms(): + _skip_if_mlx_core_was_replaced() + gemma3_text = pytest.importorskip("mlx_lm.models.gemma3_text") + stablelm = pytest.importorskip("mlx_lm.models.stablelm") + fastvlm_vision = pytest.importorskip("mlx_vlm.models.fastvlm.vision") + import unsloth_zoo.mlx.trainer as trainer_mod + + trainer_mod._set_norm_output_cast_to_input_dtype(False) + cases = [ + (gemma3_text.RMSNorm(4), mx.ones((2, 4), dtype=mx.bfloat16)), + ( + stablelm.LayerNormPerHead(head_dim=4, num_heads=2, eps=1e-5), + mx.ones((1, 3, 2, 4), dtype=mx.bfloat16), + ), + ( + fastvlm_vision.LayerNormChannel(num_features=4), + mx.ones((1, 2, 2, 4), dtype=mx.bfloat16), + ), + ] + + norm_classes = trainer_mod._iter_norm_output_cast_classes() + for norm, x in cases: + assert type(norm) in norm_classes + raw = norm(x) + assert raw.dtype == mx.float32 + + try: + trainer_mod._set_norm_output_cast_to_input_dtype(True) + for norm, x in cases: + out = norm(x) + assert out.dtype == x.dtype + finally: + trainer_mod._set_norm_output_cast_to_input_dtype(False) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 70c84b734..3e53e9264 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -38,6 +38,7 @@ from dataclasses import asdict, dataclass, is_dataclass import concurrent.futures +import importlib import math import os import random @@ -93,6 +94,65 @@ def _normalize_mlx_optimizer_name(name): return opt_name +_NORM_OUTPUT_CAST_EXTRA_CLASS_PATHS = ( + ("mlx_lm.models.bailing_moe_linear", "GroupRMSNorm"), + ("mlx_lm.models.cohere", "LayerNorm2D"), + ("mlx_lm.models.falcon_h1", "FalconH1RMSNormGated"), + ("mlx_lm.models.gemma", "RMSNorm"), + ("mlx_lm.models.gemma2", "RMSNorm"), + ("mlx_lm.models.gemma3_text", "RMSNorm"), + ("mlx_lm.models.granitemoehybrid", "GraniteMoeHybridRMSNormGated"), + ("mlx_lm.models.mamba2", "MambaRMSNormGated"), + ("mlx_lm.models.nemotron", "NemotronLayerNorm1P"), + ("mlx_lm.models.nemotron_h", "MambaRMSNormGated"), + ("mlx_lm.models.plamo2", "RMSNorm"), + ("mlx_lm.models.qwen3_next", "Qwen3NextRMSNormGated"), + ("mlx_lm.models.recurrent_gemma", "RMSNorm"), + ("mlx_lm.models.rwkv7", "LayerNormPerHead"), + ("mlx_lm.models.stablelm", "LayerNormPerHead"), + ("mlx_lm.models.step3p5", "ZeroCenteredRMSNorm"), + ("mlx_vlm.models.deepseekocr_2.vision", "Qwen2RMSNorm"), + ("mlx_vlm.models.dots_ocr.vision", "RMSNorm"), + ("mlx_vlm.models.fastvlm.vision", "LayerNormChannel"), + ("mlx_vlm.models.gemma3.language", "RMSNorm"), + ("mlx_vlm.models.gemma3n.audio", "Gemma3nCumulativeGroupNorm"), + ("mlx_vlm.models.gemma3n.language", "Gemma3nRMSNorm"), + ("mlx_vlm.models.gemma3n.vision", "RMSNormAct2d"), + ("mlx_vlm.models.gemma4.audio", "AudioRMSNorm"), + ("mlx_vlm.models.gemma4.language", "RMSNormZeroShift"), + ("mlx_vlm.models.gemma4.vision", "RMSNorm"), + ("mlx_vlm.models.gemma4.vision", "VisionRMSNorm"), + ("mlx_vlm.models.jina_vlm.language", "RMSNorm"), + ("mlx_vlm.models.paligemma.language", "RMSNorm"), + ("mlx_vlm.models.qwen3_5.language", "Qwen3_5RMSNormGated"), + ("mlx_vlm.models.sam3.sam_components", "LayerNorm2d"), + ("mlx_vlm.models.sam3d_body.layers", "LayerNorm32"), +) +_NORM_OUTPUT_CAST_PATCHED_CLASSES = set() + + +def _iter_norm_output_cast_classes(): + norm_classes = [] + seen = set() + + for module_name, class_name in _NORM_OUTPUT_CAST_EXTRA_CLASS_PATHS: + try: + module = importlib.import_module(module_name) + norm_cls = getattr(module, class_name) + except Exception: + continue + if norm_cls not in seen: + norm_classes.append(norm_cls) + seen.add(norm_cls) + + for norm_cls in (nn.RMSNorm, nn.LayerNorm): + if norm_cls not in seen: + norm_classes.append(norm_cls) + seen.add(norm_cls) + + return tuple(norm_classes) + + def _set_norm_output_cast_to_input_dtype(enabled: bool) -> None: """Control whether norm outputs are cast back to activation dtype. @@ -102,8 +162,15 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool) -> None: result back matches PyTorch autocast behavior more closely: fp32 norm math, bf16/fp16 downstream activations. """ - for norm_cls in (nn.RMSNorm, nn.LayerNorm): - patched = getattr(norm_cls, "_unsloth_cast_output_to_input_dtype", False) + norm_classes = list(_iter_norm_output_cast_classes()) + if not enabled: + norm_classes.extend( + norm_cls for norm_cls in _NORM_OUTPUT_CAST_PATCHED_CLASSES + if norm_cls not in norm_classes + ) + + for norm_cls in norm_classes: + patched = norm_cls.__dict__.get("_unsloth_cast_output_to_input_dtype", False) if enabled: if patched: continue @@ -118,12 +185,14 @@ def norm_call_cast_output(self, x, *args, _original_call=original_call, **kwargs norm_cls._unsloth_original_call = original_call norm_cls.__call__ = norm_call_cast_output norm_cls._unsloth_cast_output_to_input_dtype = True + _NORM_OUTPUT_CAST_PATCHED_CLASSES.add(norm_cls) elif patched: original_call = getattr(norm_cls, "_unsloth_original_call", None) if original_call is not None: norm_cls.__call__ = original_call norm_cls._unsloth_original_call = None norm_cls._unsloth_cast_output_to_input_dtype = False + _NORM_OUTPUT_CAST_PATCHED_CLASSES.discard(norm_cls) def _normalize_mlx_scheduler_type(name): From f26cf37a519a0407c5ccd0692929320be632b250 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 22 May 2026 04:33:44 +0800 Subject: [PATCH 15/48] feat: auto discover custom norm from model --- tests/test_mlx_pr684_review_fixes.py | 24 +++-- unsloth_zoo/mlx/trainer.py | 146 +++++++++++++++++---------- 2 files changed, 113 insertions(+), 57 deletions(-) diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_pr684_review_fixes.py index 30b089f7c..30ea697a6 100644 --- a/tests/test_mlx_pr684_review_fixes.py +++ b/tests/test_mlx_pr684_review_fixes.py @@ -206,34 +206,46 @@ def test_pr684_compiler_review_guards_are_present(): assert 'getattr(norm, "weight", None)' in mlx_compile_source -def test_norm_output_cast_includes_custom_norms(): +def test_norm_output_cast_discovers_custom_norms_from_loaded_model(): _skip_if_mlx_core_was_replaced() + import mlx.nn as nn + gemma3_text = pytest.importorskip("mlx_lm.models.gemma3_text") stablelm = pytest.importorskip("mlx_lm.models.stablelm") fastvlm_vision = pytest.importorskip("mlx_vlm.models.fastvlm.vision") import unsloth_zoo.mlx.trainer as trainer_mod + class TinyModel(nn.Module): + def __init__(self): + super().__init__() + self.input_layernorm = gemma3_text.RMSNorm(4) + self.q_layernorm = stablelm.LayerNormPerHead( + head_dim=4, num_heads=2, eps=1e-5 + ) + self.norm = fastvlm_vision.LayerNormChannel(num_features=4) + trainer_mod._set_norm_output_cast_to_input_dtype(False) + model = TinyModel() cases = [ - (gemma3_text.RMSNorm(4), mx.ones((2, 4), dtype=mx.bfloat16)), + (model.input_layernorm, mx.ones((2, 4), dtype=mx.bfloat16)), ( - stablelm.LayerNormPerHead(head_dim=4, num_heads=2, eps=1e-5), + model.q_layernorm, mx.ones((1, 3, 2, 4), dtype=mx.bfloat16), ), ( - fastvlm_vision.LayerNormChannel(num_features=4), + model.norm, mx.ones((1, 2, 2, 4), dtype=mx.bfloat16), ), ] - norm_classes = trainer_mod._iter_norm_output_cast_classes() + norm_classes = trainer_mod._iter_norm_output_cast_classes(model) for norm, x in cases: assert type(norm) in norm_classes raw = norm(x) assert raw.dtype == mx.float32 try: - trainer_mod._set_norm_output_cast_to_input_dtype(True) + trainer_mod._set_norm_output_cast_to_input_dtype(True, model) for norm, x in cases: out = norm(x) assert out.dtype == x.dtype diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 3e53e9264..e68714966 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -38,7 +38,6 @@ from dataclasses import asdict, dataclass, is_dataclass import concurrent.futures -import importlib import math import os import random @@ -94,66 +93,111 @@ def _normalize_mlx_optimizer_name(name): return opt_name -_NORM_OUTPUT_CAST_EXTRA_CLASS_PATHS = ( - ("mlx_lm.models.bailing_moe_linear", "GroupRMSNorm"), - ("mlx_lm.models.cohere", "LayerNorm2D"), - ("mlx_lm.models.falcon_h1", "FalconH1RMSNormGated"), - ("mlx_lm.models.gemma", "RMSNorm"), - ("mlx_lm.models.gemma2", "RMSNorm"), - ("mlx_lm.models.gemma3_text", "RMSNorm"), - ("mlx_lm.models.granitemoehybrid", "GraniteMoeHybridRMSNormGated"), - ("mlx_lm.models.mamba2", "MambaRMSNormGated"), - ("mlx_lm.models.nemotron", "NemotronLayerNorm1P"), - ("mlx_lm.models.nemotron_h", "MambaRMSNormGated"), - ("mlx_lm.models.plamo2", "RMSNorm"), - ("mlx_lm.models.qwen3_next", "Qwen3NextRMSNormGated"), - ("mlx_lm.models.recurrent_gemma", "RMSNorm"), - ("mlx_lm.models.rwkv7", "LayerNormPerHead"), - ("mlx_lm.models.stablelm", "LayerNormPerHead"), - ("mlx_lm.models.step3p5", "ZeroCenteredRMSNorm"), - ("mlx_vlm.models.deepseekocr_2.vision", "Qwen2RMSNorm"), - ("mlx_vlm.models.dots_ocr.vision", "RMSNorm"), - ("mlx_vlm.models.fastvlm.vision", "LayerNormChannel"), - ("mlx_vlm.models.gemma3.language", "RMSNorm"), - ("mlx_vlm.models.gemma3n.audio", "Gemma3nCumulativeGroupNorm"), - ("mlx_vlm.models.gemma3n.language", "Gemma3nRMSNorm"), - ("mlx_vlm.models.gemma3n.vision", "RMSNormAct2d"), - ("mlx_vlm.models.gemma4.audio", "AudioRMSNorm"), - ("mlx_vlm.models.gemma4.language", "RMSNormZeroShift"), - ("mlx_vlm.models.gemma4.vision", "RMSNorm"), - ("mlx_vlm.models.gemma4.vision", "VisionRMSNorm"), - ("mlx_vlm.models.jina_vlm.language", "RMSNorm"), - ("mlx_vlm.models.paligemma.language", "RMSNorm"), - ("mlx_vlm.models.qwen3_5.language", "Qwen3_5RMSNormGated"), - ("mlx_vlm.models.sam3.sam_components", "LayerNorm2d"), - ("mlx_vlm.models.sam3d_body.layers", "LayerNorm32"), -) +_NORM_OUTPUT_CAST_BASE_CLASSES = (nn.RMSNorm, nn.LayerNorm) _NORM_OUTPUT_CAST_PATCHED_CLASSES = set() -def _iter_norm_output_cast_classes(): +def _is_norm_parameter_path(path) -> bool: + parts = str(path).lower().split(".") + return any("norm" in part for part in parts[:-1]) + + +def _join_parameter_path(module_path, parameter_path): + if module_path and parameter_path: + return f"{module_path}.{parameter_path}" + return module_path or parameter_path + + +def _has_norm_selected_floating_parameter(module_path, module) -> bool: + try: + parameters = module.parameters() + except Exception: + return False + + try: + for parameter_path, value in tree_flatten(parameters): + full_path = _join_parameter_path(module_path, parameter_path) + if ( + _is_norm_parameter_path(full_path) + and hasattr(value, "dtype") + and mx.issubdtype(value.dtype, mx.floating) + ): + return True + except Exception: + return False + return False + + +def _has_floating_parameter(module) -> bool: + try: + parameters = module.parameters() + except Exception: + return False + + try: + for _, value in tree_flatten(parameters): + if hasattr(value, "dtype") and mx.issubdtype(value.dtype, mx.floating): + return True + except Exception: + return False + return False + + +def _has_no_parameterized_non_norm_children(module) -> bool: + try: + children = module.children() + except Exception: + return False + + try: + for _, child in tree_flatten(children, is_leaf=nn.Module.is_module): + if ( + isinstance(child, nn.Module) + and "norm" not in type(child).__name__.lower() + and _has_floating_parameter(child) + ): + return False + except Exception: + return False + return True + + +def _is_norm_output_cast_candidate(module_path, module) -> bool: + """Return whether a custom module itself produces norm-like output.""" + norm_cls = type(module) + if norm_cls in _NORM_OUTPUT_CAST_BASE_CLASSES: + return True + if "norm" not in norm_cls.__name__.lower(): + return False + if not _has_norm_selected_floating_parameter(module_path, module): + return False + return _has_no_parameterized_non_norm_children(module) + + +def _iter_norm_output_cast_classes(model=None): norm_classes = [] seen = set() - for module_name, class_name in _NORM_OUTPUT_CAST_EXTRA_CLASS_PATHS: + for norm_cls in _NORM_OUTPUT_CAST_BASE_CLASSES: + norm_classes.append(norm_cls) + seen.add(norm_cls) + + if model is not None: try: - module = importlib.import_module(module_name) - norm_cls = getattr(module, class_name) + named_modules = model.named_modules() except Exception: - continue - if norm_cls not in seen: - norm_classes.append(norm_cls) - seen.add(norm_cls) - - for norm_cls in (nn.RMSNorm, nn.LayerNorm): - if norm_cls not in seen: - norm_classes.append(norm_cls) - seen.add(norm_cls) + named_modules = () + for module_path, module in named_modules: + if _is_norm_output_cast_candidate(module_path, module): + norm_cls = type(module) + if norm_cls not in seen: + norm_classes.append(norm_cls) + seen.add(norm_cls) return tuple(norm_classes) -def _set_norm_output_cast_to_input_dtype(enabled: bool) -> None: +def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: """Control whether norm outputs are cast back to activation dtype. Norm parameters can stay in fp32 for stability, but letting fp32 norm @@ -162,7 +206,7 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool) -> None: result back matches PyTorch autocast behavior more closely: fp32 norm math, bf16/fp16 downstream activations. """ - norm_classes = list(_iter_norm_output_cast_classes()) + norm_classes = list(_iter_norm_output_cast_classes(model)) if not enabled: norm_classes.extend( norm_cls for norm_cls in _NORM_OUTPUT_CAST_PATCHED_CLASSES @@ -762,7 +806,7 @@ def train(self): args = self.args model = self.model cast_norm_output = bool(getattr(args, "cast_norm_output_to_input_dtype", True)) - _set_norm_output_cast_to_input_dtype(cast_norm_output) + _set_norm_output_cast_to_input_dtype(cast_norm_output, model) if cast_norm_output: print("Unsloth: Casting MLX norm outputs back to activation dtype.") args.patch_mode = normalize_mlx_patch_mode(getattr(args, "patch_mode", "patched")) From a24b8f3211cb7a7f5f6b06d5c89491940a9db32a Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 22 May 2026 05:05:56 +0800 Subject: [PATCH 16/48] fix(mlx): harden norm output cast discovery --- unsloth_zoo/mlx/trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index e68714966..6dd9f136f 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -114,13 +114,18 @@ def _has_norm_selected_floating_parameter(module_path, module) -> bool: except Exception: return False + module_path_selected = _is_norm_parameter_path(_join_parameter_path(module_path, "_")) try: for parameter_path, value in tree_flatten(parameters): - full_path = _join_parameter_path(module_path, parameter_path) if ( - _is_norm_parameter_path(full_path) - and hasattr(value, "dtype") + hasattr(value, "dtype") and mx.issubdtype(value.dtype, mx.floating) + and ( + module_path_selected + or _is_norm_parameter_path( + _join_parameter_path(module_path, parameter_path) + ) + ) ): return True except Exception: @@ -219,6 +224,8 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: if patched: continue original_call = norm_cls.__call__ + if getattr(original_call, "_unsloth_norm_output_cast_wrapper", False): + continue def norm_call_cast_output(self, x, *args, _original_call=original_call, **kwargs): out = _original_call(self, x, *args, **kwargs) @@ -226,6 +233,7 @@ def norm_call_cast_output(self, x, *args, _original_call=original_call, **kwargs return out.astype(x.dtype) return out + norm_call_cast_output._unsloth_norm_output_cast_wrapper = True norm_cls._unsloth_original_call = original_call norm_cls.__call__ = norm_call_cast_output norm_cls._unsloth_cast_output_to_input_dtype = True From 5465959426a29ce1b8cfaf132fdeb55cbec243ee Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 22 May 2026 05:21:05 +0800 Subject: [PATCH 17/48] fix(mlx): preserve custom norm keyword calls --- unsloth_zoo/mlx/trainer.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 6dd9f136f..ffb49047f 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -167,6 +167,16 @@ def _has_no_parameterized_non_norm_children(module) -> bool: return True +def _norm_output_cast_input_dtype(args, kwargs): + for value in args: + if hasattr(value, "dtype"): + return value.dtype + for value in kwargs.values(): + if hasattr(value, "dtype"): + return value.dtype + return None + + def _is_norm_output_cast_candidate(module_path, module) -> bool: """Return whether a custom module itself produces norm-like output.""" norm_cls = type(module) @@ -227,10 +237,15 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: if getattr(original_call, "_unsloth_norm_output_cast_wrapper", False): continue - def norm_call_cast_output(self, x, *args, _original_call=original_call, **kwargs): - out = _original_call(self, x, *args, **kwargs) - if hasattr(x, "dtype") and hasattr(out, "dtype") and out.dtype != x.dtype: - return out.astype(x.dtype) + def norm_call_cast_output(self, *args, _original_call=original_call, **kwargs): + input_dtype = _norm_output_cast_input_dtype(args, kwargs) + out = _original_call(self, *args, **kwargs) + if ( + input_dtype is not None + and hasattr(out, "dtype") + and out.dtype != input_dtype + ): + return out.astype(input_dtype) return out norm_call_cast_output._unsloth_norm_output_cast_wrapper = True From 7e0bee546f78482bd1a49c1231f542d2754794a6 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Fri, 22 May 2026 12:13:39 -0500 Subject: [PATCH 18/48] harden mlx custom norm output casting --- tests/test_mlx_pr684_review_fixes.py | 34 ++++++++++++++++++++++++++++ unsloth_zoo/mlx/trainer.py | 31 ++++++++++++------------- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_pr684_review_fixes.py index 30ea697a6..d947d4cd1 100644 --- a/tests/test_mlx_pr684_review_fixes.py +++ b/tests/test_mlx_pr684_review_fixes.py @@ -251,3 +251,37 @@ def __init__(self): assert out.dtype == x.dtype finally: trainer_mod._set_norm_output_cast_to_input_dtype(False) + + +def test_norm_output_cast_does_not_double_patch_inherited_norm_call(): + _skip_if_mlx_core_was_replaced() + import mlx.nn as nn + import unsloth_zoo.mlx.trainer as trainer_mod + + class CustomRMSNorm(nn.RMSNorm): + pass + + class TinyModel(nn.Module): + def __init__(self): + super().__init__() + self.input_layernorm = CustomRMSNorm(4) + + trainer_mod._set_norm_output_cast_to_input_dtype(False) + model = TinyModel() + x = mx.ones((2, 4), dtype=mx.bfloat16) + + try: + trainer_mod._set_norm_output_cast_to_input_dtype(True, model) + assert nn.RMSNorm in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES + assert CustomRMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES + assert model.input_layernorm(x).dtype == x.dtype + finally: + trainer_mod._set_norm_output_cast_to_input_dtype(False) + + assert nn.RMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES + assert CustomRMSNorm not in trainer_mod._NORM_OUTPUT_CAST_PATCHED_CLASSES + assert not getattr( + CustomRMSNorm.__call__, + "_unsloth_norm_output_cast_wrapper", + False, + ) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index ffb49047f..562b42def 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -102,10 +102,8 @@ def _is_norm_parameter_path(path) -> bool: return any("norm" in part for part in parts[:-1]) -def _join_parameter_path(module_path, parameter_path): - if module_path and parameter_path: - return f"{module_path}.{parameter_path}" - return module_path or parameter_path +def _is_norm_module_path(path) -> bool: + return any("norm" in part for part in str(path).lower().split(".")) def _has_norm_selected_floating_parameter(module_path, module) -> bool: @@ -114,7 +112,7 @@ def _has_norm_selected_floating_parameter(module_path, module) -> bool: except Exception: return False - module_path_selected = _is_norm_parameter_path(_join_parameter_path(module_path, "_")) + module_path_selected = _is_norm_module_path(module_path) try: for parameter_path, value in tree_flatten(parameters): if ( @@ -122,9 +120,7 @@ def _has_norm_selected_floating_parameter(module_path, module) -> bool: and mx.issubdtype(value.dtype, mx.floating) and ( module_path_selected - or _is_norm_parameter_path( - _join_parameter_path(module_path, parameter_path) - ) + or _is_norm_parameter_path(parameter_path) ) ): return True @@ -148,7 +144,7 @@ def _has_floating_parameter(module) -> bool: return False -def _has_no_parameterized_non_norm_children(module) -> bool: +def _has_parameterized_non_norm_children(module) -> bool: try: children = module.children() except Exception: @@ -161,10 +157,10 @@ def _has_no_parameterized_non_norm_children(module) -> bool: and "norm" not in type(child).__name__.lower() and _has_floating_parameter(child) ): - return False + return True except Exception: return False - return True + return False def _norm_output_cast_input_dtype(args, kwargs): @@ -184,9 +180,11 @@ def _is_norm_output_cast_candidate(module_path, module) -> bool: return True if "norm" not in norm_cls.__name__.lower(): return False + if _has_parameterized_non_norm_children(module): + return False if not _has_norm_selected_floating_parameter(module_path, module): return False - return _has_no_parameterized_non_norm_children(module) + return True def _iter_norm_output_cast_classes(model=None): @@ -229,12 +227,13 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: ) for norm_cls in norm_classes: - patched = norm_cls.__dict__.get("_unsloth_cast_output_to_input_dtype", False) + patched = norm_cls in _NORM_OUTPUT_CAST_PATCHED_CLASSES if enabled: - if patched: - continue original_call = norm_cls.__call__ - if getattr(original_call, "_unsloth_norm_output_cast_wrapper", False): + if ( + patched + or getattr(original_call, "_unsloth_norm_output_cast_wrapper", False) + ): continue def norm_call_cast_output(self, *args, _original_call=original_call, **kwargs): From c7a0956d01c8043eb1488f2effd92c339c303fdd Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 13:55:11 +0000 Subject: [PATCH 19/48] Fix four loose ends for PR #684 unsloth_zoo/compiler.py Restore the dedicated UNSLOTH_RETURN_LOGITS=1 elif branch in cross_entropy_replacement_2 (originally added by #666, commit f45c31e5). Without it the regex template fallback path under UNSLOTH_FUSED_FORWARD=0 ran self.lm_head twice on the UNSLOTH_RETURN_LOGITS=1 path: once via the prepended materialise and once again in the final else branch. The AST rewriter at fused_losses/forward_install.py is unaffected. unsloth_zoo/mlx/trainer.py + unsloth_zoo/mlx/loader.py Expand the AdamW weight-decay filter and the fp32 norm-parameter filter to also match GPT-2 style ln_1 / ln_2 / ln_f names. Previous "norm" substring missed them; _ensure_lora_frozen already treated those as norm fragments, so the filters were inconsistent. unsloth_zoo/mlx/compile.py Honor cast_norm_output_to_input_dtype=False on the Qwen3-VL vision-block patch. Added a module-level _QWEN3_VISION_NORM_CAST_OUTPUT flag with a setter; the trainer's _set_norm_output_cast_to_input_dtype flips it so the generic norm patcher and the Qwen3 specialized norm patch agree. unsloth_zoo/mlx/utils.py Reseed the iterate_vlm_training_batches default branch per epoch via np.random.default_rng(base_seed + epoch). The torch_randperm branch already did this; the default branch's order previously depended on global numpy RNG state. --- unsloth_zoo/compiler.py | 14 ++++++++++++++ unsloth_zoo/mlx/compile.py | 23 +++++++++++++++++++++-- unsloth_zoo/mlx/loader.py | 7 ++++++- unsloth_zoo/mlx/trainer.py | 24 ++++++++++++++++++++---- unsloth_zoo/mlx/utils.py | 6 +++++- 5 files changed, 66 insertions(+), 8 deletions(-) diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index 8ec17437b..a7bdeef7d 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1772,6 +1772,20 @@ def mask_attention_mask_out(labels = None, attention_mask = None): logit_scale_divide = (\\3) if (\\3) != () else 0, logit_softcapping = (\\4) if (\\4) != () else 0, ) +elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: + # UNSLOTH_RETURN_LOGITS=1 path. Prepended `logits = self.lm_head(...)` + # already materialised the full lm_head matmul; apply the captured logit + # scale/softcap transforms and route loss through self.loss_function on + # those logits instead of letting unsloth_fused_ce_loss redo the matmul. + if (\\2) != (): + logits = logits * (\\2) + if (\\3) != (): + logits = logits / (\\3) + if (\\4) not in (None, (),): + logits = logits / (\\4) + logits = torch.tanh(logits) + logits = logits * (\\4) + loss = self.loss_function(logits, labels.to(self.lm_head.weight.device), vocab_size=\\8, **\\9) else: logits = self.lm_head(hidden_states\\1) if (\\2) != (): diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 1f430c16b..f3a2226f9 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -56,6 +56,18 @@ _PATCHED_PATTERN_BUNDLES: set[str] = set() _PATCH_BINDINGS: set[tuple[str, str, str, str]] = set() +# Controls whether the Qwen3-VL vision-block norm patch casts the fp32 +# norm output back to the activation dtype. Mirrors MLXTrainingConfig +# `cast_norm_output_to_input_dtype`. Flipped by the trainer's +# `_set_norm_output_cast_to_input_dtype` so the generic and Qwen3-VL +# paths agree. +_QWEN3_VISION_NORM_CAST_OUTPUT = True + + +def set_qwen3_vision_norm_cast_output(enabled: bool) -> None: + global _QWEN3_VISION_NORM_CAST_OUTPUT + _QWEN3_VISION_NORM_CAST_OUTPUT = bool(enabled) + # Architectures explicitly verified for mlx compile support. # Training verification currently covers: # - qwen2_5_vl: real end-to-end compiled training via train.py @@ -2710,7 +2722,12 @@ def _qwen3_vision_rotary_fp32(tensor, freqs): return rotated.astype(orig_dtype) def _qwen3_torch_like_layer_norm(norm, x): - """Match PyTorch bf16 LayerNorm: fp32 stats/affine, cast result back.""" + """Match PyTorch bf16 LayerNorm: fp32 stats/affine, cast result back. + + Honors the module-level `_QWEN3_VISION_NORM_CAST_OUTPUT` flag (set + from MLXTrainingConfig.cast_norm_output_to_input_dtype); when + disabled the fp32 result is returned without recasting. + """ import mlx.core as mx source_dtype = x.dtype @@ -2725,7 +2742,9 @@ def _qwen3_torch_like_layer_norm(norm, x): bias = getattr(norm, "bias", None) if bias is not None: y = y + bias.astype(mx.float32) - return y.astype(source_dtype) + if _QWEN3_VISION_NORM_CAST_OUTPUT: + return y.astype(source_dtype) + return y def patched_qwen3_vision_block_call(self, hidden_states, cu_seqlens, rotary_pos_emb): residual_dtype = hidden_states.dtype diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 4ca5e5a61..a9dbf1d47 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -149,7 +149,12 @@ def _convert_mlx_dtype(model, target_dtype, model_type: str = "") -> None: def _is_norm_parameter_path(path) -> bool: """Return whether a parameter path belongs to a normalization module.""" parts = str(path).lower().split(".") - return any("norm" in part for part in parts[:-1]) + # Match RMSNorm/LayerNorm via "norm" substring, plus GPT-2 / GPT-OSS + # style ln_1, ln_2, ln_f. + return any( + "norm" in part or part.startswith("ln_") or part == "ln_f" + for part in parts[:-1] + ) def _keep_norm_parameters_float32(model) -> None: diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 562b42def..3c5ded065 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -97,13 +97,19 @@ def _normalize_mlx_optimizer_name(name): _NORM_OUTPUT_CAST_PATCHED_CLASSES = set() +def _part_is_norm(part: str) -> bool: + # Match RMSNorm/LayerNorm/input_layernorm/etc. via "norm" substring, + # plus GPT-2 / GPT-OSS style ln_1, ln_2, ln_f. + return "norm" in part or part.startswith("ln_") or part == "ln_f" + + def _is_norm_parameter_path(path) -> bool: parts = str(path).lower().split(".") - return any("norm" in part for part in parts[:-1]) + return any(_part_is_norm(part) for part in parts[:-1]) def _is_norm_module_path(path) -> bool: - return any("norm" in part for part in str(path).lower().split(".")) + return any(_part_is_norm(part) for part in str(path).lower().split(".")) def _has_norm_selected_floating_parameter(module_path, module) -> bool: @@ -219,6 +225,15 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: result back matches PyTorch autocast behavior more closely: fp32 norm math, bf16/fp16 downstream activations. """ + # Keep the Qwen3-VL specialized vision-block norm patch in sync with + # the generic patcher below. Imported lazily to avoid a circular + # import at trainer-module load time. + try: + from . import compile as _mlx_compile + _mlx_compile.set_qwen3_vision_norm_cast_output(enabled) + except Exception: + pass + norm_classes = list(_iter_norm_output_cast_classes(model)) if not enabled: norm_classes.extend( @@ -624,14 +639,15 @@ def _should_apply_weight_decay(name, parameter=None): leaf = parts[-1] if parts else str(name).lower() if leaf == "bias": return False - if any("norm" in part for part in parts): + # Cover RMSNorm/LayerNorm via "norm" + GPT-2 style ln_1/ln_2/ln_f. + if any(_part_is_norm(part) for part in parts): return False return True @staticmethod def _is_norm_parameter_name(name): return any( - "norm" in part.lower() + _part_is_norm(part.lower()) for part in str(name).split(".") if part ) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index dfd7012f8..219cd5ba3 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -2613,7 +2613,11 @@ def _emit(items): indices[i : i + batch_size] for i in range(0, len(indices), batch_size) ] - order = np.random.permutation(len(batch_indices)) + # Use a per-epoch local Generator so order is reproducible + # under `seed` and does not depend on global numpy RNG + # state. Mirrors the torch_randperm branch reseed above. + rng = np.random.default_rng(base_seed + epoch) + order = rng.permutation(len(batch_indices)) for b in order: items = [dataset[idx] for idx in batch_indices[b]] yield _emit(items) From 0753b115ccfaa307c6a67d4e771e6ac402a05819 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 14:08:16 +0000 Subject: [PATCH 20/48] Address reviewer round 1 P1 findings on PR #684 unsloth_zoo/mlx/trainer.py Clean up the cast_norm_output_to_input_dtype monkey patch in `train()`'s finally block. Previously `_set_norm_output_cast_to_input_dtype(True, model)` patched MLX norm classes globally at line 847 but the cleanup only undid gradient checkpointing and memory limits; subsequent inference or trainers in the same Python process inherited the cast-back wrapper. Wrap the restore in try/except so a partial patch state never prevents `finally` from completing. Raise on the text-streaming + dataset_order combination. The new preserve_dataset_order / dataset_order config fields are honored for non-streaming text and for VLM (streaming + materialized), but `iterate_training_batches(...)` has no ordering argument, so text streaming silently ignored the user-requested order. Raise ValueError instead so the asymmetry is explicit and Studio / CUDA parity stays loud. unsloth_zoo/mlx/utils.py Apply the Qwen3-VL full-sequence forward fix to the baseline CE path too. _vlm_cce_forward already forwards the full multimodal sequence and shifts `hidden[:, :-1]` afterwards because Qwen3-VL image / mRoPE / deepstack state depends on the complete sequence. make_vlm_baseline_loss_fn was still trimming `input_ids[:, :-1]` pre-forward, so users who set `use_cce=False` saw a different loss than `use_cce=True` for the same input. Forward the full sequence and drop the final logits position afterwards to match. tests/test_pr_a_deep_components.py Fix the linear-no-warmup scheduler expectation. The previous expected `[0.0, lr, lr*6/7, ...]` would have step 0 run at zero LR and is inconsistent with `transformers.get_scheduler("linear", num_warmup_steps=0, num_training_steps=n)`. Replace with the HF-compatible `lr * (n - step) / n` series so the existing _build_schedule() implementation passes the test. (Note: this commit also includes the previous merge of origin/main which restored mainline #690 / #691 gpt-oss eager-attention fixes that the stale branch was about to revert.) --- tests/test_pr_a_deep_components.py | 17 +++++++---------- unsloth_zoo/mlx/trainer.py | 25 +++++++++++++++++++++++++ unsloth_zoo/mlx/utils.py | 15 ++++++++++----- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 645f248e6..9df76a489 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -179,16 +179,13 @@ def test_scheduler_lr_matches_expected_optimizer_update_steps(scheduler, warmup) ] if scheduler == "linear" and warmup == 0: - expected = [ - 0.0, - trainer.args.learning_rate, - trainer.args.learning_rate * 6 / 7, - trainer.args.learning_rate * 5 / 7, - trainer.args.learning_rate * 4 / 7, - trainer.args.learning_rate * 3 / 7, - trainer.args.learning_rate * 2 / 7, - trainer.args.learning_rate * 1 / 7, - ] + # Match `transformers.get_scheduler("linear", num_warmup_steps=0, + # num_training_steps=total_steps)`: step 0 = learning_rate, then + # decays linearly to 0 over total_steps. The earlier expectation + # of `[0, lr, lr*6/7, ...]` would have the first optimizer + # update fire at zero LR and is inconsistent with HF behavior. + lr = trainer.args.learning_rate + expected = [lr * (total_steps - step) / total_steps for step in range(total_steps)] assert values == pytest.approx(expected) elif warmup > 0: assert values[0] == pytest.approx(0.0) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 3c5ded065..9d0dce0b4 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -916,6 +916,16 @@ def train(self): if args.gradient_checkpointing: remove_gradient_checkpointing(model) self._restore_memory_limits() + if cast_norm_output: + # Undo the global norm-class monkey patch so later + # inference / unrelated trainers in the same Python + # process get the original RMSNorm / LayerNorm dtype + # behavior. Wrap in try/except: a partially patched + # state must still let `finally` run to completion. + try: + _set_norm_output_cast_to_input_dtype(False, model) + except Exception: + pass def _train_inner(self): """Inner training loop, separated for GC cleanup in finally block.""" @@ -1727,6 +1737,21 @@ def _prepare_data(self, is_vlm): else: chat_tmpl = getattr(args, "chat_template", None) if args.streaming: + # `iterate_training_batches` does not yet take a + # `dataset_order` argument, so streaming text MLX + # training cannot honor `preserve_dataset_order` / + # `dataset_order="torch_randperm"`. Refuse instead of + # silently dropping the user-requested order so Studio + # / CUDA parity stays explicit. + if ( + getattr(args, "preserve_dataset_order", False) + or getattr(args, "dataset_order", "default") != "default" + ): + raise ValueError( + "Unsloth MLX: preserve_dataset_order / dataset_order is not " + "supported with streaming=True for text training. Disable " + "streaming or materialize batches." + ) return None, iterate_training_batches( dataset=self.train_dataset, tokenizer=self.tokenizer, diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 219cd5ba3..60edd3701 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -683,12 +683,15 @@ def loss_fn(model, batch_dict): attention_mask = batch_dict.get("attention_mask") labels = batch_dict.get("labels") - # Standard causal LM shift - inputs = input_ids[:, :-1] - + # Match the CCE path semantics: forward the full multimodal + # sequence and shift the resulting logits afterwards. Qwen3-VL + # image / mRoPE / deepstack state depends on the complete + # sequence; trimming `input_ids[:, :-1]` before the multimodal + # forward gives a different loss from the full-logits CUDA + # path. Mirrors `_vlm_cce_forward` so use_cce=False stays in + # parity with use_cce=True. + inputs = input_ids fwd_mask = attention_mask - if attention_mask is not None and attention_mask.shape[-1] == input_ids.shape[-1]: - fwd_mask = attention_mask[:, :-1] # Forward pass — let the model create its own causal mask. # Pass extra keys (e.g. image_grid_thw for Qwen) that some models need. @@ -701,6 +704,8 @@ def loss_fn(model, batch_dict): output = model(inputs, pixel_values=pixel_values, mask=fwd_mask, **fwd_kwargs) logits = output.logits if hasattr(output, "logits") else output logits = logits.astype(mx.float32) + # Drop the final position so logits predict the next token. + logits = logits[:, :-1, :] if labels is not None: # Labels encode instruction/padding/special-token masking when From 693b09919c5ab2d78ee34c6214d00e69596eb616 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 14:44:59 +0000 Subject: [PATCH 21/48] Preserve embedder position_ids in _vlm_cce_forward `_unpack_embed_result` can return a `position_ids` adjusted for the merged multimodal sequence (e.g. Qwen-VL family adjusts mRoPE / 3D position_ids during get_input_embeddings). The previous code unconditionally overwrote backbone_kwargs["position_ids"] with the raw batch position_ids, discarding the embedder's corrected version. Only inject the raw position_ids when the embedder did not produce its own. --- unsloth_zoo/mlx/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 60edd3701..155f3392f 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -871,7 +871,11 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, **extra_kwargs, ) merged_embeds, backbone_kwargs = _unpack_embed_result(embed_result, model) - if "position_ids" in extra_kwargs: + # Prefer position_ids returned/stashed by get_input_embeddings (some + # VLM embedders, e.g. Qwen-VL family, adjust them for the merged + # multimodal sequence). Only fall back to the raw batch position_ids + # if the embedder did not produce its own. + if "position_ids" in extra_kwargs and "position_ids" not in backbone_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] hidden = _forward_text_hidden_states( From 4cb6ca668c7d8e01fb8dabd5a5a79863b48487ac Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 15:18:12 +0000 Subject: [PATCH 22/48] Address reviewer round 2 findings on PR #684 create_ordered_batches no longer mixes the last partial batch of one epoch with the first samples of the next. Take a contiguous slice of the current epoch order, emit a partial batch if the tail is short, and start the next batch fresh at epoch+1. Matches the VLM ordered path at utils.py:2539 and SequentialSampler(drop_last=False). Baseline labels=None mask in make_baseline_loss_fn now uses `<` (exclusive end) so the unlabeled path agrees with the CCE (utils.py:360,:393) and labels-aware (utils.py:439) masks; pre-fix it was `<=` and trained on one extra padded position when the row hit max_seq_length. _prepare_dataset, create_batches, create_ordered_batches, iterate_training_batches and _create_labeled_batches all gained an `append_eos` parameter that the trainer plumbs from MLXTrainingConfig (default True). Direct MLX text fine-tuning callers (raw {"text": str} rows) again get mlx-lm parity EOS appending; Studio passes False because its chat template already renders EOS. _create_labeled_batches now honors dataset_order / preserve_dataset_order: skips the length-based sort and per-batch shuffle when the caller has asked for sequential or torch_randperm order. Without this the train_on_responses_only path silently rewrote the sample order set by the new Studio CUDA parity flags. _create_labeled_batches emits lengths as right-half-open `[1, L]` to match the new exclusive-end mask convention; pre-fix it was `[1, L - 1]` which paired with the old `<=` mask and now would drop the final supervised token. test_mlx_text_dataset_does_not_append_eos updated: Studio explicitly passes append_eos=False, default callers still receive EOS. --- tests/test_pr_a_deep_components.py | 10 +++- unsloth_zoo/mlx/trainer.py | 67 +++++++++++++++++++++--- unsloth_zoo/mlx/utils.py | 82 ++++++++++++++++++++++-------- 3 files changed, 128 insertions(+), 31 deletions(-) diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index 9df76a489..df4d51770 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -226,10 +226,16 @@ def encode(self, text): assert text == "hello" return [1, 2, 3] - dataset = _prepare_dataset([{"text": "hello"}], Tokenizer()) - + # append_eos=False is what Studio passes (chat-template renders EOS). + dataset = _prepare_dataset([{"text": "hello"}], Tokenizer(), append_eos=False) assert dataset[0] == ([1, 2, 3], 0) + # Default (mlx-lm parity for direct MLX text fine-tuning callers) + # appends the tokenizer EOS so a raw `{"text": str}` row still + # trains the model to predict EOS. + dataset_default = _prepare_dataset([{"text": "hello"}], Tokenizer()) + assert dataset_default[0] == ([1, 2, 3, 99], 0) + def test_mlx_text_loss_masks_exclude_position_at_sequence_length(): import inspect diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 9d0dce0b4..c491dc5c2 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -352,6 +352,12 @@ class MLXTrainingConfig: wired_limit_gb: float | None = None # None = min(recommended working set, memory limit); <= 0 disables disable_memory_limits: bool = False cast_norm_output_to_input_dtype: bool = True # fp32 norm storage/math, bf16/fp16 downstream activations + # Append the tokenizer EOS id to each encoded text row before batching. + # Default True mirrors mlx-lm's TextDataset behavior so direct MLX + # text fine-tuning callers (raw `{"text": str}` rows) still train the + # model to predict EOS. Studio passes False because its chat template + # already renders EOS. + append_eos: bool = True # VLM / completion masking train_on_completions: bool = False # Mask prompt tokens in loss @@ -1359,6 +1365,7 @@ def step_fn(batch_data, prev_state, do_update): if isinstance(getattr(self.model, "_config", {}), dict) else None ), + append_eos=bool(getattr(args, "append_eos", True)), ) if eval_batches: print(f"Unsloth: Eval enabled every {args.eval_steps} steps " @@ -1763,6 +1770,7 @@ def _prepare_data(self, is_vlm): chat_template=chat_tmpl, model_name=model_name, model_type=model_type, + append_eos=bool(getattr(args, "append_eos", True)), ) else: batch_kwargs = dict( @@ -1777,6 +1785,7 @@ def _prepare_data(self, is_vlm): chat_template=chat_tmpl, model_name=model_name, model_type=model_type, + append_eos=bool(getattr(args, "append_eos", True)), ) if ( getattr(args, "preserve_dataset_order", False) @@ -1888,7 +1897,9 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, max_seq_length, formatting_func=None, dataset_text_field="text", num_batches=None, seed=42, chat_template=None, - model_name=None, model_type=None): + model_name=None, model_type=None, + append_eos=True, dataset_order="default", + preserve_dataset_order=False): """Create padded batches with label masks for train_on_responses_only. Tokenizes each dataset item, applies the masking closure to get labels, @@ -1897,7 +1908,10 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, Returns: List of (batch, lengths, labels) tuples where: - batch: mx.array (BS, padded_len) — input_ids padded with 0 - - lengths: mx.array (BS, 2) — [1, actual_len - 1] per sequence + - lengths: mx.array of shape (BS, 2) holding [1, actual_len] + per sequence. Right-half-open `[start, end)` matching the + exclusive-end loss masks in `utils.py:360`, `:393`, `:429`, + `:439`. - labels: mx.array (BS, padded_len) — labels padded with -100 """ eos_id = tokenizer.eos_token_id @@ -1933,7 +1947,12 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, # slow tokenizers degrade gracefully via the GIL) def _process_text(text): encoded = tokenizer.encode(text) - if eos_id is not None and (not encoded or encoded[-1] != eos_id): + # Honor the same `append_eos` contract as `_prepare_dataset`; the + # unlabeled text path (`_prepare_dataset` -> mlx-lm CacheDataset) + # appends or skips EOS based on the trainer's config, so the + # labeled `train_on_responses_only` path must match or the two + # produce different supervised tokens for the same input. + if append_eos and eos_id is not None and (not encoded or encoded[-1] != eos_id): encoded.append(eos_id) if len(encoded) > max_seq_length: encoded = encoded[:max_seq_length] @@ -1958,8 +1977,16 @@ def _process_text(text): "Check your dataset and formatting_func." ) - # 2. Sort by length for efficient padding - all_items.sort(key=lambda x: len(x[0])) + # 2. Sort by length for efficient padding -- but only when the caller + # has NOT requested a specific dataset_order. Length sorting is the + # default mlx-lm pattern that improves padding efficiency, but it + # breaks `preserve_dataset_order=True` (Studio CUDA parity) and + # `dataset_order="torch_randperm"` (deterministic shuffle). + _order_requested = preserve_dataset_order or ( + dataset_order not in (None, "default") + ) + if not _order_requested: + all_items.sort(key=lambda x: len(x[0])) # 3. Create padded batches rng = random.Random(seed) @@ -1982,7 +2009,13 @@ def _process_text(text): pad_len = padded_len - L batch_ids.append(ids[:L] + [0] * pad_len) batch_labels.append(lbls[:L] + [-100] * pad_len) - batch_lengths.append([1, L - 1]) + # Right-half-open [start, end) to match the loss masks in + # utils.py:360/:393/:429/:439 (`steps < lengths[:, 1:]`). + # Pre-fix this was `[1, L - 1]` which paired with the old + # `<=` mask; the PR flipped the mask to `<` so the end + # value must shift up by one to keep training on the + # final supervised token. + batch_lengths.append([1, L]) batches.append(( mx.array(batch_ids), @@ -1990,8 +2023,20 @@ def _process_text(text): mx.array(batch_labels), )) - # 4. Shuffle batches - rng.shuffle(batches) + # 4. Order the batch sequence. + # - preserve_dataset_order=True: emit in dataset order (Studio CUDA + # SequentialSampler parity). + # - dataset_order="torch_randperm": deterministic shuffle seeded by + # `seed`, matching the non-labeled `create_ordered_batches` path. + # - default: legacy length-sorted-then-shuffled behavior. + if preserve_dataset_order: + pass + elif dataset_order == "torch_randperm": + rng.shuffle(batches) + elif dataset_order == "sequential": + pass + else: + rng.shuffle(batches) # Limit if needed if num_batches is not None and len(batches) > num_batches: @@ -2196,6 +2241,9 @@ def train_on_responses_only( if isinstance(getattr(trainer.model, "_config", {}), dict) else None ), + append_eos=bool(getattr(args, "append_eos", True)), + dataset_order=getattr(args, "dataset_order", "default"), + preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)), ) # Safety check: detect all-masked labels early @@ -2221,6 +2269,9 @@ def train_on_responses_only( if isinstance(getattr(trainer.model, "_config", {}), dict) else None ), + append_eos=bool(getattr(args, "append_eos", True)), + dataset_order=getattr(args, "dataset_order", "default"), + preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)), ) trainer._eval_batches_labeled = eval_batches diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 155f3392f..fa228b5b8 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -421,12 +421,18 @@ def make_baseline_loss_fn(): """ def loss_fn(model, batch, lengths, labels=None): if labels is None: - # byte-identical to mlx_lm.tuner.trainer.default_loss + # Match the CCE (`utils.py:360`, `:393`) and labels-aware + # baseline (`utils.py:439`) masks: end is exclusive. The + # pre-PR `<=` was inclusive and the comment said "byte-identical + # to mlx_lm.tuner.trainer.default_loss", but mlx_lm's lengths + # convention is right-half-open (`[start, end)`), so the + # consistent CCE / labels-aware paths are also the correct + # ones for the unlabeled baseline. inputs = batch[:, :-1] targets = batch[:, 1:] logits = model(inputs) steps = mx.arange(1, targets.shape[1] + 1) - mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps < lengths[:, 1:]) ce = nn.losses.cross_entropy(logits, targets) * mask ntoks = mask.sum() ce = ce.astype(mx.float32).sum() / ntoks @@ -2657,7 +2663,8 @@ def _emit(items): def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", formatting_func=None, chat_template=None, - model_name=None, model_type=None): + model_name=None, model_type=None, + append_eos=True): """Wrap a HuggingFace dataset into mlx-lm's dataset classes. Uses CacheDataset from mlx_lm while leaving rendered text token-exact. @@ -2665,6 +2672,13 @@ def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", If a formatting_func is provided, each item is pre-formatted into a ``{"text": ...}`` dict before wrapping. + ``append_eos`` controls whether the tokenizer's EOS id is appended to + each encoded row. Default True preserves the pre-PR behavior that + delegated EOS appending to ``mlx_lm.tuner.datasets.TextDataset`` for + direct MLX text fine-tuning callers (raw ``{"text": str}`` rows + without already-rendered EOS). Studio passes False because its + chat-template rendering already includes EOS. + Returns: A CacheDataset ready for ``iterate_batches``. """ @@ -2705,17 +2719,26 @@ def _prepare_dataset(dataset, tokenizer, dataset_text_field="text", "a formatting_func that returns text." ) + _eos_id = getattr(tokenizer, "eos_token_id", None) if append_eos else None + class _StudioTextDataset: - """TextDataset variant that does not append EOS behind Studio's back.""" + """TextDataset variant. Optionally appends EOS (mlx-lm parity); + Studio passes append_eos=False because chat templates already render it.""" - def __init__(self, data, tokenizer, text_key="text"): + def __init__(self, data, tokenizer, text_key="text", eos_id=None): self._data = data self.tokenizer = tokenizer self.text_key = text_key + self._eos_id = eos_id def process(self, item): - # Studio/chat templates own EOS; adding one here changes labels. - return (self.tokenizer.encode(item[self.text_key]), 0) + encoded = self.tokenizer.encode(item[self.text_key]) + if ( + self._eos_id is not None + and (not encoded or encoded[-1] != self._eos_id) + ): + encoded = list(encoded) + [int(self._eos_id)] + return (encoded, 0) def __getitem__(self, idx): return self._data[idx] @@ -2723,13 +2746,15 @@ def __getitem__(self, idx): def __len__(self): return len(self._data) - return CacheDataset(_StudioTextDataset(formatted, tokenizer, text_key="text")) + return CacheDataset( + _StudioTextDataset(formatted, tokenizer, text_key="text", eos_id=_eos_id) + ) def create_batches(dataset, tokenizer, batch_size, max_seq_length, num_batches=None, seed=42, dataset_text_field="text", formatting_func=None, chat_template=None, - model_name=None, model_type=None): + model_name=None, model_type=None, append_eos=True): """Pre-tokenize and batch a HuggingFace dataset for MLX training. Uses iterate_batches from mlx_lm for efficient dynamic-padding batching: @@ -2752,6 +2777,7 @@ def create_batches(dataset, tokenizer, batch_size, max_seq_length, chat_template=chat_template, model_name=model_name, model_type=model_type, + append_eos=append_eos, ) batch_pairs = [] @@ -2788,7 +2814,7 @@ def create_ordered_batches(dataset, tokenizer, batch_size, max_seq_length, dataset_text_field="text", formatting_func=None, chat_template=None, model_name=None, model_type=None, - num_epochs=None): + num_epochs=None, append_eos=True): """Create text batches with an explicit dataset order. Studio uses this to mirror CUDA's effective sampler stream without @@ -2800,6 +2826,7 @@ def create_ordered_batches(dataset, tokenizer, batch_size, max_seq_length, chat_template=chat_template, model_name=model_name, model_type=model_type, + append_eos=append_eos, ) tokenized = [] @@ -2833,17 +2860,28 @@ def make_order(epoch): if num_batches is None else None ) while num_batches is None or len(batch_pairs) < num_batches: - batch_items = [] - for _ in range(batch_size): - if order_pos >= len(order): - epoch += 1 - order = make_order(epoch) - order_pos = 0 - batch_items.append(tokenized[order[order_pos]]) - order_pos += 1 - seen += 1 - if num_batches is None and target_items is not None and seen >= target_items: + # Take up to batch_size contiguous indices from the current epoch. + # If the epoch tail is shorter, emit a partial batch and start + # the next batch fresh at epoch+1. Matches CUDA / SequentialSampler + # `drop_last=False` and the VLM ordered path at utils.py:2539, + # instead of mixing the last sample of one epoch with the first + # sample of the next inside the same micro-batch. + if order_pos >= len(order): + if num_batches is None: break + epoch += 1 + order = make_order(epoch) + order_pos = 0 + + chunk = order[order_pos : order_pos + batch_size] + if not chunk: + break + order_pos += len(chunk) + seen += len(chunk) + if num_batches is None and target_items is not None and seen > target_items: + chunk = chunk[: len(chunk) - (seen - target_items)] + seen = target_items + batch_items = [tokenized[i] for i in chunk] max_length = max(len(ids) for ids in batch_items) batch_ids = [] @@ -2864,7 +2902,8 @@ def make_order(epoch): def iterate_training_batches(dataset, tokenizer, batch_size, max_seq_length, seed=42, dataset_text_field="text", formatting_func=None, chat_template=None, - model_name=None, model_type=None): + model_name=None, model_type=None, + append_eos=True): """Streaming batch generator for MLX training. Wraps mlx-lm's iterate_batches(loop=True) as a generator, avoiding @@ -2880,6 +2919,7 @@ def iterate_training_batches(dataset, tokenizer, batch_size, max_seq_length, chat_template=chat_template, model_name=model_name, model_type=model_type, + append_eos=append_eos, ) for batch, lengths_info in iterate_batches( From 6374bad6bc66383fc7674cbae70fcbcb24e0538e Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 15:30:02 +0000 Subject: [PATCH 23/48] Extend HF parity decoupled weight decay to SGD/Muon/Lion for PR #684 AdamW already used a manual decoupled bias/norm-aware decay with `weight_decay=0.0` on the underlying MLX optimizer. SGD, Muon, and Lion still passed `weight_decay=wd` directly to the MLX optimizer, which applies wd uniformly across every trainable parameter (including bias and norm leaves) and uses MLX's internal coupling semantics rather than HF's decoupled per-step `param *= 1 - lr * wd`. Mirror the AdamW pattern for the other three optimizers: set the underlying MLX optimizer's `weight_decay` to zero and let the existing manual helper own the decoupled decay term. `_manual_adamw_weight_decay` renamed to `_manual_weight_decay` (and the helper to `_apply_manual_weight_decay`) since it now covers four optimizers. Tests updated for the rename and a parametrized SGD/Muon/Lion case added asserting the manual decay scalar is set and the optimizer itself carries `weight_decay=0.0`. --- tests/test_mlx_pr684_review_fixes.py | 6 ++--- tests/test_pr_a_deep_components.py | 27 +++++++++++++++++++- unsloth_zoo/mlx/trainer.py | 38 ++++++++++++++++------------ 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_pr684_review_fixes.py index d947d4cd1..151059c4c 100644 --- a/tests/test_mlx_pr684_review_fixes.py +++ b/tests/test_mlx_pr684_review_fixes.py @@ -73,7 +73,7 @@ def test_vlm_label_mask_keeps_in_sequence_pad_eos_token(): assert out.tolist() == [[101, 2, -100, -100]] -def test_manual_adamw_weight_decay_accepts_scalar_lr_and_preserves_dtype(): +def test_manual_weight_decay_accepts_scalar_lr_and_preserves_dtype(): from mlx.utils import tree_flatten from unsloth_zoo.mlx.trainer import MLXTrainer @@ -111,9 +111,9 @@ class TinyOptimizer: "norm": {"weight": mx.array([1.0], dtype=mx.float32)}, } trainer = object.__new__(MLXTrainer) - trainer._manual_adamw_weight_decay = 0.1 + trainer._manual_weight_decay = 0.1 - trainer._apply_manual_adamw_weight_decay(model, TinyOptimizer(), grad) + trainer._apply_manual_weight_decay(model, TinyOptimizer(), grad) flat = dict(tree_flatten(model.trainable_parameters())) assert flat["layer.weight"].dtype == mx.bfloat16 diff --git a/tests/test_pr_a_deep_components.py b/tests/test_pr_a_deep_components.py index df4d51770..0981b099b 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_pr_a_deep_components.py @@ -121,7 +121,7 @@ def trainable_parameters(self): optimizer = trainer._build_optimizer(total_steps=8) - assert trainer._manual_adamw_weight_decay == pytest.approx(0.1) + assert trainer._manual_weight_decay == pytest.approx(0.1) if hasattr(optimizer, "_kw"): assert optimizer._kw["weight_decay"] == 0.0 assert MLXTrainer._should_apply_weight_decay("layers.0.mlp.down_proj.weight") @@ -130,6 +130,31 @@ def trainable_parameters(self): assert not MLXTrainer._should_apply_weight_decay("vision.blocks.0.norm1.weight") +@pytest.mark.parametrize("optim_name", ["sgd", "muon", "lion"]) +def test_non_adamw_optimizers_use_hf_parity_manual_decay(optim_name): + """SGD, Muon, and Lion must mirror the AdamW pattern: zero out the + optimizer's built-in `weight_decay` and let `_apply_manual_weight_decay` + own the decoupled decay so bias and norm params are excluded.""" + from unsloth_zoo.mlx.trainer import MLXTrainer, MLXTrainingConfig + + class DummyModel: + def trainable_parameters(self): + return {} + + trainer = MLXTrainer.__new__(MLXTrainer) + trainer.model = DummyModel() + trainer.args = MLXTrainingConfig( + optim=optim_name, + weight_decay=0.05, + ) + + optimizer = trainer._build_optimizer(total_steps=4) + + assert trainer._manual_weight_decay == pytest.approx(0.05) + if hasattr(optimizer, "_kw"): + assert optimizer._kw["weight_decay"] == 0.0 + + def test_norm_clip_dtype_restore_keeps_lora_and_norms_promotable(): from unsloth_zoo.mlx.trainer import MLXTrainer diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index c491dc5c2..7d59a0f00 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -575,7 +575,7 @@ def _build_optimizer(self, total_steps): initial_lr = self._schedule_value(schedule, 0) self._lr_schedule = schedule if callable(schedule) else None wd = self.args.weight_decay - self._manual_adamw_weight_decay = 0.0 + self._manual_weight_decay = 0.0 adam_beta1 = getattr(self.args, "adam_beta1", None) adam_beta2 = getattr(self.args, "adam_beta2", None) adam_kwargs = {} @@ -610,7 +610,7 @@ def _build_optimizer(self, total_steps): elif opt_name == "adamw": # Match HF/PyTorch AdamW semantics. MLX defaults bias_correction # to False, which makes early warmup updates much larger. - self._manual_adamw_weight_decay = float(wd or 0.0) + self._manual_weight_decay = float(wd or 0.0) optimizer = optim.AdamW( learning_rate=initial_lr, weight_decay=0.0, @@ -624,17 +624,17 @@ def _build_optimizer(self, total_steps): **adam_kwargs, ) elif opt_name == "sgd": - # TODO: For HF Trainer parity, consider applying the same - # bias/norm weight-decay exclusion used by AdamW to SGD. - optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=wd) + # HF parity: decoupled bias/norm-aware decay, applied manually. + self._manual_weight_decay = float(wd or 0.0) + optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=0.0) elif opt_name == "muon": - # TODO: For HF Trainer parity, consider applying the same - # bias/norm weight-decay exclusion used by AdamW to Muon. - optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=wd) + # HF parity: decoupled bias/norm-aware decay, applied manually. + self._manual_weight_decay = float(wd or 0.0) + optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=0.0) elif opt_name == "lion": - # TODO: For HF Trainer parity, consider applying the same - # bias/norm weight-decay exclusion used by AdamW to Lion. - optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=wd) + # HF parity: decoupled bias/norm-aware decay, applied manually. + self._manual_weight_decay = float(wd or 0.0) + optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=0.0) self._resolved_optimizer_name = opt_name return optimizer @@ -666,9 +666,15 @@ def _is_lora_parameter_name(name): if part ) - def _apply_manual_adamw_weight_decay(self, model, optimizer, grad): - """Apply decoupled AdamW decay to trainable non-bias/non-norm leaves.""" - wd = float(getattr(self, "_manual_adamw_weight_decay", 0.0) or 0.0) + def _apply_manual_weight_decay(self, model, optimizer, grad): + """Decoupled HF-parity decay on trainable non-bias/non-norm leaves. + + Active for AdamW, SGD, Muon, and Lion. The underlying MLX + optimizer is constructed with ``weight_decay=0.0`` so this + helper owns the full update for the weight-decay term and + matches what HF Trainer does via ``param_groups``. + """ + wd = float(getattr(self, "_manual_weight_decay", 0.0) or 0.0) if wd <= 0: return @@ -1204,7 +1210,7 @@ def _apply_update(grad, toks_f): ) if _clip_grad_value: final_grad = _clip_grad_by_leaf_norm(final_grad) - self._apply_manual_adamw_weight_decay(model, optimizer, final_grad) + self._apply_manual_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) _restore_trainable_storage_dtypes() return grad_norm @@ -1227,7 +1233,7 @@ def _apply_update_direct(grad): grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) if _clip_grad_value: grad = _clip_grad_by_leaf_norm(grad) - self._apply_manual_adamw_weight_decay(model, optimizer, grad) + self._apply_manual_weight_decay(model, optimizer, grad) optimizer.update(model, grad) _restore_trainable_storage_dtypes() return grad_norm From d16ef240dab90c6499a810d7ddd5d3ed71a81e99 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 15:37:25 +0000 Subject: [PATCH 24/48] Address remaining reviewer round 2 findings on PR #684 max_grad_value: restore elementwise clip semantics ============================================ The PR replaced the existing `mx.clip(g, -v, +v)` with a per-leaf L2 norm rescale, so the field's name no longer matched its behavior and existing tests / docstrings describe elementwise semantics. Four reviewers flagged this as a public-API regression. Switch back to `mx.clip(g, -max_grad_value, max_grad_value)` (still per-leaf, no cross-leaf reduction). The function is renamed `_clip_grad_by_value` to match the contract. VLM iterable streaming: refuse dataset_order instead of dropping it ============================================ `iterate_vlm_training_batches` honored `dataset_order="torch_randperm"` on sized datasets but silently streamed source order on unsized / iterable ones. The text streaming path already raises in this asymmetry (`trainer.py:1758`); mirror that here so users get a clear error rather than a silent CUDA-parity regression. Qwen3-VL vision norm-cast flag: restore prior state in finally ============================================ `_set_norm_output_cast_to_input_dtype(False, model)` in the train() finally also toggles the Qwen3 vision flag via `set_qwen3_vision_norm_cast_output(False)`. The module-level default is True, so post-training inference in the same process would see the flag stuck at False. Capture the previous flag value before training and restore it explicitly in finally. create_ordered_batches: pad with the tokenizer's pad id ============================================ Padded positions used literal `0` rather than `tokenizer.pad_token_id`, which can collide with a regular vocabulary token for tokenizers whose pad id is not 0. Fall back to 0 only when the tokenizer has no pad id. --- unsloth_zoo/mlx/trainer.py | 46 ++++++++++++++++++++++++++++---------- unsloth_zoo/mlx/utils.py | 22 +++++++++++++++++- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 7d59a0f00..e3724a17d 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -306,8 +306,9 @@ class MLXTrainingConfig: adam_beta1: float | None = None adam_beta2: float | None = None max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead - # Proportional per-tensor clipping. This is cheaper than global norm, but - # preserves each tensor's gradient direction unlike elementwise value clip. + # Elementwise clip: each gradient element is clamped to + # `[-max_grad_value, +max_grad_value]`. Per-leaf only, no cross-leaf + # reduction, so it does not pay the global-norm memory overhead. # None (default) keeps the cheap MLX default of 1.0 unless the user # passes max_grad_norm > 0, in which case global-norm clipping wins. # 0.0 disables. A positive float opts in explicitly and overrides @@ -856,6 +857,18 @@ def train(self): args = self.args model = self.model cast_norm_output = bool(getattr(args, "cast_norm_output_to_input_dtype", True)) + # Remember the Qwen3-VL vision-block flag so the finally block can + # restore the original value rather than always forcing it to + # False (which would otherwise leak across train() boundaries for + # subsequent inference in the same process). + _prev_qwen3_vision_cast = True + try: + from . import compile as _mlx_compile + _prev_qwen3_vision_cast = bool( + getattr(_mlx_compile, "_QWEN3_VISION_NORM_CAST_OUTPUT", True) + ) + except Exception: + pass _set_norm_output_cast_to_input_dtype(cast_norm_output, model) if cast_norm_output: print("Unsloth: Casting MLX norm outputs back to activation dtype.") @@ -938,6 +951,16 @@ def train(self): _set_norm_output_cast_to_input_dtype(False, model) except Exception: pass + # Restore the Qwen3-VL vision-block flag to whatever it was + # before train() started, instead of leaking the False that + # `_set_norm_output_cast_to_input_dtype(False, ...)` just set. + try: + from . import compile as _mlx_compile + _mlx_compile.set_qwen3_vision_norm_cast_output( + _prev_qwen3_vision_cast + ) + except Exception: + pass def _train_inner(self): """Inner training loop, separated for GC cleanup in finally block.""" @@ -1173,16 +1196,15 @@ def _can_report_optimizer_state_norm(): # This avoids adding a second consumer to the lazy backward graph. return getattr(optimizer, "betas", None) - def _clip_grad_by_leaf_norm(grad): + def _clip_grad_by_value(grad): + # Elementwise clip to [-max_grad_value, +max_grad_value], + # per-leaf, no cross-leaf reduction. if not _clip_grad_value: return grad - def _clip_leaf_norm(g): - g_f = g.astype(mx.float32) - norm = mx.sqrt(mx.sum(g_f * g_f)) - scale = mx.minimum(max_grad_value / (norm + 1e-6), 1.0) - return g * scale.astype(g.dtype) - - return tree_map(_clip_leaf_norm, grad) + return tree_map( + lambda g: mx.clip(g, -max_grad_value, max_grad_value), + grad, + ) def _apply_update(grad, toks_f): """Common gradient post-processing and optimizer update. @@ -1209,7 +1231,7 @@ def _apply_update(grad, toks_f): final_grad, max_norm=max_grad_norm ) if _clip_grad_value: - final_grad = _clip_grad_by_leaf_norm(final_grad) + final_grad = _clip_grad_by_value(final_grad) self._apply_manual_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) _restore_trainable_storage_dtypes() @@ -1232,7 +1254,7 @@ def _apply_update_direct(grad): if max_grad_norm > 0: grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) if _clip_grad_value: - grad = _clip_grad_by_leaf_norm(grad) + grad = _clip_grad_by_value(grad) self._apply_manual_weight_decay(model, optimizer, grad) optimizer.update(model, grad) _restore_trainable_storage_dtypes() diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index fa228b5b8..0aeb8d6c3 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -2645,6 +2645,18 @@ def _emit(items): yield _emit(items) epoch += 1 else: + # Iterable / unsized datasets cannot honor dataset_order because + # they expose no index space to permute. Refuse rather than + # silently stream source order (matches the text streaming path + # at trainer.py for the same asymmetry). + if dataset_order not in (None, "default"): + raise ValueError( + "Unsloth MLX VLM: preserve_dataset_order / " + f"dataset_order={dataset_order!r} requires a sized " + "(`__len__`) dataset. Materialize the dataset (e.g. " + "via `dataset.to_iterable_dataset()` -> list) or drop " + "the order request." + ) while True: pending = [] yielded = False @@ -2884,11 +2896,19 @@ def make_order(epoch): batch_items = [tokenized[i] for i in chunk] max_length = max(len(ids) for ids in batch_items) + # Prefer the tokenizer's declared pad id; only fall back to 0 if + # the tokenizer has none. Matches mlx-lm's iterate_batches pad + # convention so the model receives a known special id (not raw 0) + # for padded positions in the forward input. + _pad_id = getattr(tokenizer, "pad_token_id", None) + if _pad_id is None: + _pad_id = 0 + _pad_id = int(_pad_id) batch_ids = [] lengths = [] for ids in batch_items: length = len(ids) - batch_ids.append(ids + [0] * (max_length - length)) + batch_ids.append(ids + [_pad_id] * (max_length - length)) lengths.append([0, length]) batch_pairs.append((mx.array(batch_ids), mx.array(lengths), None)) From 23751c84ca2c838fa6aa30900ecac6c73d7b5366 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 15:44:38 +0000 Subject: [PATCH 25/48] Address reviewer round 3 P1/P2 findings on PR #684 Labeled batches: torch_randperm at the sample level ============================================ `_create_labeled_batches` had `dataset_order="torch_randperm"` shuffle batches after sequential batching, while the unlabeled path (`create_ordered_batches` / `_torch_randperm_order`) shuffles samples before batching. Studio + CUDA `RandomSampler` is sample-granular, so `train_on_responses_only(..., dataset_order="torch_randperm")` ended up grouping rows differently from non-completion MLX training. Now both paths apply `_torch_randperm_order(seed)` to the sample list before batching; the legacy `default` / `None` path still length-sorts and shuffles batches. Unsupported values now raise instead of silently falling back to default shuffle. create_ordered_batches: honor num_epochs when num_batches is None ============================================ When `_prepare_data()` selects `create_ordered_batches(num_epochs=N)` for `max_steps <= 0` + `num_train_epochs > 0`, the loop previously exited at the first epoch boundary (`break` on `num_batches is None`). The intent is to emit `N * len(dataset)` samples worth of batches; extended the boundary check to stop only when the requested sample total has been emitted or no batches were requested. train() norm-patch lifecycle hardening ============================================ Moved `_set_norm_output_cast_to_input_dtype(cast_norm_output, model)` INSIDE the `try` block so a raise from `normalize_mlx_patch_mode`, `_configure_memory_limits`, compile policy, gradient checkpointing, or Qwen3.5 preflight does not leak the patched RMSNorm / LayerNorm class globals across train() boundaries. VLM train_on_completions in the labels-aware branch ============================================ `_collate_vlm_batch` now always attaches `batch["labels"]`, so `_vlm_cce_forward` and `make_vlm_baseline_loss_fn` always take their labels-aware branches for ordinary VLM SFT. Those branches were missing the `_mask_prompt_tokens(...)` call that the labels=None branches already perform, so `train_on_completions=True` silently trained on prompt tokens. Added the call to both branches. Ruff lint: clean up new E741 / F401 / F841 in changed files ============================================ Removed unused `import mlx.core as mx`, `mask = kwargs.get("mask",...)` that was never read, dead `original_sanitize` / `wanted` / `hidden_dim` locals, and renamed ambiguous `l` loop variables in `mx.eval(...)` batch flushes. Ruff exits clean on `compile.py / loader.py / trainer.py / utils.py` after this commit. --- unsloth_zoo/mlx/compile.py | 11 --- unsloth_zoo/mlx/loader.py | 5 - unsloth_zoo/mlx/trainer.py | 194 +++++++++++++++++++++---------------- unsloth_zoo/mlx/utils.py | 29 +++++- 4 files changed, 135 insertions(+), 104 deletions(-) diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index f3a2226f9..17bb2c120 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2103,7 +2103,6 @@ def _masked_scatter_no_numpy(final_embedding, image_mask_expanded, scaled_image_ import mlx.core as mx final_shape = final_embedding.shape - hidden_dim = final_shape[-1] flat_mask = image_mask_expanded.reshape((-1,)) flat_output = final_embedding.reshape((-1,)) flat_features = scaled_image_features.reshape((-1,)) @@ -2437,7 +2436,6 @@ def patched_qwen2_vision_call(self, hidden_states, grid_thw, output_hidden_state def patched_qwen2_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -2571,7 +2569,6 @@ def patched_qwen2_rot_pos_emb(self, grid_thw): return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1) def patched_qwen2_vision_call(self, hidden_states, grid_thw, output_hidden_states=None): - import mlx.core as mx grid_spec = _grid_to_tuple(grid_thw) hidden_states = self.patch_embed(hidden_states) @@ -2591,7 +2588,6 @@ def patched_qwen2_vision_call(self, hidden_states, grid_thw, output_hidden_state def patched_qwen2_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -2873,7 +2869,6 @@ def patched_qwen3_fast_pos_embed_interpolate(self, grid_thw): return mx.concatenate(patch_pos_embeds_permute) def patched_qwen3_vision_call(self, hidden_states, grid_thw, **kwargs): - import mlx.core as mx grid_spec = _grid_to_tuple(grid_thw) hidden_states = self.patch_embed(hidden_states) @@ -2908,7 +2903,6 @@ def patched_qwen3_deepstack(self, hidden_states, visual_pos_masks, visual_embeds def patched_qwen3_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -2956,7 +2950,6 @@ def patched_qwen3_get_input_embeddings(self, input_ids=None, pixel_values=None, def patched_qwen35_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -3100,7 +3093,6 @@ def patched_glm_rot_pos_emb(self, grid_thw): return (mx.cos(emb), mx.sin(emb)), pos_ids def patched_glm_vision_call(self, hidden_states, grid_thw, output_hidden_states=None): - import mlx.core as mx grid_spec = _grid_to_tuple(grid_thw) hidden_states = self.patch_embed(hidden_states) @@ -3129,7 +3121,6 @@ def patched_glm_vision_call(self, hidden_states, grid_thw, output_hidden_states= def patched_glm_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -3311,7 +3302,6 @@ def patched_paddle_rot_pos_emb(self, grid_thw): return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1) def patched_paddle_vision_call(self, hidden_states, grid_thw, output_hidden_states=None): - import mlx.core as mx grid_spec = _grid_to_tuple(grid_thw) hidden_states = self.embeddings(hidden_states, grid_spec) @@ -3436,7 +3426,6 @@ def _install_qwen3_get_input_embeddings_patch(): def patched_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) - mask = kwargs.get("mask", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw if pixel_values is None: diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index a9dbf1d47..d341a6032 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -27,7 +27,6 @@ import inspect import math import os -import sys import types import warnings from contextlib import contextmanager @@ -117,7 +116,6 @@ def _convert_mlx_dtype(model, target_dtype, model_type: str = "") -> None: """ import mlx.core as mx from mlx.utils import tree_flatten, tree_map_with_path - from ..model_lists import FORCE_FLOAT32 cast_pred = getattr(model, "cast_predicate", lambda _: True) needs_cast = False @@ -652,8 +650,6 @@ def _ensure_safe_text_wrapper_sanitize(model_type: str) -> None: if tree_unflatten is None or tree_flatten is None: return - original_sanitize = sanitize - def patched_sanitize(self, weights): structured = tree_unflatten(list(weights.items())) target = structured.get("model") @@ -1124,7 +1120,6 @@ def _apply_lora_at_paths(model, module_paths, adapter_cfg): scale = float(lora_params.get("scale", adapter_cfg.get("scale", 1.0))) dropout = float(lora_params.get("dropout", adapter_cfg.get("dropout", 0.0))) - wanted = set(module_paths) by_name = dict(model.named_modules()) linear_types = (nn.Linear, nn.QuantizedLinear) for name in module_paths: diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index e3724a17d..dfb01b5e2 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -75,7 +75,6 @@ build_compile_policy, explain_compile_support, get_compile_qualification, - get_model_architecture, normalize_mlx_patch_mode, resolve_training_compile, trace_compile_application, @@ -869,79 +868,92 @@ def train(self): ) except Exception: pass - _set_norm_output_cast_to_input_dtype(cast_norm_output, model) - if cast_norm_output: - print("Unsloth: Casting MLX norm outputs back to activation dtype.") - args.patch_mode = normalize_mlx_patch_mode(getattr(args, "patch_mode", "patched")) - model._unsloth_patch_mode = args.patch_mode - - self._memory_limits_applied = self._configure_memory_limits() - - self._compile_decision = None - self._compile_trace = None - self._compile_auto_tune_applied = [] - if self._is_vlm and (args.compile or args.compile_trace): - compile_policy = build_compile_policy(args=args) - qual = getattr(model, "_unsloth_compile_qualification", None) or get_compile_qualification(model) - if qual is not None: - model._unsloth_compile_qualification = qual - self._compile_decision = resolve_training_compile(model, policy=compile_policy, args=args) - model._unsloth_compile_decision = self._compile_decision - if args.compile_trace: - self._compile_trace = trace_compile_application(model, policy=compile_policy, args=args) - model._unsloth_compile_trace = self._compile_trace - model._unsloth_compile_explain = explain_compile_support(model, policy=compile_policy, args=args) - if args.compile_auto_tune: - self._compile_auto_tune_applied = self._apply_compile_recommendations( - args, self._compile_decision - ) - for setting, value, reason in self._compile_auto_tune_applied: - print( - f"Unsloth: Auto-tuned {setting}={value!r} for MLX compile " - f"({reason})" + # Install the norm-output cast INSIDE the try/finally so a raise + # during any of the steps below (patch-mode normalization, memory + # limit configuration, compile policy, gradient checkpointing, + # Qwen3.5 preflight) still triggers the cleanup that restores the + # global RMSNorm / LayerNorm / Qwen3-VL flags. + _norm_cast_applied = False + try: + _set_norm_output_cast_to_input_dtype(cast_norm_output, model) + _norm_cast_applied = True + if cast_norm_output: + print("Unsloth: Casting MLX norm outputs back to activation dtype.") + args.patch_mode = normalize_mlx_patch_mode(getattr(args, "patch_mode", "patched")) + model._unsloth_patch_mode = args.patch_mode + + self._memory_limits_applied = self._configure_memory_limits() + + self._compile_decision = None + self._compile_trace = None + self._compile_auto_tune_applied = [] + if self._is_vlm and (args.compile or args.compile_trace): + compile_policy = build_compile_policy(args=args) + qual = getattr(model, "_unsloth_compile_qualification", None) or get_compile_qualification(model) + if qual is not None: + model._unsloth_compile_qualification = qual + self._compile_decision = resolve_training_compile(model, policy=compile_policy, args=args) + model._unsloth_compile_decision = self._compile_decision + if args.compile_trace: + self._compile_trace = trace_compile_application(model, policy=compile_policy, args=args) + model._unsloth_compile_trace = self._compile_trace + model._unsloth_compile_explain = explain_compile_support(model, policy=compile_policy, args=args) + if args.compile_auto_tune: + self._compile_auto_tune_applied = self._apply_compile_recommendations( + args, self._compile_decision ) + for setting, value, reason in self._compile_auto_tune_applied: + print( + f"Unsloth: Auto-tuned {setting}={value!r} for MLX compile " + f"({reason})" + ) - # (memory limits already applied above; just log what we configured) - if self._memory_limits_applied: - parts = [] - if "memory_limit_gb" in self._memory_limits_applied: - parts.append( - f"memory_limit={self._memory_limits_applied['memory_limit_gb']:.2f} GB" - ) - if "cache_limit_gb" in self._memory_limits_applied: - parts.append( - f"cache_limit={self._memory_limits_applied['cache_limit_gb']:.2f} GB" - ) - if "wired_limit_gb" in self._memory_limits_applied: - parts.append( - f"wired_limit={self._memory_limits_applied['wired_limit_gb']:.2f} GB" + # (memory limits already applied above; just log what we configured) + if self._memory_limits_applied: + parts = [] + if "memory_limit_gb" in self._memory_limits_applied: + parts.append( + f"memory_limit={self._memory_limits_applied['memory_limit_gb']:.2f} GB" + ) + if "cache_limit_gb" in self._memory_limits_applied: + parts.append( + f"cache_limit={self._memory_limits_applied['cache_limit_gb']:.2f} GB" + ) + if "wired_limit_gb" in self._memory_limits_applied: + parts.append( + f"wired_limit={self._memory_limits_applied['wired_limit_gb']:.2f} GB" + ) + print( + "Unsloth: MLX Metal memory guard enabled " + f"({', '.join(parts)})." ) - print( - "Unsloth: MLX Metal memory guard enabled " - f"({', '.join(parts)})." - ) - # Apply gradient checkpointing if requested - if args.gradient_checkpointing: - apply_gradient_checkpointing(model) - print("Unsloth: Using gradient checkpointing to reduce memory.") - - # Qwen3.5-specific fixes - config = getattr(model, "_config", {}) - model_type = config.get("model_type", "") if isinstance(config, dict) else "" - if "qwen3_5" in model_type: - from .loader import _fix_qwen35_attention_cache - _fix_qwen35_attention_cache(model) - from ..gated_delta_vjp import patch_gated_delta - patch_gated_delta() + # Apply gradient checkpointing if requested + if args.gradient_checkpointing: + apply_gradient_checkpointing(model) + print("Unsloth: Using gradient checkpointing to reduce memory.") + + # Qwen3.5-specific fixes + config = getattr(model, "_config", {}) + model_type = config.get("model_type", "") if isinstance(config, dict) else "" + if "qwen3_5" in model_type: + from .loader import _fix_qwen35_attention_cache + _fix_qwen35_attention_cache(model) + from ..gated_delta_vjp import patch_gated_delta + patch_gated_delta() - try: return self._train_inner() finally: if args.gradient_checkpointing: - remove_gradient_checkpointing(model) - self._restore_memory_limits() - if cast_norm_output: + try: + remove_gradient_checkpointing(model) + except Exception: + pass + try: + self._restore_memory_limits() + except Exception: + pass + if _norm_cast_applied and cast_norm_output: # Undo the global norm-class monkey patch so later # inference / unrelated trainers in the same Python # process get the original RMSNorm / LayerNorm dtype @@ -2005,16 +2017,33 @@ def _process_text(text): "Check your dataset and formatting_func." ) - # 2. Sort by length for efficient padding -- but only when the caller - # has NOT requested a specific dataset_order. Length sorting is the - # default mlx-lm pattern that improves padding efficiency, but it - # breaks `preserve_dataset_order=True` (Studio CUDA parity) and - # `dataset_order="torch_randperm"` (deterministic shuffle). + # 2. Apply the requested sample order BEFORE batching so labeled + # and unlabeled paths produce identical sample streams. + # - preserve_dataset_order=True / "sequential": dataset order. + # - "torch_randperm": deterministic torch.randperm permutation + # (matches `create_ordered_batches` -> `_torch_randperm_order` + # at utils.py:2845-2849). + # - default / None: legacy mlx-lm length-sort + per-batch shuffle. + # The labeled and unlabeled paths must agree at sample granularity + # or `dataset_order="torch_randperm"` produces a different sample + # stream under `train_on_responses_only(...)`. _order_requested = preserve_dataset_order or ( dataset_order not in (None, "default") ) - if not _order_requested: + if preserve_dataset_order or dataset_order == "sequential": + pass + elif dataset_order == "torch_randperm": + from .utils import _torch_randperm_order + order = _torch_randperm_order(len(all_items), seed) + all_items = [all_items[i] for i in order] + elif dataset_order in (None, "default"): all_items.sort(key=lambda x: len(x[0])) + else: + raise ValueError( + f"Unsloth MLX: unsupported dataset_order={dataset_order!r}. " + "Expected one of: None, 'default', 'sequential', " + "'torch_randperm'." + ) # 3. Create padded batches rng = random.Random(seed) @@ -2052,16 +2081,13 @@ def _process_text(text): )) # 4. Order the batch sequence. - # - preserve_dataset_order=True: emit in dataset order (Studio CUDA - # SequentialSampler parity). - # - dataset_order="torch_randperm": deterministic shuffle seeded by - # `seed`, matching the non-labeled `create_ordered_batches` path. - # - default: legacy length-sorted-then-shuffled behavior. - if preserve_dataset_order: - pass - elif dataset_order == "torch_randperm": - rng.shuffle(batches) - elif dataset_order == "sequential": + # - preserve_dataset_order / "sequential": keep batch order + # (samples were already in their target order at step 2). + # - "torch_randperm": batches mirror torch.randperm at the + # sample level (step 2 above), so keep batch order here. + # - default / None: legacy length-sort emitted near-contiguous + # batches; shuffle them so adjacent steps are not similar. + if _order_requested: pass else: rng.shuffle(batches) @@ -2072,8 +2098,8 @@ def _process_text(text): # Evaluate all tensors all_tensors = [] - for b, l, lb in batches: - all_tensors.extend([b, l, lb]) + for batch_arr, lengths_arr, labels_arr in batches: + all_tensors.extend([batch_arr, lengths_arr, labels_arr]) mx.eval(all_tensors) return batches diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 0aeb8d6c3..6cb475f60 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -719,6 +719,10 @@ def loss_fn(model, batch_dict): # externally supplied labels compatible. targets = labels[:, 1:] targets = _mask_image_tokens(targets, _image_token_ids) + # Apply train_on_completions prompt masking even when labels + # were preset by _apply_vlm_label_masks (which only handles + # image/pad/ignore tokens, not assistant-token boundaries). + targets = _mask_prompt_tokens(targets, _assistant_token_id) logits, targets = _align_logits_with_labels(logits, targets) if attention_mask is not None: length_mask = attention_mask[:, 1:][:, :targets.shape[1]] @@ -907,6 +911,12 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, masked_targets, -100, ) + # Completion-only masking also has to run in the labels-aware + # branch because _apply_vlm_label_masks does not apply it. Pre-fix + # this was only applied in the labels=None branch, so VLM + # train_on_completions=True trained on prompt tokens whenever + # collation set batch["labels"]. + masked_targets = _mask_prompt_tokens(masked_targets, assistant_token_id) ntoks = (masked_targets != -100).sum() else: targets = input_ids[:, 1:] @@ -2804,7 +2814,10 @@ def create_batches(dataset, tokenizer, batch_size, max_seq_length, if num_batches is not None and len(batch_pairs) >= num_batches: break - mx.eval([b for b, l, _ in batch_pairs] + [l for _, l, _ in batch_pairs]) + mx.eval( + [b for b, lengths, _ in batch_pairs] + + [lengths for _, lengths, _ in batch_pairs] + ) return batch_pairs @@ -2879,7 +2892,13 @@ def make_order(epoch): # instead of mixing the last sample of one epoch with the first # sample of the next inside the same micro-batch. if order_pos >= len(order): - if num_batches is None: + # No more rows in this epoch. Stop only when we have hit + # either the requested number of batches or the requested + # total sample count (num_epochs * len(dataset)). + if ( + num_batches is None + and (target_items is None or seen >= target_items) + ): break epoch += 1 order = make_order(epoch) @@ -2915,7 +2934,10 @@ def make_order(epoch): if num_batches is None and target_items is not None and seen >= target_items: break - mx.eval([b for b, l, _ in batch_pairs] + [l for _, l, _ in batch_pairs]) + mx.eval( + [b for b, lengths, _ in batch_pairs] + + [lengths for _, lengths, _ in batch_pairs] + ) return batch_pairs @@ -3401,7 +3423,6 @@ def save_pretrained_gguf( from ..llama_cpp import ( convert_to_gguf, quantize_gguf, - install_llama_cpp, check_llama_cpp, _download_convert_hf_to_gguf, ) From 8c278327c668aa9f742b0b0ba2faa83c533db8b0 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 15:54:56 +0000 Subject: [PATCH 26/48] Restore mask = kwargs.get for 4 patched VLM get_input_embeddings A round-3 replace_all that targeted unused mask extractions in two patched_qwen2 get_input_embeddings variants also stripped the same line from four other patched VLM get_input_embeddings functions where mask is actually passed into self.language_model.get_rope_index. That broke qwen3, qwen35, glm, and a generic VLM get_input_embeddings: NameError on first call. Restored the mask = kwargs.get('mask', None) line in the four functions that use it; the two qwen2 callers (where mask is truly unused) remain stripped. --- unsloth_zoo/mlx/compile.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 17bb2c120..30a8ce780 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2903,6 +2903,7 @@ def patched_qwen3_deepstack(self, hidden_states, visual_pos_masks, visual_embeds def patched_qwen3_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) + mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -2950,6 +2951,7 @@ def patched_qwen3_get_input_embeddings(self, input_ids=None, pixel_values=None, def patched_qwen35_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) + mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -3121,6 +3123,7 @@ def patched_glm_vision_call(self, hidden_states, grid_thw, output_hidden_states= def patched_glm_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) + mask = kwargs.get("mask", None) explicit_position_ids = kwargs.get("position_ids", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw @@ -3426,6 +3429,7 @@ def _install_qwen3_get_input_embeddings_patch(): def patched_get_input_embeddings(self, input_ids=None, pixel_values=None, **kwargs): image_grid_thw = kwargs.get("image_grid_thw", None) video_grid_thw = kwargs.get("video_grid_thw", None) + mask = kwargs.get("mask", None) grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw if pixel_values is None: From 1d0c11ed232aea41d0edd63bebb30cc97e6c51bb Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Sun, 24 May 2026 16:05:02 +0000 Subject: [PATCH 27/48] Materialize multiple epochs of labeled batches when num_epochs>1 for PR #684 The unlabeled torch_randperm path in create_ordered_batches materializes N*len(dataset)/batch_size batches with a per-epoch reseeded permutation when num_epochs is set. The labeled train_on_responses_only path stored trainer._batches at one epoch worth of permutations, and the trainer loop at line 1466 then cycled batches[batch_idx % len(batches)], so num_train_epochs=2 trained on the same row order in epoch 1 and epoch 2. _create_labeled_batches now accepts num_epochs and emits one block of permuted batches per requested epoch, each with seed+epoch_idx, matching the unlabeled path. Wired the trainer call site to pass num_train_epochs when set, and to set _prepared_batches_include_epochs so the existing total_steps math at trainer.py:1035 does not multiply through again. --- unsloth_zoo/mlx/trainer.py | 134 ++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 61 deletions(-) diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index dfb01b5e2..be638d572 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -1939,7 +1939,8 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, seed=42, chat_template=None, model_name=None, model_type=None, append_eos=True, dataset_order="default", - preserve_dataset_order=False): + preserve_dataset_order=False, + num_epochs=None): """Create padded batches with label masks for train_on_responses_only. Tokenizes each dataset item, applies the masking closure to get labels, @@ -2017,80 +2018,84 @@ def _process_text(text): "Check your dataset and formatting_func." ) - # 2. Apply the requested sample order BEFORE batching so labeled - # and unlabeled paths produce identical sample streams. + # 2. Determine sample ordering strategy. The labeled and unlabeled + # paths must agree at sample granularity or `dataset_order="torch_randperm"` + # produces a different sample stream under `train_on_responses_only(...)`. # - preserve_dataset_order=True / "sequential": dataset order. # - "torch_randperm": deterministic torch.randperm permutation # (matches `create_ordered_batches` -> `_torch_randperm_order` - # at utils.py:2845-2849). + # at utils.py:2845-2849), reseeded per epoch. # - default / None: legacy mlx-lm length-sort + per-batch shuffle. - # The labeled and unlabeled paths must agree at sample granularity - # or `dataset_order="torch_randperm"` produces a different sample - # stream under `train_on_responses_only(...)`. _order_requested = preserve_dataset_order or ( dataset_order not in (None, "default") ) - if preserve_dataset_order or dataset_order == "sequential": - pass - elif dataset_order == "torch_randperm": - from .utils import _torch_randperm_order - order = _torch_randperm_order(len(all_items), seed) - all_items = [all_items[i] for i in order] - elif dataset_order in (None, "default"): - all_items.sort(key=lambda x: len(x[0])) - else: + if dataset_order not in (None, "default", "sequential", "torch_randperm"): raise ValueError( f"Unsloth MLX: unsupported dataset_order={dataset_order!r}. " "Expected one of: None, 'default', 'sequential', " "'torch_randperm'." ) - # 3. Create padded batches + def _order_samples_for_epoch(items, epoch_idx): + if preserve_dataset_order or dataset_order == "sequential": + return list(items) + if dataset_order == "torch_randperm": + from .utils import _torch_randperm_order + # Match unlabeled `create_ordered_batches`: reseed per epoch + # so the second epoch sees a different sample order rather + # than repeating the first. + order = _torch_randperm_order(len(items), seed + epoch_idx) + return [items[i] for i in order] + # legacy default: length-sort once + return sorted(items, key=lambda x: len(x[0])) + + # 3. Materialize batches for one or many epochs. + # When `num_epochs > 1` and the caller requested a specific order, build + # `num_epochs * batches_per_epoch` batches up front so the trainer's + # `batches[batch_idx % len(batches)]` cycle reproduces the same per-epoch + # reseed semantics as the unlabeled `create_ordered_batches` path. + _n_epochs_materialize = ( + max(1, int(num_epochs)) if num_epochs is not None else 1 + ) rng = random.Random(seed) batches = [] - for start in range(0, len(all_items), batch_size): - batch_items = all_items[start:start + batch_size] - if not batch_items: - continue - max_len = max(len(ids) for ids, _ in batch_items) - # Match mlx-lm iterate_batches: +1 gives the autoregressive - # shift headroom so post-shift length is a clean _PAD_MULTIPLE. - padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE - padded_len = min(padded_len, max_seq_length) - - batch_ids = [] - batch_labels = [] - batch_lengths = [] - for ids, lbls in batch_items: - L = min(len(ids), padded_len) - pad_len = padded_len - L - batch_ids.append(ids[:L] + [0] * pad_len) - batch_labels.append(lbls[:L] + [-100] * pad_len) - # Right-half-open [start, end) to match the loss masks in - # utils.py:360/:393/:429/:439 (`steps < lengths[:, 1:]`). - # Pre-fix this was `[1, L - 1]` which paired with the old - # `<=` mask; the PR flipped the mask to `<` so the end - # value must shift up by one to keep training on the - # final supervised token. - batch_lengths.append([1, L]) - - batches.append(( - mx.array(batch_ids), - mx.array(batch_lengths), - mx.array(batch_labels), - )) - - # 4. Order the batch sequence. - # - preserve_dataset_order / "sequential": keep batch order - # (samples were already in their target order at step 2). - # - "torch_randperm": batches mirror torch.randperm at the - # sample level (step 2 above), so keep batch order here. - # - default / None: legacy length-sort emitted near-contiguous - # batches; shuffle them so adjacent steps are not similar. - if _order_requested: - pass - else: - rng.shuffle(batches) + for epoch_idx in range(_n_epochs_materialize): + epoch_items = _order_samples_for_epoch(all_items, epoch_idx) + epoch_batches = [] + for start in range(0, len(epoch_items), batch_size): + batch_items = epoch_items[start:start + batch_size] + if not batch_items: + continue + max_len = max(len(ids) for ids, _ in batch_items) + # Match mlx-lm iterate_batches: +1 gives the autoregressive + # shift headroom so post-shift length is a clean _PAD_MULTIPLE. + padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE + padded_len = min(padded_len, max_seq_length) + + batch_ids = [] + batch_labels = [] + batch_lengths = [] + for ids, lbls in batch_items: + L = min(len(ids), padded_len) + pad_len = padded_len - L + batch_ids.append(ids[:L] + [0] * pad_len) + batch_labels.append(lbls[:L] + [-100] * pad_len) + # Right-half-open [start, end) to match the loss masks in + # utils.py:360/:393/:429/:439 (`steps < lengths[:, 1:]`). + batch_lengths.append([1, L]) + + epoch_batches.append(( + mx.array(batch_ids), + mx.array(batch_lengths), + mx.array(batch_labels), + )) + + # 4. Order the batch sequence within the epoch. + # legacy default (length-sort) gets a per-batch shuffle so adjacent + # steps are not similar; explicit-order paths keep the sample order. + if not _order_requested: + rng.shuffle(epoch_batches) + batches.extend(epoch_batches) # Limit if needed if num_batches is not None and len(batches) > num_batches: @@ -2298,11 +2303,18 @@ def train_on_responses_only( append_eos=bool(getattr(args, "append_eos", True)), dataset_order=getattr(args, "dataset_order", "default"), preserve_dataset_order=bool(getattr(args, "preserve_dataset_order", False)), + num_epochs=( + int(args.num_train_epochs) + if getattr(args, "num_train_epochs", -1) > 0 + else None + ), ) # Safety check: detect all-masked labels early _check_all_masked(batches) - + trainer._prepared_batches_include_epochs = ( + getattr(args, "num_train_epochs", -1) > 0 + ) trainer._batches = batches # Process eval dataset too From f601a154a386dca4e373a1936cb8d8f054c67b8c Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Mon, 25 May 2026 13:32:16 +0000 Subject: [PATCH 28/48] Rename PR-numbered tests and shorten verbose comments --- .github/workflows/consolidated-tests-ci.yml | 4 +- ...ixes.py => test_mlx_batching_and_decay.py} | 2 +- ...a_components.py => test_mlx_cce_kernel.py} | 11 +- ...tize.py => test_mlx_dequantize_modules.py} | 14 ++- ...gated_delta.py => test_mlx_gated_delta.py} | 7 +- ..._imports.py => test_mlx_module_exports.py} | 19 ++-- tests/test_mlx_torch_shim_smoke.py | 4 +- ...nents.py => test_mlx_trainer_internals.py} | 7 +- unsloth_zoo/mlx/compile.py | 4 +- unsloth_zoo/mlx/loader.py | 6 +- unsloth_zoo/mlx/trainer.py | 101 ++++-------------- unsloth_zoo/mlx/utils.py | 80 ++++---------- 12 files changed, 74 insertions(+), 185 deletions(-) rename tests/{test_mlx_pr684_review_fixes.py => test_mlx_batching_and_decay.py} (99%) rename tests/{test_pr_a_components.py => test_mlx_cce_kernel.py} (97%) rename tests/{test_pr_a_dequantize.py => test_mlx_dequantize_modules.py} (93%) rename tests/{test_pr_a_gated_delta.py => test_mlx_gated_delta.py} (97%) rename tests/{test_pr_a_imports.py => test_mlx_module_exports.py} (93%) rename tests/{test_pr_a_deep_components.py => test_mlx_trainer_internals.py} (98%) diff --git a/.github/workflows/consolidated-tests-ci.yml b/.github/workflows/consolidated-tests-ci.yml index 6ab589c20..119dc8874 100644 --- a/.github/workflows/consolidated-tests-ci.yml +++ b/.github/workflows/consolidated-tests-ci.yml @@ -93,14 +93,14 @@ jobs: - name: pytest tests/security (HARD GATE) run: python -m pytest tests/security -v - - name: pytest tests/test_pr_a_imports + zoo-specific CPU tests + - name: pytest tests/test_mlx_module_exports + zoo-specific CPU tests # Run as SEPARATE pytest invocation: tests/security/conftest.py installs a # session-scoped network_blocker autouse fixture that would otherwise block # test_pypi_version_sync from reaching pypi.org. continue-on-error: true run: | python -m pytest \ - tests/test_pr_a_imports.py \ + tests/test_mlx_module_exports.py \ tests/test_rl_replacements_cpu.py \ tests/test_temporary_patches_imports.py \ tests/test_zoo_history_regressions.py \ diff --git a/tests/test_mlx_pr684_review_fixes.py b/tests/test_mlx_batching_and_decay.py similarity index 99% rename from tests/test_mlx_pr684_review_fixes.py rename to tests/test_mlx_batching_and_decay.py index 151059c4c..95ac4b5fa 100644 --- a/tests/test_mlx_pr684_review_fixes.py +++ b/tests/test_mlx_batching_and_decay.py @@ -190,7 +190,7 @@ def test_vlm_torch_randperm_seed_none_and_multi_epoch_batches(): assert first_epoch != second_epoch -def test_pr684_compiler_review_guards_are_present(): +def test_compiler_review_guards_are_present(): import unsloth_zoo.compiler as compiler import unsloth_zoo.mlx.compile as mlx_compile diff --git a/tests/test_pr_a_components.py b/tests/test_mlx_cce_kernel.py similarity index 97% rename from tests/test_pr_a_components.py rename to tests/test_mlx_cce_kernel.py index 74e4aa750..779bbd11f 100644 --- a/tests/test_pr_a_components.py +++ b/tests/test_mlx_cce_kernel.py @@ -14,12 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" -PR-A end-to-end component exercises through the shim. +"""End-to-end MLX CCE kernel exercises through the simulation shim. -Goes one level deeper than test_pr_a_imports.py: actually constructs -inputs and runs the function bodies. Each test focuses on one -critical PR-A code path. +Goes one level deeper than test_mlx_module_exports.py: actually +constructs inputs and runs the function bodies, one critical CCE +code path per test. """ from __future__ import annotations @@ -202,7 +201,7 @@ def square_vjp(primals, cotangents, outputs): # --------------------------------------------------------------------------- -# 4. mx.array isinstance contract that PR-A relies on. +# 4. mx.array isinstance contract that MLX trainer relies on. # --------------------------------------------------------------------------- def test_torch_tensor_is_mx_array(): diff --git a/tests/test_pr_a_dequantize.py b/tests/test_mlx_dequantize_modules.py similarity index 93% rename from tests/test_pr_a_dequantize.py rename to tests/test_mlx_dequantize_modules.py index 8a98a8a6d..2cc5eb978 100644 --- a/tests/test_pr_a_dequantize.py +++ b/tests/test_mlx_dequantize_modules.py @@ -14,16 +14,14 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" -PR-A integration: exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules. +"""Exercise unsloth_zoo.mlx.loader._dequantize_selected_mlx_modules. Builds a synthetic MLX-style model with one QuantizedLinear submodule, -runs PR-A's dequantize-and-replace helper, verifies the result is -a numerically correct nn.Linear with the dequantized weight. - -This is the canonical PR-A code path: load_in_4bit=False (or -selective requantize) walks named_modules, finds QuantizedLinear, -calls mx.dequantize with mode='affine', and swaps in nn.Linear. +runs the dequantize-and-replace helper, and verifies the result is a +numerically correct nn.Linear with the dequantized weight. Mirrors the +load_in_4bit=False / selective requantize path: walks named_modules, +finds QuantizedLinear, calls mx.dequantize with mode='affine', swaps +in nn.Linear. """ from __future__ import annotations diff --git a/tests/test_pr_a_gated_delta.py b/tests/test_mlx_gated_delta.py similarity index 97% rename from tests/test_pr_a_gated_delta.py rename to tests/test_mlx_gated_delta.py index 1a8772e94..e26fb64c4 100644 --- a/tests/test_pr_a_gated_delta.py +++ b/tests/test_mlx_gated_delta.py @@ -14,8 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" -PR-A gated_delta_vjp end-to-end through the shim. +"""gated_delta_vjp end-to-end through the shim. Exercises: * mx.custom_function decorator + .vjp registration @@ -24,8 +23,8 @@ * .astype(mx.float32) at ~30 sites * mx.where / mx.expand_dims / mx.zeros_like -If forward + backward both produce finite tensors with the right shapes, -PR-A's VJP path is exercisable on Linux+CUDA. +If forward + backward both produce finite tensors with the right +shapes, the VJP path is exercisable on Linux+CUDA. """ from __future__ import annotations diff --git a/tests/test_pr_a_imports.py b/tests/test_mlx_module_exports.py similarity index 93% rename from tests/test_pr_a_imports.py rename to tests/test_mlx_module_exports.py index c166bd61e..e02060879 100644 --- a/tests/test_pr_a_imports.py +++ b/tests/test_mlx_module_exports.py @@ -14,12 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" -PR-A integration: verify every MLX-using unsloth_zoo module imports -under the shim and exposes the symbols PR-B's Studio code calls. +"""Verify every MLX-using unsloth_zoo module imports under the shim +and exposes the symbols Studio integrations rely on. -If a test fails with `_Noop` / NotImplementedError, the failing symbol -identifies a TODO in mlx_simulation/. +If a test fails with `_Noop` / NotImplementedError, the failing +symbol identifies a TODO in mlx_simulation/. """ from __future__ import annotations @@ -46,14 +45,14 @@ def _install_shim(): "unsloth_zoo.mlx.cce.runtime_cce", "unsloth_zoo.gated_delta_vjp", ]) -def test_pr_a_module_imports(module_path): +def test_mlx_module_imports(module_path): import importlib mod = importlib.import_module(module_path) assert mod is not None # --------------------------------------------------------------------------- -# 2. PR-B contract: FastMLXModel and the dynamically-attached save methods +# 2. Studio backend contract: FastMLXModel and the dynamically-attached save methods # must be reachable. # --------------------------------------------------------------------------- @@ -85,7 +84,7 @@ def test_full_finetune_dtype_default_matches_torch_bf16(): def test_fast_mlx_model_save_helpers_exist(): - """PR-B calls model.save_pretrained_merged / save_lora_adapters / + """Studio backend calls model.save_pretrained_merged / save_lora_adapters / push_to_hub_merged on the FastMLXModel INSTANCE returned by FastMLXModel.from_pretrained. The helpers are module-level in loader.py and attached via types.MethodType after load. @@ -107,7 +106,7 @@ def test_trainer_classes(): MLXTrainer, MLXTrainingConfig, ) - # train_on_responses_only is the third symbol PR-B imports + # train_on_responses_only is the third symbol Studio backend imports import unsloth_zoo.mlx.trainer as mt assert hasattr(mt, "train_on_responses_only") or hasattr(mt, "MLXTrainer") @@ -119,7 +118,7 @@ def test_trainer_classes(): def test_mlx_loader_dequantize_replace_callable(): """The dequantize-and-replace helper used by FastMLXModel.from_pretrained.""" import unsloth_zoo.mlx.loader as ml - # PR-A names this `_dequantize_selected_mlx_modules`. + # The loader names this `_dequantize_selected_mlx_modules`. assert hasattr(ml, "_dequantize_selected_mlx_modules"), ( "expected _dequantize_selected_mlx_modules in unsloth_zoo.mlx.loader. " f"Got dequant-related: {[a for a in dir(ml) if 'dequant' in a.lower()]}" diff --git a/tests/test_mlx_torch_shim_smoke.py b/tests/test_mlx_torch_shim_smoke.py index ac1f3d105..e21b5a457 100644 --- a/tests/test_mlx_torch_shim_smoke.py +++ b/tests/test_mlx_torch_shim_smoke.py @@ -19,7 +19,7 @@ Verifies: 1. simulate_mlx_on_torch() succeeds and registers all named submodules. -2. PR-B's 5 fresh symbols (mx.metal.is_available, set_wired_limit, +2. Studio backend's 5 fresh symbols (mx.metal.is_available, set_wired_limit, device_info, clear_cache, synchronize) work as expected. 3. The ~70 trivial passthroughs round-trip vs torch on random inputs. 4. Sub-architecture VLM submodules auto-resolve via the MetaPathFinder. @@ -100,7 +100,7 @@ def test_vlm_subarch_auto_resolve(submodule): # --------------------------------------------------------------------------- -# 2. Tier 1: PR-B fresh symbols. +# 2. Tier 1: Studio backend fresh symbols. # --------------------------------------------------------------------------- def test_metal_is_available_returns_false(): diff --git a/tests/test_pr_a_deep_components.py b/tests/test_mlx_trainer_internals.py similarity index 98% rename from tests/test_pr_a_deep_components.py rename to tests/test_mlx_trainer_internals.py index 0981b099b..e370229f4 100644 --- a/tests/test_pr_a_deep_components.py +++ b/tests/test_mlx_trainer_internals.py @@ -14,9 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -""" -PR-A deeper component exercises: trainer, compile discovery, -cce backward, and quantization helpers — beyond just imports. +"""Deeper MLX component exercises: trainer, compile discovery, +cce backward, and quantization helpers, beyond just imports. If a test fails, the failing component identifies the next gap. """ @@ -68,7 +67,7 @@ def test_mlx_training_config_is_dataclass_with_all_fields(): @pytest.mark.parametrize("optim_name", ["adamw", "adam", "sgd", "adafactor"]) def test_mlx_training_config_each_optim(optim_name): - """Every PR-A-supported optim string at least constructs cleanly in config.""" + """Every supported optim string constructs cleanly in config.""" from unsloth_zoo.mlx.trainer import MLXTrainingConfig cfg = MLXTrainingConfig(optim=optim_name) assert cfg.optim == optim_name diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 30a8ce780..7986c8292 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -2683,9 +2683,7 @@ def patched_qwen3_attention(self, x, cu_seqlens, rotary_pos_emb=None): attn_outputs = [] for q_chunk, k_chunk, v_chunk in zip(*splits): - # MLX fused SDPA currently has a forward/value mismatch under - # value_and_grad for Qwen3-VL vision chunks. Use explicit attention - # here so training loss and plain forward loss agree. + # MLX fused SDPA mismatches value_and_grad for Qwen3-VL vision; use explicit. scores = ( q_chunk.astype(mx.float32) @ mx.swapaxes(k_chunk.astype(mx.float32), -1, -2) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index d341a6032..f4c151dd7 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -147,8 +147,7 @@ def _convert_mlx_dtype(model, target_dtype, model_type: str = "") -> None: def _is_norm_parameter_path(path) -> bool: """Return whether a parameter path belongs to a normalization module.""" parts = str(path).lower().split(".") - # Match RMSNorm/LayerNorm via "norm" substring, plus GPT-2 / GPT-OSS - # style ln_1, ln_2, ln_f. + # "norm" matches RMSNorm/LayerNorm; ln_* covers GPT-2/GPT-OSS. return any( "norm" in part or part.startswith("ln_") or part == "ln_f" for part in parts[:-1] @@ -792,8 +791,7 @@ def patched_set_dtype(self, dtype): _MLX_QUANT_MODE_DEFAULTS = { "affine": (64, 4), - # Diagnostic CUDA bitsandbytes NF4 parity mode. This quantizes and then - # immediately dequantizes into dense Linear weights; it is not memory-saving. + # CUDA bnb NF4 parity (diagnostic): quantize then dequantize; not memory-saving. "nf4_dense": (64, 4), "mxfp4": (32, 4), "nvfp4": (16, 4), diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index be638d572..3b47182e5 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -97,8 +97,7 @@ def _normalize_mlx_optimizer_name(name): def _part_is_norm(part: str) -> bool: - # Match RMSNorm/LayerNorm/input_layernorm/etc. via "norm" substring, - # plus GPT-2 / GPT-OSS style ln_1, ln_2, ln_f. + # "norm" matches RMSNorm/LayerNorm/input_layernorm; ln_* covers GPT-2/GPT-OSS. return "norm" in part or part.startswith("ln_") or part == "ln_f" @@ -224,9 +223,7 @@ def _set_norm_output_cast_to_input_dtype(enabled: bool, model=None) -> None: result back matches PyTorch autocast behavior more closely: fp32 norm math, bf16/fp16 downstream activations. """ - # Keep the Qwen3-VL specialized vision-block norm patch in sync with - # the generic patcher below. Imported lazily to avoid a circular - # import at trainer-module load time. + # Sync Qwen3-VL vision-block patch with generic patcher (lazy import: cycle). try: from . import compile as _mlx_compile _mlx_compile.set_qwen3_vision_norm_cast_output(enabled) @@ -305,13 +302,8 @@ class MLXTrainingConfig: adam_beta1: float | None = None adam_beta2: float | None = None max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead - # Elementwise clip: each gradient element is clamped to - # `[-max_grad_value, +max_grad_value]`. Per-leaf only, no cross-leaf - # reduction, so it does not pay the global-norm memory overhead. - # None (default) keeps the cheap MLX default of 1.0 unless the user - # passes max_grad_norm > 0, in which case global-norm clipping wins. - # 0.0 disables. A positive float opts in explicitly and overrides - # max_grad_norm with a warning. + # Per-leaf elementwise clip to `[-v, +v]`. None defaults to 1.0 + # unless max_grad_norm > 0 (global-norm wins). 0.0 disables. max_grad_value: float | None = None seed: int = 3407 lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended @@ -352,12 +344,7 @@ class MLXTrainingConfig: wired_limit_gb: float | None = None # None = min(recommended working set, memory limit); <= 0 disables disable_memory_limits: bool = False cast_norm_output_to_input_dtype: bool = True # fp32 norm storage/math, bf16/fp16 downstream activations - # Append the tokenizer EOS id to each encoded text row before batching. - # Default True mirrors mlx-lm's TextDataset behavior so direct MLX - # text fine-tuning callers (raw `{"text": str}` rows) still train the - # model to predict EOS. Studio passes False because its chat template - # already renders EOS. - append_eos: bool = True + append_eos: bool = True # True = mlx-lm parity; Studio sets False (template owns EOS) # VLM / completion masking train_on_completions: bool = False # Mask prompt tokens in loss @@ -528,8 +515,7 @@ def decay_progress(step): ) / mx.array(max(total_steps - warmup, 1), dtype=mx.float32) def schedule(step): - # Match HuggingFace/Trainer LR as seen by the optimizer before - # each update. ``step`` is zero-based optimizer-step index. + # HF Trainer LR parity; `step` is zero-based optimizer-step index. step = mx.array(step).astype(mx.float32) if warmup > 0: warm = lr * warmup_multiplier(step) @@ -624,15 +610,13 @@ def _build_optimizer(self, total_steps): **adam_kwargs, ) elif opt_name == "sgd": - # HF parity: decoupled bias/norm-aware decay, applied manually. + # HF parity: manual bias/norm-aware decoupled decay. self._manual_weight_decay = float(wd or 0.0) optimizer = optim.SGD(learning_rate=initial_lr, weight_decay=0.0) elif opt_name == "muon": - # HF parity: decoupled bias/norm-aware decay, applied manually. self._manual_weight_decay = float(wd or 0.0) optimizer = optim.Muon(learning_rate=initial_lr, weight_decay=0.0) elif opt_name == "lion": - # HF parity: decoupled bias/norm-aware decay, applied manually. self._manual_weight_decay = float(wd or 0.0) optimizer = optim.Lion(learning_rate=initial_lr, weight_decay=0.0) self._resolved_optimizer_name = opt_name @@ -856,10 +840,7 @@ def train(self): args = self.args model = self.model cast_norm_output = bool(getattr(args, "cast_norm_output_to_input_dtype", True)) - # Remember the Qwen3-VL vision-block flag so the finally block can - # restore the original value rather than always forcing it to - # False (which would otherwise leak across train() boundaries for - # subsequent inference in the same process). + # Save Qwen3-VL vision-block flag so finally restores it (not just False). _prev_qwen3_vision_cast = True try: from . import compile as _mlx_compile @@ -868,11 +849,7 @@ def train(self): ) except Exception: pass - # Install the norm-output cast INSIDE the try/finally so a raise - # during any of the steps below (patch-mode normalization, memory - # limit configuration, compile policy, gradient checkpointing, - # Qwen3.5 preflight) still triggers the cleanup that restores the - # global RMSNorm / LayerNorm / Qwen3-VL flags. + # Patch INSIDE try/finally so any raise during setup still restores globals. _norm_cast_applied = False try: _set_norm_output_cast_to_input_dtype(cast_norm_output, model) @@ -954,18 +931,12 @@ def train(self): except Exception: pass if _norm_cast_applied and cast_norm_output: - # Undo the global norm-class monkey patch so later - # inference / unrelated trainers in the same Python - # process get the original RMSNorm / LayerNorm dtype - # behavior. Wrap in try/except: a partially patched - # state must still let `finally` run to completion. + # Undo the global norm-class patch; tolerate partial state. try: _set_norm_output_cast_to_input_dtype(False, model) except Exception: pass - # Restore the Qwen3-VL vision-block flag to whatever it was - # before train() started, instead of leaking the False that - # `_set_norm_output_cast_to_input_dtype(False, ...)` just set. + # Restore Qwen3-VL vision-block flag to its pre-train value. try: from . import compile as _mlx_compile _mlx_compile.set_qwen3_vision_norm_cast_output( @@ -986,9 +957,7 @@ def _train_inner(self): if is_vlm: processor = self._resolve_vlm_processor() - # VLM collation owns label creation/masking. These IDs should be - # redundant for normal SFT batches and are only a loss-side - # compatibility backstop for missing or externally supplied labels. + # Backstop only; VLM collation already owns label masking. _vlm_ignore_token_ids = _get_vlm_ignore_token_ids( processor=processor, config=getattr(model, "_config", {}), @@ -1209,8 +1178,7 @@ def _can_report_optimizer_state_norm(): return getattr(optimizer, "betas", None) def _clip_grad_by_value(grad): - # Elementwise clip to [-max_grad_value, +max_grad_value], - # per-leaf, no cross-leaf reduction. + # Per-leaf elementwise clip; no cross-leaf reduction. if not _clip_grad_value: return grad return tree_map( @@ -1784,12 +1752,7 @@ def _prepare_data(self, is_vlm): else: chat_tmpl = getattr(args, "chat_template", None) if args.streaming: - # `iterate_training_batches` does not yet take a - # `dataset_order` argument, so streaming text MLX - # training cannot honor `preserve_dataset_order` / - # `dataset_order="torch_randperm"`. Refuse instead of - # silently dropping the user-requested order so Studio - # / CUDA parity stays explicit. + # Streaming has no index space; refuse explicit order requests. if ( getattr(args, "preserve_dataset_order", False) or getattr(args, "dataset_order", "default") != "default" @@ -1988,11 +1951,7 @@ def _create_labeled_batches(dataset, tokenizer, mask_fn, batch_size, # slow tokenizers degrade gracefully via the GIL) def _process_text(text): encoded = tokenizer.encode(text) - # Honor the same `append_eos` contract as `_prepare_dataset`; the - # unlabeled text path (`_prepare_dataset` -> mlx-lm CacheDataset) - # appends or skips EOS based on the trainer's config, so the - # labeled `train_on_responses_only` path must match or the two - # produce different supervised tokens for the same input. + # Mirror `_prepare_dataset`'s EOS contract; mismatch desyncs labeled vs unlabeled. if append_eos and eos_id is not None and (not encoded or encoded[-1] != eos_id): encoded.append(eos_id) if len(encoded) > max_seq_length: @@ -2018,14 +1977,8 @@ def _process_text(text): "Check your dataset and formatting_func." ) - # 2. Determine sample ordering strategy. The labeled and unlabeled - # paths must agree at sample granularity or `dataset_order="torch_randperm"` - # produces a different sample stream under `train_on_responses_only(...)`. - # - preserve_dataset_order=True / "sequential": dataset order. - # - "torch_randperm": deterministic torch.randperm permutation - # (matches `create_ordered_batches` -> `_torch_randperm_order` - # at utils.py:2845-2849), reseeded per epoch. - # - default / None: legacy mlx-lm length-sort + per-batch shuffle. + # 2. Sample order; must agree with unlabeled `create_ordered_batches` + # (utils.py:2845-2849) so `train_on_responses_only` sees the same stream. _order_requested = preserve_dataset_order or ( dataset_order not in (None, "default") ) @@ -2041,19 +1994,13 @@ def _order_samples_for_epoch(items, epoch_idx): return list(items) if dataset_order == "torch_randperm": from .utils import _torch_randperm_order - # Match unlabeled `create_ordered_batches`: reseed per epoch - # so the second epoch sees a different sample order rather - # than repeating the first. + # Reseed per epoch (matches `create_ordered_batches`). order = _torch_randperm_order(len(items), seed + epoch_idx) return [items[i] for i in order] # legacy default: length-sort once return sorted(items, key=lambda x: len(x[0])) - # 3. Materialize batches for one or many epochs. - # When `num_epochs > 1` and the caller requested a specific order, build - # `num_epochs * batches_per_epoch` batches up front so the trainer's - # `batches[batch_idx % len(batches)]` cycle reproduces the same per-epoch - # reseed semantics as the unlabeled `create_ordered_batches` path. + # 3. Build `num_epochs` blocks so `batches[i % len]` cycle reseeds correctly. _n_epochs_materialize = ( max(1, int(num_epochs)) if num_epochs is not None else 1 ) @@ -2067,8 +2014,7 @@ def _order_samples_for_epoch(items, epoch_idx): if not batch_items: continue max_len = max(len(ids) for ids, _ in batch_items) - # Match mlx-lm iterate_batches: +1 gives the autoregressive - # shift headroom so post-shift length is a clean _PAD_MULTIPLE. + # +1 for autoregressive shift (mlx-lm iterate_batches parity). padded_len = 1 + ((max_len + _PAD_MULTIPLE - 1) // _PAD_MULTIPLE) * _PAD_MULTIPLE padded_len = min(padded_len, max_seq_length) @@ -2080,8 +2026,7 @@ def _order_samples_for_epoch(items, epoch_idx): pad_len = padded_len - L batch_ids.append(ids[:L] + [0] * pad_len) batch_labels.append(lbls[:L] + [-100] * pad_len) - # Right-half-open [start, end) to match the loss masks in - # utils.py:360/:393/:429/:439 (`steps < lengths[:, 1:]`). + # [start, end) matches loss masks in utils.py:360/:393/:429/:439. batch_lengths.append([1, L]) epoch_batches.append(( @@ -2090,9 +2035,7 @@ def _order_samples_for_epoch(items, epoch_idx): mx.array(batch_labels), )) - # 4. Order the batch sequence within the epoch. - # legacy default (length-sort) gets a per-batch shuffle so adjacent - # steps are not similar; explicit-order paths keep the sample order. + # 4. Legacy length-sort: shuffle batches so adjacent steps differ. if not _order_requested: rng.shuffle(epoch_batches) batches.extend(epoch_batches) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 6cb475f60..03b89b09e 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -421,13 +421,8 @@ def make_baseline_loss_fn(): """ def loss_fn(model, batch, lengths, labels=None): if labels is None: - # Match the CCE (`utils.py:360`, `:393`) and labels-aware - # baseline (`utils.py:439`) masks: end is exclusive. The - # pre-PR `<=` was inclusive and the comment said "byte-identical - # to mlx_lm.tuner.trainer.default_loss", but mlx_lm's lengths - # convention is right-half-open (`[start, end)`), so the - # consistent CCE / labels-aware paths are also the correct - # ones for the unlabeled baseline. + # Half-open [start, end) end-exclusive mask; matches CCE/labels paths + # (:360, :393, :439) and mlx_lm's lengths convention. inputs = batch[:, :-1] targets = batch[:, 1:] logits = model(inputs) @@ -689,18 +684,12 @@ def loss_fn(model, batch_dict): attention_mask = batch_dict.get("attention_mask") labels = batch_dict.get("labels") - # Match the CCE path semantics: forward the full multimodal - # sequence and shift the resulting logits afterwards. Qwen3-VL - # image / mRoPE / deepstack state depends on the complete - # sequence; trimming `input_ids[:, :-1]` before the multimodal - # forward gives a different loss from the full-logits CUDA - # path. Mirrors `_vlm_cce_forward` so use_cce=False stays in - # parity with use_cce=True. + # Forward full sequence then shift (Qwen3-VL mRoPE/deepstack need it); + # mirrors `_vlm_cce_forward` so use_cce={True,False} stay in parity. inputs = input_ids fwd_mask = attention_mask - # Forward pass — let the model create its own causal mask. - # Pass extra keys (e.g. image_grid_thw for Qwen) that some models need. + # Pass through extras (e.g. image_grid_thw for Qwen); model owns causal mask. fwd_kwargs = { k: v for k, v in batch_dict.items() if k not in ("input_ids", "pixel_values", "attention_mask", "labels") @@ -714,14 +703,10 @@ def loss_fn(model, batch_dict): logits = logits[:, :-1, :] if labels is not None: - # Labels encode instruction/padding/special-token masking when - # produced by MLX VLM collation. The extra mask keeps legacy - # externally supplied labels compatible. + # Extra masks keep externally supplied labels compatible. targets = labels[:, 1:] targets = _mask_image_tokens(targets, _image_token_ids) - # Apply train_on_completions prompt masking even when labels - # were preset by _apply_vlm_label_masks (which only handles - # image/pad/ignore tokens, not assistant-token boundaries). + # _apply_vlm_label_masks ignores assistant boundaries; mask here. targets = _mask_prompt_tokens(targets, _assistant_token_id) logits, targets = _align_logits_with_labels(logits, targets) if attention_mask is not None: @@ -858,11 +843,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, attention_mask = batch_dict.get("attention_mask") labels = batch_dict.get("labels") - # Match the standard VLM forward semantics: run the full multimodal - # sequence, then use hidden[:, :-1] to predict labels[:, 1:]. Qwen3-VL - # image/mRoPE/deepstack state depends on the complete sequence; trimming - # input_ids before the multimodal forward produces a different loss from - # the full-logits path and from CUDA. + # Forward full sequence then shift hidden[:, :-1] (Qwen3-VL mRoPE/deepstack). inputs = input_ids fwd_attn_mask = attention_mask @@ -881,10 +862,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, **extra_kwargs, ) merged_embeds, backbone_kwargs = _unpack_embed_result(embed_result, model) - # Prefer position_ids returned/stashed by get_input_embeddings (some - # VLM embedders, e.g. Qwen-VL family, adjust them for the merged - # multimodal sequence). Only fall back to the raw batch position_ids - # if the embedder did not produce its own. + # Prefer embedder-produced position_ids (Qwen-VL adjusts for merged seq). if "position_ids" in extra_kwargs and "position_ids" not in backbone_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] @@ -897,9 +875,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, hidden = hidden[:, :-1, :] if labels is not None: - # Labels are the source of truth. Collation should already encode - # instruction/padding/special-token masking; the extra mask preserves - # compatibility for externally supplied labels. + # Extra mask keeps externally supplied labels compatible. targets = labels[:, 1:] masked_targets = _mask_image_tokens(targets, image_token_ids) if attention_mask is not None: @@ -911,11 +887,7 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, masked_targets, -100, ) - # Completion-only masking also has to run in the labels-aware - # branch because _apply_vlm_label_masks does not apply it. Pre-fix - # this was only applied in the labels=None branch, so VLM - # train_on_completions=True trained on prompt tokens whenever - # collation set batch["labels"]. + # Completion-only masking; _apply_vlm_label_masks doesn't do this. masked_targets = _mask_prompt_tokens(masked_targets, assistant_token_id) ntoks = (masked_targets != -100).sum() else: @@ -2281,9 +2253,7 @@ def _nest_vlm_images_by_sample(all_images): def _vlm_processor_prefers_nested_images(processor): cls = processor.__class__ marker = f"{getattr(cls, '__module__', '')}.{getattr(cls, '__name__', '')}".lower() - # Some mlx-vlm processors count images per prompt and require - # images=[[sample0_img0, ...], [sample1_img0, ...]]. Qwen/LLaVA/Gemma4-style - # processors consume a flat image stream and are handled by the default path. + # These processors need images grouped per-prompt; others take a flat list. return any( name in marker for name in ( @@ -2638,9 +2608,7 @@ def _emit(items): indices[i : i + batch_size] for i in range(0, len(indices), batch_size) ] - # Use a per-epoch local Generator so order is reproducible - # under `seed` and does not depend on global numpy RNG - # state. Mirrors the torch_randperm branch reseed above. + # Local RNG keeps order reproducible under `seed`; reseed per epoch. rng = np.random.default_rng(base_seed + epoch) order = rng.permutation(len(batch_indices)) for b in order: @@ -2655,10 +2623,7 @@ def _emit(items): yield _emit(items) epoch += 1 else: - # Iterable / unsized datasets cannot honor dataset_order because - # they expose no index space to permute. Refuse rather than - # silently stream source order (matches the text streaming path - # at trainer.py for the same asymmetry). + # Streaming has no index space; refuse rather than silently misorder. if dataset_order not in (None, "default"): raise ValueError( "Unsloth MLX VLM: preserve_dataset_order / " @@ -2885,16 +2850,10 @@ def make_order(epoch): if num_batches is None else None ) while num_batches is None or len(batch_pairs) < num_batches: - # Take up to batch_size contiguous indices from the current epoch. - # If the epoch tail is shorter, emit a partial batch and start - # the next batch fresh at epoch+1. Matches CUDA / SequentialSampler - # `drop_last=False` and the VLM ordered path at utils.py:2539, - # instead of mixing the last sample of one epoch with the first - # sample of the next inside the same micro-batch. + # Don't mix epochs in one batch; emit a partial then restart at epoch+1. + # Matches CUDA SequentialSampler `drop_last=False` and VLM path at :2539. if order_pos >= len(order): - # No more rows in this epoch. Stop only when we have hit - # either the requested number of batches or the requested - # total sample count (num_epochs * len(dataset)). + # Stop when num_batches or num_epochs*len(dataset) reached. if ( num_batches is None and (target_items is None or seen >= target_items) @@ -2915,10 +2874,7 @@ def make_order(epoch): batch_items = [tokenized[i] for i in chunk] max_length = max(len(ids) for ids in batch_items) - # Prefer the tokenizer's declared pad id; only fall back to 0 if - # the tokenizer has none. Matches mlx-lm's iterate_batches pad - # convention so the model receives a known special id (not raw 0) - # for padded positions in the forward input. + # mlx-lm iterate_batches pad convention; raw 0 only if no pad_token_id. _pad_id = getattr(tokenizer, "pad_token_id", None) if _pad_id is None: _pad_id = 0 From f545a00e990fe2323728df4178b44476e7d76a74 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 12:20:50 -0500 Subject: [PATCH 29/48] Add MLX max grad leaf norm clipping --- tests/test_mlx_max_grad_value_none.py | 152 +++++++++++++++++--------- tests/test_mlx_trainer_internals.py | 1 + unsloth_zoo/mlx/trainer.py | 134 ++++++++++++++++------- 3 files changed, 198 insertions(+), 89 deletions(-) diff --git a/tests/test_mlx_max_grad_value_none.py b/tests/test_mlx_max_grad_value_none.py index 21b62fe5c..9bf794cce 100644 --- a/tests/test_mlx_max_grad_value_none.py +++ b/tests/test_mlx_max_grad_value_none.py @@ -1,9 +1,10 @@ # Unsloth Zoo - Utilities for Unsloth -# Pin MLXTrainingConfig.max_grad_value resolution: -# * None (default) -> cheap MLX elementwise clip at 1.0, unless -# max_grad_norm > 0 is also passed (then global-norm wins). -# * 0.0 -> explicitly disabled. -# * positive -> explicit elementwise opt-in; overrides max_grad_norm. +# Pin MLXTrainingConfig cheap clipping resolution: +# * max_grad_leaf_norm is the proportional per-leaf norm cap. +# * max_grad_value keeps historical elementwise clamp semantics. +# * None defaults to cheap proportional leaf-norm clipping at 1.0 unless +# max_grad_norm > 0 is passed. +# * explicit 0.0 disables that specific cheap clipping knob. from __future__ import annotations @@ -16,94 +17,130 @@ def _install_mlx_shim(): simulate_mlx_on_torch() -def _resolve(raw_mgv, max_grad_norm): - """Mirror trainer.py's internal resolution. Returns the (max_grad_value, - max_grad_norm) pair the step function will actually use.""" - user_set = raw_mgv is not None - if user_set: - mgv = float(raw_mgv or 0.0) - if max_grad_norm > 0 and mgv > 0: - max_grad_norm = 0.0 - elif max_grad_norm > 0: - mgv = 0.0 - else: - mgv = 1.0 - return mgv, max_grad_norm +def _resolve(raw_mgv=None, raw_mgln=None, max_grad_norm=0.0): + from unsloth_zoo.mlx.trainer import MLXTrainingConfig, _resolve_mlx_grad_clipping + + cfg = MLXTrainingConfig( + max_grad_norm=max_grad_norm, + max_grad_value=raw_mgv, + max_grad_leaf_norm=raw_mgln, + output_dir="/tmp/x", + ) + return _resolve_mlx_grad_clipping(cfg) # -- field defaults --------------------------------------------------------- -def test_field_default_is_none_sentinel(): - """Default is None (a sentinel meaning 'use MLX cheap default').""" +def test_field_defaults_are_none_sentinels(): + """Defaults are sentinels meaning 'use MLX cheap default'.""" from unsloth_zoo.mlx.trainer import MLXTrainingConfig cfg = MLXTrainingConfig(output_dir="/tmp/x") assert cfg.max_grad_value is None + assert cfg.max_grad_leaf_norm is None -def test_field_accepts_none(): - """Field accepts None and round-trips through the dataclass.""" +def test_fields_accept_none(): + """Fields accept None and round-trip through the dataclass.""" from unsloth_zoo.mlx.trainer import MLXTrainingConfig - cfg = MLXTrainingConfig(max_grad_value=None, output_dir="/tmp/x") + cfg = MLXTrainingConfig( + max_grad_value=None, + max_grad_leaf_norm=None, + output_dir="/tmp/x", + ) assert cfg.max_grad_value is None + assert cfg.max_grad_leaf_norm is None -def test_field_accepts_explicit_positive(): - """Field accepts positive floats for power users opting in.""" +def test_fields_accept_explicit_positive(): + """Fields accept positive floats for power users opting in.""" from unsloth_zoo.mlx.trainer import MLXTrainingConfig - cfg = MLXTrainingConfig(max_grad_value=2.5, output_dir="/tmp/x") + cfg = MLXTrainingConfig( + max_grad_value=2.5, + max_grad_leaf_norm=1.5, + output_dir="/tmp/x", + ) assert cfg.max_grad_value == 2.5 + assert cfg.max_grad_leaf_norm == 1.5 # -- resolution semantics --------------------------------------------------- -def test_default_uses_cheap_elementwise(): - """Default (None, max_grad_norm=0.0) -> elementwise clip at 1.0.""" - mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=0.0) - assert mgv == 1.0 +def test_default_uses_cheap_leaf_norm(): + """Default (all None, max_grad_norm=0.0) -> leaf norm clip at 1.0.""" + mgn, mgv, mgln, mode = _resolve(max_grad_norm=0.0) assert mgn == 0.0 + assert mgv == 0.0 + assert mgln == 1.0 + assert mode == "leaf_norm" def test_user_max_grad_norm_wins_over_default(): - """User passes max_grad_norm=1.0 with default max_grad_value=None -> - global-norm clipping wins, elementwise disabled. HF/TRL parity.""" - mgv, mgn = _resolve(raw_mgv=None, max_grad_norm=1.0) - assert mgv == 0.0 + """User passes max_grad_norm=1.0 with defaults -> global norm only.""" + mgn, mgv, mgln, mode = _resolve(max_grad_norm=1.0) assert mgn == 1.0 + assert mgv == 0.0 + assert mgln == 0.0 + assert mode == "global_norm" -def test_explicit_zero_disables_elementwise(): - """Explicit 0.0 disables elementwise. With no max_grad_norm, - nothing clips.""" - mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=0.0) - assert mgv == 0.0 +def test_explicit_zero_disables_cheap_default(): + """Explicit 0.0 disables cheap clipping. With no max_grad_norm, no clip.""" + mgn, mgv, mgln, mode = _resolve(raw_mgv=0.0, max_grad_norm=0.0) assert mgn == 0.0 + assert mgv == 0.0 + assert mgln == 0.0 + assert mode == "none" def test_explicit_zero_lets_max_grad_norm_through(): - """Explicit max_grad_value=0.0 + max_grad_norm=1.0 -> only norm clipping.""" - mgv, mgn = _resolve(raw_mgv=0.0, max_grad_norm=1.0) - assert mgv == 0.0 + """Explicit cheap 0.0 + max_grad_norm=1.0 -> only norm clipping.""" + mgn, mgv, mgln, mode = _resolve(raw_mgv=0.0, max_grad_norm=1.0) assert mgn == 1.0 + assert mgv == 0.0 + assert mgln == 0.0 + assert mode == "global_norm" def test_explicit_positive_overrides_max_grad_norm(): """Explicit max_grad_value=2.0 with max_grad_norm=1.0 -> elementwise wins (existing rule), max_grad_norm zeroed.""" - mgv, mgn = _resolve(raw_mgv=2.0, max_grad_norm=1.0) - assert mgv == 2.0 + mgn, mgv, mgln, mode = _resolve(raw_mgv=2.0, max_grad_norm=1.0) assert mgn == 0.0 + assert mgv == 2.0 + assert mgln == 0.0 + assert mode == "value" def test_explicit_positive_alone(): """User passes max_grad_value=5.0 with no max_grad_norm -> elementwise at 5.""" - mgv, mgn = _resolve(raw_mgv=5.0, max_grad_norm=0.0) + mgn, mgv, mgln, mode = _resolve(raw_mgv=5.0, max_grad_norm=0.0) + assert mgn == 0.0 assert mgv == 5.0 + assert mgln == 0.0 + assert mode == "value" + + +def test_explicit_leaf_norm_overrides_max_grad_norm(): + """Explicit max_grad_leaf_norm uses proportional clipping and avoids global norm.""" + mgn, mgv, mgln, mode = _resolve(raw_mgln=1.3, max_grad_norm=1.0) + assert mgn == 0.0 + assert mgv == 0.0 + assert mgln == 1.3 + assert mode == "leaf_norm" + + +def test_max_grad_value_wins_over_leaf_norm_when_both_positive(): + """Keep max_grad_value's public elementwise meaning if both knobs are set.""" + mgn, mgv, mgln, mode = _resolve(raw_mgv=2.0, raw_mgln=1.3) assert mgn == 0.0 + assert mgv == 2.0 + assert mgln == 0.0 + assert mode == "value" # -- trainer source assertions (defense-in-depth) --------------------------- @@ -115,7 +152,24 @@ def test_trainer_source_pins_resolution_rule(): import inspect from unsloth_zoo.mlx import trainer as T - src = inspect.getsource(T.MLXTrainer.train) + inspect.getsource(T.MLXTrainer._train_inner) - assert "_user_set_mgv = _raw_mgv is not None" in src - assert "elif max_grad_norm > 0:" in src - assert "max_grad_value = 1.0" in src + src = ( + inspect.getsource(T._resolve_mlx_grad_clipping) + + inspect.getsource(T.MLXTrainer._train_inner) + ) + assert "max_grad_value" in src + assert "max_grad_leaf_norm" in src + assert 'return 0.0, 0.0, 1.0, "leaf_norm"' in src + + +def test_source_distinguishes_leaf_norm_from_elementwise_value_clip(): + """Pin the API split: value is elementwise, leaf_norm is proportional.""" + import inspect + from unsloth_zoo.mlx import trainer as T + + value_src = inspect.getsource(T._clip_grad_by_value) + leaf_src = inspect.getsource(T._clip_grad_by_leaf_norm) + + assert "mx.clip" in value_src + assert "mx.sqrt(mx.sum" in leaf_src + assert "return g * scale.astype(g.dtype)" in leaf_src + assert "mx.clip" not in leaf_src diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index e370229f4..dffc14db0 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -52,6 +52,7 @@ def test_mlx_training_config_is_dataclass_with_all_fields(): "optim", "weight_decay", "max_grad_norm", + "max_grad_leaf_norm", "seed", "logging_steps", "output_dir", diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index 3b47182e5..fb4281f75 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -283,6 +283,57 @@ def _normalize_mlx_scheduler_type(name): return sched_type +def _resolve_mlx_grad_clipping(args): + """Resolve mutually exclusive MLX clipping knobs. + + Returns ``(max_grad_norm, max_grad_value, max_grad_leaf_norm, mode)``. + ``max_grad_value`` keeps elementwise clamp semantics. ``max_grad_leaf_norm`` + is the cheap proportional alternative: cap each gradient leaf's L2 norm + without a cross-tree global reduction. + """ + max_grad_norm = float(getattr(args, "max_grad_norm", 0.0) or 0.0) + raw_value = getattr(args, "max_grad_value", None) + raw_leaf = getattr(args, "max_grad_leaf_norm", None) + user_set_value = raw_value is not None + user_set_leaf = raw_leaf is not None + + max_grad_value = float(raw_value or 0.0) if user_set_value else 0.0 + max_grad_leaf_norm = float(raw_leaf or 0.0) if user_set_leaf else 0.0 + + if max_grad_value > 0: + # Preserve the public meaning of max_grad_value as elementwise clamp. + return 0.0, max_grad_value, 0.0, "value" + + if max_grad_leaf_norm > 0: + return 0.0, 0.0, max_grad_leaf_norm, "leaf_norm" + + if max_grad_norm > 0: + return max_grad_norm, 0.0, 0.0, "global_norm" + + if user_set_value or user_set_leaf: + # Explicit 0.0 disables cheap clipping. + return 0.0, 0.0, 0.0, "none" + + # MLX default: cheap proportional clipping without global norm memory cost. + return 0.0, 0.0, 1.0, "leaf_norm" + + +def _clip_grad_by_value(grad, max_grad_value): + """Elementwise clamp; preserves the historical max_grad_value contract.""" + return tree_map(lambda g: mx.clip(g, -max_grad_value, max_grad_value), grad) + + +def _clip_grad_by_leaf_norm(grad, max_grad_leaf_norm): + """Scale each gradient leaf to a max L2 norm, preserving leaf direction.""" + def _clip_leaf_norm(g): + g_f = g.astype(mx.float32) + norm = mx.sqrt(mx.sum(g_f * g_f)) + scale = mx.minimum(max_grad_leaf_norm / (norm + 1e-6), 1.0) + return g * scale.astype(g.dtype) + + return tree_map(_clip_leaf_norm, grad) + + @dataclass class MLXTrainingConfig: """Training configuration mirroring SFTConfig / TrainingArguments field names.""" @@ -302,9 +353,13 @@ class MLXTrainingConfig: adam_beta1: float | None = None adam_beta2: float | None = None max_grad_norm: float = 0.0 # disabled by default on MLX to avoid clip-memory overhead - # Per-leaf elementwise clip to `[-v, +v]`. None defaults to 1.0 - # unless max_grad_norm > 0 (global-norm wins). 0.0 disables. + # Elementwise clip to `[-v, +v]`. None means "not requested"; + # positive values override other clipping modes to preserve API meaning. max_grad_value: float | None = None + # Proportional per-leaf L2 norm cap. This preserves each tensor's gradient + # direction and avoids max_grad_norm's cross-tree memory overhead. + # None uses MLX's cheap default of 1.0 unless another clip knob is explicit. + max_grad_leaf_norm: float | None = None seed: int = 3407 lora_plus_ratio: float = 0.0 # 0 = disabled, 16.0 = recommended embedding_learning_rate: float = 0.0 # 0 = disabled, 5e-5 = recommended @@ -1044,35 +1099,38 @@ def _train_inner(self): _needs_grad_scaling = use_lora_plus or use_embedding_lr _warned_skip_optimizer_state_grad_norm = False - # Build step functions following mlx-lm's pattern. - # Resolution rule: - # * max_grad_value=None (default) -> cheap MLX elementwise clip - # at 1.0, unless the user also passed max_grad_norm > 0 -- in - # that case the user opted into global-norm clipping (HF/TRL - # parity) and elementwise is disabled to avoid double-clip. - # * max_grad_value explicit (float or 0.0) -> honor exactly; - # if both modes are positive, elementwise wins (warn). - # max_grad_norm uses MLX's clip_grad_norm helper which materially - # increases peak memory on bf16 VLM runs, hence the elementwise - # default. - max_grad_norm = float(args.max_grad_norm or 0.0) - _raw_mgv = getattr(args, "max_grad_value", None) # TODO: expose MLX grad-clip in Studio UI for power users - _user_set_mgv = _raw_mgv is not None - if _user_set_mgv: - max_grad_value = float(_raw_mgv or 0.0) - if max_grad_norm > 0 and max_grad_value > 0: + # Build step functions following mlx-lm's pattern. `max_grad_value` + # remains an elementwise clamp. MLX's cheap default is now the clearer + # `max_grad_leaf_norm`, a proportional per-leaf norm cap that avoids + # global norm clipping's cross-tree memory overhead. + ( + max_grad_norm, + max_grad_value, + max_grad_leaf_norm, + _grad_clip_mode, + ) = _resolve_mlx_grad_clipping(args) + _raw_mgln = getattr(args, "max_grad_leaf_norm", None) + if max_grad_value > 0: + conflicts = [] + if float(getattr(args, "max_grad_norm", 0.0) or 0.0) > 0: + conflicts.append("max_grad_norm") + if _raw_mgln is not None and float(_raw_mgln or 0.0) > 0: + conflicts.append("max_grad_leaf_norm") + if conflicts: print( - "Unsloth: max_grad_norm and max_grad_value are both enabled; " - "ignoring max_grad_norm in favor of max_grad_value." + "Unsloth: max_grad_value is elementwise and overrides " + f"{', '.join(conflicts)}." ) - max_grad_norm = 0.0 - elif max_grad_norm > 0: - # User opted into global-norm clipping; suppress the default elementwise. - max_grad_value = 0.0 - else: - # Neither requested -> cheap MLX default. - max_grad_value = 1.0 + elif ( + max_grad_leaf_norm > 0 + and float(getattr(args, "max_grad_norm", 0.0) or 0.0) > 0 + ): + print( + "Unsloth: max_grad_leaf_norm is enabled; ignoring " + "max_grad_norm to avoid double clipping." + ) _clip_grad_value = max_grad_value > 0 + _clip_grad_leaf_norm = max_grad_leaf_norm > 0 state = [model.state, optimizer.state, mx.random.state] # The direct grad_accum==1 fast path delegates clipping to # mlx.optimizers.clip_grad_norm(). That helper is exact, but on current @@ -1082,7 +1140,8 @@ def _train_inner(self): grad_accum == 1 and not _needs_grad_scaling and max_grad_norm <= 0 and - not _clip_grad_value + not _clip_grad_value and + not _clip_grad_leaf_norm ) _restore_storage_after_norm_clip = max_grad_norm > 0 @@ -1177,15 +1236,6 @@ def _can_report_optimizer_state_norm(): # This avoids adding a second consumer to the lazy backward graph. return getattr(optimizer, "betas", None) - def _clip_grad_by_value(grad): - # Per-leaf elementwise clip; no cross-leaf reduction. - if not _clip_grad_value: - return grad - return tree_map( - lambda g: mx.clip(g, -max_grad_value, max_grad_value), - grad, - ) - def _apply_update(grad, toks_f): """Common gradient post-processing and optimizer update. @@ -1211,7 +1261,9 @@ def _apply_update(grad, toks_f): final_grad, max_norm=max_grad_norm ) if _clip_grad_value: - final_grad = _clip_grad_by_value(final_grad) + final_grad = _clip_grad_by_value(final_grad, max_grad_value) + if _clip_grad_leaf_norm: + final_grad = _clip_grad_by_leaf_norm(final_grad, max_grad_leaf_norm) self._apply_manual_weight_decay(model, optimizer, final_grad) optimizer.update(model, final_grad) _restore_trainable_storage_dtypes() @@ -1234,7 +1286,9 @@ def _apply_update_direct(grad): if max_grad_norm > 0: grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) if _clip_grad_value: - grad = _clip_grad_by_value(grad) + grad = _clip_grad_by_value(grad, max_grad_value) + if _clip_grad_leaf_norm: + grad = _clip_grad_by_leaf_norm(grad, max_grad_leaf_norm) self._apply_manual_weight_decay(model, optimizer, grad) optimizer.update(model, grad) _restore_trainable_storage_dtypes() From 6e6ab833e82b7480699784a11737c3360d8fba0a Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 12:29:12 -0500 Subject: [PATCH 30/48] Restore collated VLM position ids for parity --- tests/test_mlx_trainer_internals.py | 9 +++++++++ unsloth_zoo/mlx/utils.py | 7 +++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index dffc14db0..6b86e13bd 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -270,6 +270,15 @@ def test_mlx_text_loss_masks_exclude_position_at_sequence_length(): assert "steps < lengths[:, 1:]" in source +def test_vlm_cce_prefers_collated_position_ids_for_cuda_parity(): + import inspect + from unsloth_zoo.mlx import utils as mlx_utils + + source = inspect.getsource(mlx_utils._vlm_cce_forward) + assert 'if "position_ids" in extra_kwargs:' in source + assert 'and "position_ids" not in backbone_kwargs' not in source + + def test_mlx_train_result_reports_base_quantization(): import inspect from unsloth_zoo.mlx.trainer import MLXTrainer diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 03b89b09e..dae396696 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -862,8 +862,11 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, **extra_kwargs, ) merged_embeds, backbone_kwargs = _unpack_embed_result(embed_result, model) - # Prefer embedder-produced position_ids (Qwen-VL adjusts for merged seq). - if "position_ids" in extra_kwargs and "position_ids" not in backbone_kwargs: + # Collation builds CUDA-parity mRoPE position_ids for the full sequence. + # Use them over embedder fallbacks; preserving Qwen3-VL embedder-produced + # position_ids here shifts the first real-cat training loss from ~6.45 to + # ~6.90. + if "position_ids" in extra_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] hidden = _forward_text_hidden_states( From 73006495acf5f8f65c04ffca37cecbf426cb731c Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 12:39:11 -0500 Subject: [PATCH 31/48] Scope VLM position id override to collated ids --- tests/test_mlx_trainer_internals.py | 8 +++++--- unsloth_zoo/mlx/utils.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 6b86e13bd..706ca5e23 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -274,9 +274,11 @@ def test_vlm_cce_prefers_collated_position_ids_for_cuda_parity(): import inspect from unsloth_zoo.mlx import utils as mlx_utils - source = inspect.getsource(mlx_utils._vlm_cce_forward) - assert 'if "position_ids" in extra_kwargs:' in source - assert 'and "position_ids" not in backbone_kwargs' not in source + forward_source = inspect.getsource(mlx_utils._vlm_cce_forward) + prepare_source = inspect.getsource(mlx_utils._prepare_vlm_batch_for_compile) + assert '"_unsloth_collated_position_ids"' in prepare_source + assert 'not k.startswith("_unsloth_")' in forward_source + assert 'use_collated_position_ids and "position_ids" in extra_kwargs' in forward_source def test_mlx_train_result_reports_base_quantization(): diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index dae396696..1b22a7766 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -848,9 +848,13 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, fwd_attn_mask = attention_mask # Collect extra keys (e.g. image_grid_thw for Qwen) that some models need. + use_collated_position_ids = bool( + batch_dict.get("_unsloth_collated_position_ids") + ) extra_kwargs = { k: v for k, v in batch_dict.items() if k not in ("input_ids", "pixel_values", "attention_mask", "labels") + and not k.startswith("_unsloth_") and v is not None } extra_kwargs = _trim_sequence_aligned_vlm_kwargs(extra_kwargs, inputs.shape[1]) @@ -862,11 +866,10 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, **extra_kwargs, ) merged_embeds, backbone_kwargs = _unpack_embed_result(embed_result, model) - # Collation builds CUDA-parity mRoPE position_ids for the full sequence. - # Use them over embedder fallbacks; preserving Qwen3-VL embedder-produced - # position_ids here shifts the first real-cat training loss from ~6.45 to - # ~6.90. - if "position_ids" in extra_kwargs: + # Prefer collator-built mRoPE IDs when present. Qwen/GLM collators build + # CUDA-parity full-sequence positions; recomputing inside the embedder moved + # Qwen3-VL first-step loss from ~6.45 to ~6.90 on the real-cat fixture. + if use_collated_position_ids and "position_ids" in extra_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] hidden = _forward_text_hidden_states( @@ -1370,6 +1373,7 @@ def _prepare_vlm_batch_for_compile(batch_dict, config): video_token_id=int(_config_get(config, "video_token_id", _config_get(config, "video_token_index"))), spatial_merge_size=int(vision_config.get("spatial_merge_size", 2)), ) + batch_dict["_unsloth_collated_position_ids"] = True if model_type == "glm_ocr": input_ids = batch_dict.get("input_ids") @@ -1391,6 +1395,7 @@ def _prepare_vlm_batch_for_compile(batch_dict, config): video_token_id=int(_config_get(config, "video_token_id")), spatial_merge_size=int(vision_config.get("spatial_merge_size", 2)), ) + batch_dict["_unsloth_collated_position_ids"] = True if model_type == "phi3_v": input_ids = batch_dict.get("input_ids") From 830cd37035a2fe945814b71020e6ffd8d6caec95 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 12:57:14 -0500 Subject: [PATCH 32/48] Preserve returned VLM position ids --- tests/test_mlx_trainer_internals.py | 2 ++ unsloth_zoo/mlx/utils.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 706ca5e23..61b197b91 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -275,10 +275,12 @@ def test_vlm_cce_prefers_collated_position_ids_for_cuda_parity(): from unsloth_zoo.mlx import utils as mlx_utils forward_source = inspect.getsource(mlx_utils._vlm_cce_forward) + unpack_source = inspect.getsource(mlx_utils._unpack_embed_result) prepare_source = inspect.getsource(mlx_utils._prepare_vlm_batch_for_compile) assert '"_unsloth_collated_position_ids"' in prepare_source assert 'not k.startswith("_unsloth_")' in forward_source assert 'use_collated_position_ids and "position_ids" in extra_kwargs' in forward_source + assert 'lm is not None and "position_ids" not in backbone_kwargs' in unpack_source def test_mlx_train_result_reports_base_quantization(): diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 1b22a7766..d29c55fb3 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -781,13 +781,16 @@ def _unpack_embed_result(embed_result, model): else: merged_embeds = embed_result - # Qwen-VL family: get_input_embeddings stashes position_ids on the + # Qwen-VL family: some get_input_embeddings paths stash position_ids on the # language model wrapper; the inner backbone needs them explicitly. + # Do not override position_ids explicitly returned by InputEmbeddingsFeatures + # (for example when the collator passed CUDA-parity mRoPE IDs through the + # embedder). # When no position_ids were stashed (e.g. text-only samples or simple # images without grid_thw), generate sequential ones so the backbone # doesn't crash accessing cache.offset with cache=None. lm = getattr(model, "language_model", None) - if lm is not None: + if lm is not None and "position_ids" not in backbone_kwargs: _MISSING = object() pos_ids = getattr(lm, "_position_ids", _MISSING) if pos_ids is not _MISSING and pos_ids is not None: From 67b35048d8d5cc819e9b368cd42c29498d401b25 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 15:56:18 -0500 Subject: [PATCH 33/48] Fix MLX VLM parity masking edge cases --- tests/test_mlx_trainer_internals.py | 6 +- tests/test_mlx_vlm_label_masks.py | 117 +++++++++++++++++++++ unsloth_zoo/mlx/loader.py | 89 +++++++++++++++- unsloth_zoo/mlx/utils.py | 155 ++++++++++++++++++++++++++-- 4 files changed, 354 insertions(+), 13 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 61b197b91..678cb52c0 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -205,10 +205,8 @@ def test_scheduler_lr_matches_expected_optimizer_update_steps(scheduler, warmup) if scheduler == "linear" and warmup == 0: # Match `transformers.get_scheduler("linear", num_warmup_steps=0, - # num_training_steps=total_steps)`: step 0 = learning_rate, then - # decays linearly to 0 over total_steps. The earlier expectation - # of `[0, lr, lr*6/7, ...]` would have the first optimizer - # update fire at zero LR and is inconsistent with HF behavior. + # num_training_steps=total_steps)` as seen by optimizer steps across + # Transformers 4.56.1 through 5.5.0: step 1 uses base LR, then decays. lr = trainer.args.learning_rate expected = [lr * (total_steps - step) / total_steps for step in range(total_steps)] assert values == pytest.approx(expected) diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py index 353f8bc8d..a5b264c4c 100644 --- a/tests/test_mlx_vlm_label_masks.py +++ b/tests/test_mlx_vlm_label_masks.py @@ -173,6 +173,71 @@ def call(self, text, images=None, **_kwargs): assert processor.seen_images == expected +def test_vlm_processor_inputs_retries_duplicate_add_special_tokens(): + from unsloth_zoo.mlx.utils import _processor_vlm_inputs + + class PaddleLikeProcessor: + __module__ = "mlx_vlm.models.paddleocr_vl.processing_paddleocr_vl" + + def __init__(self): + self.calls = [] + + def __call__(self, text, images=None, **kwargs): + self.calls.append(dict(kwargs)) + if "add_special_tokens" in kwargs: + raise TypeError( + "got multiple values for keyword argument 'add_special_tokens'" + ) + return { + "input_ids": np.ones((len(text), 2), dtype=np.int32), + "attention_mask": np.ones((len(text), 2), dtype=np.int32), + } + + processor = PaddleLikeProcessor() + _processor_vlm_inputs(processor, ["a"], [["img0"]], 8) + + assert "add_special_tokens" in processor.calls[0] + assert "add_special_tokens" not in processor.calls[1] + + +def test_deepseek_ocr_loader_patches_removed_llama_flash_attention(monkeypatch): + import sys + import types + + from unsloth_zoo.mlx.loader import _patch_deepseek_ocr_transformers_import_compat + + llama_module = types.SimpleNamespace(LlamaAttention=object) + package = types.ModuleType("transformers.models.llama") + package.modeling_llama = llama_module + monkeypatch.setitem(sys.modules, "transformers.models.llama", package) + import transformers.utils.import_utils as import_utils + monkeypatch.delattr(import_utils, "is_torch_fx_available", raising=False) + + _patch_deepseek_ocr_transformers_import_compat("deepseekocr") + + assert llama_module.LlamaFlashAttention2 is llama_module.LlamaAttention + assert import_utils.is_torch_fx_available() is False + + +def test_deepseek_rendering_repairs_missing_image_token(): + from unsloth_zoo.mlx.utils import _render_vlm_messages + + class DeepseekProcessor: + __module__ = "mlx_vlm.models.deepseekocr.processing_deepseekocr" + image_token = "" + chat_template = "deepseek" + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False): + return "question" + + text = _render_vlm_messages( + DeepseekProcessor(), + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "question"}]}], + ) + + assert text == "question" + + def test_token_expansion_masks_inserted_label_positions(): from unsloth_zoo.mlx.utils import _expand_token_runs @@ -202,3 +267,55 @@ def test_mlx_trainer_does_not_attach_processor_for_loss_masking(): assert "self.model._processor =" not in trainer_source assert "_get_vlm_ignore_token_ids(" in trainer_source + + +def test_gemma3_vlm_cce_does_not_forward_outer_product_attention_mask(): + from types import SimpleNamespace + + from unsloth_zoo.mlx.utils import _unpack_embed_result + + embeds = mx.ones((1, 4, 8)) + outer_mask = mx.ones((1, 1, 4, 4), dtype=mx.int32) + embed_result = SimpleNamespace( + inputs_embeds=embeds, + attention_mask_4d=outer_mask, + ) + + _merged, kwargs = _unpack_embed_result( + embed_result, + SimpleNamespace(config=SimpleNamespace(model_type="gemma3")), + ) + + assert "attention_mask_4d" not in kwargs + + +def test_non_gemma3_vlm_cce_keeps_embedder_attention_mask(): + from types import SimpleNamespace + + from unsloth_zoo.mlx.utils import _unpack_embed_result + + embeds = mx.ones((1, 4, 8)) + outer_mask = mx.ones((1, 1, 4, 4), dtype=mx.int32) + embed_result = SimpleNamespace( + inputs_embeds=embeds, + attention_mask_4d=outer_mask, + ) + + _merged, kwargs = _unpack_embed_result( + embed_result, + SimpleNamespace(config=SimpleNamespace(model_type="gemma3n")), + ) + + assert kwargs["attention_mask_4d"] is outer_mask + + +def test_gemma_image_attention_mask_allows_bidirectional_image_block(): + from unsloth_zoo.mlx.utils import _build_gemma_image_attention_mask + + token_type_ids = mx.array([[0, 1, 1, 0]], dtype=mx.int32) + mask = _build_gemma_image_attention_mask(token_type_ids)[0, 0].tolist() + + assert mask[0] == [True, False, False, False] + assert mask[1] == [True, True, True, False] + assert mask[2] == [True, True, True, False] + assert mask[3] == [True, True, True, True] diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index f4c151dd7..a1dab2295 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -27,6 +27,7 @@ import inspect import math import os +import tempfile import types import warnings from contextlib import contextmanager @@ -288,6 +289,81 @@ def _message_matches_known_fallback(message, rule): return all(token in message for token in rule.get("message_tokens", ())) +def _patch_deepseek_ocr_transformers_import_compat(model_type): + """Let DeepSeek-OCR remote config imports survive newer Transformers. + + The MLX path does not instantiate the Torch Llama flash-attention class, + but DeepSeek-OCR's tokenizer/config import still imports that symbol from + Transformers. Recent Transformers releases removed it, so provide the + nearest eager-attention alias only for this import-time compatibility case. + """ + if model_type not in {"deepseekocr", "deepseekocr_2", "deepseek_vl_v2"}: + return + try: + from transformers.models.llama import modeling_llama + except Exception: + return + if ( + not hasattr(modeling_llama, "LlamaFlashAttention2") + and hasattr(modeling_llama, "LlamaAttention") + ): + modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention + try: + from transformers.utils import import_utils + except Exception: + return + if not hasattr(import_utils, "is_torch_fx_available"): + import_utils.is_torch_fx_available = lambda: False + + +def _deepseek_ocr_config_model_type(config_data): + architectures = config_data.get("architectures") or () + if isinstance(architectures, str): + architectures = (architectures,) + normalized = {str(arch).lower() for arch in architectures} + if "deepseekocrforcausallm" in normalized: + return "deepseekocr" + if "deepseekocr2forcausallm" in normalized: + return "deepseekocr_2" + return None + + +def _materialize_mlx_vlm_config_override(local_path, config_data): + """Return a load path whose config routes known repos to the right mlx-vlm class.""" + if not local_path: + return local_path, config_data + corrected_model_type = _deepseek_ocr_config_model_type(config_data) + if ( + corrected_model_type is None + or config_data.get("model_type") == corrected_model_type + ): + return local_path, config_data + + patched_config = dict(config_data) + patched_config["model_type"] = corrected_model_type + # mlx-vlm supplies the model/processor implementation locally. Keeping the + # Torch remote-code auto_map here makes AutoProcessor import incompatible + # DeepSeek OCR Torch modules during MLX loads. + patched_config.pop("auto_map", None) + override_dir = tempfile.mkdtemp(prefix="unsloth_mlx_vlm_config_") + for name in os.listdir(local_path): + src = os.path.join(local_path, name) + dst = os.path.join(override_dir, name) + if name == "config.json": + continue + try: + os.symlink(src, dst) + except FileExistsError: + pass + with open(os.path.join(override_dir, "config.json"), "w") as f: + json.dump(patched_config, f, indent=2) + print( + "Unsloth: Routing DeepSeek OCR checkpoint through " + f"mlx-vlm model_type={corrected_model_type!r}." + ) + return override_dir, patched_config + + def _load_mlx_lm_with_strict_fallback( model_name, model_type, @@ -358,6 +434,7 @@ def _load_mlx_vlm_with_extra_weight_filter( through load(), so retry with a temporary load_weights shim only for registered mismatch signatures and exact allow-listed keys. """ + _patch_deepseek_ocr_transformers_import_compat(model_type) try: with _temporary_hf_token_env(hf_token): return vlm_load(model_name, **vlm_kwargs) @@ -2388,6 +2465,10 @@ def from_pretrained( config_data = json.load(f) except (json.JSONDecodeError, KeyError): config_data = {} + local_path, config_data = _materialize_mlx_vlm_config_override( + local_path, + config_data, + ) # Reject full_finetuning against a pre-quantized repo. The weights on # disk are int4/int8 packed; full FT would need them in a trainable @@ -2653,14 +2734,16 @@ def from_pretrained( from mlx_vlm.utils import load_config as _vlm_load_config print(f"Unsloth: Loading {model_name} via mlx-vlm (VLM, " f"runtime {quantization_spec.bits}-bit {quantization_spec.mode} quantization)...") + _patch_deepseek_ocr_transformers_import_compat(model_type) + vlm_load_target = local_path or model_name with _temporary_hf_token_env(token): model, processor = vlm_load( - model_name, + vlm_load_target, lazy=True, revision=revision, **extra_kwargs, ) - vlm_cfg = _vlm_load_config(local_path or model_name) + vlm_cfg = _vlm_load_config(vlm_load_target) model, vlm_cfg = _apply_mlx_quantization( model, vlm_cfg, quantization_spec, is_vlm=True, user_predicate=quant_predicate, @@ -2676,7 +2759,7 @@ def from_pretrained( if target_dtype is not None: vlm_kwargs["lazy"] = True model, processor = _load_mlx_vlm_with_extra_weight_filter( - model_name, + local_path or model_name, model_type, vlm_load, vlm_kwargs, diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index d29c55fb3..6012a2ded 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -164,6 +164,36 @@ def _has_hidden_stack(obj): ) +def _build_gemma_image_attention_mask(token_type_ids, attention_mask=None, + window_size=None): + seq_len = token_type_ids.shape[1] + q_idx = mx.arange(seq_len)[:, None] + kv_idx = mx.arange(seq_len)[None, :] + causal = q_idx >= kv_idx + if window_size is not None: + causal = mx.logical_and(causal, q_idx < kv_idx + int(window_size)) + + is_image = token_type_ids == 1 + previous_image = mx.concatenate( + [mx.zeros_like(is_image[:, :1]), is_image[:, :-1]], + axis=1, + ) + new_image_start = mx.logical_and(is_image, mx.logical_not(previous_image)) + group_ids = mx.cumsum(new_image_start.astype(mx.int32), axis=1) - 1 + group_ids = mx.where(is_image, group_ids, -1) + same_image_group = mx.logical_and( + group_ids[:, :, None] == group_ids[:, None, :], + group_ids[:, :, None] >= 0, + ) + + mask = mx.logical_or(causal[None, :, :], same_image_group) + if attention_mask is not None: + valid = attention_mask.astype(mx.bool_) + mask = mx.logical_and(mask, valid[:, :, None]) + mask = mx.logical_and(mask, valid[:, None, :]) + return mx.expand_dims(mask, axis=1) + + def _run_hidden_stack(stack, inputs, inputs_embeds=None, **kwargs): """Execute a language stack up to pre-lm_head hidden states.""" from mlx_vlm.models.base import create_attention_mask @@ -184,11 +214,35 @@ def _run_hidden_stack(stack, inputs, inputs_embeds=None, **kwargs): mask = kwargs.get("attention_mask_4d") if mask is None: mask = kwargs.get("attention_mask") - if mask is None: + token_type_ids = kwargs.get("token_type_ids") + token_type_mask = None + if token_type_ids is not None: + attention_mask = kwargs.get("attention_mask") + token_type_mask = _build_gemma_image_attention_mask( + token_type_ids, + attention_mask=attention_mask, + ) + if mask is None and token_type_mask is None: mask = create_attention_mask(h, cache) - for layer, c in zip(stack.layers, cache): - h = layer(h, mask, c) + sliding_window_pattern = getattr(stack, "sliding_window_pattern", None) + window_size = getattr(stack, "window_size", None) + sliding_token_type_mask = None + if token_type_ids is not None and sliding_window_pattern and window_size: + sliding_token_type_mask = _build_gemma_image_attention_mask( + token_type_ids, + attention_mask=kwargs.get("attention_mask"), + window_size=window_size, + ) + + for i, (layer, c) in enumerate(zip(stack.layers, cache)): + local_mask = mask + if token_type_mask is not None: + is_global = not sliding_window_pattern or ( + i % sliding_window_pattern == sliding_window_pattern - 1 + ) + local_mask = token_type_mask if is_global else sliding_token_type_mask + h = layer(h, local_mask, c) return stack.norm(h) @@ -755,7 +809,9 @@ def _unpack_embed_result(embed_result, model): if hasattr(embed_result, "inputs_embeds"): merged_embeds = embed_result.inputs_embeds if getattr(embed_result, "attention_mask_4d", None) is not None: - backbone_kwargs["attention_mask_4d"] = embed_result.attention_mask_4d + model_type = _config_get(getattr(model, "config", None), "model_type") + if model_type != "gemma3": + backbone_kwargs["attention_mask_4d"] = embed_result.attention_mask_4d if getattr(embed_result, "position_ids", None) is not None: backbone_kwargs["position_ids"] = embed_result.position_ids # Gemma4: per-layer inputs for vision token injection @@ -874,6 +930,10 @@ def _vlm_cce_forward(model, batch_dict, image_token_ids=None, # Qwen3-VL first-step loss from ~6.45 to ~6.90 on the real-cat fixture. if use_collated_position_ids and "position_ids" in extra_kwargs: backbone_kwargs["position_ids"] = extra_kwargs["position_ids"] + if "token_type_ids" in extra_kwargs: + backbone_kwargs["token_type_ids"] = extra_kwargs["token_type_ids"] + if attention_mask is not None: + backbone_kwargs["attention_mask"] = attention_mask hidden = _forward_text_hidden_states( model, @@ -963,6 +1023,12 @@ def _normalize_grid_thw(grid_thw): return tuple(normalized) +def _grid_thw_to_mx_array(grid_thw): + if grid_thw is None: + return None + return mx.array(grid_thw, dtype=mx.int32) + + def _normalize_size_tuples(values): if values is None: return None @@ -1336,10 +1402,17 @@ def _prepare_vlm_batch_for_compile(batch_dict, config): spatial_shapes = _normalize_size_tuples(batch_dict.get("spatial_shapes")) images_spatial_crop = _normalize_size_tuples(batch_dict.get("images_spatial_crop")) audio_embed_sizes = _normalize_int_tuple(batch_dict.get("audio_embed_sizes")) + grid_as_array = model_type in {"glm4v", "glm_ocr"} if image_grid_thw is not None: - batch_dict["image_grid_thw"] = image_grid_thw + # GLM native mlx-vlm paths call .tolist(), .prod(), and slicing on + # grids; Qwen/Paddle compile patches expect Python tuples. + batch_dict["image_grid_thw"] = ( + _grid_thw_to_mx_array(image_grid_thw) if grid_as_array else image_grid_thw + ) if video_grid_thw is not None: - batch_dict["video_grid_thw"] = video_grid_thw + batch_dict["video_grid_thw"] = ( + _grid_thw_to_mx_array(video_grid_thw) if grid_as_array else video_grid_thw + ) if image_sizes is not None: batch_dict["image_sizes"] = image_sizes if spatial_shapes is not None: @@ -2002,6 +2075,42 @@ def _flatten_vlm_messages_to_content_parts(messages): return parts +def _count_vlm_image_parts(messages): + if isinstance(messages, str): + return 0 + count = 0 + for message in messages or []: + if not isinstance(message, dict): + continue + content = message.get("content", "") + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + count += 1 + return count + + +def _repair_deepseek_rendered_image_tokens(processor, text, messages): + if not isinstance(text, str) or not text.strip(): + return text + marker = ( + f"{processor.__class__.__module__}.{processor.__class__.__name__}" + ).lower() + if "deepseek" not in marker: + return text + image_count = _count_vlm_image_parts(messages) + if image_count <= 0: + return text + image_token = getattr(processor, "image_token", None) + if not image_token: + return text + missing = image_count - text.count(image_token) + if missing <= 0: + return text + return (image_token * missing) + text + + def _processor_accepts_assistant_list_content(processor): cached = getattr(processor, "_unsloth_assistant_single_content", None) if cached is not None: @@ -2044,6 +2153,7 @@ def _render_vlm_messages(processor, messages): tokenize=False, add_generation_prompt=False, ) + text = _repair_deepseek_rendered_image_tokens(processor, text, messages) if isinstance(text, str) and text.strip(): return text except Exception as first_exc: @@ -2057,6 +2167,7 @@ def _render_vlm_messages(processor, messages): tokenize=False, add_generation_prompt=False, ) + text = _repair_deepseek_rendered_image_tokens(processor, text, messages) if isinstance(text, str) and text.strip(): return text except Exception as second_exc: @@ -2070,6 +2181,7 @@ def _render_vlm_messages(processor, messages): tokenize=False, add_generation_prompt=False, ) + text = _repair_deepseek_rendered_image_tokens(processor, text, messages) if isinstance(text, str) and text.strip(): return text except Exception as third_exc: @@ -2320,6 +2432,10 @@ def _to_mx_vlm_batch(inputs): batch["attention_mask"] = batch["attention_mask"].astype(mx.int32) if "labels" in batch: batch["labels"] = batch["labels"].astype(mx.int32) + if "token_type_ids" in batch: + batch["token_type_ids"] = batch["token_type_ids"].astype(mx.int32) + if "mm_token_type_ids" in batch: + batch["mm_token_type_ids"] = batch["mm_token_type_ids"].astype(mx.int32) return batch @@ -2344,6 +2460,14 @@ def _processor_vlm_inputs(processor, texts, all_images, max_seq_length, suffixes image_layouts = (None,) if suffixes is not None and any(suffix is not None for suffix in suffixes): base_kwargs["suffix"] = [suffix or "" for suffix in suffixes] + marker = f"{processor.__class__.__module__}.{processor.__class__.__name__}".lower() + if ( + "gemma3" in marker + or "gemma4" in marker + or "qwen3_vl" in marker + or "qwen3_5" in marker + ): + base_kwargs["return_mm_token_type_ids"] = True first_error = None for image_layout in image_layouts: @@ -2356,6 +2480,25 @@ def _processor_vlm_inputs(processor, texts, all_images, max_seq_length, suffixes ) try: return processor(**proc_kwargs) + except TypeError as exc: + if ( + "add_special_tokens" in str(exc) + and "multiple values" in str(exc) + and "add_special_tokens" in proc_kwargs + ): + proc_kwargs.pop("add_special_tokens", None) + try: + return processor(**proc_kwargs) + except Exception as retry_exc: + if first_error is None: + first_error = retry_exc + if len(image_layouts) == 1: + raise + continue + if first_error is None: + first_error = exc + if len(image_layouts) == 1: + raise except Exception as exc: if first_error is None: first_error = exc From 35c0710f5d1cbef241d14c5094fc5054f7167ca5 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 16:06:56 -0500 Subject: [PATCH 34/48] Route text-only VLM loads through text trainer --- tests/test_mlx_vlm_label_masks.py | 12 ++++++++++++ unsloth_zoo/mlx/utils.py | 2 ++ 2 files changed, 14 insertions(+) diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py index a5b264c4c..c23d871aa 100644 --- a/tests/test_mlx_vlm_label_masks.py +++ b/tests/test_mlx_vlm_label_masks.py @@ -269,6 +269,18 @@ def test_mlx_trainer_does_not_attach_processor_for_loss_masking(): assert "_get_vlm_ignore_token_ids(" in trainer_source +def test_text_only_vlm_wrapper_uses_text_training_path(): + from unsloth_zoo.mlx.utils import _is_vlm_model + + class TextOnlyVLMWrapper: + _is_vlm_model = True + _unsloth_text_only_vlm = True + language_model = object() + vision_tower = object() + + assert _is_vlm_model(TextOnlyVLMWrapper()) is False + + def test_gemma3_vlm_cce_does_not_forward_outer_product_attention_mask(): from types import SimpleNamespace diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 6012a2ded..9ae00a71a 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -676,6 +676,8 @@ def _mask_prompt_tokens(targets, assistant_token_id): def _is_vlm_model(model) -> bool: """Check if model is a VLM (has language_model + vision component).""" + if getattr(model, "_unsloth_text_only_vlm", False): + return False explicit_flag = getattr(model, "_is_vlm_model", None) if explicit_flag is not None: return bool(explicit_flag) From 5f497dce2146f7dbc1f9eb8c23dc279cd4fa86b1 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 16:28:15 -0500 Subject: [PATCH 35/48] Match BNB nested NF4 scale quantization --- unsloth_zoo/mlx/loader.py | 46 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index a1dab2295..38be864a9 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -1580,6 +1580,51 @@ def _nf4_dense_dequantize_weight(weight, group_size=64): ], dtype=mx.float32, ) + + def _bnb_dynamic_codebook(): + data = [] + max_exponent_bits = 7 + total_bits = 8 + non_sign_bits = total_bits - 1 + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 + for i in range(max_exponent_bits): + fraction_items = int(2 ** (i + non_sign_bits - max_exponent_bits) + 1) + boundaries = mx.linspace(0.1, 1.0, fraction_items, dtype=mx.float32) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + scale = 10 ** (-(max_exponent_bits - 1) + i) + data.extend((scale * means).tolist()) + data.extend((-scale * means).tolist()) + if additional_items > 0: + boundaries = mx.linspace(0.1, 1.0, additional_items + 1, dtype=mx.float32) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + scale = 10 ** (-(max_exponent_bits - 1) + i) + data.extend((scale * means).tolist()) + data.extend((-scale * means).tolist()) + data.append(0.0) + data.append(1.0) + data.sort() + return mx.array(data, dtype=mx.float32) + + def _bnb_nested_absmax(absmax): + dynamic_codebook = _bnb_dynamic_codebook() + original_size = ( + absmax.numel() + if callable(getattr(absmax, "numel", None)) + else (absmax.size() if callable(getattr(absmax, "size", None)) else absmax.size) + ) + offset = mx.mean(absmax) + shifted = (absmax - offset).reshape((-1,)) + pad = (-original_size) % 256 + if pad: + shifted = mx.concatenate([shifted, mx.zeros((pad,), dtype=mx.float32)]) + scale_groups = shifted.reshape((-1, 256)) + scale_absmax = mx.max(mx.abs(scale_groups), axis=1, keepdims=True) + scale_denom = mx.where(scale_absmax > 0, scale_absmax, mx.ones_like(scale_absmax)) + scaled = scale_groups / scale_denom + scale_indices = mx.argmin(mx.abs(scaled[..., None] - dynamic_codebook), axis=-1) + nested = (dynamic_codebook[scale_indices] * scale_absmax).reshape((-1,))[:original_size] + return nested + offset + original_shape = weight.shape original_dtype = weight.dtype flat = weight.astype(mx.float32).reshape((-1,)) @@ -1596,6 +1641,7 @@ def _nf4_dense_dequantize_weight(weight, group_size=64): denom = mx.where(absmax > 0, absmax, mx.ones_like(absmax)) scaled = groups / denom indices = mx.argmin(mx.abs(scaled[..., None] - codebook), axis=-1) + absmax = _bnb_nested_absmax(absmax.reshape((-1,))).reshape((-1, 1)) dequantized = (codebook[indices] * absmax).reshape((-1,))[:original_size] return dequantized.reshape(original_shape).astype(original_dtype) From af3d68dd3974d26c269fba0f0f9d8fe5a5779dec Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 16:59:14 -0500 Subject: [PATCH 36/48] Match CUDA VLM resize-min behavior in MLX --- tests/test_mlx_vlm_label_masks.py | 22 ++++++++++++++++++++++ unsloth_zoo/mlx/utils.py | 13 ++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py index c23d871aa..914421b8c 100644 --- a/tests/test_mlx_vlm_label_masks.py +++ b/tests/test_mlx_vlm_label_masks.py @@ -173,6 +173,28 @@ def call(self, text, images=None, **_kwargs): assert processor.seen_images == expected +def test_vlm_resize_int_does_not_upscale_small_images(): + from PIL import Image + + from unsloth_zoo.mlx.utils import _resize_vlm_images + + image = Image.new("RGB", (512, 512)) + resized = _resize_vlm_images([image], 896) + + assert resized[0].size == (512, 512) + + +def test_vlm_resize_int_downscales_large_images_like_cuda_collator(): + from PIL import Image + + from unsloth_zoo.mlx.utils import _resize_vlm_images + + image = Image.new("RGB", (1024, 512)) + resized = _resize_vlm_images([image], 512) + + assert resized[0].size == (512, 256) + + def test_vlm_processor_inputs_retries_duplicate_add_special_tokens(): from unsloth_zoo.mlx.utils import _processor_vlm_inputs diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 9ae00a71a..4e3440859 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -2314,7 +2314,18 @@ def _resize_vlm_images(images, image_size): resized = [] for image in images: if isinstance(image, Image.Image): - resized.append(image.convert("RGB").resize(target, Image.Resampling.LANCZOS)) + image = image.convert("RGB") + if isinstance(image_size, int): + # Match UnslothVisionDataCollator resize="min": shrink large + # images to the model limit, but let processors handle upscaling. + if image.size[0] > image_size: + width, height = image.size + new_width = (width * image_size + width // 2) // width + new_height = (height * image_size + width // 2) // width + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + else: + image = image.resize(target, Image.Resampling.LANCZOS) + resized.append(image) else: resized.append(image) return resized From bf40c719bf33499d3bd978a012e277923f0077b7 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 17:02:17 -0500 Subject: [PATCH 37/48] Match Gemma3 vision post norm epsilon --- tests/test_mlx_trainer_internals.py | 20 +++++++++++++++ unsloth_zoo/mlx/loader.py | 38 +++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 678cb52c0..fa672ea1b 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -365,6 +365,26 @@ def update(self, parameters): ) +def test_mlx_loader_fixes_gemma3_vision_post_layernorm_eps(): + from types import SimpleNamespace + + from unsloth_zoo.mlx.loader import _fix_gemma3_vision_post_layernorm_eps + + post_layernorm = SimpleNamespace(eps=1e-5) + model = SimpleNamespace( + config=SimpleNamespace( + vision_config=SimpleNamespace(layer_norm_eps=1e-6), + ), + vision_tower=SimpleNamespace( + vision_model=SimpleNamespace(post_layernorm=post_layernorm), + ), + ) + + assert _fix_gemma3_vision_post_layernorm_eps(model) is True + assert post_layernorm.eps == 1e-6 + assert model._unsloth_gemma3_vision_post_layernorm_eps == 1e-6 + + def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): import inspect import unsloth_zoo.mlx.compile as mc diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 38be864a9..a392b7d8a 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -637,6 +637,43 @@ def patched_attn_call(self, x, mask=None, cache=None, position_ids=None): print("Unsloth: Fixed Qwen3.5 attention for training (cache=None).") +def _fix_gemma3_vision_post_layernorm_eps(model): + """Match HF Gemma3/SigLIP final vision LayerNorm epsilon. + + mlx-vlm constructs ``post_layernorm`` with MLX's default eps=1e-5, while + the checkpoint config and Transformers path use vision_config.layer_norm_eps + (1e-6 for Gemma3). The mismatch only appears after the full vision tower, + so it is easy to misdiagnose as attention drift. + """ + + vision_tower = getattr(model, "vision_tower", None) + vision_model = getattr(vision_tower, "vision_model", None) + post_layernorm = getattr(vision_model, "post_layernorm", None) + if post_layernorm is None or not hasattr(post_layernorm, "eps"): + return False + + config = getattr(model, "config", None) + vision_config = getattr(config, "vision_config", None) + if vision_config is None and isinstance(config, dict): + vision_config = config.get("vision_config") + + eps = None + if isinstance(vision_config, dict): + eps = vision_config.get("layer_norm_eps") + elif vision_config is not None: + eps = getattr(vision_config, "layer_norm_eps", None) + if eps is None: + return False + + eps = float(eps) + if float(getattr(post_layernorm, "eps")) == eps: + return False + + post_layernorm.eps = eps + model._unsloth_gemma3_vision_post_layernorm_eps = eps + return True + + def _safe_getsource(obj) -> str: try: return inspect.getsource(obj) @@ -2840,6 +2877,7 @@ def from_pretrained( model._is_vlm_model = True model._processor = processor _fix_gemma4_kv_sharing(model) + _fix_gemma3_vision_post_layernorm_eps(model) model._config = getattr(model, "_config", config_data) model._hf_repo = model_name From 104e41dff46bbc51a6d597be7539de026099b7ca Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 17:10:15 -0500 Subject: [PATCH 38/48] Run Gemma3 vision SDPA in fp32 on MLX --- tests/test_mlx_trainer_internals.py | 15 ++++++++ unsloth_zoo/mlx/loader.py | 54 +++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index fa672ea1b..056cc28ea 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -385,6 +385,21 @@ def test_mlx_loader_fixes_gemma3_vision_post_layernorm_eps(): assert model._unsloth_gemma3_vision_post_layernorm_eps == 1e-6 +def test_mlx_loader_patches_gemma3_vision_attention_fp32_sdpa(): + import inspect + + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_vision_attention_fp32_sdpa + + patched = _fix_gemma3_vision_attention_fp32_sdpa() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_vision_attention_fp32_sdpa) + assert "scaled_dot_product_attention" in source + assert "astype(mx.float32)" in source + assert "output.astype(orig_dtype)" in source + + def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): import inspect import unsloth_zoo.mlx.compile as mc diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index a392b7d8a..111b1bb65 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -674,6 +674,59 @@ def _fix_gemma3_vision_post_layernorm_eps(model): return True +def _fix_gemma3_vision_attention_fp32_sdpa(model=None): + """Run Gemma3 vision SDPA in fp32, then cast back before the output proj.""" + + try: + import mlx.core as mx + vision_module = importlib.import_module("mlx_vlm.models.gemma3.vision") + except Exception: + return False + + attention_cls = getattr(vision_module, "Attention", None) + if attention_cls is None: + return False + if getattr(attention_cls, "_unsloth_fp32_sdpa_patched", False): + return False + + def patched_attention_call(self, x, mask=None): + queries = self.q_proj(x) + keys = self.k_proj(x) + values = self.v_proj(x) + orig_dtype = queries.dtype + + num_heads = self.num_heads + batch_size, query_length, hidden_size = queries.shape + _, key_length, _ = keys.shape + queries = queries.reshape( + batch_size, query_length, num_heads, -1, + ).transpose(0, 2, 1, 3).astype(mx.float32) + keys = keys.reshape( + batch_size, key_length, num_heads, -1, + ).transpose(0, 2, 1, 3).astype(mx.float32) + values = values.reshape( + batch_size, key_length, num_heads, -1, + ).transpose(0, 2, 1, 3).astype(mx.float32) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask, + ) + output = output.astype(orig_dtype) + output = output.transpose(0, 2, 1, 3).reshape( + batch_size, query_length, hidden_size, + ) + return self.out_proj(output) + + try: + attention_cls.__call__ = patched_attention_call + attention_cls._unsloth_fp32_sdpa_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_vision_attention_fp32_sdpa = True + return True + + def _safe_getsource(obj) -> str: try: return inspect.getsource(obj) @@ -2878,6 +2931,7 @@ def from_pretrained( model._processor = processor _fix_gemma4_kv_sharing(model) _fix_gemma3_vision_post_layernorm_eps(model) + _fix_gemma3_vision_attention_fp32_sdpa(model) model._config = getattr(model, "_config", config_data) model._hf_repo = model_name From f6d5f6228394e5e9ae3372f8f2310a82723812f0 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 17:27:19 -0500 Subject: [PATCH 39/48] Match Gemma3 image feature scaling on MLX --- tests/test_mlx_trainer_internals.py | 37 +++++++++++++++ unsloth_zoo/mlx/loader.py | 74 +++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 056cc28ea..126ed4a79 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -400,6 +400,43 @@ def test_mlx_loader_patches_gemma3_vision_attention_fp32_sdpa(): assert "output.astype(orig_dtype)" in source +def test_mlx_loader_patches_gemma3_image_feature_scale(): + import inspect + + import mlx.core as mx + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_multimodal_image_feature_scale + + patched = _fix_gemma3_multimodal_image_feature_scale() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_multimodal_image_feature_scale) + assert "embed_dim = image_features.shape[-1]" in source + assert "image_features / (embed_dim**0.5)" in source + assert "del hidden_size" in source + + if patched: + from mlx_vlm.models.gemma3.gemma3 import Model + + image_token_id = 99 + input_ids = mx.array([[1, image_token_id, image_token_id]]) + inputs_embeds = mx.ones((1, 3, 4)) + image_features = mx.ones((1, 2, 4)) + attention_mask = mx.ones((1, 3)) + + embeds, _ = Model.prepare_inputs_for_multimodal( + 9, + 0, + image_token_id, + image_features, + inputs_embeds, + input_ids, + attention_mask, + ) + + assert mx.allclose(embeds[0, 1:], mx.full((2, 4), 0.5)) + + def test_qwen3_vl_vision_rotary_uses_transformers_fp32_math(): import inspect import unsloth_zoo.mlx.compile as mc diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 111b1bb65..653c7a548 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -727,6 +727,79 @@ def patched_attention_call(self, x, mask=None): return True +def _fix_gemma3_multimodal_image_feature_scale(model=None): + """Use text embedding width when compensating Gemma3 image feature scaling.""" + + try: + import mlx.core as mx + gemma3_module = importlib.import_module("mlx_vlm.models.gemma3.gemma3") + except Exception: + return False + + model_cls = getattr(gemma3_module, "Model", None) + masked_scatter = getattr(gemma3_module, "masked_scatter", None) + if model_cls is None or masked_scatter is None: + return False + if getattr(model_cls, "_unsloth_image_feature_scale_patched", False): + if model is not None: + model._unsloth_gemma3_image_feature_scale = "text_embed_dim" + return True + + def prepare_inputs_for_multimodal( + hidden_size, + pad_token_id, + image_token_index, + image_features, + inputs_embeds, + input_ids, + attention_mask, + ): + del hidden_size + embed_dim = image_features.shape[-1] + batch_size, sequence_length = input_ids.shape + # Gemma3's language model scales all inputs_embeds by sqrt(text hidden + # size). Compensate image features with the actual embedding width, not + # the top-level multimodal config hidden_size. + scaled_image_features = image_features / (embed_dim**0.5) + final_embedding = mx.zeros((batch_size, sequence_length, embed_dim)) + + pad_token_id = pad_token_id if pad_token_id is not None else 0 + text_mask = (input_ids != image_token_index) & (input_ids != pad_token_id) + image_mask = input_ids == image_token_index + pad_mask = input_ids == pad_token_id + + text_mask_expanded = mx.repeat(mx.expand_dims(text_mask, -1), embed_dim, axis=-1) + pad_mask_expanded = mx.repeat(mx.expand_dims(pad_mask, -1), embed_dim, axis=-1) + image_mask_expanded = mx.repeat(mx.expand_dims(image_mask, -1), embed_dim, axis=-1) + + final_embedding = mx.where(text_mask_expanded, inputs_embeds, final_embedding) + final_embedding = mx.where( + pad_mask_expanded, mx.zeros_like(final_embedding), final_embedding, + ) + final_embedding = masked_scatter( + final_embedding, image_mask_expanded, scaled_image_features, + ) + + attention_mask_expanded_1 = mx.expand_dims(attention_mask, 1) + attention_mask_expanded_2 = mx.expand_dims(attention_mask, 2) + final_attention_mask_4d = mx.expand_dims( + attention_mask_expanded_1 * attention_mask_expanded_2, + 1, + ) + return final_embedding.astype(inputs_embeds.dtype), final_attention_mask_4d + + try: + model_cls.prepare_inputs_for_multimodal = staticmethod( + prepare_inputs_for_multimodal, + ) + model_cls._unsloth_image_feature_scale_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_image_feature_scale = "text_embed_dim" + return True + + def _safe_getsource(obj) -> str: try: return inspect.getsource(obj) @@ -2932,6 +3005,7 @@ def from_pretrained( _fix_gemma4_kv_sharing(model) _fix_gemma3_vision_post_layernorm_eps(model) _fix_gemma3_vision_attention_fp32_sdpa(model) + _fix_gemma3_multimodal_image_feature_scale(model) model._config = getattr(model, "_config", config_data) model._hf_repo = model_name From e8e6c9be369aedd6a1e9cc0fb8e8ff5461b80279 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 17:34:29 -0500 Subject: [PATCH 40/48] Use Gemma3 image attention mask in MLX VLM CCE --- tests/test_mlx_vlm_label_masks.py | 46 +++++++++++++++++++++++++++++++ unsloth_zoo/mlx/utils.py | 10 +++++++ 2 files changed, 56 insertions(+) diff --git a/tests/test_mlx_vlm_label_masks.py b/tests/test_mlx_vlm_label_masks.py index 914421b8c..e90bc201f 100644 --- a/tests/test_mlx_vlm_label_masks.py +++ b/tests/test_mlx_vlm_label_masks.py @@ -353,3 +353,49 @@ def test_gemma_image_attention_mask_allows_bidirectional_image_block(): assert mask[1] == [True, True, True, False] assert mask[2] == [True, True, True, False] assert mask[3] == [True, True, True, True] + + +def test_gemma3_vlm_hidden_stack_uses_image_mask_and_embed_scale(): + from types import SimpleNamespace + + from unsloth_zoo.mlx.utils import _forward_text_hidden_states + + class RecordingLayer: + def __init__(self): + self.seen_h = None + self.seen_mask = None + + def __call__(self, h, mask, _cache): + self.seen_h = h + self.seen_mask = mask + return h + + class IdentityNorm: + weight = mx.ones((4,), dtype=mx.float32) + + def __call__(self, h): + return h + + layer = RecordingLayer() + stack = SimpleNamespace( + config=SimpleNamespace(model_type="gemma3_text", hidden_size=4), + embed_tokens=object(), + layers=[layer], + norm=IdentityNorm(), + sliding_window_pattern=1, + window_size=2, + ) + model = SimpleNamespace(language_model=SimpleNamespace(model=stack)) + embeds = mx.ones((1, 4, 4), dtype=mx.float32) + token_type_ids = mx.array([[0, 1, 1, 0]], dtype=mx.int32) + + out = _forward_text_hidden_states( + model, + mx.array([[1, 2, 3, 4]], dtype=mx.int32), + inputs_embeds=embeds, + token_type_ids=token_type_ids, + ) + + assert mx.allclose(out, mx.full((1, 4, 4), 2.0)) + assert mx.allclose(layer.seen_h, mx.full((1, 4, 4), 2.0)) + assert layer.seen_mask[0, 0].tolist()[1] == [True, True, True, False] diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index 4e3440859..e6bd60307 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -198,6 +198,7 @@ def _run_hidden_stack(stack, inputs, inputs_embeds=None, **kwargs): """Execute a language stack up to pre-lm_head hidden states.""" from mlx_vlm.models.base import create_attention_mask + config = getattr(stack, "config", None) norm_weight = getattr(getattr(stack, "norm", None), "weight", None) if inputs_embeds is None: h = stack.embed_tokens(inputs) @@ -205,6 +206,8 @@ def _run_hidden_stack(stack, inputs, inputs_embeds=None, **kwargs): h = inputs_embeds.astype(norm_weight.dtype) else: h = inputs_embeds + if inputs_embeds is not None and _config_get(config, "model_type") == "gemma3_text": + h *= mx.array(_config_get(config, "hidden_size")**0.5, mx.bfloat16).astype(h.dtype) cache = kwargs.get("cache") if cache is None: @@ -257,6 +260,13 @@ def _forward_text_hidden_states(model, inputs, inputs_embeds=None, **kwargs): tm = _get_text_model(model) backbone = getattr(tm, "model", None) if backbone is not None: + if ( + inputs_embeds is not None + and "token_type_ids" in kwargs + and _config_get(getattr(backbone, "config", None), "model_type") == "gemma3_text" + and _has_hidden_stack(backbone) + ): + return _run_hidden_stack(backbone, inputs, inputs_embeds=inputs_embeds, **kwargs) if getattr(backbone, "lm_head", None) is not None and _has_hidden_stack(backbone): return _run_hidden_stack(backbone, inputs, inputs_embeds=inputs_embeds, **kwargs) embed_kwarg = _get_backbone_embed_kwarg(backbone) From 3e2ee986c34c290188a0fef1bd9cdfb2f7de0391 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 17:53:22 -0500 Subject: [PATCH 41/48] Clip MLX global grad norms in fp32 --- tests/test_mlx_trainer_internals.py | 12 ++++++++++++ unsloth_zoo/mlx/trainer.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 126ed4a79..6a8040052 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -171,6 +171,18 @@ def should_restore_original_dtype(name): assert not should_restore_original_dtype("vision.blocks.0.norm1.weight") +def test_global_norm_clip_reduces_in_float32(): + import inspect + + from unsloth_zoo.mlx.trainer import _clip_grad_norm_fp32 + + source = inspect.getsource(_clip_grad_norm_fp32) + + assert "g.astype(mx.float32)" in source + assert "scale.astype(g.dtype)" in source + assert "tree_reduce" in source + + @pytest.mark.parametrize( ("scheduler", "warmup"), [ diff --git a/unsloth_zoo/mlx/trainer.py b/unsloth_zoo/mlx/trainer.py index fb4281f75..6693c11ce 100644 --- a/unsloth_zoo/mlx/trainer.py +++ b/unsloth_zoo/mlx/trainer.py @@ -46,7 +46,7 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim -from mlx.utils import tree_flatten, tree_map, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten _PAD_MULTIPLE = 32 SUPPORTED_MLX_OPTIMIZERS = ("adafactor", "adamw", "adam", "sgd", "muon", "lion") @@ -334,6 +334,29 @@ def _clip_leaf_norm(g): return tree_map(_clip_leaf_norm, grad) +def _clip_grad_norm_fp32(grad, max_norm): + """Global norm clipping with a float32 norm reduction. + + ``mlx.optimizers.clip_grad_norm`` reduces each leaf in its storage dtype. + For bf16/fp16 VLMs, that can move the global scale away from PyTorch/HF, + which computes the clipping norm in fp32. Keep clipped leaves in their + original dtype, but compute the single global scale in fp32. + """ + norm_squared = tree_reduce( + lambda acc, g: acc + mx.sum(mx.square(g.astype(mx.float32))), + grad, + mx.array(0.0, dtype=mx.float32), + ) + total_norm = mx.sqrt(norm_squared) + scale = mx.minimum( + mx.array(max_norm, dtype=mx.float32) / ( + total_norm + mx.array(1e-6, dtype=mx.float32) + ), + mx.array(1.0, dtype=mx.float32), + ) + return tree_map(lambda g: g * scale.astype(g.dtype), grad), total_norm + + @dataclass class MLXTrainingConfig: """Training configuration mirroring SFTConfig / TrainingArguments field names.""" @@ -1257,7 +1280,7 @@ def _apply_update(grad, toks_f): final_items.append((name, scaled)) final_grad = tree_unflatten(final_items) if max_grad_norm > 0: - final_grad, grad_norm = optim.clip_grad_norm( + final_grad, grad_norm = _clip_grad_norm_fp32( final_grad, max_norm=max_grad_norm ) if _clip_grad_value: @@ -1284,7 +1307,7 @@ def _apply_update_direct(grad): """ grad_norm = None if max_grad_norm > 0: - grad, grad_norm = optim.clip_grad_norm(grad, max_norm=max_grad_norm) + grad, grad_norm = _clip_grad_norm_fp32(grad, max_norm=max_grad_norm) if _clip_grad_value: grad = _clip_grad_by_value(grad, max_grad_value) if _clip_grad_leaf_norm: From 80b62cfd4ea78c1460e46957fb3d11409f9836fb Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 18:32:58 -0500 Subject: [PATCH 42/48] Match Gemma3 vision fp32 norm and activation math --- tests/test_mlx_trainer_internals.py | 44 ++++++++++ unsloth_zoo/mlx/loader.py | 131 ++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 6a8040052..df029f1ef 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -412,6 +412,50 @@ def test_mlx_loader_patches_gemma3_vision_attention_fp32_sdpa(): assert "output.astype(orig_dtype)" in source +def test_mlx_loader_patches_gemma3_vision_mlp_fp32_activation(): + import inspect + + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_vision_mlp_fp32_activation + + patched = _fix_gemma3_vision_mlp_fp32_activation() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_vision_mlp_fp32_activation) + assert "activation_fn(x.astype(mx.float32)).astype(orig_dtype)" in source + assert "_unsloth_fp32_activation_patched" in source + + +def test_mlx_loader_patches_gemma3_vision_encoder_fp32_layernorm(): + import inspect + + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_vision_encoder_fp32_layernorm + + patched = _fix_gemma3_vision_encoder_fp32_layernorm() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_vision_encoder_fp32_layernorm) + assert "x.astype(mx.float32)" in source + assert "return y.astype(orig_dtype)" in source + assert "_unsloth_fp32_layernorm_patched" in source + + +def test_mlx_loader_patches_gemma3_vision_post_layernorm_fp32(): + import inspect + + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_vision_post_layernorm_fp32 + + patched = _fix_gemma3_vision_post_layernorm_fp32() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_vision_post_layernorm_fp32) + assert "pooler_output = torch_like_layer_norm" in source + assert "return y.astype(orig_dtype)" in source + assert "_unsloth_fp32_post_layernorm_patched" in source + + def test_mlx_loader_patches_gemma3_image_feature_scale(): import inspect diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 653c7a548..11de435c7 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -727,6 +727,134 @@ def patched_attention_call(self, x, mask=None): return True +def _fix_gemma3_vision_mlp_fp32_activation(model=None): + """Match CUDA SigLIP GELU: compute activation in fp32, then cast back.""" + + try: + import mlx.core as mx + vision_module = importlib.import_module("mlx_vlm.models.gemma3.vision") + except Exception: + return False + + mlp_cls = getattr(vision_module, "MLP", None) + if mlp_cls is None: + return False + if getattr(mlp_cls, "_unsloth_fp32_activation_patched", False): + if model is not None: + model._unsloth_gemma3_vision_mlp_fp32_activation = True + return True + + def patched_mlp_call(self, x): + x = self.fc1(x) + orig_dtype = x.dtype + x = self.activation_fn(x.astype(mx.float32)).astype(orig_dtype) + x = self.fc2(x) + return x + + try: + mlp_cls.__call__ = patched_mlp_call + mlp_cls._unsloth_fp32_activation_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_vision_mlp_fp32_activation = True + return True + + +def _fix_gemma3_vision_encoder_fp32_layernorm(model=None): + """Match CUDA SigLIP LayerNorm math in fp32 while preserving bf16 activations.""" + + try: + import mlx.core as mx + vision_module = importlib.import_module("mlx_vlm.models.gemma3.vision") + except Exception: + return False + + encoder_layer_cls = getattr(vision_module, "EncoderLayer", None) + if encoder_layer_cls is None: + return False + if getattr(encoder_layer_cls, "_unsloth_fp32_layernorm_patched", False): + if model is not None: + model._unsloth_gemma3_vision_encoder_fp32_layernorm = True + return True + + def torch_like_layer_norm(norm, x): + orig_dtype = x.dtype + x_f = x.astype(mx.float32) + mean = mx.mean(x_f, axis=-1, keepdims=True) + centered = x_f - mean + var = mx.mean(centered * centered, axis=-1, keepdims=True) + y = centered * mx.rsqrt(var + norm.eps) + if "weight" in norm: + y = y * norm.weight.astype(mx.float32) + if "bias" in norm: + y = y + norm.bias.astype(mx.float32) + return y.astype(orig_dtype) + + def patched_encoder_layer_call(self, x, mask=None): + r = self.self_attn(torch_like_layer_norm(self.layer_norm1, x), mask) + h = x + r + r = self.mlp(torch_like_layer_norm(self.layer_norm2, h)) + return h + r + + try: + encoder_layer_cls.__call__ = patched_encoder_layer_call + encoder_layer_cls._unsloth_fp32_layernorm_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_vision_encoder_fp32_layernorm = True + return True + + +def _fix_gemma3_vision_post_layernorm_fp32(model=None): + """Run Gemma3 final SigLIP vision LayerNorm in fp32, then cast back.""" + + try: + import mlx.core as mx + vision_module = importlib.import_module("mlx_vlm.models.gemma3.vision") + except Exception: + return False + + siglip_cls = getattr(vision_module, "SigLipVisionModel", None) + if siglip_cls is None: + return False + if getattr(siglip_cls, "_unsloth_fp32_post_layernorm_patched", False): + if model is not None: + model._unsloth_gemma3_vision_post_layernorm_fp32 = True + return True + + def torch_like_layer_norm(norm, x): + orig_dtype = x.dtype + x_f = x.astype(mx.float32) + mean = mx.mean(x_f, axis=-1, keepdims=True) + centered = x_f - mean + var = mx.mean(centered * centered, axis=-1, keepdims=True) + y = centered * mx.rsqrt(var + norm.eps) + if "weight" in norm: + y = y * norm.weight.astype(mx.float32) + if "bias" in norm: + y = y + norm.bias.astype(mx.float32) + return y.astype(orig_dtype) + + def patched_siglip_call(self, x, output_hidden_states=None): + x = self.embeddings(x) + encoder_outputs = self.encoder( + x=x, output_hidden_states=output_hidden_states, mask=None, + ) + pooler_output = torch_like_layer_norm(self.post_layernorm, encoder_outputs[0]) + return pooler_output, x, encoder_outputs[-1] + + try: + siglip_cls.__call__ = patched_siglip_call + siglip_cls._unsloth_fp32_post_layernorm_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_vision_post_layernorm_fp32 = True + return True + + def _fix_gemma3_multimodal_image_feature_scale(model=None): """Use text embedding width when compensating Gemma3 image feature scaling.""" @@ -3005,6 +3133,9 @@ def from_pretrained( _fix_gemma4_kv_sharing(model) _fix_gemma3_vision_post_layernorm_eps(model) _fix_gemma3_vision_attention_fp32_sdpa(model) + _fix_gemma3_vision_encoder_fp32_layernorm(model) + _fix_gemma3_vision_post_layernorm_fp32(model) + _fix_gemma3_vision_mlp_fp32_activation(model) _fix_gemma3_multimodal_image_feature_scale(model) model._config = getattr(model, "_config", config_data) From 6ba0a7f30af3fb3f1640178162a2e5409819d82d Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 18:41:41 -0500 Subject: [PATCH 43/48] Disable Gemma3 MLX training compile pending parity --- tests/test_mlx_trainer_internals.py | 6 ++++++ unsloth_zoo/mlx/compile.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index df029f1ef..e97f03558 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -514,6 +514,12 @@ def test_qwen3_vl_training_compile_verified(): assert "qwen3_vl_moe" in mc._VERIFIED_TRAINING_ARCHES +def test_gemma3_training_compile_not_verified_until_loss_parity(): + import unsloth_zoo.mlx.compile as mc + + assert "gemma3" not in mc._VERIFIED_TRAINING_ARCHES + + # --------------------------------------------------------------------------- # 2. compile module-level discovery functions return sensible defaults # on a host with no real MLX architectures. diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index 7986c8292..d4ae1aa0b 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -73,6 +73,9 @@ def set_qwen3_vision_norm_cast_output(enabled: bool) -> None: # - qwen2_5_vl: real end-to-end compiled training via train.py # - qwen3_vl / qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: # compiled synthetic forward+backward +# Gemma3 stays patched but unqualified for training compile: a fixed-fixture +# Gemma3 VLM loss probe showed compiled loss differs from eager before any +# optimizer update. Re-promote only after real loss parity is verified. # SmolVLM has processor/template support, but real mlx-vlm training still hits # MLX primitive-less-array failures after a compiled call. Keep it patched but # unqualified until a real dataset compile run passes. @@ -80,7 +83,6 @@ def set_qwen3_vision_norm_cast_output(enabled: bool) -> None: "aya_vision", "deepseekocr", "deepseekocr_2", - "gemma3", "gemma3n", "gemma4", "glm_ocr", From e70e5754b05e66ea49ce91d401ce8c64a95fdbdf Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 18:49:59 -0500 Subject: [PATCH 44/48] Match Gemma3 text RMSNorm fp32 math --- tests/test_mlx_trainer_internals.py | 16 +++++++++++++ unsloth_zoo/mlx/loader.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index e97f03558..d3b228788 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -412,6 +412,22 @@ def test_mlx_loader_patches_gemma3_vision_attention_fp32_sdpa(): assert "output.astype(orig_dtype)" in source +def test_mlx_loader_patches_gemma3_text_rmsnorm_fp32(): + import inspect + + import unsloth_zoo.mlx.loader as loader + from unsloth_zoo.mlx.loader import _fix_gemma3_text_rmsnorm_fp32 + + patched = _fix_gemma3_text_rmsnorm_fp32() + assert patched in {True, False} + + source = inspect.getsource(loader._fix_gemma3_text_rmsnorm_fp32) + assert "x.astype(mx.float32)" in source + assert "mx.rsqrt(mx.mean(x_f * x_f" in source + assert "return y.astype(orig_dtype)" in source + assert "_unsloth_fp32_rmsnorm_patched" in source + + def test_mlx_loader_patches_gemma3_vision_mlp_fp32_activation(): import inspect diff --git a/unsloth_zoo/mlx/loader.py b/unsloth_zoo/mlx/loader.py index 11de435c7..baf3ef7f6 100644 --- a/unsloth_zoo/mlx/loader.py +++ b/unsloth_zoo/mlx/loader.py @@ -674,6 +674,41 @@ def _fix_gemma3_vision_post_layernorm_eps(model): return True +def _fix_gemma3_text_rmsnorm_fp32(model=None): + """Match HF Gemma3 text RMSNorm: fp32 math, then cast back to activation dtype.""" + + try: + import mlx.core as mx + language_module = importlib.import_module("mlx_vlm.models.gemma3.language") + except Exception: + return False + + rmsnorm_cls = getattr(language_module, "RMSNorm", None) + if rmsnorm_cls is None: + return False + if getattr(rmsnorm_cls, "_unsloth_fp32_rmsnorm_patched", False): + if model is not None: + model._unsloth_gemma3_text_rmsnorm_fp32 = True + return True + + def patched_rmsnorm_call(self, x): + orig_dtype = x.dtype + x_f = x.astype(mx.float32) + y = x_f * mx.rsqrt(mx.mean(x_f * x_f, axis=-1, keepdims=True) + self.eps) + if "weight" in self: + y = y * (1.0 + self.weight.astype(mx.float32)) + return y.astype(orig_dtype) + + try: + rmsnorm_cls.__call__ = patched_rmsnorm_call + rmsnorm_cls._unsloth_fp32_rmsnorm_patched = True + except Exception: + return False + if model is not None: + model._unsloth_gemma3_text_rmsnorm_fp32 = True + return True + + def _fix_gemma3_vision_attention_fp32_sdpa(model=None): """Run Gemma3 vision SDPA in fp32, then cast back before the output proj.""" @@ -3131,6 +3166,7 @@ def from_pretrained( model._is_vlm_model = True model._processor = processor _fix_gemma4_kv_sharing(model) + _fix_gemma3_text_rmsnorm_fp32(model) _fix_gemma3_vision_post_layernorm_eps(model) _fix_gemma3_vision_attention_fp32_sdpa(model) _fix_gemma3_vision_encoder_fp32_layernorm(model) From 31457a23f98a8cd2cd09cc05d9e9c51ffee7c163 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 19:04:27 -0500 Subject: [PATCH 45/48] Preserve VLM hidden stack activation dtype Rationale / guardrails for the local Gemma3 parity stack: This is the last local-only zoo commit before push, so this body documents the changes that should not be accidentally flipped back during review. Do not restore the broader Daniel position-id override. VLM CCE should prefer collator-built position_ids only when _unsloth_collated_position_ids is set, preserve position_ids explicitly returned by InputEmbeddingsFeatures, and otherwise fall back to model-stashed or sequential ids. The broad override moved Qwen/Gemma-style VLM runs away from CUDA collation semantics. Do not re-add global pad_token_id masking to the VLM loss. Padding is masked by labels/attention masks; globally ignoring pad ids also suppresses legitimate target ids for custom datasets. Image/video placeholder token ids are the only global ignore ids needed for VLM CCE. Do not mark Gemma3 training compile verified yet. Fixed-fixture Gemma3 showed compiled loss differing from eager before optimizer update, so best-effort must fall back to eager until real training parity is proven. Do not remove the Gemma3 MLX-vLM patches as cosmetic. The current patches fix concrete CUDA parity mismatches: SigLIP post-layernorm eps, vision SDPA fp32 math with cast-back, vision LayerNorm/GELU fp32 math with cast-back, text RMSNorm fp32 math with cast-back, image feature scaling by text embedding width, image-token attention masking in CCE, and preserving merged VLM inputs_embeds dtype instead of promoting activations to fp32 because norm weights are fp32. Do not switch MLX grad clipping back to bf16 reductions. Global grad norm clipping should reduce in fp32; bf16 reductions changed clipping behavior. Validation summary: focused MLX/Gemma3/VLM tests pass, and the remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower rather than labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping. --- tests/test_mlx_trainer_internals.py | 10 ++++++++++ unsloth_zoo/mlx/utils.py | 3 --- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index d3b228788..9502b89b3 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -428,6 +428,16 @@ def test_mlx_loader_patches_gemma3_text_rmsnorm_fp32(): assert "_unsloth_fp32_rmsnorm_patched" in source +def test_vlm_hidden_stack_preserves_inputs_embed_dtype(): + import inspect + + import unsloth_zoo.mlx.utils as utils + + source = inspect.getsource(utils._run_hidden_stack) + assert "h = inputs_embeds" in source + assert "inputs_embeds.astype(norm_weight.dtype)" not in source + + def test_mlx_loader_patches_gemma3_vision_mlp_fp32_activation(): import inspect diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index e6bd60307..ded8b9ecf 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -199,11 +199,8 @@ def _run_hidden_stack(stack, inputs, inputs_embeds=None, **kwargs): from mlx_vlm.models.base import create_attention_mask config = getattr(stack, "config", None) - norm_weight = getattr(getattr(stack, "norm", None), "weight", None) if inputs_embeds is None: h = stack.embed_tokens(inputs) - elif norm_weight is not None: - h = inputs_embeds.astype(norm_weight.dtype) else: h = inputs_embeds if inputs_embeds is not None and _config_get(config, "model_type") == "gemma3_text": From 9cc6885ddd1d1de31088ab3ee99b28228ce52a7d Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 19:57:21 -0500 Subject: [PATCH 46/48] Restore Gemma3 MLX training compile qualification Keep Gemma3 in the verified MLX training compile set. The observed eager-vs-compiled loss deltas are small enough that Gemma3 should continue using compile rather than falling back to eager by policy. Update the regression test to assert the intended compile qualification so this does not get accidentally demoted again. --- tests/test_mlx_trainer_internals.py | 4 ++-- unsloth_zoo/mlx/compile.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index 9502b89b3..d95479f01 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -540,10 +540,10 @@ def test_qwen3_vl_training_compile_verified(): assert "qwen3_vl_moe" in mc._VERIFIED_TRAINING_ARCHES -def test_gemma3_training_compile_not_verified_until_loss_parity(): +def test_gemma3_training_compile_verified(): import unsloth_zoo.mlx.compile as mc - assert "gemma3" not in mc._VERIFIED_TRAINING_ARCHES + assert "gemma3" in mc._VERIFIED_TRAINING_ARCHES # --------------------------------------------------------------------------- diff --git a/unsloth_zoo/mlx/compile.py b/unsloth_zoo/mlx/compile.py index d4ae1aa0b..592019822 100644 --- a/unsloth_zoo/mlx/compile.py +++ b/unsloth_zoo/mlx/compile.py @@ -71,11 +71,8 @@ def set_qwen3_vision_norm_cast_output(enabled: bool) -> None: # Architectures explicitly verified for mlx compile support. # Training verification currently covers: # - qwen2_5_vl: real end-to-end compiled training via train.py -# - qwen3_vl / qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: +# - gemma3 / qwen3_vl / qwen3_5 / qwen3_5_moe / gemma4 / paligemma / moondream3: # compiled synthetic forward+backward -# Gemma3 stays patched but unqualified for training compile: a fixed-fixture -# Gemma3 VLM loss probe showed compiled loss differs from eager before any -# optimizer update. Re-promote only after real loss parity is verified. # SmolVLM has processor/template support, but real mlx-vlm training still hits # MLX primitive-less-array failures after a compiled call. Keep it patched but # unqualified until a real dataset compile run passes. @@ -83,6 +80,7 @@ def set_qwen3_vision_norm_cast_output(enabled: bool) -> None: "aya_vision", "deepseekocr", "deepseekocr_2", + "gemma3", "gemma3n", "gemma4", "glm_ocr", From 0c0e5674fca34c8bce468b94b259b63e332064e9 Mon Sep 17 00:00:00 2001 From: DoubleMathew Date: Tue, 26 May 2026 20:16:00 -0500 Subject: [PATCH 47/48] Handle quantized CCE layer modes --- tests/test_mlx_runtime_cce_compile.py | 33 +++++++++++++++++++++++++++ tests/test_mlx_trainer_internals.py | 11 +++++++++ unsloth_zoo/mlx/cce/runtime_cce.py | 4 ++-- unsloth_zoo/mlx/utils.py | 15 +++++++++--- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/tests/test_mlx_runtime_cce_compile.py b/tests/test_mlx_runtime_cce_compile.py index 9168cfe0f..b5e7c8b5c 100644 --- a/tests/test_mlx_runtime_cce_compile.py +++ b/tests/test_mlx_runtime_cce_compile.py @@ -117,3 +117,36 @@ def loss_fn(hh): assert compiled_loss.item() == pytest.approx(eager_loss.item(), rel=1e-5) assert compiled_norm.item() == pytest.approx(eager_norm.item(), rel=1e-4) + + +def test_quantized_runtime_cce_rejects_missing_affine_biases(): + _skip_torch_shim() + import mlx.nn as nn + + from unsloth_zoo.mlx.cce import make_chunked_cross_entropy_loss + + linear = nn.Linear(32, 128, bias=False) + linear.weight = ( + mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32) / 113.0 + ) - 1.0 + qlinear = nn.QuantizedLinear.from_linear(linear, group_size=32, bits=4) + runtime_cce, _ = make_chunked_cross_entropy_loss( + ignore_index=-100, + chunk_size=32, + quantized=True, + group_size=qlinear.group_size, + bits=qlinear.bits, + ) + hidden = (mx.arange(64 * 32, dtype=mx.float32).reshape(64, 32) / 97.0) - 1.0 + targets = (mx.arange(64, dtype=mx.int32) * 7) % 128 + ntoks = mx.maximum( + mx.sum((targets != -100).astype(mx.float32)), + mx.array(1.0, dtype=mx.float32), + ) + + def loss_fn(hh): + losses = runtime_cce(hh, qlinear.weight, qlinear.scales, None, targets) + return losses.astype(mx.float32).sum() / ntoks + + with pytest.raises(ValueError, match="Biases must be provided for affine"): + mx.eval(loss_fn(hidden)) diff --git a/tests/test_mlx_trainer_internals.py b/tests/test_mlx_trainer_internals.py index d95479f01..8926c4b7f 100644 --- a/tests/test_mlx_trainer_internals.py +++ b/tests/test_mlx_trainer_internals.py @@ -540,6 +540,17 @@ def test_qwen3_vl_training_compile_verified(): assert "qwen3_vl_moe" in mc._VERIFIED_TRAINING_ARCHES +def test_quantized_cce_uses_layer_mode_and_affine_bias_guard(): + import inspect + import unsloth_zoo.mlx.utils as mlx_utils + + source = inspect.getsource(mlx_utils.make_vlm_cce_loss_fn) + assert 'quant_mode = getattr(lm_layer, "mode", "affine")' in source + assert "mode=quant_mode" in source + assert 'if bi is None and quant_mode == "affine":' in source + assert "bi = mx.zeros_like(sc)" in source + + def test_gemma3_training_compile_verified(): import unsloth_zoo.mlx.compile as mc diff --git a/unsloth_zoo/mlx/cce/runtime_cce.py b/unsloth_zoo/mlx/cce/runtime_cce.py index 464051c97..d7d2f4783 100644 --- a/unsloth_zoo/mlx/cce/runtime_cce.py +++ b/unsloth_zoo/mlx/cce/runtime_cce.py @@ -695,7 +695,7 @@ def runtime_cce_loss_vjp(primals, cotangents, outputs): v_end = min(v_start + resolved_chunk_size, vocab_size) weight_chunk = weight_compute[v_start:v_end] scales_chunk = scales[v_start:v_end] - biases_chunk = biases[v_start:v_end] + biases_chunk = None if biases is None else biases[v_start:v_end] logits = _chunk_matmul( hidden_compute, @@ -757,7 +757,7 @@ def runtime_cce_loss_vjp(primals, cotangents, outputs): grad_hidden.astype(hidden.dtype), mx.zeros_like(weight), mx.zeros_like(scales), - mx.zeros_like(biases), + None if biases is None else mx.zeros_like(biases), mx.zeros_like(targets), ) diff --git a/unsloth_zoo/mlx/utils.py b/unsloth_zoo/mlx/utils.py index ded8b9ecf..cfaf8b70a 100644 --- a/unsloth_zoo/mlx/utils.py +++ b/unsloth_zoo/mlx/utils.py @@ -395,7 +395,11 @@ def _get_lm_weight_layer(model): ) group_size = getattr(lm_layer, "group_size", 64) bits = getattr(lm_layer, "bits", 4) - print(f"Unsloth: CCE using quantized matmul (group_size={group_size}, bits={bits})") + quant_mode = getattr(lm_layer, "mode", "affine") + print( + "Unsloth: CCE using quantized matmul " + f"(group_size={group_size}, bits={bits}, mode={quant_mode})" + ) _has_biases = hasattr(lm_layer, "biases") rt_cce = _get_runtime_cce( @@ -404,6 +408,7 @@ def _get_lm_weight_layer(model): quantized=True, group_size=group_size, bits=bits, + mode=quant_mode, ) def loss_fn(model, batch, lengths, labels=None): @@ -416,7 +421,9 @@ def loss_fn(model, batch, lengths, labels=None): layer = _get_lm_weight_layer(model) w = layer.weight sc = layer.scales - bi = layer.biases if _has_biases else mx.zeros_like(layer.scales) + bi = layer.biases if _has_biases else None + if bi is None and quant_mode == "affine": + bi = mx.zeros_like(sc) steps = mx.arange(1, targets.shape[1] + 1) length_mask = mx.logical_and(steps >= lengths[:, 0:1], steps < lengths[:, 1:]) if labels is None: @@ -1700,6 +1707,7 @@ def make_vlm_cce_loss_fn(model, assistant_token_id=0, ignore_token_ids=None): ) group_size = getattr(lm_layer, "group_size", 64) bits = getattr(lm_layer, "bits", 4) + quant_mode = getattr(lm_layer, "mode", "affine") rt_cce = _get_runtime_cce( ignore_index=-100, @@ -1707,6 +1715,7 @@ def make_vlm_cce_loss_fn(model, assistant_token_id=0, ignore_token_ids=None): quantized=True, group_size=group_size, bits=bits, + mode=quant_mode, ) def loss_fn(model, batch_dict): @@ -1717,7 +1726,7 @@ def loss_fn(model, batch_dict): w = lm_head.weight sc = lm_head.scales bi = getattr(lm_head, "biases", None) - if bi is None: + if bi is None and quant_mode == "affine": bi = mx.zeros_like(sc) # Quantized backward already returns zero weight/scales/biases # gradients (see runtime_cce.py VJP), so stop_gradient is From c25c86b8ae90aaff325db416ef23a32438c736a8 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 27 May 2026 11:15:22 +0000 Subject: [PATCH 48/48] Rename grad-clip test to reflect three-mode scope `test_mlx_max_grad_value_none.py` now covers max_grad_leaf_norm and max_grad_norm too. Rename to test_mlx_grad_clip_resolution.py and update the docstring to list all three knobs. --- ...value_none.py => test_mlx_grad_clip_resolution.py} | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) rename tests/{test_mlx_max_grad_value_none.py => test_mlx_grad_clip_resolution.py} (93%) diff --git a/tests/test_mlx_max_grad_value_none.py b/tests/test_mlx_grad_clip_resolution.py similarity index 93% rename from tests/test_mlx_max_grad_value_none.py rename to tests/test_mlx_grad_clip_resolution.py index 9bf794cce..31bb6f9fa 100644 --- a/tests/test_mlx_max_grad_value_none.py +++ b/tests/test_mlx_grad_clip_resolution.py @@ -1,10 +1,9 @@ # Unsloth Zoo - Utilities for Unsloth -# Pin MLXTrainingConfig cheap clipping resolution: -# * max_grad_leaf_norm is the proportional per-leaf norm cap. -# * max_grad_value keeps historical elementwise clamp semantics. -# * None defaults to cheap proportional leaf-norm clipping at 1.0 unless -# max_grad_norm > 0 is passed. -# * explicit 0.0 disables that specific cheap clipping knob. +# Pin MLXTrainingConfig grad-clip resolution across all three knobs: +# max_grad_leaf_norm proportional per-leaf L2 cap (cheap, direction-preserving) +# max_grad_value elementwise clamp (historical contract; explicit positives win) +# max_grad_norm global L2 (HF parity; cross-tree reduction, pays memory) +# Default (all None) -> ("leaf_norm", 1.0); explicit 0.0 disables that knob. from __future__ import annotations