FSDP2 fix validation and add tests#2910
Conversation
WalkthroughValidation logic for FSDP (Fully Sharded Data Parallel) configuration in Axolotl has been refactored. Validators were removed from the main config schema and migrated to a dedicated validation module, where they were restructured, expanded, and clarified. Comprehensive tests for FSDP validation scenarios were also introduced. Additionally, the deprecated FSDP config migration function was removed and replaced by validation-based normalization in tests and CLI. Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant ConfigSchema
participant ValidationMixin
User->>ConfigSchema: Provide configuration dict
ConfigSchema->>ValidationMixin: Run before-validation checks (class methods)
ValidationMixin-->>ConfigSchema: Log warnings/errors, migrate config keys
ConfigSchema->>ValidationMixin: Run after-validation checks (instance methods)
ValidationMixin-->>ConfigSchema: Raise errors for invalid FSDP settings
ConfigSchema-->>User: Return validated config or raise ValueError
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (3)
💤 Files with no reviewable changes (2)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
🔇 Additional comments (12)
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
|
📖 Documentation Preview: https://6874effbb39dc04f274f4d18--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit 492c344 |
There was a problem hiding this comment.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
tests/utils/schemas/validation/test_fsdp.py (1)
112-112: Fix incomplete file structure.The file appears to end abruptly without properly closing the last test method. This will cause a syntax error.
Add the missing line to properly close the test method:
+ ): + validate_config(cfg)
🧹 Nitpick comments (3)
src/axolotl/utils/schemas/validation.py (3)
763-777: Simplify nested if statements for better readability.The validation logic is correct, but the nested if statements can be combined for clarity.
def check_fsdp2_base_model_quant_ram_efficient_loading(self): fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None - if fsdp_config and fsdp_version == 2: - if fsdp_config.get("cpu_ram_efficient_loading") and ( - load_in_8bit or load_in_4bit - ): - raise ValueError( - "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " - "set fsdp_version to 1, or disable cpu_ram_efficient_loading." - ) + if ( + fsdp_config + and fsdp_version == 2 + and fsdp_config.get("cpu_ram_efficient_loading") + and (load_in_8bit or load_in_4bit) + ): + raise ValueError( + "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " + "set fsdp_version to 1, or disable cpu_ram_efficient_loading." + ) return self
778-793: Simplify conditional logic for clarity.The validation correctly prevents incompatible configurations, but can be more readable.
def check_fsdp2_base_model_quant_dpo(cls, data): - if data.get("fsdp_version") == 2 and data.get("rl") in [ + if ( + data.get("fsdp_version") == 2 + and data.get("rl") in [ RLType.DPO, RLType.KTO, RLType.ORPO, RLType.IPO, - ]: - if data.get("load_in_8bit") or data.get("load_in_4bit"): - raise ValueError( - "FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1." - ) + ] + and (data.get("load_in_8bit") or data.get("load_in_4bit")) + ): + raise ValueError( + "FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1." + ) return data
794-805: Good deprecation handling with minor improvement opportunity.The method properly handles the deprecated configuration pattern.
def check_fsdp_version_in_fsdp_config(cls, data): - if data.get("fsdp_config"): - if data.get("fsdp_config", {}).get("fsdp_version"): - LOG.warning( - "Configuring `fsdp_version` in `fsdp_config` is deprecated. " - "Please configure `fsdp_version` as a top-level field." - ) - data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") + if data.get("fsdp_config") and data.get("fsdp_config").get("fsdp_version"): + LOG.warning( + "Configuring `fsdp_version` in `fsdp_config` is deprecated. " + "Please configure `fsdp_version` as a top-level field." + ) + data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version") return data
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/axolotl/utils/schemas/config.py(0 hunks)src/axolotl/utils/schemas/validation.py(2 hunks)tests/utils/schemas/validation/test_fsdp.py(1 hunks)
💤 Files with no reviewable changes (1)
- src/axolotl/utils/schemas/config.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/validation.py (2)
src/axolotl/utils/schemas/enums.py (1)
RLType(24-32)src/axolotl/utils/logging.py (1)
warning_once(31-39)
🪛 Ruff (0.11.9)
src/axolotl/utils/schemas/validation.py
768-771: Use a single if statement instead of nested if statements
(SIM102)
781-787: Use a single if statement instead of nested if statements
(SIM102)
797-798: Use a single if statement instead of nested if statements
(SIM102)
847-854: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (15)
tests/utils/schemas/validation/test_fsdp.py (11)
12-26: Well-structured test fixture.The base configuration fixture provides a clean foundation for FSDP validation tests with all required fields.
33-44: Good test coverage for fsdp_version migration.The test properly validates that
fsdp_versionis migrated from the nestedfsdp_configto the top level and removed from its original location.
45-70: Comprehensive test for SHARDED_STATE_DICT validation.The test thoroughly covers both prefix variants (
fsdp_state_dict_typeandstate_dict_type), ensuring the validation works correctly regardless of the deprecated prefix usage.
71-112: Excellent test coverage for FSDP version-specific validations.The tests properly validate:
- FSDP1 offload incompatibility with 8-bit optimizers
- FSDP2 specific incompatibilities with certain 8-bit optimizers
- FSDP2 constraints with quantization and CPU RAM efficient loading
All error messages are clear and actionable.
1-10: LGTM! Clean module setup with appropriate imports.The module docstring is clear, imports are relevant to the testing functionality, and the pylint disable is reasonable for test files.
12-25: Well-structured base configuration fixture.The fixture provides a comprehensive baseline configuration with all essential fields for FSDP testing. Using
DictDefaultenables easy extension in individual test methods.
28-31: Standard test class structure - looks good.Clear naming and documentation following pytest conventions.
33-43: Excellent test for configuration migration logic.This test effectively validates that
fsdp_versionis properly extracted from the nestedfsdp_configand moved to the top level while being removed from the original location.
45-69: Comprehensive coverage of state dict compatibility.Good practice testing both prefixed (
fsdp_state_dict_type) and unprefixed (state_dict_type) versions of the configuration key. This ensures robustness across different configuration formats.
71-96: Clear validation of optimizer compatibility constraints.Both FSDP v1 and v2 optimizer compatibility tests are well-structured with specific error messages that guide users toward the correct optimizer choice (
adamw_torch_8bitfor FSDP2).
98-111: Good coverage of complex feature interactions.This test effectively validates the constraint between FSDP2, quantization (8-bit loading), and CPU RAM efficient loading - a realistic scenario that could cause issues in production.
src/axolotl/utils/schemas/validation.py (4)
751-761: Helpful FSDP version migration notice.The validation method provides clear guidance about FSDP1 deprecation with a link to migration documentation.
806-827: Excellent handling of deprecated prefix removal.The method properly warns users about deprecated prefixes using
warning_onceto avoid log spam, and cleanly migrates the configuration keys.
845-861: Clear FSDP2 optimizer compatibility checks.The validation properly prevents incompatible 8-bit optimizer usage with FSDP2 and suggests the correct alternative.
862-875: Proper safetensors compatibility validation.The method correctly validates the incompatibility between SHARDED_STATE_DICT and safetensors format.
Codecov ReportAll modified and coverable lines are covered by tests ✅ 📢 Thoughts on this report? Let us know! |
There was a problem hiding this comment.
Actionable comments posted: 0
🧹 Nitpick comments (5)
tests/utils/schemas/validation/test_fsdp.py (1)
113-129: Fix static analysis hint for better code style.The test logic is excellent for validating deprecated prefix removal and configuration normalization. However, there's a minor style improvement needed.
Apply this diff to address the static analysis hint:
- for keys in cfg.fsdp_config.keys(): + for keys in cfg.fsdp_config: assert not keys.startswith("fsdp_")src/axolotl/utils/schemas/validation.py (4)
762-776: Consider simplifying nested conditions per static analysis hint.The validation logic is correct and the error message provides helpful alternatives. However, the nested if statements could be simplified.
Apply this diff to simplify the nested conditions:
- if fsdp_config and fsdp_version == 2: - if fsdp_config.get("cpu_ram_efficient_loading") and ( - load_in_8bit or load_in_4bit - ): + if ( + fsdp_config + and fsdp_version == 2 + and fsdp_config.get("cpu_ram_efficient_loading") + and (load_in_8bit or load_in_4bit) + ): raise ValueError( "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, " "set fsdp_version to 1, or disable cpu_ram_efficient_loading." )
778-792: Consider simplifying nested conditions for better readability.The validation logic correctly enforces FSDP2 constraints with RL modes and uses the RLType enum appropriately. The nested if statements could be simplified per static analysis.
Apply this diff to simplify the nested conditions:
- if data.get("fsdp_version") == 2 and data.get("rl") in [ - RLType.DPO, - RLType.KTO, - RLType.ORPO, - RLType.IPO, - ]: - if data.get("load_in_8bit") or data.get("load_in_4bit"): + if ( + data.get("fsdp_version") == 2 + and data.get("rl") in [RLType.DPO, RLType.KTO, RLType.ORPO, RLType.IPO] + and (data.get("load_in_8bit") or data.get("load_in_4bit")) + ): raise ValueError( f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1." )
794-804: Good backwards compatibility handling with room for simplification.The deprecation migration logic is well-implemented with appropriate warnings. The nested if statement could be simplified per static analysis.
Apply this diff to simplify the nested condition:
- if data.get("fsdp_config"): - if data.get("fsdp_config", {}).get("fsdp_version"): + if data.get("fsdp_config") and data.get("fsdp_config", {}).get("fsdp_version"): LOG.warning( "Configuring `fsdp_version` in `fsdp_config` is deprecated. " "Please configure `fsdp_version` as a top-level field." ) data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
843-858: Consider simplifying nested conditions in FSDP2 optimizer validation.The validation logic correctly identifies incompatible 8-bit optimizers with FSDP2 and provides helpful alternative suggestions. The nested if statements could be simplified per static analysis.
Apply this diff to simplify the nested conditions:
- if ( - hasattr(self, "fsdp_config") - and self.fsdp_config - and self.optimizer - and "8bit" in self.optimizer.value - and str(self.fsdp_version) == "2" - ): - if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]: + if ( + hasattr(self, "fsdp_config") + and self.fsdp_config + and self.optimizer + and "8bit" in self.optimizer.value + and str(self.fsdp_version) == "2" + and self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"] + ): # CUDA ops errors with bnb 8bit optimizer + FSDP2 raise ValueError( f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead" )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/axolotl/utils/schemas/validation.py(2 hunks)tests/utils/schemas/validation/test_fsdp.py(1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/utils/schemas/validation.py (2)
src/axolotl/utils/schemas/enums.py (1)
RLType(24-32)src/axolotl/utils/logging.py (1)
warning_once(31-39)
tests/utils/schemas/validation/test_fsdp.py (2)
src/axolotl/utils/config/__init__.py (1)
validate_config(260-304)src/axolotl/utils/dict.py (1)
DictDefault(6-38)
🪛 Ruff (0.11.9)
src/axolotl/utils/schemas/validation.py
768-771: Use a single if statement instead of nested if statements
(SIM102)
781-787: Use a single if statement instead of nested if statements
(SIM102)
797-798: Use a single if statement instead of nested if statements
(SIM102)
845-852: Use a single if statement instead of nested if statements
(SIM102)
tests/utils/schemas/validation/test_fsdp.py
125-125: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (12)
tests/utils/schemas/validation/test_fsdp.py (7)
1-25: Excellent test setup and fixture design.The imports are clean and the
fsdp_base_cfgfixture provides a comprehensive baseline configuration that covers all essential fields needed for FSDP validation testing.
33-43: Good test coverage for deprecated fsdp_version migration.This test properly validates the extraction of
fsdp_versionfrom the nestedfsdp_configand ensures the deprecated field is cleaned up after migration.
45-69: Comprehensive validation testing for FSDP state dict compatibility.Excellent coverage testing both the prefixed (
fsdp_state_dict_type) and non-prefixed (state_dict_type) configurations to ensure the incompatibility withsave_safetensorsis properly caught.
71-82: Proper validation of FSDP v1 offload incompatibility.This test correctly ensures that FSDP version 1 with parameter offloading raises an error when used with 8-bit optimizers, which is a known incompatibility.
84-96: Good validation with helpful error messaging.This test properly validates FSDP version 2 incompatibility with
adamw_8bitand ensures the error message provides constructive guidance by recommendingadamw_torch_8bitas an alternative.
98-111: Thorough testing of FSDP v2 CPU RAM loading constraints.This test properly validates the incompatibility between FSDP version 2's CPU RAM efficient loading and quantization, which is an important technical constraint to enforce.
131-155: Excellent parameterized testing for RL mode constraints.This parameterized test efficiently covers multiple reinforcement learning modes (dpo, kto, orpo, ipo) to ensure FSDP version 2 properly rejects 8-bit/4-bit loading in RL contexts, providing comprehensive validation coverage.
src/axolotl/utils/schemas/validation.py (5)
3-3: Appropriate pylint disable for validation complexity.The addition of
too-many-boolean-expressionsto the pylint disable is justified given the inherent complexity of validation methods that require multiple boolean conditions.
750-760: Good deprecation guidance with helpful documentation link.This validation method appropriately informs users about FSDP1 deprecation while providing a helpful link to migration documentation. The non-blocking approach is user-friendly.
806-826: Excellent prefix cleanup with efficient deprecation warnings.This method provides robust backwards compatibility by automatically removing deprecated
fsdp_prefixes while usingwarning_oncefor efficient deprecation messaging. The logic correctly preserves non-prefixed keys and handles the special case offsdp_version.
828-841: Well-refactored validation with proper attribute handling.The method has been correctly updated to use instance attributes with appropriate
hasattrsafety checks. The logic properly validates FSDP offload incompatibility with 8-bit optimizers for non-FSDP2 versions.
860-872: Properly refactored safetensors validation with instance attributes.The method has been correctly updated to use instance attributes with appropriate safety checks. The validation properly prevents the incompatible combination of FSDP sharded state dict with safetensors format.
Follow on for #2760 for dropping prefixes properly and handling validation since Pydantic validation ordering is hard to trace
Summary by CodeRabbit
Refactor
Tests