Skip to content

Create FSDPConfig schema#3157

Closed
ved1beta wants to merge 11 commits into
axolotl-ai-cloud:mainfrom
ved1beta:Fsdp_config
Closed

Create FSDPConfig schema#3157
ved1beta wants to merge 11 commits into
axolotl-ai-cloud:mainfrom
ved1beta:Fsdp_config

Conversation

@ved1beta

@ved1beta ved1beta commented Sep 12, 2025

Copy link
Copy Markdown
Member

Movign pramas to FSDP_config

Motivation and Context

How has this been tested?

WIP

Summary by CodeRabbit

  • New Features

    • Centralized, structured FSDP configuration with clearer options and defaults.
    • Built-in validation to ensure FSDP2 is only used with supported PyTorch versions.
  • Refactor

    • Migrated from ad-hoc dict-based FSDP settings to a typed configuration model for improved clarity and reliability.
    • Removed several legacy FSDP options and deprecated fields to streamline configuration.
  • Chores

    • Cleaned up outdated validation logic and deprecated configuration paths, aligning settings under the new FSDP configuration.

@coderabbitai

coderabbitai Bot commented Sep 12, 2025

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Refactors AxolotlInputConfig to use a typed FSDPConfig instead of dict-based config, removing several deprecated FSDP fields and the prior torch-version validator. Adds a new FSDP schema module defining FSDPConfig with numerous FSDP-related fields and a central validator enforcing torch version compatibility for FSDP2.

Changes

Cohort / File(s) Summary
Config schema refactor
src/axolotl/utils/schemas/config.py
Replaces `fsdp_config: dict[str, Any]
New FSDP schema module
src/axolotl/utils/schemas/fsdp.py
Adds FSDPConfig Pydantic model with FSDP-related fields (including deprecated aliases) and a class-level validator to enforce torch>=2.7.0 when fsdp_version == 2; introduces module logger.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • SalmanMohammadi

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly reflects the main change—moving FSDP-related parameters into an FSDP_config—so it accurately summarizes the PR's primary intent, but it contains spelling errors ("Movign pramas") that reduce clarity and polish.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.
Title Check ✅ Passed The title "Create FSDPConfig schema" accurately and concisely describes the primary change in this PR — adding a typed FSDPConfig model and updating AxolotlInputConfig.fsdp_config to use it while removing legacy FSDP fields — so it is clear to a reviewer what the main intent is.
✨ Finishing touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (2)
src/axolotl/utils/schemas/fsdp.py (2)

28-33: Two overlapping fields for state-dict type; unify source of truth.

Both final_state_dict_type and state_dict_type exist and accept the same literals, which is confusing. Keep one canonical field and alias the other during validation to avoid contradictory inputs.

Minimal aliasing without breaking configs:

@@
-    final_state_dict_type: (
+    final_state_dict_type: (
         Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
     ) = Field(
         default=None,
         deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",
     )
@@
-    state_dict_type: (
-        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
-    ) = Field(default=None, description="Type of state dict to use for FSDP.")
+    state_dict_type: (
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
+    ) = Field(default=None, description="Deprecated alias; use final_state_dict_type.")
+
+    @model_validator(mode="before")
+    @classmethod
+    def _alias_state_dict_type(cls, data):
+        if isinstance(data, dict):
+            if data.get("state_dict_type") and not data.get("final_state_dict_type"):
+                data["final_state_dict_type"] = data["state_dict_type"]
+        return data

Also applies to: 46-48


21-27: Avoid performance claims in schema descriptions.

“Improve training speed by 10–15%” is context-sensitive and may not hold. Consider neutral wording to prevent misleading docs.

-            "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can "
-            "improve training speed by 10-15% when FSDP is enabled."
+            "description": "Enable FSDP float8 all-gather optimization for FP8 training."
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58d67bf and b7245c4.

📒 Files selected for processing (2)
  • src/axolotl/utils/schemas/config.py (2 hunks)
  • src/axolotl/utils/schemas/fsdp.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/schemas/fsdp.py (1)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/schemas/fsdp.py (1)
  • FSDPConfig (11-71)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/fsdp.py

69-69: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/schemas/fsdp.py (1)

12-16: Deprecated fsdp field inside FSDPConfig is confusing.

Deprecation text says “use fsdp_config instead,” but we are already inside FSDPConfig. Either remove this field or clarify that it’s a legacy passthrough for old top-level fsdp configs.

Would you like me to add a validator that errors if this inner fsdp is set, with guidance on the new layout?

src/axolotl/utils/schemas/config.py (1)

27-27: Import wiring looks good.

Using a typed schema here improves validation, docs, and IDE help.

Comment on lines +660 to 662
fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Typed FSDP config: good move. Also fix downstream validators that still read top-level fsdp_version.

This change makes fsdp_version nested. Validators below still do str(data.get("fsdp_version")) == "2", which will now miss FSDP2 and mis-gate LoRA kernels on multi-GPU.

Patch both validators to read the nested version:

@@ def check_multigpu_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"

Run to find any remaining top-level uses:


🏁 Script executed:

#!/bin/bash
rg -n -C2 -g 'src/**' -e '\bfsdp_version\b' -e '\bfsdp_final_state_dict_type\b' -e '\bfp8_enable_fsdp_float8_all_gather\b' | grep -v 'schemas/fsdp.py'

Length of output: 9737


Fix all top-level fsdp_version checks — read version from fsdp_config (critical)

fsdp_version was nested into fsdp_config; existing top-level checks (data.get("fsdp_version") / cfg.fsdp_version / str(...)) will miss FSDP2 and mis-gate runtime behavior. Patch the two validators below and every other occurrence to read the nested value (handle dict or typed object, optionally fall back to top-level for compatibility).

Patch for the two validators (unchanged from suggestion):

@@ def check_multigpu_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"
@@ def check_auto_enable_lora_kernels(cls, data):
-            capabilities = data.get("capabilities")
-            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
-            is_fsdp = data.get("fsdp_config") is not None
-            is_fsdp2 = is_fsdp and str(data.get("fsdp_version")) == "2"
+            capabilities = data.get("capabilities")
+            is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
+            fsdp_cfg = data.get("fsdp_config")
+            is_fsdp = fsdp_cfg is not None
+            fsdp_version = (
+                fsdp_cfg.get("fsdp_version")
+                if isinstance(fsdp_cfg, dict)
+                else getattr(fsdp_cfg, "fsdp_version", None)
+            )
+            is_fsdp2 = is_fsdp and str(fsdp_version) == "2"

Other places found by the grep that must be updated to use the same pattern (update logic or centralize a helper to extract fsdp_version from fsdp_config):

  • src/axolotl/utils/trainer.py: line ~584 — if str(cfg.fsdp_version) == "2":
  • src/axolotl/utils/schemas/config.py: lines ~1097–1101 and ~1133–1136 — (the two validator sites above; diff applied)
  • src/axolotl/utils/schemas/validation.py: multiple locations (~417, ~810, ~822–827, ~838–846, ~855–861, ~891–907) — update data.get/self.fsdp_version checks and the migration/warning in check_fsdp_version_in_fsdp_config (it currently pops nested fsdp_version into top-level; reverse or remove this migration).
  • src/axolotl/loaders/patch_manager.py: lines ~112, ~119, ~179, ~286 — str(self.cfg.fsdp_version) == "2" checks
  • src/axolotl/loaders/model.py: lines ~203, ~730 — self.cfg.fsdp_version comparisons
  • src/axolotl/monkeypatch/accelerate/parallelism_config.py: line ~69 — self.fsdp_plugin.fsdp_version == 2

Actionable: replace direct top-level reads with a small centralized getter (or inline pattern above) that:

  • reads fsdp_cfg = cfg.fsdp_config (or data.get("fsdp_config"))
  • if dict -> fsdp_cfg.get("fsdp_version"); else -> getattr(fsdp_cfg, "fsdp_version", None)
  • fallback to getattr(cfg, "fsdp_version", None) only if backward-compat needed
    use str(...) == "2" or int compare consistently once value resolved.

Committable suggestion skipped: line range outside the PR's diff.

Comment thread src/axolotl/utils/schemas/fsdp.py Outdated
Comment on lines +17 to +20
fsdp_version: int | None = Field(
default=None,
json_schema_extra={"description": "FSDP version"},
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Constrain fsdp_version to known values.

Limit to Literal[1, 2] to catch typos early and avoid string/int ambiguity downstream.

-    fsdp_version: int | None = Field(
+    fsdp_version: Literal[1, 2] | None = Field(
         default=None,
         json_schema_extra={"description": "FSDP version"},
     )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/fsdp.py around lines 17-20, the fsdp_version field
is currently typed as int | None; change its annotation to Literal[1, 2] | None
to restrict allowed values to 1 or 2 and prevent typos, and add the required
import for Literal (from typing import Literal) if not already present; keep the
Field(...) call and existing json_schema_extra, ensuring the default remains
None so Optional Literal values are allowed.

Comment on lines +56 to +71
@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")

if torch_version is None:
import torch

torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError("FSDP2 is not supported on torch version < 2.7.0")

return data

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

FSDP2 torch-version gate never triggers (wrong keys in validator).

Inside FSDPConfig, data does not contain fsdp_config, so if data.get("fsdp_config") ... is always false and the FSDP2 check is skipped. Also comparing str(data.get("fsdp_version")) == "2" is brittle. Fix by checking fsdp_version directly and handling both int/str. Shorten the error to satisfy TRY003.

Apply:

@@
-    def check_fsdp_torch_version(cls, data):
-        env_capabilities = data.get("env_capabilities", {})
-        torch_version = env_capabilities.get("torch_version")
-
-        if torch_version is None:
-            import torch
-
-            torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
-
-        if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
-            if version.parse(torch_version) < version.parse("2.7.0"):
-                raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
-
-        return data
+    def check_fsdp_torch_version(cls, data):
+        # derive torch version from env_capabilities (if provided) or torch.__version__
+        torch_version = None
+        if isinstance(data, dict):
+            env_capabilities = data.get("env_capabilities") or {}
+            torch_version = env_capabilities.get("torch_version")
+        if torch_version is None:
+            try:
+                import torch
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+            except Exception:
+                return data  # torch not importable; skip
+
+        v = data.get("fsdp_version") if isinstance(data, dict) else None
+        if v in (2, "2"):
+            if version.parse(torch_version) < version.parse("2.7.0"):
+                raise ValueError("FSDP2 requires torch>=2.7.0")
+        return data
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if data.get("fsdp_config") and str(data.get("fsdp_version")) == "2":
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
# derive torch version from env_capabilities (if provided) or torch.__version__
torch_version = None
if isinstance(data, dict):
env_capabilities = data.get("env_capabilities") or {}
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
try:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
except Exception:
return data # torch not importable; skip
v = data.get("fsdp_version") if isinstance(data, dict) else None
if v in (2, "2"):
if version.parse(torch_version) < version.parse("2.7.0"):
raise ValueError("FSDP2 requires torch>=2.7.0")
return data
🧰 Tools
🪛 Ruff (0.12.2)

69-69: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/fsdp.py around lines 56–71, the before-model
validator uses the wrong key so the FSDP2 torch-version gate never runs; change
the check to read fsdp_version directly from data (coerce to string or int so
both 2 and "2" work) and remove the unused fsdp_config check, then keep the
existing torch_version lookup/fallback and raise a short error like "FSDP v2
requires torch>=2.7.0" when version.parse(torch_version) <
version.parse("2.7.0").

@salmanmohammadi salmanmohammadi changed the title Movign pramas to FSDP_config Create FSDPConfig schema Sep 12, 2025
@winglian

Copy link
Copy Markdown
Collaborator

I do worry that this PR breaks backwards compatibility of being able to handle the legacy fsdp_ prefixed options.

@codecov

codecov Bot commented Sep 15, 2025

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 82.60870% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/utils/schemas/fsdp.py 84.90% 8 Missing ⚠️
src/axolotl/utils/schemas/validation.py 69.23% 4 Missing ⚠️

📢 Thoughts on this report? Let us know!

"used in combination with torch.compile."
},
)
fp8_enable_fsdp_float8_all_gather: bool | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't have been moved

fsdp_config: FSDPConfig | None = Field(
default=None, json_schema_extra={"description": "FSDP configuration options"}
)
fsdp_version: int | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also should not have been moved

Comment thread src/axolotl/utils/schemas/fsdp.py Outdated


class FSDPConfig(BaseModel):
model_config = ConfigDict(extra="allow")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this used for?


class FSDPConfig(BaseModel):
model_config = ConfigDict(extra="allow")
fsdp: list[str] | None = Field(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be here

cpu_ram_efficient_loading = None
if hasattr(fsdp_config, "cpu_ram_efficient_loading"):
cpu_ram_efficient_loading = fsdp_config.cpu_ram_efficient_loading
elif isinstance(fsdp_config, dict):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would this be the case?

@ved1beta ved1beta closed this Sep 23, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants