MLX Update Training#684
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces several enhancements and fixes for MLX training, focusing on VLM support and parity with HuggingFace's trainer behavior. Key updates include a manual AdamW weight decay implementation that filters out bias and normalization parameters, a diagnostic 'nf4_dense' quantization mode, and logic to maintain normalization parameters in float32. Additionally, it refines VLM collation, fixes a loss masking off-by-one error, and prevents automatic EOS appending in datasets. Feedback from the review identified a bug in the Qwen3-VL LayerNorm parameter check, precision and memory issues in the manual weight decay logic, and a performance regression in the compiler's logit handling.
| if "weight" in norm: | ||
| y = y * norm.weight.astype(mx.float32) | ||
| if "bias" in norm: | ||
| y = y + norm.bias.astype(mx.float32) |
There was a problem hiding this comment.
The check if "weight" in norm: is not a valid way to verify parameter existence on an mlx.nn.Module. This will likely evaluate to False or raise a TypeError, causing the LayerNorm calculation to skip applying the weight and bias tensors. This will result in incorrect numerical output for the Qwen3-VL vision blocks.
| if "weight" in norm: | |
| y = y * norm.weight.astype(mx.float32) | |
| if "bias" in norm: | |
| y = y + norm.bias.astype(mx.float32) | |
| if hasattr(norm, "weight"): | |
| y = y * norm.weight.astype(mx.float32) | |
| if hasattr(norm, "bias"): | |
| y = y + norm.bias.astype(mx.float32) |
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) | ||
| scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) | ||
| decayed.append((name, parameter * scale)) |
There was a problem hiding this comment.
This manual weight decay implementation has two significant issues:
- Precision Underflow: Calculating the
scalein the parameter's native dtype (e.g.,float16orbfloat16) will cause the weight decay to be ignored. For typical values likelr=2e-4andwd=0.01, the termlr * wd(2e-6) is smaller than the machine epsilon forfloat16/bfloat16relative to 1.0, so1.0 - 2e-6rounds back to1.0. - Unintended Parameter Promotion: If
scaleis calculated infloat32(to fix the precision issue), the operationparameter * scalewill promote the model parameters tofloat32. Since these parameters (LoRA weights and norms) are explicitly excluded from the restoration logic in_restore_trainable_storage_dtypes, they will remain infloat32, doubling their memory footprint for the rest of the training session.
The calculation should be done in float32 and explicitly cast back to the original dtype.
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) | |
| scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype) | |
| decayed.append((name, parameter * scale)) | |
| lr = optimizer.learning_rate.astype(mx.float32) | |
| scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32) | |
| decayed.append((name, (parameter * scale).astype(parameter.dtype))) |
| logit_softcapping = None if (\\4) == () else (\\4), | ||
| ) | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: |
There was a problem hiding this comment.
Removing the NOT_RETURN_LOGITS check here, combined with the deletion of the specialized branch for returned logits, introduces a performance and memory regression. When UNSLOTH_RETURN_LOGITS=1 is active, logits are materialized early; forcing the code into the unsloth_fused_ce_loss branch causes a redundant matmul. Furthermore, ensure that logit scaling (multiply/divide) and softcapping transformations are applied to these materialized logits before they are passed to the loss function to maintain consistency with standard execution paths.
References
- When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling and softcapping transformations are applied before the loss function.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a93449f518
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| "pad_token_id", | ||
| "image_token_id", | ||
| "video_token_id", | ||
| "audio_token_id", | ||
| ): | ||
| _append_unique_int(ids, getattr(tokenizer, attr, None)) |
There was a problem hiding this comment.
Exclude pad token ID from global VLM ignore list
Adding pad_token_id to _get_vlm_ignore_token_ids and then masking via _mask_label_token_ids masks that token everywhere, not only in padded positions. For tokenizers where pad_token_id == eos_token_id (common in Llama-derived setups), valid in-sequence EOS targets are turned into -100, so EOS supervision is silently removed across training/eval batches. Padding should be masked only through attention_mask == 0, not by globally ignoring the pad ID.
Useful? React with 👍 / 👎.
| logit_softcapping = None if (\\4) == () else (\\4), | ||
| ) | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS: | ||
| elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None: |
There was a problem hiding this comment.
Keep return-logits mode out of fused CE loss branch
This condition now routes all labeled ForCausalLMLoss calls through unsloth_fused_ce_loss, including UNSLOTH_RETURN_LOGITS=1. In that mode the compiler replacement already materializes logits up front, so this branch computes a second LM-head projection inside fused CE, increasing memory/compute and risking OOM on large models. The previous NOT_RETURN_LOGITS guard avoided that duplicate matmul.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dcd0a9001c
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if "weight" in norm: | ||
| y = y * norm.weight.astype(mx.float32) | ||
| if "bias" in norm: | ||
| y = y + norm.bias.astype(mx.float32) |
There was a problem hiding this comment.
Apply LayerNorm affine params in Qwen3 vision block patch
The new _qwen3_torch_like_layer_norm gate checks use membership ("weight" in norm, "bias" in norm) instead of attribute existence. For module-style LayerNorm objects, this test does not reliably indicate parameter presence, so the affine scale/bias can be skipped (or error), making every patched Qwen3-VL vision block run a non-affine normalization that diverges from pretrained/Transformers behavior and can degrade training and loss parity.
Useful? React with 👍 / 👎.
| if dataset_order == "torch_randperm": | ||
| indices = _torch_randperm_order(len(dataset), int(seed) + epoch) | ||
| elif dataset_order in (None, "default"): | ||
| np.random.seed(int(seed) + epoch) |
There was a problem hiding this comment.
Normalize nullable seeds before VLM epoch reshuffle
When create_vlm_batches rolls to a new epoch, it unconditionally computes int(seed) + epoch for both torch_randperm and default reshuffling. If callers pass seed=None (which is already supported in _torch_randperm_order and by NumPy seeding), multi-epoch materialization will raise TypeError on the first rollover and stop training/eval batch creation.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e16efc0f9c
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if dataset_order == "torch_randperm": | ||
| indices = _torch_randperm_order(len(dataset), int(seed) + epoch) |
There was a problem hiding this comment.
Normalize nullable seed in streaming torch_randperm path
When dataset_order="torch_randperm", the streaming iterator now does int(seed) + epoch before calling _torch_randperm_order. If callers pass seed=None (which _torch_randperm_order explicitly supports by defaulting to 3407), this raises TypeError on the first epoch and stops batch generation. Please normalize seed the same way as _torch_randperm_order (or pass it through directly) before adding epoch.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b0a83b52cf
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| continue | ||
| if not self._should_apply_weight_decay(name, parameter): | ||
| continue | ||
| lr = optimizer.learning_rate.astype(flat_grad[name].dtype) |
There was a problem hiding this comment.
Handle scalar learning rates in manual AdamW decay
_apply_manual_adamw_weight_decay assumes optimizer.learning_rate has an .astype(...) method, but in valid configurations it can be a Python float (notably the constant scheduler path without warmup, and the MLX simulation optimizer stubs). In that case this line raises AttributeError before any optimizer step, so AdamW training fails as soon as manual decay runs. Please normalize LR to an mx.array (or branch on scalar) before dtype-casting.
Useful? React with 👍 / 👎.
| if not tokenized: | ||
| return [] |
There was a problem hiding this comment.
Raise on empty ordered token stream instead of returning []
When create_ordered_batches drops all rows via if len(ids) >= 2 (for example very small max_seq_length or single-token rows), it returns an empty batch list. The new ordered-data path in MLXTrainer then indexes batches with batch_idx % len(batches), which crashes at runtime with division-by-zero instead of surfacing a data error. This should raise a clear ValueError here, matching the other dataset-prep guards.
Useful? React with 👍 / 👎.
| centered = x_f - mean | ||
| var = mx.mean(centered * centered, axis=-1, keepdims=True) | ||
| y = centered * mx.rsqrt(var + norm.eps) | ||
| if "weight" in norm: |
There was a problem hiding this comment.
If this is specific to qwen3 do we need the if check?
or even in qwen3 family some norms have weight/bias and some dont?
| flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)]) | ||
| groups = flat.reshape((-1, group_size)) | ||
| absmax = mx.max(mx.abs(groups), axis=1, keepdims=True) | ||
| denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32)) |
There was a problem hiding this comment.
I'm thinking we do this to avoid 0 division. But div by 1e-12 might cause the numbers to blow up?
For a similar thing what I did was set the scale to 1
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 964be34ec2
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if num_batches is None and seen >= len(tokenized): | ||
| break |
There was a problem hiding this comment.
Keep torch_randperm active across epoch-based text training
When dataset_order="torch_randperm" is used with epoch-based training (max_steps == 0), this early break materializes only one epoch of randomized order and returns. MLXTrainer then reuses that fixed batch list with batch_idx % len(batches) for subsequent epochs, so later epochs never get a fresh randperm order. This diverges from the expected per-epoch reshuffle behavior of a torch-style random sampler and can bias training/eval results for multi-epoch runs.
Useful? React with 👍 / 👎.
| if num_batches is None: | ||
| break |
There was a problem hiding this comment.
Rebuild VLM torch_randperm order for each epoch run
In pre-materialized VLM batching, this branch exits as soon as one dataset pass completes whenever num_batches is None (the normal epoch-based path). With dataset_order="torch_randperm", that means only the first epoch gets randomized; later epochs in MLXTrainer reuse the same stored batches instead of generating a new permutation. This breaks sampler parity for multi-epoch VLM training and can skew convergence behavior.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ca08652226
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if order_pos >= len(order): | ||
| epoch += 1 | ||
| order = make_order(epoch) | ||
| order_pos = 0 | ||
| batch_items.append(tokenized[order[order_pos]]) |
There was a problem hiding this comment.
Stop batches from spanning torch_randperm epoch boundaries
create_ordered_batches wraps to a new permutation inside the same batch (order_pos >= len(order)), so when num_epochs is used and len(dataset) % batch_size != 0, one batch can contain samples from two different epochs. This changes both sample order and batch count (e.g., 5 samples, batch size 4, 2 epochs yields 3 batches instead of the expected 4), so MLXTrainer under-trains in epoch mode and no longer mirrors CUDA/DataLoader epoch semantics for dataset_order="torch_randperm".
Useful? React with 👍 / 👎.
|
Ok pretty sure resolved review feedback |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7e0bee546f
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if any("norm" in part for part in parts): | ||
| return False |
There was a problem hiding this comment.
Exclude ln_ normalization weights from AdamW decay*
The new HF-style decay filter only checks for path segments containing "norm", so normalization parameters named like ln_1.weight or ln_f.weight still get decayed. This repo already treats ln_* as norm parameters in _ensure_lora_frozen (see _NORM_FRAGMENTS), so models using those names will now receive unintended weight decay on norm scales, diverging from the stated “skip bias and norms” behavior and changing optimization dynamics.
Useful? React with 👍 / 👎.
| parts = str(path).lower().split(".") | ||
| return any("norm" in part for part in parts[:-1]) |
There was a problem hiding this comment.
Include ln_ params in fp32 norm-parameter preservation*
_keep_norm_parameters_float32 claims to keep normalization parameters in fp32, but _is_norm_parameter_path only matches components containing "norm". Any normalization layer named with ln_* (which this codebase already recognizes as norm-like elsewhere) is skipped and left in lower precision, undermining the stabilization this pass is meant to provide for FT/LoRA/QLoRA training.
Useful? React with 👍 / 👎.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Rationale / guardrails for the local Gemma3 parity stack: This is the last local-only zoo commit before push, so this body documents the changes that should not be accidentally flipped back during review. Do not restore the broader Daniel position-id override. VLM CCE should prefer collator-built position_ids only when _unsloth_collated_position_ids is set, preserve position_ids explicitly returned by InputEmbeddingsFeatures, and otherwise fall back to model-stashed or sequential ids. The broad override moved Qwen/Gemma-style VLM runs away from CUDA collation semantics. Do not re-add global pad_token_id masking to the VLM loss. Padding is masked by labels/attention masks; globally ignoring pad ids also suppresses legitimate target ids for custom datasets. Image/video placeholder token ids are the only global ignore ids needed for VLM CCE. Do not mark Gemma3 training compile verified yet. Fixed-fixture Gemma3 showed compiled loss differing from eager before optimizer update, so best-effort must fall back to eager until real training parity is proven. Do not remove the Gemma3 MLX-vLM patches as cosmetic. The current patches fix concrete CUDA parity mismatches: SigLIP post-layernorm eps, vision SDPA fp32 math with cast-back, vision LayerNorm/GELU fp32 math with cast-back, text RMSNorm fp32 math with cast-back, image feature scaling by text embedding width, image-token attention masking in CCE, and preserving merged VLM inputs_embeds dtype instead of promoting activations to fp32 because norm weights are fp32. Do not switch MLX grad clipping back to bf16 reductions. Global grad norm clipping should reduce in fp32; bf16 reductions changed clipping behavior. Validation summary: focused MLX/Gemma3/VLM tests pass, and the remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower rather than labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping.
|
Reviewer / maintainer guardrail for the next MLX parity push: A few of the local commits intentionally narrow or revert behavior from the recent review commits. Please do not flip these back without re-running the parity probes.
Validation: focused MLX/Gemma3/VLM tests pass. The remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower, not labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping. |
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Keep Gemma3 in the verified MLX training compile set. The observed eager-vs-compiled loss deltas are small enough that Gemma3 should continue using compile rather than falling back to eager by policy. Update the regression test to assert the intended compile qualification so this does not get accidentally demoted again.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
|
Confirmed the new Verifier setupIdentical gemma-3-270m + r=8 LoRA memorisation fixture across CUDA and MLX (same as the existing macos-14 smoke). Each side ran the three clip modes back-to-back with the same seed and the same data ordering. Staging fork: Loss curves (30-step memorisation, first step where loss < 0.001)Within CUDA (precision-controlled, fp32 path):
On real MLX (macos-14, fp16), runs
Mean abs delta Memory + wall-clock probe (macos-14, run
|
| mode | peak GPU | step time |
|---|---|---|
max_grad_value=1.0 |
0.6643 GB | 3.29 s |
max_grad_leaf_norm=1.0 |
0.6638 GB | 3.35 s |
max_grad_norm=1.0 |
0.6664 GB | 3.64 s |
max_grad_normis +2.7 MB peak and +9 to 10% wall-clock vs the per-leaf modes.leaf_normandelementwiseare within 0.5 MB of each other (noise floor).- This is on a 270M model with ~1.7M trainable params. The overhead scales linearly with
sum(num_trainable_params), so on Llama-3-8B LoRA r=16 (~40M trainable) it's ~60 MB peak, and on 70B LoRA or full FT the gap matters.
Reading
- The new default lands where it should: proportional per-leaf rescale preserves direction (unlike elementwise) without the cross-tree reduction cost (unlike global norm).
max_grad_valuesemantics are preserved, so existing users who explicitly opt in keep the elementwise behavior.- The
_resolve_mlx_grad_clippingprecedence is verified on real MLX (separate assertion step in the same staging workflow): default ->("leaf_norm", 1.0), explicitmax_grad_value=1.5->("value", 1.5), explicitmax_grad_norm=1.0->("global_norm", 1.0).
LGTM from my side on the clip rework; nothing else to address here unless the gemma3 / mlx-vlm arch-by-arch work needs a separate PR.
Raw data: per-step loss + grad-norm JSONs and the memory probe artifact are uploaded as artifacts on staging-1 run 26497373034.
`test_mlx_max_grad_value_none.py` now covers max_grad_leaf_norm and max_grad_norm too. Rename to test_mlx_grad_clip_resolution.py and update the docstring to list all three knobs.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13 conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's behavior where it conflicts on semantics: - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=` - `if labels is None` branch preserved (PR unslothai#684 generality) alongside main's `_normalize_cce_label_dtype` dtype widening - `_get_image_token_ids` legacy wrapper kept alongside main's new `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels` - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so image masking honors main's uint-widening contract - HEAD's `_expand_token_replacements` dropped; main's three-function split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences` + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed - `_collate_vlm_prompt_completion_batch` reads back the masked labels in int64 so image + attention masking survives without narrowing - prompt-completion VLM collator routes through `_apply_vlm_label_masks` after dtype normalisation so ignore_token_ids and wide invalid ids both reach runtime CCE intact - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling - `_unsloth_*` prefix filter preserved so the new collated_position_ids flag and main's raw-input-ids carrier both get stripped 152 MLX tests pass post-merge.
|
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
Many small fixes to align MLX training with unsloth transformers style mostly related to VLMs.