feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32 #3670
📝 Walkthrough
This PR adds FP32 mixed-precision support for normalization layers under FSDP2. It introduces new configuration fields, implements norm-module detection and FSDP2 sharding logic, validates configuration constraints, provides comprehensive tests, and documents the feature for end users.
Changes: FP32 Norms for FSDP2 Mixed Precision
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/axolotl/loaders/model.py (1)
67-77 (quick win): Condense this inline rationale to a single short line.
This block is useful context, but it exceeds the repository’s comment-length rule for source files.
As per coding guidelines, "Only add comments when explaining the WHY behind non-obvious logic, hidden constraints, or workarounds for specific bugs... Comments should be a maximum of one short line".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/axolotl/loaders/model.py` around lines 67 - 77, Replace the multi-line rationale with a single short comment line that summarizes the purpose of this block: e.g. "Shard RMSNorm/LayerNorm as fp32 before decoder wrapping so FSDP2 can keep per-module MixedPrecisionPolicy (workaround for FSDP1 flat-param dtype constraints)." Locate the comment near the fp32 norm sharding section that references FSDP1, FSDP2, MixedPrecisionPolicy and RMSNorm/LayerNorm and replace the entire paragraph with that one-line summary.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/loaders/model.py`:
- Around line 92-97: The loop over patterns (variable patterns, pattern) should
ignore empty or all-whitespace entries before attempting matches because
cls_name.endswith("") will match everything; update the code in the matching
block (the for loop handling pattern, qualified, and cls_name) to skip any
pattern where pattern.strip() is empty (e.g., continue) so only non-empty
patterns are compared using the existing qualified == pattern and
cls_name.endswith(pattern) checks.
In `@src/axolotl/utils/schemas/config.py`:
- Around line 1501-1514: The validator check_fp32_norms currently only compares
fsdp_version to "2" but doesn't enforce that FSDP is enabled at all; update
check_fp32_norms so that when self.fp32_norms is True it first verifies an FSDP
configuration exists (e.g., self.fsdp_version is set/non-empty) and raises a
ValueError if FSDP is not enabled, and retain the existing fsdp_version == "2"
check; touch the same method (check_fp32_norms) and the attributes fp32_norms,
fsdp_version, and fp32_norm_classes (and use LOG for warnings) so fp32_norms
cannot be set without an FSDP config.
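The guard rails requested for `check_fp32_norms` can be sketched in plain Python as follows. This is a minimal sketch, not the axolotl implementation: `ConfigSketch` is a hypothetical dataclass standing in for the real schema, the error and warning messages are illustrative, and the contents of `fsdp_config` are treated as opaque.

```python
import logging
from dataclasses import dataclass
from typing import List, Optional

LOG = logging.getLogger("fp32_norms_sketch")  # hypothetical logger name


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for the relevant slice of the config schema."""

    fp32_norms: Optional[bool] = None
    fp32_norm_classes: Optional[List[str]] = None
    fsdp_config: Optional[dict] = None
    fsdp_version: Optional[str] = None

    def check_fp32_norms(self) -> "ConfigSketch":
        if not self.fp32_norms:
            # Class patterns without the flag have no effect, so only warn.
            if self.fp32_norm_classes:
                LOG.warning(
                    "fp32_norm_classes is set but fp32_norms is not; "
                    "it will be ignored."
                )
            return self
        # FSDP must be enabled at all, not merely versioned.
        if not self.fsdp_config:
            raise ValueError("fp32_norms requires fsdp_config to be set")
        # str() tolerates both 2 and "2" as the configured version.
        if str(self.fsdp_version) != "2":
            raise ValueError("fp32_norms requires fsdp_version: 2")
        return self
```

Checking `fsdp_config` before `fsdp_version` means a config with no FSDP section at all gets the more specific error first.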
📒 Files selected for processing (5)
- docs/mixed_precision.qmd
- src/axolotl/loaders/model.py
- src/axolotl/utils/schemas/config.py
- tests/test_fp32_norms.py
- tests/utils/schemas/validation/test_fsdp.py
📖 Documentation Preview: https://6a1207fd24a34eb8730563b1--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 3eeb200)
465871b to 8ab7c69
Add an opt-in config flag that shards norm modules under their own FSDP2 MixedPrecisionPolicy (fp32) before the standard decoder-layer wrap, so the norm and decoder shard groups stay independent. This lets models that declare fp32 norms for training stability train under FSDP2 while the rest of the model runs in bf16/fp16.
FSDP1 enforces flat-param dtype uniformity within each wrap group, which is incompatible with keeping norms in fp32; the validator therefore requires fsdp_version: 2.
Matching: patterns without a "." match type(module).__name__ as a suffix (catches LlamaRMSNorm, Qwen3RMSNorm, AfmoeRMSNorm, nn.LayerNorm, etc.); patterns containing a "." match the fully qualified class path exactly. Defaults to ["RMSNorm", "LayerNorm"].
Signed-off-by: Wing Lian <wing@axolotl.ai>
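The matching rule in this commit message can be illustrated with a short sketch. It is an approximation under stated assumptions: the function name `matches_norm_class_sketch` and its signature are hypothetical (the PR's helper is `_matches_norm_class`, whose exact signature is not shown here), the stand-in classes are not the real transformers or torch classes, and the blank-pattern guard reflects the review feedback rather than this commit.

```python
from typing import List

# Default suffixes as stated in the commit message.
DEFAULT_FP32_NORM_SUFFIXES = ["RMSNorm", "LayerNorm"]


def matches_norm_class_sketch(module: object, patterns: List[str]) -> bool:
    """Return True if the module's class matches any pattern."""
    cls = type(module)
    cls_name = cls.__name__
    qualified = f"{cls.__module__}.{cls_name}"
    for pattern in patterns:
        if not pattern.strip():
            # str.endswith("") is True for every string, so skip blanks.
            continue
        if "." in pattern:
            # Dotted patterns must equal the fully qualified class path.
            if qualified == pattern:
                return True
        elif cls_name.endswith(pattern):
            # Undotted patterns are suffixes of the bare class name.
            return True
    return False


class LlamaRMSNorm:  # hypothetical stand-in for a model-specific norm
    pass


class Linear:  # hypothetical stand-in for a non-norm module
    pass
```

With the defaults, `matches_norm_class_sketch(LlamaRMSNorm(), DEFAULT_FP32_NORM_SUFFIXES)` is true and the same call on `Linear()` is false; a dotted pattern only matches when it equals the class's full module path.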
- matcher: skip empty/whitespace-only patterns (cls_name.endswith("") is
True for any class, which would silently match everything).
- validator: also require fsdp_config to be set, not just fsdp_version==2.
fsdp_config is the canonical "is_fsdp" signal elsewhere in the codebase
(used by check_fsdp_torch_version, sample_packing validators, etc.).
- tests: temporarily flip propagate=True on the `axolotl` logger so
pytest caplog can see the warnings. axolotl.cli.configure_logging()
sets propagate=False at import time, which is the documented reason
the assertions were failing in CI even though the warnings were
firing visibly in stdout.
- comment: replace multi-line rationale near the fp32_norms helpers with
a one-line summary (the longer version lives in the PR description).
Signed-off-by: Wing Lian <wing@axolotl.ai>
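The test fix described above relies on the fact that pytest's `caplog` handler sits on the root logger, so records from a logger with `propagate=False` never reach it. A minimal sketch of the temporary flip, using only the standard library, might look like this; the helper name `propagating` is hypothetical and the actual tests may do this differently (for example with a fixture).

```python
import logging
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def propagating(logger_name: str) -> Iterator[logging.Logger]:
    """Temporarily set propagate=True on a logger, then restore it."""
    logger = logging.getLogger(logger_name)
    previous = logger.propagate
    logger.propagate = True
    try:
        yield logger
    finally:
        # Restore even if the test body raises, so other tests are unaffected.
        logger.propagate = previous
```

Inside a test, the body that triggers the warning would run under `with propagating("axolotl"):` combined with `caplog.at_level(logging.WARNING)`, after which `caplog.text` can be asserted on.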
…ertion
The existing fp32_norms tests are pure-CPU and monkeypatch fully_shard —
they cover the matcher logic and validator guard rails but never exercise
the actual FSDP2 path that motivated this PR.
Adds tests/e2e/multigpu/test_fsdp2_fp32_norms.py: spawns a 2-GPU
`axolotl train` subprocess with `fp32_norms: true` + `fsdp_version: 2` +
`bf16: true` on tiny-qwen3-129m (full FT, 2 steps) and asserts:
1. Training completes — the original FSDP1 flat-param dtype crash
can't recur because we're on FSDP2 with the per-module
MixedPrecisionPolicy.
2. All RMSNorm params are float32 after step 1 — captured via a
test-only TrainerCallback in
tests/e2e/multigpu/_fp32_norms_dtype_capture.py, dumped to JSON
at $FP32_NORMS_DTYPE_DUMP_PATH on rank 0.
3. At least one non-norm param is bfloat16 — proves the two FSDP2
MixedPrecisionPolicy groups are independent (catches a silent
globally-fp32 fallback that would technically satisfy assertion 2
but defeat the point of the feature).
The dtype-capture plugin is plumbed in via the test's yaml `plugins:`
list, with PYTHONPATH=<repo_root> on the subprocess env so the
tests.e2e.multigpu._fp32_norms_dtype_capture module resolves.
Signed-off-by: Wing Lian <wing@axolotl.ai>
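Assertions 2 and 3 above can be sketched as a check over the dumped JSON. The dump format is an assumption here: a flat mapping from parameter name to `str(dtype)` (for example `"torch.float32"`), with norm parameters identified by the substring `norm` in their name. The function name `find_dtype_problems` is hypothetical, and the real test may structure its dump and checks differently.

```python
from typing import Dict, List


def find_dtype_problems(dump: Dict[str, str], norm_marker: str = "norm") -> List[str]:
    """Return a list of human-readable problems; empty means the dump passes."""
    problems: List[str] = []
    norm = {n: d for n, d in dump.items() if norm_marker in n}
    other = {n: d for n, d in dump.items() if norm_marker not in n}

    if not norm:
        problems.append("no norm params captured")
    # Every norm param must have stayed in fp32.
    for name, dtype in sorted(norm.items()):
        if dtype != "torch.float32":
            problems.append(f"{name} is {dtype}, expected torch.float32")
    # At least one non-norm param must be bf16, otherwise the whole model
    # may have silently fallen back to fp32.
    if not any(d == "torch.bfloat16" for d in other.values()):
        problems.append("no non-norm param is torch.bfloat16")
    return problems
```

Returning a list instead of asserting inline keeps the failure message complete when several parameters are wrong at once; the test then only needs `assert find_dtype_problems(dump) == []`.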
8ab7c69 to 3eeb200
Summary
Adds two config fields and a helper so users can keep norm modules in fp32 while training the rest of the model in bf16/fp16 under FSDP2. The mechanism is per-module `fully_shard` with a dedicated `MixedPrecisionPolicy(param_dtype=fp32, reduce_dtype=fp32)`; once a norm is sharded, the surrounding decoder-layer wrap treats it as a boundary and doesn't fold it into the bf16 flat-param group.
Motivation
Surfaced during full-FT continued pretraining on `arcee-ai/Trinity-Mini-Base` (AFMOE, 26B-A4B MoE) on 8×H200. Two paths, both broken:
- `modeling_afmoe.py`: clearing `_keep_in_fp32_modules = []` to get the model loading produces 5–60× worse held-out PPL (PPL 24–346 vs 1.3–6 on the same prompts). The plain `_keep_in_fp32_modules` attribute only fires for fp16 in current transformers, so on the bf16 path norms silently drop to bf16 even with the attribute set.
- `ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32`: FSDP1's flat-param dtype uniformity constraint hitting fp32 norms inside an otherwise-bf16 decoder layer.
Any model declaring fp32 norms (Llama, Mistral, Qwen3, etc.) hits the same FSDP1 wall. FSDP2's per-module `MixedPrecisionPolicy` is the supported escape hatch; this PR wires it up via config so users don't have to hand-write `fully_shard` calls.
Changes
- `src/axolotl/loaders/model.py`: add `shard_norms_fp32`, `_matches_norm_class`, `DEFAULT_FP32_NORM_SUFFIXES`, plus a call site in `ModelLoader.load()` gated on `cfg.fp32_norms`. (The handoff named `utils/models.py`; that file no longer exists, and model loading now lives in `loaders/model.py`.)
- `src/axolotl/utils/schemas/config.py`: add `fp32_norms: bool | None` and `fp32_norm_classes: list[str] | None` fields with a `check_fp32_norms` model-validator that requires `fsdp_version: 2`.
- `tests/test_fp32_norms.py`: 10 pure-CPU unit tests (matcher + guard-rail).
- `tests/utils/schemas/validation/test_fsdp.py`: 3 new schema-validator tests.
- `docs/mixed_precision.qmd`: "Keeping norms in fp32 (FSDP2)" subsection.
Matching semantics
- Patterns without `.` match as a suffix against `type(module).__name__`. Catches the whole RMSNorm family (`LlamaRMSNorm`, `Qwen3RMSNorm`, `AfmoeRMSNorm`, `MistralRMSNorm`) without enumeration and also catches `nn.LayerNorm`.
- Patterns containing `.` match the fully qualified `f"{module.__module__}.{cls_name}"`. Use this when disambiguating same-named norms across model families or pinning a `trust_remote_code` modeling file precisely.
Scope (deliberately minimal)
- `fp32_norms` with FSDP1 or DeepSpeed raises in the validator. An FSDP1 path is possible (custom wrap policy that segregates norms) but deferred.
- […] `shard_norms_fp32` runs, it raises with a clear message. The common cause is `cpu_ram_efficient_loading: true`; follow-up PR can add a post-materialization hook.
Validator behavior
- `fp32_norms: true` + `fsdp_version: 2` validates ✓
- `fp32_norms: true` + `fsdp_version: 1` → `ValueError("fp32_norms requires fsdp_version: 2 …")`
- `fp32_norm_classes` set + `fp32_norms` unset → `LOG.warning("… will be ignored.")`
Testing
- `python3 -m pytest tests/test_fp32_norms.py tests/utils/schemas/validation/test_fsdp.py`: 27 passed
- […] `test_config_validators.py` / `test_default_values.py`): 40 passed
- `pre-commit run --from-ref origin/main --to-ref HEAD`: all 8 hooks pass (ruff, ruff-format, mypy, bandit, etc.)
- […] `modeling_afmoe.py` is updated for transformers 5.x.
Out of scope / explicit non-goals
- `cpu_ram_efficient_loading: true`
- `_keep_in_fp32_modules`
Summary by CodeRabbit
Release Notes
New Features
- […] `fp32_norms` and `fp32_norm_classes` settings.
Documentation
Tests