Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ website:
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- docs/optimizers.qmd

- section: "Troubleshooting"
contents:
Expand Down
2 changes: 1 addition & 1 deletion docs/nd_parallelism.qmd
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
title: "N-D Parallelism"
title: "N-D Parallelism (Beta)"
---

Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
Expand Down
18 changes: 18 additions & 0 deletions docs/optimizers.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
---
title: Optimizers
description: Configuring optimizers
---

### Dion Optimizer

Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.

Usage:

```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,6 @@ torchao==0.12.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-mit==0.0.4

mistral-common==1.8.3
21 changes: 21 additions & 0 deletions src/axolotl/core/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,17 @@ def _configure_custom_optimizer(

optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)

optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
Comment on lines +270 to +280

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling for required Dion parameters.

The code assumes dion_learning_rate and dion_momentum are always present in training_args_kwargs, but these could be None or missing, which would cause runtime errors.

Add validation for required parameters:

 elif self.cfg.optimizer == "dion":
     from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
         DionOptimizerFactory,
     )

     optimizer_cls = DionOptimizerFactory
+    
+    dion_lr = training_args_kwargs.get("dion_learning_rate")
+    dion_mu = training_args_kwargs.get("dion_momentum")
+    
+    if dion_lr is None:
+        raise ValueError("dion_learning_rate is required when using dion optimizer")
+    if dion_mu is None:
+        raise ValueError("dion_momentum is required when using dion optimizer")
+        
-    optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
-    optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
+    optimizer_kwargs["dion_lr"] = dion_lr
+    optimizer_kwargs["dion_mu"] = dion_mu
     optimizer_kwargs.update(adam_kwargs)
     partial_state = PartialState()
     optimizer_kwargs["device_mesh"] = partial_state.device_mesh
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
dion_lr = training_args_kwargs.get("dion_learning_rate")
dion_mu = training_args_kwargs.get("dion_momentum")
if dion_lr is None:
raise ValueError("dion_learning_rate is required when using dion optimizer")
if dion_mu is None:
raise ValueError("dion_momentum is required when using dion optimizer")
optimizer_kwargs["dion_lr"] = dion_lr
optimizer_kwargs["dion_mu"] = dion_mu
optimizer_kwargs.update(adam_kwargs)
partial_state = PartialState()
optimizer_kwargs["device_mesh"] = partial_state.device_mesh
🤖 Prompt for AI Agents
In src/axolotl/core/builders/base.py around lines 270 to 280, the code assumes
that 'dion_learning_rate' and 'dion_momentum' keys exist and are not None in
training_args_kwargs, which can cause runtime errors if missing or None. Add
validation checks before using these parameters to ensure they are present and
not None; if validation fails, raise a clear exception or handle the error
appropriately to prevent runtime failures.

elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW

Expand Down Expand Up @@ -516,10 +527,20 @@ def _set_base_training_args(
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)

arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)

training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

Expand Down
15 changes: 15 additions & 0 deletions src/axolotl/core/training_args_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
)

# end of multi-modal section

dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)
23 changes: 23 additions & 0 deletions src/axolotl/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
Expand Down Expand Up @@ -641,3 +643,24 @@ def __call__(
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass

# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.

This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', or variation of 'norm')
"""
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_parameters = get_parameter_names(
model, [nn.LayerNorm], forbidden_name_patterns
)
return decay_parameters
1 change: 1 addition & 0 deletions src/axolotl/utils/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
adopt_adamw = "adopt_adamw"
came_pytorch = "came_pytorch"
muon = "muon"
dion = "dion"


class RingAttnFunc(str, Enum):
Expand Down
20 changes: 20 additions & 0 deletions src/axolotl/utils/schemas/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)

dion_lr: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
)
dion_momentum: float | None = Field(
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
)
dion_rank_fraction: float | None = Field(
default=1.0,
json_schema_extra={
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
},
)
dion_rank_multiple_of: int | None = Field(
default=1,
json_schema_extra={
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
},
)

max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
Expand Down
44 changes: 44 additions & 0 deletions tests/e2e/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)

Expand Down Expand Up @@ -160,6 +161,49 @@ def test_muon(self, temp_dir):
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__

@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "dion",
"dion_lr": 0.01,
"dion_momentum": 0.95,
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)

_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__

@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
Expand Down