Skip to content

fix(mlx): cast custom norm outputs back to activation dtype#1

Merged
mmathew23 merged 4 commits into
mmathew23:explore/mlxfrom
Lyxot:pr/684
May 22, 2026
Merged

fix(mlx): cast custom norm outputs back to activation dtype#1
mmathew23 merged 4 commits into
mmathew23:explore/mlxfrom
Lyxot:pr/684

Conversation

@Lyxot
Copy link
Copy Markdown

@Lyxot Lyxot commented May 21, 2026

Summary

Adds MLX norm output cast-back support for custom norm modules discovered from the loaded model.

The helper now patches base nn.RMSNorm / nn.LayerNorm plus custom norm-like modules whose parameters are selected by the same norm-path rule used by _keep_norm_parameters_float32(). This keeps fp32 norm parameters/math while preventing fp32 norm outputs from promoting downstream activations.

The discovery is model-local, so hybrid mlx-lm / mlx-vlm models use whichever implementation is actually loaded.

Custom Norms Covered

Verified discovery covers these custom norms:

  • mlx_lm.models.bailing_moe_linear.GroupRMSNorm
  • mlx_lm.models.cohere.LayerNorm2D
  • mlx_lm.models.falcon_h1.FalconH1RMSNormGated
  • mlx_lm.models.gemma.RMSNorm
  • mlx_lm.models.gemma2.RMSNorm
  • mlx_lm.models.gemma3_text.RMSNorm
  • mlx_lm.models.granitemoehybrid.GraniteMoeHybridRMSNormGated
  • mlx_lm.models.mamba2.MambaRMSNormGated
  • mlx_lm.models.nemotron.NemotronLayerNorm1P
  • mlx_lm.models.nemotron_h.MambaRMSNormGated
  • mlx_lm.models.plamo2.RMSNorm
  • mlx_lm.models.qwen3_next.Qwen3NextRMSNormGated
  • mlx_lm.models.recurrent_gemma.RMSNorm
  • mlx_lm.models.rwkv7.LayerNormPerHead
  • mlx_lm.models.stablelm.LayerNormPerHead
  • mlx_lm.models.step3p5.ZeroCenteredRMSNorm
  • mlx_vlm.models.deepseekocr_2.vision.Qwen2RMSNorm
  • mlx_vlm.models.dots_ocr.vision.RMSNorm
  • mlx_vlm.models.fastvlm.vision.LayerNormChannel
  • mlx_vlm.models.gemma3.language.RMSNorm
  • mlx_vlm.models.gemma3n.audio.Gemma3nCumulativeGroupNorm
  • mlx_vlm.models.gemma3n.language.Gemma3nRMSNorm
  • mlx_vlm.models.gemma3n.vision.RMSNormAct2d
  • mlx_vlm.models.gemma4.audio.AudioRMSNorm
  • mlx_vlm.models.gemma4.language.RMSNormZeroShift
  • mlx_vlm.models.gemma4.vision.RMSNorm
  • mlx_vlm.models.gemma4.vision.VisionRMSNorm
  • mlx_vlm.models.jina_vlm.language.RMSNorm
  • mlx_vlm.models.paligemma.language.RMSNorm
  • mlx_vlm.models.qwen3_5.language.Qwen3_5RMSNormGated
  • mlx_vlm.models.sam3.sam_components.LayerNorm2d
  • mlx_vlm.models.sam3d_body.layers.LayerNorm32

Copilot AI review requested due to automatic review settings May 21, 2026 20:59
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds model-aware discovery of custom normalization modules in the MLX trainer so norm output casting can be applied beyond nn.RMSNorm / nn.LayerNorm, with a regression test covering real-world custom norm classes from optional MLX model packages.

Changes:

  • Discover additional “norm-like” module classes from a provided model and include them in the output-casting patch set.
  • Track which norm classes have been patched to support clean unpatching.
  • Add a test that verifies discovery and dtype-casting behavior for several custom norm implementations.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
unsloth_zoo/mlx/trainer.py Adds heuristics to find custom norm classes from the loaded model and patches/unpatches their __call__ for dtype casting.
tests/test_mlx_pr684_review_fixes.py Adds coverage ensuring custom norms are discovered and their outputs are cast back to activation dtype when enabled.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread unsloth_zoo/mlx/trainer.py
Comment thread unsloth_zoo/mlx/trainer.py
mmathew23 added a commit that referenced this pull request May 22, 2026
Multi-reviewer pass on the autocast wrapper / norm-upcast path:

- Instance-level forward (#2): an instance attribute `model.forward`
  (Unsloth runtime forward patching) shadows class-method overrides, so
  mutating __class__ silently bypassed the wrapper -> fp32 norm met a bf16
  linear with no autocast and crashed. Now wrap the instance attribute when
  present; otherwise subclass as before.
- Wrapper gating (unslothai#5, unslothai#7): install the wrapper iff fp32 norm params actually
  exist (from our upcast, the legacy env upcast, or an external
  _pre_set_compute_dtype policy) -- not on the upcast DECISION. Fixes the
  rollback path leaving external fp32 norms exposed, and stops wrapping models
  with no fp32 norm. Add _unwrap_forward_in_bf16_autocast for re-prepare (unslothai#10).
- config.architectures leak (unslothai#8/unslothai#9): keep the original __name__ on the
  generated subclass (unique __qualname__ for registration) so save_pretrained
  records the base architecture.
- Device detection (unslothai#11): recurse into mapping/list/tuple batches and fall
  back to the model's parameter device instead of defaulting to "cuda".
- Legacy UNSLOTH_UPCAST_LAYERNORM (#1/#3/unslothai#4): route through the shared
  _cast_named_module + union matcher and honour the external-policy deferral.
- Recursive external-ownership guard (unslothai#6): record descendants of tagged
  modules (the external policy casts recursively).
- Fresh-interpreter pickle test (unslothai#12): real subprocess load.

Shared helpers: _find_tensor_device_type, _call_forward_with_bf16_autocast,
_canonical_module_name, _cast_named_module. Unit suite: 25 passed.
@mmathew23 mmathew23 merged commit 5465959 into mmathew23:explore/mlx May 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants