Create FSDPConfig schema#3157
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughRefactors AxolotlInputConfig to use a typed FSDPConfig instead of dict-based config, removing several deprecated FSDP fields and the prior torch-version validator. Adds a new FSDP schema module defining FSDPConfig with numerous FSDP-related fields and a central validator enforcing torch version compatibility for FSDP2. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Pre-merge checks (3 passed)✅ Passed checks (3 passed)
Pre-merge checks (3 passed)✅ Passed checks (3 passed)
✨ Finishing touches🧪 Generate unit tests
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (2)
src/axolotl/utils/schemas/fsdp.py (2)
28-33: Two overlapping fields for state-dict type; unify source of truth.Both
final_state_dict_typeandstate_dict_typeexist and accept the same literals, which is confusing. Keep one canonical field and alias the other during validation to avoid contradictory inputs.Minimal aliasing without breaking configs:
@@ - final_state_dict_type: ( + final_state_dict_type: ( Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None ) = Field( default=None, deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.", ) @@ - state_dict_type: ( - Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None - ) = Field(default=None, description="Type of state dict to use for FSDP.") + state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = Field(default=None, description="Deprecated alias; use final_state_dict_type.") + + @model_validator(mode="before") + @classmethod + def _alias_state_dict_type(cls, data): + if isinstance(data, dict): + if data.get("state_dict_type") and not data.get("final_state_dict_type"): + data["final_state_dict_type"] = data["state_dict_type"] + return dataAlso applies to: 46-48
21-27: Avoid performance claims in schema descriptions.“Improve training speed by 10–15%” is context-sensitive and may not hold. Consider neutral wording to prevent misleading docs.
- "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can " - "improve training speed by 10-15% when FSDP is enabled." + "description": "Enable FSDP float8 all-gather optimization for FP8 training."
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/axolotl/utils/schemas/config.py(2 hunks)src/axolotl/utils/schemas/fsdp.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/schemas/fsdp.py (1)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/schemas/fsdp.py (1)
FSDPConfig(11-71)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/fsdp.py
69-69: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/schemas/fsdp.py (1)
12-16: Deprecatedfsdpfield inside FSDPConfig is confusing.Deprecation text says “use fsdp_config instead,” but we are already inside
FSDPConfig. Either remove this field or clarify that it’s a legacy passthrough for old top-levelfsdpconfigs.Would you like me to add a validator that errors if this inner
fsdpis set, with guidance on the new layout?src/axolotl/utils/schemas/config.py (1)
27-27: Import wiring looks good.Using a typed schema here improves validation, docs, and IDE help.
| fsdp_config: FSDPConfig | None = Field( | ||
| default=None, json_schema_extra={"description": "FSDP configuration options"} | ||
| ) |
There was a problem hiding this comment.
💡 Verification agent
🧩 Analysis chain
Typed FSDP config: good move. Also fix downstream validators that still read top-level fsdp_version.
This change makes fsdp_version nested. Validators below still do str(data.get("fsdp_version")) == "2", which will now miss FSDP2 and mis-gate LoRA kernels on multi-GPU.
Patch both validators to read the nested version:
@@ def check_multigpu_lora_kernels(cls, data):
- capabilities = data.get("capabilities")
- is_fsdp = data.get("fsdp_config") is not None
- is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+ capabilities = data.get("capabilities")
+ fsdp_cfg = data.get("fsdp_config")
+ is_fsdp = fsdp_cfg is not None
+ fsdp_version = (
+ fsdp_cfg.get("fsdp_version")
+ if isinstance(fsdp_cfg, dict)
+ else getattr(fsdp_cfg, "fsdp_version", None)
+ )
+ is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
- capabilities = data.get("capabilities")
- is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
- is_fsdp = data.get("fsdp_config") is not None
- is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+ capabilities = data.get("capabilities")
+ is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+ fsdp_cfg = data.get("fsdp_config")
+ is_fsdp = fsdp_cfg is not None
+ fsdp_version = (
+ fsdp_cfg.get("fsdp_version")
+ if isinstance(fsdp_cfg, dict)
+ else getattr(fsdp_cfg, "fsdp_version", None)
+ )
+ is_fsdp2 = is_fsdp and str(fsdp_version) == "2"Run to find any remaining top-level uses:
🏁 Script executed:
#!/bin/bash
rg -n -C2 -g 'src/**' -e '\bfsdp_version\b' -e '\bfsdp_final_state_dict_type\b' -e '\bfp8_enable_fsdp_float8_all_gather\b' | grep -v 'schemas/fsdp.py'Length of output: 9737
Fix all top-level fsdp_version checks — read version from fsdp_config (critical)
fsdp_version was nested into fsdp_config; existing top-level checks (data.get("fsdp_version") / cfg.fsdp_version / str(...)) will miss FSDP2 and mis-gate runtime behavior. Patch the two validators below and every other occurrence to read the nested value (handle dict or typed object, optionally fall back to top-level for compatibility).
Patch for the two validators (unchanged from suggestion):
@@ def check_multigpu_lora_kernels(cls, data):
- capabilities = data.get("capabilities")
- is_fsdp = data.get("fsdp_config") is not None
- is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+ capabilities = data.get("capabilities")
+ fsdp_cfg = data.get("fsdp_config")
+ is_fsdp = fsdp_cfg is not None
+ fsdp_version = (
+ fsdp_cfg.get("fsdp_version")
+ if isinstance(fsdp_cfg, dict)
+ else getattr(fsdp_cfg, "fsdp_version", None)
+ )
+ is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
- capabilities = data.get("capabilities")
- is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
- is_fsdp = data.get("fsdp_config") is not None
- is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+ capabilities = data.get("capabilities")
+ is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+ fsdp_cfg = data.get("fsdp_config")
+ is_fsdp = fsdp_cfg is not None
+ fsdp_version = (
+ fsdp_cfg.get("fsdp_version")
+ if isinstance(fsdp_cfg, dict)
+ else getattr(fsdp_cfg, "fsdp_version", None)
+ )
+ is_fsdp2 = is_fsdp and str(fsdp_version) == "2"Other places found by the grep that must be updated to use the same pattern (update logic or centralize a helper to extract fsdp_version from fsdp_config):
- src/axolotl/utils/trainer.py: line ~584 —
if str(cfg.fsdp_version) == "2": - src/axolotl/utils/schemas/config.py: lines ~1097–1101 and ~1133–1136 — (the two validator sites above; diff applied)
- src/axolotl/utils/schemas/validation.py: multiple locations (~417, ~810, ~822–827, ~838–846, ~855–861, ~891–907) — update data.get/self.fsdp_version checks and the migration/warning in check_fsdp_version_in_fsdp_config (it currently pops nested fsdp_version into top-level; reverse or remove this migration).
- src/axolotl/loaders/patch_manager.py: lines ~112, ~119, ~179, ~286 —
str(self.cfg.fsdp_version) == "2"checks - src/axolotl/loaders/model.py: lines ~203, ~730 —
self.cfg.fsdp_versioncomparisons - src/axolotl/monkeypatch/accelerate/parallelism_config.py: line ~69 —
self.fsdp_plugin.fsdp_version == 2
Actionable: replace direct top-level reads with a small centralized getter (or inline pattern above) that:
- reads fsdp_cfg = cfg.fsdp_config (or data.get("fsdp_config"))
- if dict -> fsdp_cfg.get("fsdp_version"); else -> getattr(fsdp_cfg, "fsdp_version", None)
- fallback to getattr(cfg, "fsdp_version", None) only if backward-compat needed
use str(...) == "2" or int compare consistently once value resolved.
Committable suggestion skipped: line range outside the PR's diff.
| fsdp_version: int | None = Field( | ||
| default=None, | ||
| json_schema_extra={"description": "FSDP version"}, | ||
| ) |
There was a problem hiding this comment.
🛠️ Refactor suggestion
Constrain fsdp_version to known values.
Limit to Literal[1, 2] to catch typos early and avoid string/int ambiguity downstream.
- fsdp_version: int | None = Field(
+ fsdp_version: Literal[1, 2] | None = Field(
default=None,
json_schema_extra={"description": "FSDP version"},
)Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/fsdp.py around lines 17-20, the fsdp_version field
is currently typed as int | None; change its annotation to Literal[1, 2] | None
to restrict allowed values to 1 or 2 and prevent typos, and add the required
import for Literal (from typing import Literal) if not already present; keep the
Field(...) call and existing json_schema_extra, ensuring the default remains
None so Optional Literal values are allowed.
| @model_validator(mode="before") | ||
| @classmethod | ||
| def check_fsdp_torch_version(cls, data): | ||
| env_capabilities = data.get("env_capabilities", {}) | ||
| torch_version = env_capabilities.get("torch_version") | ||
|
|
||
| if torch_version is None: | ||
| import torch | ||
|
|
||
| torch_version = str(torch.__version__).split("+", maxsplit=1)[0] | ||
|
|
||
| if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2": | ||
| if version.parse(torch_version) < version.parse("2.7.0"): | ||
| raise ValueError("FSDP2 is not supported on torch version < 2.7.0") | ||
|
|
||
| return data |
There was a problem hiding this comment.
FSDP2 torch-version gate never triggers (wrong keys in validator).
Inside FSDPConfig, data does not contain fsdp_config, so if data.get("fsdp_config") ... is always false and the FSDP2 check is skipped. Also comparing str(data.get("fsdp_version")) == "2" is brittle. Fix by checking fsdp_version directly and handling both int/str. Shorten the error to satisfy TRY003.
Apply:
@@
- def check_fsdp_torch_version(cls, data):
- env_capabilities = data.get("env_capabilities", {})
- torch_version = env_capabilities.get("torch_version")
-
- if torch_version is None:
- import torch
-
- torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
-
- if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
- if version.parse(torch_version) < version.parse("2.7.0"):
- raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
-
- return data
+ def check_fsdp_torch_version(cls, data):
+ # derive torch version from env_capabilities (if provided) or torch.__version__
+ torch_version = None
+ if isinstance(data, dict):
+ env_capabilities = data.get("env_capabilities") or {}
+ torch_version = env_capabilities.get("torch_version")
+ if torch_version is None:
+ try:
+ import torch
+ torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+ except Exception:
+ return data # torch not importable; skip
+
+ v = data.get("fsdp_version") if isinstance(data, dict) else None
+ if v in (2, "2"):
+ if version.parse(torch_version) < version.parse("2.7.0"):
+ raise ValueError("FSDP2 requires torch>=2.7.0")
+ return data📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| @model_validator(mode="before") | |
| @classmethod | |
| def check_fsdp_torch_version(cls, data): | |
| env_capabilities = data.get("env_capabilities", {}) | |
| torch_version = env_capabilities.get("torch_version") | |
| if torch_version is None: | |
| import torch | |
| torch_version = str(torch.__version__).split("+", maxsplit=1)[0] | |
| if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2": | |
| if version.parse(torch_version) < version.parse("2.7.0"): | |
| raise ValueError("FSDP2 is not supported on torch version < 2.7.0") | |
| return data | |
| @model_validator(mode="before") | |
| @classmethod | |
| def check_fsdp_torch_version(cls, data): | |
| # derive torch version from env_capabilities (if provided) or torch.__version__ | |
| torch_version = None | |
| if isinstance(data, dict): | |
| env_capabilities = data.get("env_capabilities") or {} | |
| torch_version = env_capabilities.get("torch_version") | |
| if torch_version is None: | |
| try: | |
| import torch | |
| torch_version = str(torch.__version__).split("+", maxsplit=1)[0] | |
| except Exception: | |
| return data # torch not importable; skip | |
| v = data.get("fsdp_version") if isinstance(data, dict) else None | |
| if v in (2, "2"): | |
| if version.parse(torch_version) < version.parse("2.7.0"): | |
| raise ValueError("FSDP2 requires torch>=2.7.0") | |
| return data |
🧰 Tools
🪛 Ruff (0.12.2)
69-69: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/fsdp.py around lines 56–71, the before-model
validator uses the wrong key so the FSDP2 torch-version gate never runs; change
the check to read fsdp_version directly from data (coerce to string or int so
both 2 and "2" work) and remove the unused fsdp_config check, then keep the
existing torch_version lookup/fallback and raise a short error like "FSDP v2
requires torch>=2.7.0" when version.parse(torch_version) <
version.parse("2.7.0").
FSDPConfig schema
|
I do worry that this PR breaks backwards compatibility of being able to handle the legacy |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
| "used in combination with torch.compile." | ||
| }, | ||
| ) | ||
| fp8_enable_fsdp_float8_all_gather: bool | None = Field( |
There was a problem hiding this comment.
This shouldn't have been moved
| fsdp_config: FSDPConfig | None = Field( | ||
| default=None, json_schema_extra={"description": "FSDP configuration options"} | ||
| ) | ||
| fsdp_version: int | None = Field( |
There was a problem hiding this comment.
This also should not have been moved
|
|
||
|
|
||
| class FSDPConfig(BaseModel): | ||
| model_config = ConfigDict(extra="allow") |
There was a problem hiding this comment.
What is this used for?
|
|
||
| class FSDPConfig(BaseModel): | ||
| model_config = ConfigDict(extra="allow") | ||
| fsdp: list[str] | None = Field( |
There was a problem hiding this comment.
This shouldn't be here
| cpu_ram_efficient_loading = None | ||
| if hasattr(fsdp_config, "cpu_ram_efficient_loading"): | ||
| cpu_ram_efficient_loading = fsdp_config.cpu_ram_efficient_loading | ||
| elif isinstance(fsdp_config, dict): |
There was a problem hiding this comment.
When would this be the case?
Movign pramas to FSDP_config
Motivation and Context
How has this been tested?
WIP
Summary by CodeRabbit
New Features
Refactor
Chores