-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32 #3670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
e55ee3c
feat(fsdp2): add fp32_norms for keeping RMSNorm/LayerNorm in fp32
winglian 0beee03
fixup! feat(fsdp2): address review findings + fix CI caplog assertions
winglian 9a2c37a
test(fsdp2): multi-GPU e2e for fp32_norms with dtype-preservation ass…
winglian 3eeb200
chore: lint
winglian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| """Helpers for keeping selected norm modules in fp32 under FSDP2.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Sequence | ||
|
|
||
| import torch | ||
|
|
||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| LOG = get_logger(__name__) | ||
|
|
||
| DEFAULT_FP32_NORM_SUFFIXES: tuple[str, ...] = ("RMSNorm", "LayerNorm") | ||
|
|
||
|
|
||
| def _matches_norm_class(module: "torch.nn.Module", patterns: Sequence[str]) -> bool: | ||
| """Match a module against class-name patterns. | ||
|
|
||
| Two matching modes, chosen per-pattern by presence of a dot: | ||
| - Fully qualified (contains "."): matches f"{module.__module__}.{cls}" exactly. | ||
| - Suffix (no dot): matches type(module).__name__.endswith(pattern). | ||
| Empty / whitespace-only patterns are skipped (``cls_name.endswith("")`` | ||
| is True for every class, which would silently match everything). | ||
| """ | ||
| cls = type(module) | ||
| cls_name = cls.__name__ | ||
| qualified = f"{cls.__module__}.{cls_name}" | ||
| for pattern in patterns: | ||
| if not pattern or not pattern.strip(): | ||
| continue | ||
| if "." in pattern: | ||
| if qualified == pattern: | ||
| return True | ||
| elif cls_name.endswith(pattern): | ||
| return True | ||
| return False | ||
|
|
||
|
|
||
| def get_fp32_norm_patterns(source) -> list[str] | None: | ||
| """Resolve configured fp32 norm patterns from a config or tagged model.""" | ||
| tagged_patterns = getattr(source, "_axolotl_fp32_norm_patterns", None) | ||
| if tagged_patterns is not None: | ||
| return list(tagged_patterns) | ||
|
|
||
| if not getattr(source, "fp32_norms", False): | ||
| return None | ||
|
|
||
| configured_patterns = getattr(source, "fp32_norm_classes", None) | ||
| if configured_patterns: | ||
| return list(configured_patterns) | ||
|
|
||
| return list(DEFAULT_FP32_NORM_SUFFIXES) | ||
|
|
||
|
|
||
def tag_model_fp32_norms(model: "torch.nn.Module", cfg) -> list[str] | None:
    """Record on ``model`` the fp32 norm patterns resolved from ``cfg``.

    The patterns are stored as a list copy under the
    ``_axolotl_fp32_norm_patterns`` attribute. When ``cfg`` resolves to no
    patterns (None), any existing tag on the model is removed instead.

    Args:
        model: Object to tag (or untag).
        cfg: Passed to ``get_fp32_norm_patterns`` to resolve the patterns.

    Returns:
        The resolved pattern list (the stored value is a separate copy), or
        None when nothing was resolved.
    """
    tag_attr = "_axolotl_fp32_norm_patterns"
    resolved = get_fp32_norm_patterns(cfg)

    if resolved is not None:
        # Store a copy so later mutation of the returned list does not leak
        # into the tag kept on the model.
        setattr(model, tag_attr, list(resolved))
        return resolved

    # Nothing resolved: clear any tag left on the model.
    if hasattr(model, tag_attr):
        delattr(model, tag_attr)
    return None
|
|
||
|
|
||
| def shard_norms_fp32( | ||
| model: "torch.nn.Module", | ||
| source=None, | ||
| *, | ||
| patterns: Sequence[str] | None = None, | ||
| fully_shard_kwargs: dict[str, Any] | None = None, | ||
| ) -> int: | ||
| """Wrap matching norm modules with FSDP2 + fp32 MixedPrecisionPolicy.""" | ||
| if source is not None and not getattr(source, "fp32_norms", False): | ||
| return 0 | ||
|
|
||
| if source is not None and getattr(source, "fsdp_version", None) != 2: | ||
| raise ValueError( | ||
| "fp32_norms requires fsdp_version: 2. FSDP1 enforces flat-param " | ||
| "dtype uniformity within each wrap group, which is incompatible " | ||
| "with keeping norms in fp32 while the rest of the layer is bf16." | ||
| ) | ||
|
|
||
| patterns = ( | ||
| list(patterns) | ||
| if patterns is not None | ||
| else get_fp32_norm_patterns(source or model) | ||
| ) | ||
| if not patterns: | ||
| return 0 | ||
|
|
||
| from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard | ||
|
|
||
| outer_policy = (fully_shard_kwargs or {}).get("mp_policy") | ||
| output_dtype = getattr(outer_policy, "param_dtype", None) | ||
| fp32_policy = MixedPrecisionPolicy( | ||
| param_dtype=torch.float32, | ||
| reduce_dtype=torch.float32, | ||
| output_dtype=output_dtype, | ||
| ) | ||
|
|
||
| matches = [ | ||
| (name, module) | ||
| for name, module in model.named_modules() | ||
| if _matches_norm_class(module, patterns) | ||
| ] | ||
|
|
||
| if not matches: | ||
| LOG.warning( | ||
| "fp32_norms enabled but no modules matched patterns %s. Check " | ||
| "fp32_norm_classes against the model's actual norm class names.", | ||
| patterns, | ||
| ) | ||
| return 0 | ||
|
|
||
| shard_kwargs = dict(fully_shard_kwargs or {}) | ||
| shard_kwargs["mp_policy"] = fp32_policy | ||
|
|
||
| for _name, module in matches: | ||
| for param in module.parameters(recurse=False): | ||
| param.data = param.data.to(torch.float32) | ||
| for buffer in module.buffers(recurse=False): | ||
| if buffer.dtype.is_floating_point: | ||
| buffer.data = buffer.data.to(torch.float32) | ||
| fully_shard(module, **shard_kwargs) | ||
|
|
||
| LOG.info( | ||
| "Sharded %d norm modules with fp32 MixedPrecisionPolicy " | ||
| "(patterns=%s, output_dtype=%s)", | ||
| len(matches), | ||
| patterns, | ||
| output_dtype, | ||
| ) | ||
| return len(matches) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| """Test-only plugin that captures param dtypes after the first optimizer step | ||
| and dumps them as JSON to ``$FP32_NORMS_DTYPE_DUMP_PATH``. | ||
|
|
||
| Loaded via ``plugins: [tests.e2e.multigpu._fp32_norms_dtype_capture.DtypeCapturePlugin]`` | ||
| in the test yaml config; the dump path is the contract between the subprocess | ||
| and the outer pytest function. Rank 0 only — dtype is identical across ranks. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import os | ||
|
|
||
| import torch | ||
| from transformers.trainer_callback import TrainerCallback | ||
|
|
||
| from axolotl.integrations.base import BasePlugin | ||
|
|
||
|
|
||
| def _dtype_name(dtype: torch.dtype) -> str: | ||
| return str(dtype).removeprefix("torch.") | ||
|
|
||
|
|
||
class _DtypeCaptureCallback(TrainerCallback):
    """Dump norm vs. non-norm parameter dtypes to a JSON file after step 1."""

    def on_step_end(self, args, state, control, model=None, **kwargs):  # type: ignore[override]
        """Write the dtype of every model parameter once ``global_step`` is 1.

        Nothing is written unless all of these hold: ``state.global_step`` is
        1, ``model`` is provided, the process is rank 0 (or torch.distributed
        is not initialized), and ``$FP32_NORMS_DTYPE_DUMP_PATH`` is set to a
        non-empty value.

        The output is a JSON object with two maps, ``"norms"`` and
        ``"non_norms"``, from parameter name to dtype name. A parameter counts
        as a norm when its lower-cased name contains ``"norm"``. The method
        always returns None and does not modify ``control``.
        """
        if state.global_step != 1 or model is None:
            return

        # One writer is enough: per the original author's note, every rank
        # sees the same dtype info under FSDP2.
        dist = torch.distributed
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        dump_path = os.environ.get("FP32_NORMS_DTYPE_DUMP_PATH")
        if not dump_path:
            return

        report: dict[str, dict[str, str]] = {"norms": {}, "non_norms": {}}
        for param_name, param in model.named_parameters():
            bucket = "norms" if "norm" in param_name.lower() else "non_norms"
            report[bucket][param_name] = _dtype_name(param.dtype)

        with open(dump_path, "w", encoding="utf-8") as handle:
            json.dump(report, handle, indent=2)
|
|
||
|
|
||
class DtypeCapturePlugin(BasePlugin):
    """Plugin whose only job is to add :class:`_DtypeCaptureCallback` to the trainer."""

    def add_callbacks_pre_trainer(self, cfg, model):  # type: ignore[override]
        """Return a one-element list holding a new ``_DtypeCaptureCallback``.

        ``cfg`` and ``model`` are accepted to satisfy the hook signature and
        are not used.
        """
        capture_callback = _DtypeCaptureCallback()
        return [capture_callback]
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.