Skip to content

feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32#3670

Merged
winglian merged 4 commits into
mainfrom
feat/fp32-norms-fsdp2
May 26, 2026
Merged

feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32#3670
winglian merged 4 commits into
mainfrom
feat/fp32-norms-fsdp2

Conversation

@winglian

@winglian winglian commented May 20, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds two config fields and a helper so users can keep norm modules in fp32 while training the rest of the model in bf16/fp16 under FSDP2. The mechanism is per-module fully_shard with a dedicated MixedPrecisionPolicy(param_dtype=fp32, reduce_dtype=fp32); once a norm is sharded, the surrounding decoder-layer wrap treats it as a boundary and doesn't fold it into the bf16 flat-param group.

fp32_norms: true
fsdp_version: 2

# Optional override; defaults to ["RMSNorm", "LayerNorm"] (suffix match).
# fp32_norm_classes:
#   - AfmoeRMSNorm
#   - transformers.models.llama.modeling_llama.LlamaRMSNorm

Motivation

Surfaced during full-FT continued pretraining on arcee-ai/Trinity-Mini-Base (AFMOE, 26B-A4B MoE) on 8×H200. Two paths, both broken:

  • transformers 5.x + patched modeling_afmoe.py: clearing _keep_in_fp32_modules = [] to get the model loading produces 5–60× worse held-out PPL (PPL 24–346 vs 1.3–6 on the same prompts). The plain _keep_in_fp32_modules attribute only fires for fp16 in current transformers, so on the bf16 path norms silently drop to bf16 even with the attribute set.
  • transformers 4.57.5 (canonical, no patches): inference works perfectly. Training fails at FSDP1 init with ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32 — FSDP1's flat-param dtype uniformity constraint hitting fp32 norms inside an otherwise-bf16 decoder layer.

Any model declaring fp32 norms (Llama, Mistral, Qwen3, etc.) hits the same FSDP1 wall. FSDP2's per-module MixedPrecisionPolicy is the supported escape hatch; this PR wires it up via config so users don't have to hand-write fully_shard calls.

Changes

  • src/axolotl/loaders/model.py: add shard_norms_fp32, _matches_norm_class, DEFAULT_FP32_NORM_SUFFIXES, plus a call site in ModelLoader.load() gated on cfg.fp32_norms. (The handoff named utils/models.py; that file no longer exists — model loading now lives in loaders/model.py.)
  • src/axolotl/utils/schemas/config.py: add fp32_norms: bool | None and fp32_norm_classes: list[str] | None fields with a check_fp32_norms model-validator that requires fsdp_version: 2.
  • tests/test_fp32_norms.py: 10 pure-CPU unit tests (matcher + guard-rail).
  • tests/utils/schemas/validation/test_fsdp.py: 3 new schema-validator tests.
  • docs/mixed_precision.qmd: "Keeping norms in fp32 (FSDP2)" subsection.

Matching semantics

Patterns without . match as a suffix against type(module).__name__. Catches the whole RMSNorm family (LlamaRMSNorm, Qwen3RMSNorm, AfmoeRMSNorm, MistralRMSNorm) without enumeration and also catches nn.LayerNorm.

Patterns containing . match the fully qualified f"{module.__module__}.{cls_name}". Use this when disambiguating same-named norms across model families or pinning a trust_remote_code modeling file precisely.

Scope (deliberately minimal)

  • FSDP2 only. Enabling fp32_norms with FSDP1 or DeepSpeed raises in the validator. An FSDP1 path is possible (custom wrap policy that segregates norms) but deferred.
  • No meta-device support. If params are on meta when shard_norms_fp32 runs, it raises with a clear message. The common cause is cpu_ram_efficient_loading: true; follow-up PR can add a post-materialization hook.
  • No surprise behavior for existing configs. Users opt in explicitly.

Validator behavior

  • fp32_norms: true + fsdp_version: 2 validates ✓
  • fp32_norms: true + fsdp_version: 1ValueError("fp32_norms requires fsdp_version: 2 …")
  • fp32_norm_classes set + fp32_norms unset → LOG.warning("… will be ignored.")

Testing

  • python3 -m pytest tests/test_fp32_norms.py tests/utils/schemas/validation/test_fsdp.py27 passed
  • Pre-existing schema validator tests (test_config_validators.py / test_default_values.py) — 40 passed
  • pre-commit run --from-ref origin/main --to-ref HEAD — all 8 hooks pass (ruff, ruff-format, mypy, bandit, etc.)
  • Trinity-Mini-Base full FSDP2 validation deferred until upstream modeling_afmoe.py is updated for transformers 5.x.

Out of scope / explicit non-goals

  • FSDP1 wrap-policy path
  • DeepSpeed equivalent
  • Meta-device materialization handling for cpu_ram_efficient_loading: true
  • Auto-detection of fp32 norms from _keep_in_fp32_modules
  • Per-pattern policy overrides

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for keeping normalization layers in fp32 precision with FSDP2 through configurable fp32_norms and fp32_norm_classes settings.
  • Documentation

    • Added configuration guidance for fp32 normalization handling with FSDP2.
  • Tests

    • Added comprehensive validation and unit tests for fp32 norm sharding.

Review Change Stack

@coderabbitai

coderabbitai Bot commented May 20, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1ced459d-554b-4770-83b4-385afbe76c75

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR adds FP32 mixed-precision support for normalization layers under FSDP2. It introduces new configuration fields, implements norm-module detection and FSDP2 sharding logic, validates configuration constraints, provides comprehensive tests, and documents the feature for end users.

Changes

FP32 Norms for FSDP2 Mixed Precision

Layer / File(s) Summary
FP32 norms configuration schema and validation
src/axolotl/utils/schemas/config.py
AxolotlInputConfig adds fp32_norms and fp32_norm_classes fields. A check_fp32_norms validator enforces fsdp_version == "2" and warns when fp32_norm_classes is set without fp32_norms.
FP32 norm module sharding for FSDP2
src/axolotl/loaders/model.py
Adds shard_norms_fp32() to identify norm modules by name patterns (default: RMSNorm, LayerNorm), validate FSDP2 version, reject meta-device models, wrap matched norms with fp32 MixedPrecisionPolicy, log warnings for no matches, and conditionally invoke during model setup.
Unit and integration test coverage
tests/test_fp32_norms.py, tests/utils/schemas/validation/test_fsdp.py
Tests pattern matching (suffix, explicit class, fully-qualified paths, mixed), shard_norms_fp32 behavior (disabled no-op, FSDP2 enforcement, meta rejection, no-match warning), class override behavior, and schema validation (version requirement, cross-field warnings).
Configuration guidance documentation
docs/mixed_precision.qmd
New subsection documenting how to enable FP32 normalization layers with YAML flags fsdp_version: 2 and fp32_norms: true, with guidance on optional fp32_norm_classes override.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Suggested labels

ready to merge


Suggested reviewers

  • SalmanMohammadi
  • NanoCode012
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 17.86% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: adding fp32_norms feature for keeping RMSNorm/LayerNorm in fp32 under FSDP2, which aligns with all file changes.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/fp32-norms-fsdp2

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
src/axolotl/loaders/model.py (1)

67-77: ⚡ Quick win

Condense this inline rationale to a single short line.

This block is useful context, but it exceeds the repository’s comment-length rule for source files.

As per coding guidelines, "Only add comments when explaining the WHY behind non-obvious logic, hidden constraints, or workarounds for specific bugs... Comments should be a maximum of one short line".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/loaders/model.py` around lines 67 - 77, Replace the multi-line
rationale with a single short comment line that summarizes the purpose of this
block: e.g. "Shard RMSNorm/LayerNorm as fp32 before decoder wrapping so FSDP2
can keep per-module MixedPrecisionPolicy (workaround for FSDP1 flat-param dtype
constraints)." Locate the comment near the fp32 norm sharding section that
references FSDP1, FSDP2, MixedPrecisionPolicy and RMSNorm/LayerNorm and replace
the entire paragraph with that one-line summary.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/loaders/model.py`:
- Around line 92-97: The loop over patterns (variable patterns, pattern) should
ignore empty or all-whitespace entries before attempting matches because
cls_name.endswith("") will match everything; update the code in the matching
block (the for loop handling pattern, qualified, and cls_name) to skip any
pattern where pattern.strip() is empty (e.g., continue) so only non-empty
patterns are compared using the existing qualified == pattern and
cls_name.endswith(pattern) checks.

In `@src/axolotl/utils/schemas/config.py`:
- Around line 1501-1514: The validator check_fp32_norms currently only compares
fsdp_version to "2" but doesn't enforce that FSDP is enabled at all; update
check_fp32_norms so that when self.fp32_norms is True it first verifies an FSDP
configuration exists (e.g., self.fsdp_version is set/non-empty) and raises a
ValueError if FSDP is not enabled, and retain the existing fsdp_version == "2"
check; touch the same method (check_fp32_norms) and the attributes fp32_norms,
fsdp_version, and fp32_norm_classes (and use LOG for warnings) so fp32_norms
cannot be set without an FSDP config.

---

Nitpick comments:
In `@src/axolotl/loaders/model.py`:
- Around line 67-77: Replace the multi-line rationale with a single short
comment line that summarizes the purpose of this block: e.g. "Shard
RMSNorm/LayerNorm as fp32 before decoder wrapping so FSDP2 can keep per-module
MixedPrecisionPolicy (workaround for FSDP1 flat-param dtype constraints)."
Locate the comment near the fp32 norm sharding section that references FSDP1,
FSDP2, MixedPrecisionPolicy and RMSNorm/LayerNorm and replace the entire
paragraph with that one-line summary.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: db6a4245-ef43-4d61-a3c2-62e2976c2886

📥 Commits

Reviewing files that changed from the base of the PR and between afd74ae and a271fb5.

📒 Files selected for processing (5)
  • docs/mixed_precision.qmd
  • src/axolotl/loaders/model.py
  • src/axolotl/utils/schemas/config.py
  • tests/test_fp32_norms.py
  • tests/utils/schemas/validation/test_fsdp.py

Comment thread src/axolotl/loaders/model.py Outdated
Comment thread src/axolotl/utils/schemas/config.py
@github-actions

github-actions Bot commented May 20, 2026

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://6a1207fd24a34eb8730563b1--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 3eeb200

@codecov

codecov Bot commented May 20, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 82.75862% with 15 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/utils/fp32_norms.py 84.37% 10 Missing ⚠️
src/axolotl/monkeypatch/accelerate/fsdp2.py 0.00% 4 Missing ⚠️
src/axolotl/loaders/model.py 85.71% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@winglian winglian force-pushed the feat/fp32-norms-fsdp2 branch 4 times, most recently from 465871b to 8ab7c69 Compare May 23, 2026 19:54
winglian added 4 commits May 23, 2026 15:54
Add an opt-in config flag that shards norm modules under their own FSDP2
MixedPrecisionPolicy (fp32) before the standard decoder-layer wrap, so the
norm and decoder shard groups stay independent. This lets models that
declare fp32 norms for training stability train under FSDP2 while the rest
of the model runs in bf16/fp16.

FSDP1 enforces flat-param dtype uniformity within each wrap group, which is
incompatible with keeping norms in fp32; the validator therefore requires
fsdp_version: 2.

Matching: patterns without a "." match type(module).__name__ as a suffix
(catches LlamaRMSNorm, Qwen3RMSNorm, AfmoeRMSNorm, nn.LayerNorm, etc.);
patterns containing a "." match the fully qualified class path exactly.
Defaults to ["RMSNorm", "LayerNorm"].

Signed-off-by: Wing Lian <wing@axolotl.ai>
- matcher: skip empty/whitespace-only patterns (cls_name.endswith("") is
  True for any class, which would silently match everything).
- validator: also require fsdp_config to be set, not just fsdp_version==2.
  fsdp_config is the canonical "is_fsdp" signal elsewhere in the codebase
  (used by check_fsdp_torch_version, sample_packing validators, etc.).
- tests: temporarily flip propagate=True on the `axolotl` logger so
  pytest caplog can see the warnings. axolotl.cli.configure_logging()
  sets propagate=False at import time, which is the documented reason
  the assertions were failing in CI even though the warnings were
  firing visibly in stdout.
- comment: replace multi-line rationale near the fp32_norms helpers with
  a one-line summary (the longer version lives in the PR description).

Signed-off-by: Wing Lian <wing@axolotl.ai>
…ertion

The existing fp32_norms tests are pure-CPU and monkeypatch fully_shard —
they cover the matcher logic and validator guard rails but never exercise
the actual FSDP2 path that motivated this PR.

Adds tests/e2e/multigpu/test_fsdp2_fp32_norms.py: spawns a 2-GPU
`axolotl train` subprocess with `fp32_norms: true` + `fsdp_version: 2` +
`bf16: true` on tiny-qwen3-129m (full FT, 2 steps) and asserts:

  1. Training completes — the original FSDP1 flat-param dtype crash
     can't recur because we're on FSDP2 with the per-module
     MixedPrecisionPolicy.
  2. All RMSNorm params are float32 after step 1 — captured via a
     test-only TrainerCallback in
     tests/e2e/multigpu/_fp32_norms_dtype_capture.py, dumped to JSON
     at $FP32_NORMS_DTYPE_DUMP_PATH on rank 0.
  3. At least one non-norm param is bfloat16 — proves the two FSDP2
     MixedPrecisionPolicy groups are independent (catches a silent
     globally-fp32 fallback that would technically satisfy assertion 2
     but defeat the point of the feature).

The dtype-capture plugin is plumbed in via the test's yaml `plugins:`
list, with PYTHONPATH=<repo_root> on the subprocess env so the
tests.e2e.multigpu._fp32_norms_dtype_capture module resolves.

Signed-off-by: Wing Lian <wing@axolotl.ai>
@winglian winglian force-pushed the feat/fp32-norms-fsdp2 branch from 8ab7c69 to 3eeb200 Compare May 23, 2026 19:55
@winglian winglian merged commit b05ab9a into main May 26, 2026
17 of 18 checks passed
@winglian winglian deleted the feat/fp32-norms-fsdp2 branch May 26, 2026 12:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant