Dion optimizer support #3014
📖 Documentation Preview: https://6890ee1bfe0dc2a416e0086a--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit a321454)
📝 Walkthrough
This change introduces support for a new custom optimizer named "dion" throughout the codebase. It updates configuration schemas, training argument mixins, and the optimizer builder to handle Dion-specific hyperparameters. The test suite is expanded with an end-to-end test for the Dion optimizer, and the requirements are updated to reference a newer version of axolotl-contribs-mit.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~15–20 minutes
Codecov Report
✅ All modified and coverable lines are covered by tests.
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/axolotl/core/training_args_base.py (1)
247-260: Add help metadata for consistency and documentation.
The dion_rank_fraction and dion_rank_multiple_of fields are missing help metadata, while the other Dion fields have descriptive help text. This inconsistency could affect auto-generated documentation and the user experience.
Consider adding help metadata for these fields:
  dion_rank_fraction: float | None = field(
      default=None,
+     metadata={"help": "The rank fraction for Dion optimizer"},
  )
  dion_rank_multiple_of: int | None = field(
      default=None,
+     metadata={"help": "The rank multiple for Dion optimizer"},
  )
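A minimal sketch of why the missing help text matters, assuming a plain dataclass stand-in. The class `DionArgsSketch` and the helper `collect_help` are hypothetical names for illustration, not part of axolotl. Tools that build documentation from `dataclasses.fields` find nothing to show for a field that has no `help` entry in its metadata.

```python
from dataclasses import dataclass, field, fields
from typing import Dict, Optional


@dataclass
class DionArgsSketch:
    """Hypothetical stand-in for the training-args mixin; not the axolotl class."""

    # Field with help metadata, as suggested in the review comment.
    dion_rank_fraction: Optional[float] = field(
        default=None,
        metadata={"help": "The rank fraction for Dion optimizer"},
    )
    # Field without help metadata, as in the reviewed diff.
    dion_rank_multiple_of: Optional[int] = field(default=None)


def collect_help(cls) -> Dict[str, Optional[str]]:
    """Map each dataclass field name to its help text, or None if it has none."""
    return {f.name: f.metadata.get("help") for f in fields(cls)}


help_by_field = collect_help(DionArgsSketch)
# Names of fields that would show up undocumented in generated help output.
missing = [name for name, text in help_by_field.items() if text is None]
```

Here `missing` lists the fields a documentation generator would render without a description.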
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- requirements.txt (1 hunks)
- src/axolotl/core/builders/base.py (2 hunks)
- src/axolotl/core/training_args_base.py (1 hunks)
- src/axolotl/integrations/base.py (2 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
- src/axolotl/utils/schemas/training.py (1 hunks)
- tests/e2e/test_optimizers.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
🔇 Additional comments (13)
requirements.txt (1)
69-69: LGTM: Dependency version update supports new Dion optimizer.
The version bump from 0.0.3 to 0.0.4 for axolotl-contribs-mit correctly aligns with the introduction of Dion optimizer support, which imports from this package.
src/axolotl/utils/schemas/enums.py (1)
82-82: LGTM: Enum addition follows established patterns.
The addition of "dion" to the CustomSupportedOptimizers enum is correctly implemented and consistent with existing optimizer entries.
src/axolotl/utils/schemas/training.py (1)
142-159: LGTM: Well-structured schema extensions for Dion optimizer.
The new Dion optimizer fields are properly implemented with:
- Consistent naming conventions following existing patterns
- Appropriate field types and default values
- Descriptive JSON schema metadata
- Logical parameter grouping
The schema extensions correctly support the new optimizer's configuration requirements.
tests/e2e/test_optimizers.py (1)
164-205: LGTM: Comprehensive test coverage for Dion optimizer.
The new test_dion method provides excellent test coverage with:
- Proper PyTorch version requirement enforcement
- Comprehensive configuration including Dion-specific parameters
- Standard validation flow following established patterns
- Appropriate optimizer class name verification
src/axolotl/integrations/base.py (3)
29-29: LGTM: Required import for new parameter filtering functionality.
The torch.nn import is correctly added to support the new get_decay_parameter_names method.
33-33: LGTM: Import follows transformers library patterns.
The get_parameter_names import from transformers provides established parameter filtering functionality.
647-666: LGTM: Well-implemented parameter filtering for weight decay.
The get_decay_parameter_names method provides essential functionality for optimizers with:
- Proper filtering of normalization layers (LayerNorm)
- Comprehensive regex patterns for bias and norm parameter exclusion
- Use of established transformers utility functions
- Clear documentation explaining the filtering logic
This will enable proper weight decay application in optimizers like Dion.
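The filtering idea described above can be sketched on plain parameter-name strings. This is an illustration only: the review says the real method relies on the transformers `get_parameter_names` utility and on LayerNorm filtering, whereas the function name `decay_parameter_names_sketch` and the regex patterns below are assumptions and may differ from the actual implementation.

```python
import re

# Assumed exclusion patterns for illustration; the real method's patterns may differ.
NO_DECAY_PATTERNS = [r"bias$", r"layernorm", r"(^|\.)norm(\.|$)", r"_norm"]


def decay_parameter_names_sketch(names):
    """Return the parameter names that should receive weight decay."""
    compiled = [re.compile(p, re.IGNORECASE) for p in NO_DECAY_PATTERNS]
    # Keep a name only if none of the exclusion patterns match it.
    return [n for n in names if not any(c.search(n) for c in compiled)]


names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.bias",
    "model.layers.0.input_layernorm.weight",
    "model.norm.weight",
    "lm_head.weight",
]
decay = decay_parameter_names_sketch(names)
```

Bias and normalization parameters drop out of `decay`, so an optimizer would apply weight decay only to the remaining weight matrices.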
src/axolotl/core/training_args_base.py (2)
247-250: LGTM!
The Dion learning rate field is properly implemented with appropriate type annotation and descriptive help metadata.
251-254: LGTM!
The Dion momentum field is properly implemented with appropriate type annotation and descriptive help metadata.
src/axolotl/core/builders/base.py (4)
530-532: LGTM!
The Dion optimizer parameters are correctly added to the training arguments list following the existing pattern.
537-542: LGTM!
The argument mapping for dion_learning_rate to dion_lr is implemented correctly and follows a clean pattern for parameter name translation.
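The name translation mentioned in this comment can be illustrated with a small sketch. The helper `translate_arg_names` and the example values are hypothetical; only the `dion_learning_rate` to `dion_lr` mapping comes from the review, and the builder's actual code may be structured differently.

```python
def translate_arg_names(values, mapping):
    """Copy non-None values, renaming any key that appears in `mapping`."""
    return {mapping.get(key, key): value for key, value in values.items() if value is not None}


# Illustrative values only; they are not recommended hyperparameters.
cfg_values = {
    "dion_learning_rate": 0.01,
    "dion_momentum": 0.95,
    "dion_rank_fraction": None,  # unset options are skipped
}
translated = translate_arg_names(cfg_values, {"dion_learning_rate": "dion_lr"})
```

Keeping the renames in one mapping means a new config-facing name needs a single dictionary entry rather than another special case.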
279-280: Inefficient PartialState Instantiation: No Reusable Instance in Scope
Each call to PartialState() creates a fresh object, and in this optimizer-building block there is no existing partial_state to reuse. The only other instantiation in this file lives inside a separate method (_configure_accelerator_config), so it isn't accessible here. Also, PartialState().device_mesh is guaranteed to exist, as downstream code uses it extensively (e.g. in sequence-parallel, FSDP monkey patches, model loaders).
Locations to inspect or improve:
- src/axolotl/core/builders/base.py (around line 279):
    partial_state = PartialState()
    optimizer_kwargs["device_mesh"] = partial_state.device_mesh
- _configure_accelerator_config (around line 447): separate scope, its partial_state cannot be shared.
Recommendation:
- If performance is a concern, consider hoisting a single PartialState instance to self (e.g., self._partial_state) in the builder's constructor or earliest configuration method, then reuse it in both _configure_accelerator_config and optimizer setup.
- Otherwise, leave as is: device_mesh is always available on a fresh PartialState.
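One way to read the hoisting recommendation is sketched below. To keep the example runnable without accelerate, `FakePartialState` stands in for `PartialState`; it and `BuilderSketch` are hypothetical names. The sketch also uses a lazily evaluated `cached_property` instead of assigning in the constructor, which is a variation on the recommendation rather than what the review literally proposes.

```python
from functools import cached_property


class FakePartialState:
    """Stand-in for accelerate's PartialState so the sketch runs on its own."""

    instances = 0  # counts constructions, to show the instance is shared

    def __init__(self):
        FakePartialState.instances += 1
        self.device_mesh = None  # placeholder; the real attribute holds a device mesh


class BuilderSketch:
    """Hypothetical builder holding one shared state object."""

    @cached_property
    def _partial_state(self):
        # Created on first access, then reused by every later access.
        return FakePartialState()

    def configure_accelerator(self):
        return self._partial_state

    def optimizer_kwargs(self):
        return {"device_mesh": self._partial_state.device_mesh}


builder = BuilderSketch()
builder.configure_accelerator()
kwargs = builder.optimizer_kwargs()
```

Both methods go through the same cached attribute, so the state object is constructed once per builder.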
270-273: No changes needed for the dion import
The path axolotl.contribs.mit.dion is correct and the package is declared in requirements.txt (axolotl-contribs-mit==0.0.4). Other contrib imports (e.g. muon) use the same pattern, and no additional error handling is required here.
elif self.cfg.optimizer == "dion":
    from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
        DionOptimizerFactory,
    )

    optimizer_cls = DionOptimizerFactory
    optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
    optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
    optimizer_kwargs.update(adam_kwargs)
    partial_state = PartialState()
    optimizer_kwargs["device_mesh"] = partial_state.device_mesh
Add error handling for required Dion parameters.
The code assumes dion_learning_rate and dion_momentum are always present in training_args_kwargs, but these could be None or missing, which would cause runtime errors.
Add validation for required parameters:
  elif self.cfg.optimizer == "dion":
      from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
          DionOptimizerFactory,
      )
      optimizer_cls = DionOptimizerFactory
+
+     dion_lr = training_args_kwargs.get("dion_learning_rate")
+     dion_mu = training_args_kwargs.get("dion_momentum")
+
+     if dion_lr is None:
+         raise ValueError("dion_learning_rate is required when using dion optimizer")
+     if dion_mu is None:
+         raise ValueError("dion_momentum is required when using dion optimizer")
+
-     optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
-     optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
+     optimizer_kwargs["dion_lr"] = dion_lr
+     optimizer_kwargs["dion_mu"] = dion_mu
      optimizer_kwargs.update(adam_kwargs)
      partial_state = PartialState()
      optimizer_kwargs["device_mesh"] = partial_state.device_mesh
📝 Committable suggestion
elif self.cfg.optimizer == "dion":
    from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
        DionOptimizerFactory,
    )
    optimizer_cls = DionOptimizerFactory
    dion_lr = training_args_kwargs.get("dion_learning_rate")
    dion_mu = training_args_kwargs.get("dion_momentum")
    if dion_lr is None:
        raise ValueError("dion_learning_rate is required when using dion optimizer")
    if dion_mu is None:
        raise ValueError("dion_momentum is required when using dion optimizer")
    optimizer_kwargs["dion_lr"] = dion_lr
    optimizer_kwargs["dion_mu"] = dion_mu
    optimizer_kwargs.update(adam_kwargs)
    partial_state = PartialState()
    optimizer_kwargs["device_mesh"] = partial_state.device_mesh
🤖 Prompt for AI Agents
In src/axolotl/core/builders/base.py around lines 270 to 280, the code assumes
that 'dion_learning_rate' and 'dion_momentum' keys exist and are not None in
training_args_kwargs, which can cause runtime errors if missing or None. Add
validation checks before using these parameters to ensure they are present and
not None; if validation fails, raise a clear exception or handle the error
appropriately to prevent runtime failures.
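The validation requested in this comment can also be exercised in isolation. The sketch below wraps the same checks in a standalone function; the name `dion_optimizer_kwargs_sketch` and the example values are hypothetical, and the function covers only the two required keys named in the review, not the rest of the builder logic.

```python
def dion_optimizer_kwargs_sketch(training_args_kwargs):
    """Validate the required Dion settings and map them to optimizer kwargs."""
    # Config-facing name -> optimizer-facing name, as used in the reviewed diff.
    required = {"dion_learning_rate": "dion_lr", "dion_momentum": "dion_mu"}
    kwargs = {}
    for source_key, target_key in required.items():
        value = training_args_kwargs.get(source_key)
        if value is None:
            # Covers both a missing key and an explicit None.
            raise ValueError(f"{source_key} is required when using dion optimizer")
        kwargs[target_key] = value
    return kwargs


# Illustrative values only; they are not recommended hyperparameters.
ok = dion_optimizer_kwargs_sketch({"dion_learning_rate": 0.01, "dion_momentum": 0.95})

try:
    dion_optimizer_kwargs_sketch({"dion_learning_rate": 0.01})
    error_message = None
except ValueError as exc:
    error_message = str(exc)
```

Using `.get` followed by an explicit `None` check gives a descriptive `ValueError` instead of a bare `KeyError` or a failure later inside the optimizer.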
Description
Adds support for https://github.com/microsoft/dion via the contribs integration
Summary by CodeRabbit
axolotl-contribs-mit to 0.0.4.