Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/mixed_precision.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,26 @@ bf16: true
bf16: full # Equivalent to bf16_full_eval in the HF trainer
```

### Keeping norms in fp32 (FSDP2) {#sec-fp32-norms}

Some models declare RMSNorm/LayerNorm layers as fp32 for training
stability — the variance computation in RMSNorm is numerically poor in
bf16, and the learned gain γ quantizes harshly. With FSDP1 this fights
the flat-param dtype uniformity constraint; with FSDP2 each norm can have
its own `MixedPrecisionPolicy`. Enable with:

```{.yaml}
fsdp_version: 2
fp32_norms: true
# fp32_norm_classes: # optional override
# - RMSNorm
# - LayerNorm
```

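By default, any module whose class name ends in `RMSNorm` or `LayerNorm` is
matched. To pin a specific implementation, list its fully qualified name
(`module.path.ClassName`), which must match exactly. Note that `fp32_norms`
also requires an `fsdp_config` block to be set; config validation fails
otherwise.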

## FP8 Mixed Precision {#sec-fp8}

::: {.callout-note}
Expand Down
13 changes: 12 additions & 1 deletion src/axolotl/loaders/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@
get_device_count,
get_device_type,
)
from axolotl.utils.fp32_norms import (
_matches_norm_class,
get_fp32_norm_patterns,
tag_model_fp32_norms,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
Expand Down Expand Up @@ -191,6 +196,9 @@ def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | Non
self.patch_manager.apply_post_model_load_patches(self.model)
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)

if self.cfg.fp32_norms:
tag_model_fp32_norms(self.model, self.cfg)

return self.model, lora_config

def _apply_pre_model_load_setup(self):
Expand Down Expand Up @@ -911,8 +919,11 @@ def _convert_embedding_modules_dtype(
dest = {"dtype": dist_dtype}
if self.cfg.lora_on_cpu:
dest["device"] = "cpu"
fp32_norm_patterns = get_fp32_norm_patterns(self.cfg)
for name, module in self.model.named_modules():
if "norm" in name:
if fp32_norm_patterns and _matches_norm_class(module, fp32_norm_patterns):
module.to(torch.float32)
elif "norm" in name:
module.to(dist_dtype)
if before_kbit_train_or_finetune:
if name.endswith(".gate"):
Expand Down
9 changes: 9 additions & 0 deletions src/axolotl/monkeypatch/accelerate/fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import nn

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.fp32_norms import get_fp32_norm_patterns, shard_norms_fp32
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
Expand Down Expand Up @@ -426,6 +427,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:

auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
fp32_norm_patterns = get_fp32_norm_patterns(model)
if fp32_norm_patterns:
shard_norms_fp32(
model,
patterns=fp32_norm_patterns,
fully_shard_kwargs=fsdp2_kwargs,
)

if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
Expand Down
135 changes: 135 additions & 0 deletions src/axolotl/utils/fp32_norms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Helpers for keeping selected norm modules in fp32 under FSDP2."""

from __future__ import annotations

from typing import Any, Sequence

import torch

from axolotl.utils.logging import get_logger

# Module-level logger obtained from axolotl's get_logger helper.
LOG = get_logger(__name__)

# Class-name suffixes used as the fallback pattern list when fp32_norms is
# enabled but no fp32_norm_classes override is configured.
DEFAULT_FP32_NORM_SUFFIXES: tuple[str, ...] = ("RMSNorm", "LayerNorm")


def _matches_norm_class(module: "torch.nn.Module", patterns: Sequence[str]) -> bool:
"""Match a module against class-name patterns.

Two matching modes, chosen per-pattern by presence of a dot:
- Fully qualified (contains "."): matches f"{module.__module__}.{cls}" exactly.
- Suffix (no dot): matches type(module).__name__.endswith(pattern).
Empty / whitespace-only patterns are skipped (``cls_name.endswith("")``
is True for every class, which would silently match everything).
"""
cls = type(module)
cls_name = cls.__name__
qualified = f"{cls.__module__}.{cls_name}"
for pattern in patterns:
if not pattern or not pattern.strip():
continue
if "." in pattern:
if qualified == pattern:
return True
elif cls_name.endswith(pattern):
return True
return False


def get_fp32_norm_patterns(source) -> list[str] | None:
    """Work out which fp32 norm patterns apply to a config or tagged model.

    Resolution order:
      1. A ``_axolotl_fp32_norm_patterns`` attribute on ``source`` that is
         not None wins and is returned as a fresh list.
      2. Otherwise, when ``source.fp32_norms`` is falsy or missing, the
         feature is off and None is returned.
      3. Otherwise a non-empty ``source.fp32_norm_classes`` is returned as a
         list, falling back to ``DEFAULT_FP32_NORM_SUFFIXES``.

    Args:
        source: Any object exposing the attributes above (all optional).

    Returns:
        A new list of patterns, or None when fp32 norms are not enabled.
    """
    tagged = getattr(source, "_axolotl_fp32_norm_patterns", None)
    if tagged is not None:
        return list(tagged)

    if getattr(source, "fp32_norms", False):
        explicit = getattr(source, "fp32_norm_classes", None)
        return list(explicit) if explicit else list(DEFAULT_FP32_NORM_SUFFIXES)

    return None


def tag_model_fp32_norms(model: "torch.nn.Module", cfg) -> list[str] | None:
    """Record the fp32 norm patterns resolved from ``cfg`` on the model.

    The patterns are stored as ``model._axolotl_fp32_norm_patterns`` so that
    code holding only the model can recover them via
    ``get_fp32_norm_patterns(model)``. When ``cfg`` resolves to no patterns,
    any existing tag is removed instead.

    Args:
        model: Model to tag (mutated in place).
        cfg: Config object passed to ``get_fp32_norm_patterns``.

    Returns:
        The resolved patterns, or None when fp32 norms are not enabled.
    """
    resolved = get_fp32_norm_patterns(cfg)
    if resolved is not None:
        # Store a copy so later edits to the returned list do not alter the tag.
        model._axolotl_fp32_norm_patterns = list(resolved)
        return resolved

    if hasattr(model, "_axolotl_fp32_norm_patterns"):
        delattr(model, "_axolotl_fp32_norm_patterns")
    return None


def shard_norms_fp32(
    model: "torch.nn.Module",
    source=None,
    *,
    patterns: Sequence[str] | None = None,
    fully_shard_kwargs: dict[str, Any] | None = None,
) -> int:
    """Wrap matching norm modules with FSDP2 + fp32 MixedPrecisionPolicy.

    Every module in ``model.named_modules()`` whose class matches one of the
    patterns has its own parameters and floating-point buffers cast to
    float32 and is then passed to ``fully_shard`` with a policy using
    float32 for ``param_dtype`` and ``reduce_dtype``. The policy's
    ``output_dtype`` is the ``param_dtype`` of the ``mp_policy`` found in
    ``fully_shard_kwargs`` (None when there is none).

    Args:
        model: Model whose norm modules are sharded in place.
        source: Optional config-like object. When given, ``fp32_norms`` must
            be truthy (otherwise nothing is done) and ``fsdp_version`` must
            be 2.
        patterns: Explicit class-name patterns. When None they are resolved
            from ``source`` if given, else from ``model``.
        fully_shard_kwargs: Extra keyword arguments forwarded to
            ``fully_shard``; its ``mp_policy`` entry is replaced by the fp32
            policy.

    Returns:
        Number of modules that were sharded (0 when disabled or unmatched).

    Raises:
        ValueError: If ``source`` enables fp32 norms without FSDP version 2.
    """
    if source is not None:
        if not getattr(source, "fp32_norms", False):
            return 0
        # Compare textually so that both 2 and "2" are accepted, consistent
        # with the config-level validation of fsdp_version.
        if str(getattr(source, "fsdp_version", None)) != "2":
            raise ValueError(
                "fp32_norms requires fsdp_version: 2. FSDP1 enforces flat-param "
                "dtype uniformity within each wrap group, which is incompatible "
                "with keeping norms in fp32 while the rest of the layer is bf16."
            )

    if patterns is not None:
        patterns = list(patterns)
    else:
        # Explicit None test: truthiness would silently drop a source object
        # that happens to evaluate as falsy.
        patterns = get_fp32_norm_patterns(source if source is not None else model)
    if not patterns:
        return 0

    # Imported lazily so this module can be imported without pulling in the
    # distributed FSDP package.
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    shard_kwargs = dict(fully_shard_kwargs or {})
    outer_policy = shard_kwargs.get("mp_policy")
    output_dtype = getattr(outer_policy, "param_dtype", None)
    fp32_policy = MixedPrecisionPolicy(
        param_dtype=torch.float32,
        reduce_dtype=torch.float32,
        output_dtype=output_dtype,
    )

    matches = [
        (name, module)
        for name, module in model.named_modules()
        if _matches_norm_class(module, patterns)
    ]

    if not matches:
        LOG.warning(
            "fp32_norms enabled but no modules matched patterns %s. Check "
            "fp32_norm_classes against the model's actual norm class names.",
            patterns,
        )
        return 0

    shard_kwargs["mp_policy"] = fp32_policy

    for _name, module in matches:
        # Only the module's own tensors are cast; nested submodules are left
        # untouched unless they match a pattern themselves.
        for param in module.parameters(recurse=False):
            param.data = param.data.to(torch.float32)
        for buffer in module.buffers(recurse=False):
            if buffer.dtype.is_floating_point:
                buffer.data = buffer.data.to(torch.float32)
        fully_shard(module, **shard_kwargs)

    LOG.info(
        "Sharded %d norm modules with fp32 MixedPrecisionPolicy "
        "(patterns=%s, output_dtype=%s)",
        len(matches),
        patterns,
        output_dtype,
    )
    return len(matches)
45 changes: 45 additions & 0 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,27 @@ class AxolotlInputConfig(
default=None,
json_schema_extra={"description": "FSDP version"},
)
# Opt-in switch for fp32 norm sharding. check_fp32_norms rejects a truthy
# value unless fsdp_config is set and fsdp_version is 2.
fp32_norms: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Keep norm modules (RMSNorm/LayerNorm) in fp32 by sharding them "
"under their own FSDP2 MixedPrecisionPolicy. Requires fsdp_version: 2."
)
},
)
# Optional override of the class-name patterns. check_fp32_norms only logs a
# warning when this is set while fp32_norms is not enabled.
fp32_norm_classes: list[str] | None = Field(
default=None,
json_schema_extra={
"description": (
"Class-name patterns to match for fp32 norm sharding. Patterns "
"without a '.' match against type(module).__name__ as a suffix. "
"Patterns containing a '.' match the fully qualified class path "
"exactly. Defaults to ['RMSNorm', 'LayerNorm'] when fp32_norms is "
"true and this is unset."
)
},
)
fsdp_final_state_dict_type: (
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
) = Field(
Expand Down Expand Up @@ -1477,6 +1498,30 @@ def validate_attn_implementation(cls, value):
f"path containing '/'."
)

@model_validator(mode="after")
def check_fp32_norms(self):
    """Validate the fp32_norms / fp32_norm_classes combination.

    With fp32_norms enabled, an fsdp_config block must be present and
    fsdp_version must render as "2"; otherwise a ValueError is raised. With
    fp32_norms disabled, a configured fp32_norm_classes only triggers a
    warning because it has no effect.

    Returns:
        The validated model instance (self).

    Raises:
        ValueError: If fp32_norms is enabled without FSDP2 being configured.
    """
    if not self.fp32_norms:
        if self.fp32_norm_classes:
            LOG.warning(
                "fp32_norm_classes is set but fp32_norms is not enabled; "
                "it will be ignored."
            )
        return self

    # FSDP must actually be configured — fsdp_version alone is not
    # sufficient since the rest of axolotl treats fsdp_config as the
    # canonical "is_fsdp" signal.
    if self.fsdp_config is None:
        raise ValueError(
            "fp32_norms requires FSDP to be enabled "
            "(fsdp_config block must be set)."
        )

    # Compared as text so both 2 and "2" pass.
    if str(self.fsdp_version) != "2":
        raise ValueError(
            "fp32_norms requires fsdp_version: 2. FSDP1's flat-param "
            "dtype uniformity constraint is incompatible with keeping "
            "norms in fp32 while decoder layers run in bf16."
        )
    return self
Comment thread
winglian marked this conversation as resolved.

@model_validator(mode="after")
def check_sageattn_wo_sample_packing(self):
if (
Expand Down
58 changes: 58 additions & 0 deletions tests/e2e/multigpu/_fp32_norms_dtype_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Test-only plugin that captures param dtypes after the first optimizer step
and dumps them as JSON to ``$FP32_NORMS_DTYPE_DUMP_PATH``.

Loaded via ``plugins: [tests.e2e.multigpu._fp32_norms_dtype_capture.DtypeCapturePlugin]``
in the test yaml config; the dump path is the contract between the subprocess
and the outer pytest function. Rank 0 only — dtype is identical across ranks.
"""

from __future__ import annotations

import json
import os

import torch
from transformers.trainer_callback import TrainerCallback

from axolotl.integrations.base import BasePlugin


def _dtype_name(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")


class _DtypeCaptureCallback(TrainerCallback):
    """Write norm vs non-norm parameter dtypes to a JSON file after step 1.

    The callback only records information; it returns None and does not
    modify the trainer's control object.
    """

    def on_step_end(self, args, state, control, model=None, **kwargs):  # type: ignore[override]
        """Dump parameter dtypes once, when ``state.global_step`` is 1.

        Nothing is written unless a model is supplied, the process is rank 0
        (or torch.distributed is not initialized), and the
        ``FP32_NORMS_DTYPE_DUMP_PATH`` environment variable is non-empty.
        """
        if state.global_step != 1 or model is None:
            return

        # Only rank 0 writes the file when running distributed.
        dist = torch.distributed
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        target = os.environ.get("FP32_NORMS_DTYPE_DUMP_PATH")
        if not target:
            return

        # A parameter counts as a norm when its name contains "norm"
        # (case-insensitive); everything else goes to the other bucket.
        buckets = {"norms": {}, "non_norms": {}}
        for param_name, param in model.named_parameters():
            bucket = "norms" if "norm" in param_name.lower() else "non_norms"
            buckets[bucket][param_name] = _dtype_name(param.dtype)

        with open(target, "w", encoding="utf-8") as handle:
            json.dump(buckets, handle, indent=2)


class DtypeCapturePlugin(BasePlugin):
    """Axolotl plugin that contributes a :class:`_DtypeCaptureCallback`."""

    def add_callbacks_pre_trainer(self, cfg, model):  # type: ignore[override]
        """Return a one-element list holding a new dtype-capture callback."""
        callbacks = [_DtypeCaptureCallback()]
        return callbacks
Loading
Loading