Skip to content

MLX Update Training#684

Open
mmathew23 wants to merge 51 commits into
unslothai:mainfrom
mmathew23:explore/mlx
Open

MLX Update Training#684
mmathew23 wants to merge 51 commits into
unslothai:mainfrom
mmathew23:explore/mlx

Conversation

@mmathew23

Copy link
Copy Markdown
Collaborator

Many small fixes to align MLX training with unsloth transformers style mostly related to VLMs.

@mmathew23 mmathew23 requested a review from danielhanchen as a code owner May 20, 2026 22:08

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several enhancements and fixes for MLX training, focusing on VLM support and parity with HuggingFace's trainer behavior. Key updates include a manual AdamW weight decay implementation that filters out bias and normalization parameters, a diagnostic 'nf4_dense' quantization mode, and logic to maintain normalization parameters in float32. Additionally, it refines VLM collation, fixes a loss masking off-by-one error, and prevents automatic EOS appending in datasets. Feedback from the review identified a bug in the Qwen3-VL LayerNorm parameter check, precision and memory issues in the manual weight decay logic, and a performance regression in the compiler's logit handling.

Comment thread unsloth_zoo/mlx/compile.py Outdated
Comment on lines +2722 to +2725
if "weight" in norm:
y = y * norm.weight.astype(mx.float32)
if "bias" in norm:
y = y + norm.bias.astype(mx.float32)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check if "weight" in norm: is not a valid way to verify parameter existence on an mlx.nn.Module. This will likely evaluate to False or raise a TypeError, causing the LayerNorm calculation to skip applying the weight and bias tensors. This will result in incorrect numerical output for the Qwen3-VL vision blocks.

Suggested change
if "weight" in norm:
y = y * norm.weight.astype(mx.float32)
if "bias" in norm:
y = y + norm.bias.astype(mx.float32)
if hasattr(norm, "weight"):
y = y * norm.weight.astype(mx.float32)
if hasattr(norm, "bias"):
y = y + norm.bias.astype(mx.float32)

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +495 to +497
lr = optimizer.learning_rate.astype(flat_grad[name].dtype)
scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype)
decayed.append((name, parameter * scale))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This manual weight decay implementation has two significant issues:

  1. Precision Underflow: Calculating the scale in the parameter's native dtype (e.g., float16 or bfloat16) will cause the weight decay to be ignored. For typical values like lr=2e-4 and wd=0.01, the term lr * wd (2e-6) is smaller than the machine epsilon for float16/bfloat16 relative to 1.0, so 1.0 - 2e-6 rounds back to 1.0.
  2. Unintended Parameter Promotion: If scale is calculated in float32 (to fix the precision issue), the operation parameter * scale will promote the model parameters to float32. Since these parameters (LoRA weights and norms) are explicitly excluded from the restoration logic in _restore_trainable_storage_dtypes, they will remain in float32, doubling their memory footprint for the rest of the training session.

The calculation should be done in float32 and explicitly cast back to the original dtype.

Suggested change
lr = optimizer.learning_rate.astype(flat_grad[name].dtype)
scale = mx.array(1.0, dtype=lr.dtype) - lr * mx.array(wd, dtype=lr.dtype)
decayed.append((name, parameter * scale))
lr = optimizer.learning_rate.astype(mx.float32)
scale = mx.array(1.0, dtype=mx.float32) - lr * mx.array(wd, dtype=mx.float32)
decayed.append((name, (parameter * scale).astype(parameter.dtype)))

Comment thread unsloth_zoo/compiler.py Outdated
logit_softcapping = None if (\\4) == () else (\\4),
)
elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS:
elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Removing the NOT_RETURN_LOGITS check here, combined with the deletion of the specialized branch for returned logits, introduces a performance and memory regression. When UNSLOTH_RETURN_LOGITS=1 is active, logits are materialized early; forcing the code into the unsloth_fused_ce_loss branch causes a redundant matmul. Furthermore, ensure that logit scaling (multiply/divide) and softcapping transformations are applied to these materialized logits before they are passed to the loss function to maintain consistency with standard execution paths.

References
  1. When implementing code paths that return materialized logits (e.g., using UNSLOTH_RETURN_LOGITS=1), ensure that logit scaling and softcapping transformations are applied before the loss function.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a93449f518

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +533 to +538
"pad_token_id",
"image_token_id",
"video_token_id",
"audio_token_id",
):
_append_unique_int(ids, getattr(tokenizer, attr, None))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Exclude pad token ID from global VLM ignore list

Adding pad_token_id to _get_vlm_ignore_token_ids and then masking via _mask_label_token_ids masks that token everywhere, not only in padded positions. For tokenizers where pad_token_id == eos_token_id (common in Llama-derived setups), valid in-sequence EOS targets are turned into -100, so EOS supervision is silently removed across training/eval batches. Padding should be masked only through attention_mask == 0, not by globally ignoring the pad ID.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/compiler.py Outdated
logit_softcapping = None if (\\4) == () else (\\4),
)
elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and NOT_RETURN_LOGITS:
elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep return-logits mode out of fused CE loss branch

This condition now routes all labeled ForCausalLMLoss calls through unsloth_fused_ce_loss, including UNSLOTH_RETURN_LOGITS=1. In that mode the compiler replacement already materializes logits up front, so this branch computes a second LM-head projection inside fused CE, increasing memory/compute and risking OOM on large models. The previous NOT_RETURN_LOGITS guard avoided that duplicate matmul.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: dcd0a9001c

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/compile.py Outdated
Comment on lines +2722 to +2725
if "weight" in norm:
y = y * norm.weight.astype(mx.float32)
if "bias" in norm:
y = y + norm.bias.astype(mx.float32)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Apply LayerNorm affine params in Qwen3 vision block patch

The new _qwen3_torch_like_layer_norm gate checks use membership ("weight" in norm, "bias" in norm) instead of attribute existence. For module-style LayerNorm objects, this test does not reliably indicate parameter presence, so the affine scale/bias can be skipped (or error), making every patched Qwen3-VL vision block run a non-affine normalization that diverges from pretrained/Transformers behavior and can degrade training and loss parity.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2518 to +2521
if dataset_order == "torch_randperm":
indices = _torch_randperm_order(len(dataset), int(seed) + epoch)
elif dataset_order in (None, "default"):
np.random.seed(int(seed) + epoch)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize nullable seeds before VLM epoch reshuffle

When create_vlm_batches rolls to a new epoch, it unconditionally computes int(seed) + epoch for both torch_randperm and default reshuffling. If callers pass seed=None (which is already supported in _torch_randperm_order and by NumPy seeding), multi-epoch materialization will raise TypeError on the first rollover and stop training/eval batch creation.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e16efc0f9c

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2600 to +2601
if dataset_order == "torch_randperm":
indices = _torch_randperm_order(len(dataset), int(seed) + epoch)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize nullable seed in streaming torch_randperm path

When dataset_order="torch_randperm", the streaming iterator now does int(seed) + epoch before calling _torch_randperm_order. If callers pass seed=None (which _torch_randperm_order explicitly supports by defaulting to 3407), this raises TypeError on the first epoch and stops batch generation. Please normalize seed the same way as _torch_randperm_order (or pass it through directly) before adding epoch.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b0a83b52cf

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/trainer.py Outdated
continue
if not self._should_apply_weight_decay(name, parameter):
continue
lr = optimizer.learning_rate.astype(flat_grad[name].dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle scalar learning rates in manual AdamW decay

_apply_manual_adamw_weight_decay assumes optimizer.learning_rate has an .astype(...) method, but in valid configurations it can be a Python float (notably the constant scheduler path without warmup, and the MLX simulation optimizer stubs). In that case this line raises AttributeError before any optimizer step, so AdamW training fails as soon as manual decay runs. Please normalize LR to an mx.array (or branch on scalar) before dtype-casting.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2792 to +2793
if not tokenized:
return []

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Raise on empty ordered token stream instead of returning []

When create_ordered_batches drops all rows via if len(ids) >= 2 (for example very small max_seq_length or single-token rows), it returns an empty batch list. The new ordered-data path in MLXTrainer then indexes batches with batch_idx % len(batches), which crashes at runtime with division-by-zero instead of surfacing a data error. This should raise a clear ValueError here, matching the other dataset-prep guards.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/compile.py Outdated
centered = x_f - mean
var = mx.mean(centered * centered, axis=-1, keepdims=True)
y = centered * mx.rsqrt(var + norm.eps)
if "weight" in norm:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is specific to qwen3 do we need the if check?
or even in qwen3 family some norms have weight/bias and some dont?

Comment thread unsloth_zoo/mlx/loader.py Outdated
flat = mx.concatenate([flat, mx.zeros((pad,), dtype=mx.float32)])
groups = flat.reshape((-1, group_size))
absmax = mx.max(mx.abs(groups), axis=1, keepdims=True)
denom = mx.maximum(absmax, mx.array(1e-12, dtype=mx.float32))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking we do this to avoid 0 division. But div by 1e-12 might cause the numbers to blow up?
For a similar thing what I did was set the scale to 1

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 964be34ec2

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2830 to +2831
if num_batches is None and seen >= len(tokenized):
break

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep torch_randperm active across epoch-based text training

When dataset_order="torch_randperm" is used with epoch-based training (max_steps == 0), this early break materializes only one epoch of randomized order and returns. MLXTrainer then reuses that fixed batch list with batch_idx % len(batches) for subsequent epochs, so later epochs never get a fresh randperm order. This diverges from the expected per-epoch reshuffle behavior of a torch-style random sampler and can bias training/eval results for multi-epoch runs.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2523 to +2524
if num_batches is None:
break

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Rebuild VLM torch_randperm order for each epoch run

In pre-materialized VLM batching, this branch exits as soon as one dataset pass completes whenever num_batches is None (the normal epoch-based path). With dataset_order="torch_randperm", that means only the first epoch gets randomized; later epochs in MLXTrainer reuse the same stored batches instead of generating a new permutation. This breaks sampler parity for multi-epoch VLM training and can skew convergence behavior.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ca08652226

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/utils.py Outdated
Comment on lines +2825 to +2829
if order_pos >= len(order):
epoch += 1
order = make_order(epoch)
order_pos = 0
batch_items.append(tokenized[order[order_pos]])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Stop batches from spanning torch_randperm epoch boundaries

create_ordered_batches wraps to a new permutation inside the same batch (order_pos >= len(order)), so when num_epochs is used and len(dataset) % batch_size != 0, one batch can contain samples from two different epochs. This changes both sample order and batch count (e.g., 5 samples, batch size 4, 2 epochs yields 3 batches instead of the expected 4), so MLXTrainer under-trains in epoch mode and no longer mirrors CUDA/DataLoader epoch semantics for dataset_order="torch_randperm".

Useful? React with 👍 / 👎.

@mmathew23

Copy link
Copy Markdown
Collaborator Author

Ok pretty sure resolved review feedback

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7e0bee546f

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth_zoo/mlx/trainer.py Outdated
Comment on lines +627 to +628
if any("norm" in part for part in parts):
return False

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Exclude ln_ normalization weights from AdamW decay*

The new HF-style decay filter only checks for path segments containing "norm", so normalization parameters named like ln_1.weight or ln_f.weight still get decayed. This repo already treats ln_* as norm parameters in _ensure_lora_frozen (see _NORM_FRAGMENTS), so models using those names will now receive unintended weight decay on norm scales, diverging from the stated “skip bias and norms” behavior and changing optimization dynamics.

Useful? React with 👍 / 👎.

Comment thread unsloth_zoo/mlx/loader.py Outdated
Comment on lines +151 to +152
parts = str(path).lower().split(".")
return any("norm" in part for part in parts[:-1])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Include ln_ params in fp32 norm-parameter preservation*

_keep_norm_parameters_float32 claims to keep normalization parameters in fp32, but _is_norm_parameter_path only matches components containing "norm". Any normalization layer named with ln_* (which this codebase already recognizes as norm-like elsewhere) is skipped and left in lower precision, undermining the stabilization this pass is meant to provide for FT/LoRA/QLoRA training.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

mmathew23 added 15 commits May 26, 2026 12:39
Rationale / guardrails for the local Gemma3 parity stack:

This is the last local-only zoo commit before push, so this body documents the changes that should not be accidentally flipped back during review.

Do not restore the broader Daniel position-id override. VLM CCE should prefer collator-built position_ids only when _unsloth_collated_position_ids is set, preserve position_ids explicitly returned by InputEmbeddingsFeatures, and otherwise fall back to model-stashed or sequential ids. The broad override moved Qwen/Gemma-style VLM runs away from CUDA collation semantics.

Do not re-add global pad_token_id masking to the VLM loss. Padding is masked by labels/attention masks; globally ignoring pad ids also suppresses legitimate target ids for custom datasets. Image/video placeholder token ids are the only global ignore ids needed for VLM CCE.

Do not mark Gemma3 training compile verified yet. Fixed-fixture Gemma3 showed compiled loss differing from eager before optimizer update, so best-effort must fall back to eager until real training parity is proven.

Do not remove the Gemma3 MLX-vLM patches as cosmetic. The current patches fix concrete CUDA parity mismatches: SigLIP post-layernorm eps, vision SDPA fp32 math with cast-back, vision LayerNorm/GELU fp32 math with cast-back, text RMSNorm fp32 math with cast-back, image feature scaling by text embedding width, image-token attention masking in CCE, and preserving merged VLM inputs_embeds dtype instead of promoting activations to fp32 because norm weights are fp32.

Do not switch MLX grad clipping back to bf16 reductions. Global grad norm clipping should reduce in fp32; bf16 reductions changed clipping behavior.

Validation summary: focused MLX/Gemma3/VLM tests pass, and the remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower rather than labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping.
@mmathew23

Copy link
Copy Markdown
Collaborator Author

Reviewer / maintainer guardrail for the next MLX parity push:

A few of the local commits intentionally narrow or revert behavior from the recent review commits. Please do not flip these back without re-running the parity probes.

  • VLM position ids: CCE should use collator-built position_ids only when _unsloth_collated_position_ids is set, preserve position_ids explicitly returned by InputEmbeddingsFeatures, and otherwise fall back to model-stashed or sequential ids. The broader override changed CUDA collation semantics.
  • VLM loss masking: do not globally ignore pad_token_id in the VLM loss. Padding is already masked by labels/attention masks; global pad-id ignore can suppress legitimate custom-dataset targets. Image/video placeholder ids remain the global ignore ids.
  • Gemma3 compile: Gemma3 training compile should stay unverified. A fixed-fixture probe showed compiled loss differing from eager before any optimizer update.
  • Gemma3 mlx-vlm patches: these are not cosmetic. They fix measured CUDA parity mismatches: SigLIP post-LN eps, vision SDPA fp32 math with cast-back, vision LayerNorm/GELU fp32 math with cast-back, text RMSNorm fp32 math with cast-back, image feature scaling by text embedding width, image-token attention masking in CCE, and preserving merged VLM activation dtype.
  • Grad clipping: global grad norm clipping should reduce in fp32; bf16 reduction changed clipping behavior.

Validation: focused MLX/Gemma3/VLM tests pass. The remaining Gemma3 VLM delta was isolated to cumulative bf16/backend drift through the 27-layer SigLIP tower, not labels, preprocessing, position ids, projector, final post-LN, block-0 attention backward, or weight mapping.

@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

mmathew23 added 2 commits May 26, 2026 19:57
Keep Gemma3 in the verified MLX training compile set. The observed eager-vs-compiled loss deltas are small enough that Gemma3 should continue using compile rather than falling back to eager by policy.

Update the regression test to assert the intended compile qualification so this does not get accidentally demoted again.
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@danielhanchen

Copy link
Copy Markdown
Member

Confirmed the new max_grad_leaf_norm default end-to-end on real Apple Silicon and against a CUDA reference on B200. Posting the data here for the record.

Verifier setup

Identical gemma-3-270m + r=8 LoRA memorisation fixture across CUDA and MLX (same as the existing macos-14 smoke). Each side ran the three clip modes back-to-back with the same seed and the same data ordering. Staging fork: unslothai/unsloth-staging-1 branch mlx-matthew-validation.

Loss curves (30-step memorisation, first step where loss < 0.001)

Within CUDA (precision-controlled, fp32 path):

clip mode converged at step-8 loss
max_grad_norm=1.0 (HF reference) step 9 0.0023
max_grad_leaf_norm=1.0 (new default) step 10 0.0047
max_grad_value=1.0 (prior default) step 11 0.0245
none step 11 0.0242

max_grad_leaf_norm converges 5x faster than max_grad_value at the critical step and lands one step behind the HF reference. Elementwise tracks essentially the no-clip rate, confirming the OP intuition that elementwise distorts direction enough to slow convergence.

On real MLX (macos-14, fp16), runs 26493499055 (elementwise baseline) and 26495635429 (leaf_norm probe):

clip mode converged at
max_grad_value=1.0 step 8
max_grad_leaf_norm=1.0 step 8

Mean abs delta |MLX-leaf - MLX-elem| over 30 steps: 0.0105. The switch is a no-op on the memorisation outcome but converges slightly faster mid-training (step 4 loss 0.389 vs 0.525, step 5 0.0091 vs 0.0182). Both reach loss=0 by step 8 and stay there through step 30.

Memory + wall-clock probe (macos-14, run 26497373034)

Three sequential trainings, only clip mode varies, MLX peak memory tracker reset between runs:

mode peak GPU step time
max_grad_value=1.0 0.6643 GB 3.29 s
max_grad_leaf_norm=1.0 0.6638 GB 3.35 s
max_grad_norm=1.0 0.6664 GB 3.64 s
  • max_grad_norm is +2.7 MB peak and +9 to 10% wall-clock vs the per-leaf modes.
  • leaf_norm and elementwise are within 0.5 MB of each other (noise floor).
  • This is on a 270M model with ~1.7M trainable params. The overhead scales linearly with sum(num_trainable_params), so on Llama-3-8B LoRA r=16 (~40M trainable) it's ~60 MB peak, and on 70B LoRA or full FT the gap matters.

Reading

  • The new default lands where it should: proportional per-leaf rescale preserves direction (unlike elementwise) without the cross-tree reduction cost (unlike global norm).
  • max_grad_value semantics are preserved, so existing users who explicitly opt in keep the elementwise behavior.
  • The _resolve_mlx_grad_clipping precedence is verified on real MLX (separate assertion step in the same staging workflow): default -> ("leaf_norm", 1.0), explicit max_grad_value=1.5 -> ("value", 1.5), explicit max_grad_norm=1.0 -> ("global_norm", 1.0).

LGTM from my side on the clip rework; nothing else to address here unless the gemma3 / mlx-vlm arch-by-arch work needs a separate PR.

Raw data: per-step loss + grad-norm JSONs and the memory probe artifact are uploaded as artifacts on staging-1 run 26497373034.

`test_mlx_max_grad_value_none.py` now covers max_grad_leaf_norm and
max_grad_norm too. Rename to test_mlx_grad_clip_resolution.py and
update the docstring to list all three knobs.
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Merges 5 main-side mlx fixes (unslothai#673 zero-token CCE, unslothai#679 + unslothai#692 LoRA save
metadata, unslothai#682 invalid label NaN-poisoning, unslothai#688 tool mask). All 13
conflict regions in unsloth_zoo/mlx/utils.py resolved to keep PR unslothai#684's
behavior where it conflicts on semantics:

  - half-open `<` length mask (PR unslothai#684 fix) wins over main's inclusive `<=`
  - `if labels is None` branch preserved (PR unslothai#684 generality) alongside
    main's `_normalize_cce_label_dtype` dtype widening
  - `_get_image_token_ids` legacy wrapper kept alongside main's new
    `_normalize_cce_label_dtype` / `_normalize_numpy_cce_labels`
  - `_mask_label_token_ids` calls `_normalize_cce_label_dtype` first so
    image masking honors main's uint-widening contract
  - HEAD's `_expand_token_replacements` dropped; main's three-function
    split (`_normalize_numpy_cce_labels` + `_expand_image_token_sequences`
    + `_expand_token_runs`) is canonical; duplicate HEAD wrappers removed
  - `_collate_vlm_prompt_completion_batch` reads back the masked labels
    in int64 so image + attention masking survives without narrowing
  - prompt-completion VLM collator routes through `_apply_vlm_label_masks`
    after dtype normalisation so ignore_token_ids and wide invalid ids
    both reach runtime CCE intact
  - `_to_mx_vlm_batch` uses main's `_normalize_cce_label_dtype` for labels
    while keeping PR unslothai#684's token_type_ids / mm_token_type_ids handling
  - `_unsloth_*` prefix filter preserved so the new collated_position_ids
    flag and main's raw-input-ids carrier both get stripped

152 MLX tests pass post-merge.
@chatgpt-codex-connector

Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants