From 09e85f22b702bca763b4a32734f77c5cb4ebf8cf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 3 Aug 2025 21:21:51 -0400 Subject: [PATCH 1/9] Add support for Dion optimizer --- src/axolotl/core/builders/base.py | 14 ++++++++++++++ src/axolotl/integrations/base.py | 23 +++++++++++++++++++++++ src/axolotl/utils/schemas/enums.py | 2 ++ src/axolotl/utils/schemas/training.py | 20 ++++++++++++++++++++ 4 files changed, 59 insertions(+) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index dbdda7a7cd..cb6b065c0a 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -267,6 +267,20 @@ def _configure_custom_optimizer( optimizer_cls = MuonOptimizerFactory optimizer_kwargs.update(adam_kwargs) + elif self.cfg.optimizer in ["dion", "dion_8bit"]: + from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module + Dion8bitOptimizerFactory, + DionOptimizerFactory, + ) + + optimizer_cls = ( + Dion8bitOptimizerFactory + if self.cfg.optimizer == "dion_8bit" + else DionOptimizerFactory + ) + optimizer_kwargs.update(adam_kwargs) + partial_state = PartialState() + optimizer_kwargs["device_mesh"] = partial_state.device_mesh elif self.cfg.optimizer == "optimi_adamw": from optimi import AdamW diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 7d9b6a6f96..f43031287f 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -26,9 +26,11 @@ from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel +from torch import nn from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from transformers import PreTrainedModel, Trainer +from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -641,3 +643,24 @@ def __call__( self, opt_model, training_args, **optimizer_kwargs ) -> Optimizer | None: pass + + 
# duplicated from transformers + def get_decay_parameter_names(self, model) -> list[str]: + """ + Get all parameter names that weight decay will be applied to. + + This function filters out parameters in two ways: + 1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS) + 2. By parameter name patterns (containing 'bias', or variation of 'norm') + """ + forbidden_name_patterns = [ + r"bias", + r"layernorm", + r"rmsnorm", + r"(?:^|\.)norm(?:$|\.)", + r"_norm(?:$|\.)", + ] + decay_parameters = get_parameter_names( + model, [nn.LayerNorm], forbidden_name_patterns + ) + return decay_parameters diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 3c88283962..6d9c6c6573 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -79,6 +79,8 @@ class CustomSupportedOptimizers(str, Enum): adopt_adamw = "adopt_adamw" came_pytorch = "came_pytorch" muon = "muon" + dion = "dion" + dion_8bit = "dion_8bit" class RingAttnFunc(str, Enum): diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index 6ee8633975..b1788dcaa5 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel): adam_beta3: float | None = Field( default=None, json_schema_extra={"description": "only used for CAME Optimizer"} ) + + dion_lr: float | None = Field( + default=None, json_schema_extra={"description": "Dion Optimizer learning rate"} + ) + dion_momentum: float | None = Field( + default=None, json_schema_extra={"description": "Dion Optimizer momentum"} + ) + dion_rank_fraction: float | None = Field( + default=1.0, + json_schema_extra={ + "description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension." 
+ }, + ) + dion_rank_multiple_of: int | None = Field( + default=1, + json_schema_extra={ + "description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding." + }, + ) + max_grad_norm: float | None = Field( default=None, json_schema_extra={"description": "Gradient clipping max norm"} ) From 60c1ed4e81ab876eda9d389e11e46194e086353b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 3 Aug 2025 21:31:49 -0400 Subject: [PATCH 2/9] dion training kwargs --- src/axolotl/core/builders/base.py | 12 ++++++++++++ src/axolotl/core/training_args_base.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index cb6b065c0a..7ae8be5f99 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -278,6 +278,8 @@ def _configure_custom_optimizer( if self.cfg.optimizer == "dion_8bit" else DionOptimizerFactory ) + optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"] + optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"] optimizer_kwargs.update(adam_kwargs) partial_state = PartialState() optimizer_kwargs["device_mesh"] = partial_state.device_mesh @@ -534,6 +536,16 @@ def _set_base_training_args( if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) + arg_map = { + "dion_learning_rate": "dion_lr", + "dion_momentum": "dion_mu", + "dion_rank_fraction": "dion_rank_fraction", + "dion_rank_multiple_of": "dion_rank_multiple_of", + } + for kwarg, cfg_arg in arg_map.items(): + if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None: + training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg) + training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size training_args_kwargs["average_tokens_across_devices"] = False diff --git a/src/axolotl/core/training_args_base.py 
b/src/axolotl/core/training_args_base.py index 66649deefd..fd0859ae91 100644 --- a/src/axolotl/core/training_args_base.py +++ b/src/axolotl/core/training_args_base.py @@ -243,3 +243,18 @@ class AxolotlTrainingMixins: ) # end of multi-modal section + + dion_learning_rate: float | None = field( + default=None, + metadata={"help": "The learning rate for Dion"}, + ) + dion_momentum: float | None = field( + default=None, + metadata={"help": "The momentum for Dion"}, + ) + dion_rank_fraction: float | None = field( + default=None, + ) + dion_rank_multiple_of: int | None = field( + default=None, + ) From 4c437af8bc7fba8406ca07eb930c43a6ab03edf9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 3 Aug 2025 21:42:27 -0400 Subject: [PATCH 3/9] fix var names --- src/axolotl/core/builders/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index 7ae8be5f99..819ff27941 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -532,15 +532,15 @@ def _set_base_training_args( "include_tokens_per_second", "weight_decay", "seed", + "dion_momentum", + "dion_rank_fraction", + "dion_rank_multiple_of", ]: if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None: training_args_kwargs[arg] = getattr(self.cfg, arg) arg_map = { "dion_learning_rate": "dion_lr", - "dion_momentum": "dion_mu", - "dion_rank_fraction": "dion_rank_fraction", - "dion_rank_multiple_of": "dion_rank_multiple_of", } for kwarg, cfg_arg in arg_map.items(): if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None: From 972aed8e5af7cad8794045c1f0dc83a8f02248c8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 00:24:58 -0400 Subject: [PATCH 4/9] no dion 8bit for now --- src/axolotl/core/builders/base.py | 9 ++------- src/axolotl/utils/schemas/enums.py | 1 - 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/axolotl/core/builders/base.py 
b/src/axolotl/core/builders/base.py index 819ff27941..5a25c1834d 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -267,17 +267,12 @@ def _configure_custom_optimizer( optimizer_cls = MuonOptimizerFactory optimizer_kwargs.update(adam_kwargs) - elif self.cfg.optimizer in ["dion", "dion_8bit"]: + elif self.cfg.optimizer == "dion": from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module - Dion8bitOptimizerFactory, DionOptimizerFactory, ) - optimizer_cls = ( - Dion8bitOptimizerFactory - if self.cfg.optimizer == "dion_8bit" - else DionOptimizerFactory - ) + optimizer_cls = DionOptimizerFactory optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"] optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"] optimizer_kwargs.update(adam_kwargs) diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 6d9c6c6573..cf2a8b484f 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -80,7 +80,6 @@ class CustomSupportedOptimizers(str, Enum): came_pytorch = "came_pytorch" muon = "muon" dion = "dion" - dion_8bit = "dion_8bit" class RingAttnFunc(str, Enum): From 6bb8d424b68edcfc4ec13e033cb55e6269f43d6b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 09:52:58 -0400 Subject: [PATCH 5/9] use updated axolotl-contribs-mit for dion optimizer --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4e82dfd893..cd9b2cf621 100644 --- a/requirements.txt +++ b/requirements.txt @@ -66,6 +66,6 @@ torchao==0.12.0 schedulefree==1.4.1 axolotl-contribs-lgpl==0.0.6 -axolotl-contribs-mit==0.0.3 +axolotl-contribs-mit==0.0.4 mistral-common==1.8.3 From 57d6895cd652f8b84858ef324c2f6785246bce57 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 09:56:32 -0400 Subject: [PATCH 6/9] add smoke test for dion optimizer --- tests/e2e/test_optimizers.py | 46 
+++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 1d233a2013..e5bc33ae29 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -13,6 +13,7 @@ check_model_output_exists, require_torch_2_5_1, require_torch_2_6_0, + require_torch_2_7_0, with_temp_dir, ) @@ -158,7 +159,50 @@ def test_muon(self, temp_dir): _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ + assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ @ with_temp_dir + + @require_torch_2_7_0 + def test_dion(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "model_type": "AutoModelForCausalLM", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "val_set_size": 0.0, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "dion", + "dion_lr": 0.01, + "dion_momentum": 0.95, + "lr_scheduler": "cosine", + "weight_decay": 0.01, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + assert "Dion" in trainer.optimizer.optimizer.__class__.__name__ @with_temp_dir def test_fft_schedule_free_adamw(self, temp_dir): From 6b764ba7ff8f3efb9dd9091c99fbb12eb389e1ab Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 10:22:39 -0400 Subject: [PATCH 7/9] add docs --- _quarto.yml | 1 + docs/nd_parallelism.qmd | 2 +- 
docs/optimizers.qmd | 18 ++++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 docs/optimizers.qmd diff --git a/_quarto.yml b/_quarto.yml index 738fe5e2fa..5bb771c01d 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -284,6 +284,7 @@ website: - docs/sequence_parallelism.qmd - docs/gradient_checkpointing.qmd - docs/nd_parallelism.qmd + - docs/optimizers.qmd - section: "Troubleshooting" contents: diff --git a/docs/nd_parallelism.qmd b/docs/nd_parallelism.qmd index d27a156634..8aebab1409 100644 --- a/docs/nd_parallelism.qmd +++ b/docs/nd_parallelism.qmd @@ -1,5 +1,5 @@ --- -title: "N-D Parallelism" +title: "N-D Parallelism (Beta)" --- Axolotl enables training models at scale by composing different parallelism techniques. This is essential when: diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd new file mode 100644 index 0000000000..563e9695bb --- /dev/null +++ b/docs/optimizers.qmd @@ -0,0 +1,18 @@ +--- +title: Optimizers +description: Configuring optimizers +--- + +### Dion Optimizer + +Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient +orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication. 
+ +Usage: + +```yaml +optimizer: dion +dion_lr: 0.01 +dion_momentum: 0.95 +learning_rate: 0.00001 # learning rate for embeddings and parameters that fall back to AdamW +``` From dcb5de80a3696a563a746791a0df89af574a71ee Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 11:30:56 -0400 Subject: [PATCH 8/9] fix typo during edits --- tests/e2e/test_optimizers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index e5bc33ae29..db6f0de112 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -159,8 +159,9 @@ def test_muon(self, temp_dir): _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ @ with_temp_dir + assert "Muon" in trainer.optimizer.optimizer.__class__.__name__ + @with_temp_dir @require_torch_2_7_0 def test_dion(self, temp_dir): # pylint: disable=duplicate-code From a3214544981666f38f002181c17dff9dad5d1d23 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 4 Aug 2025 13:23:41 -0400 Subject: [PATCH 9/9] fix test to not remove load in 8bit --- tests/e2e/test_optimizers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index db6f0de112..987d860418 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -171,7 +171,6 @@ def test_dion(self, temp_dir): "model_type": "AutoModelForCausalLM", "tokenizer_type": "AutoTokenizer", "sequence_len": 1024, - "load_in_8bit": True, "val_set_size": 0.0, "special_tokens": { "pad_token": "<|endoftext|>",