Skip to content

Fix: RL base feature parity#2133

Merged
NanoCode012 merged 83 commits into
mainfrom
fix/orpo_feature_parity
May 30, 2025
Merged

Fix: RL base feature parity#2133
NanoCode012 merged 83 commits into
mainfrom
fix/orpo_feature_parity

Conversation

@NanoCode012
Copy link
Copy Markdown
Collaborator

@NanoCode012 NanoCode012 commented Dec 6, 2024

Description

RL trainer was not loading some basic configs like logging_steps etc. This PR consolidates the setting of these params and cleans them up.

This PR also fixes a case where we did not call .map with num_proc.

To discuss:

  • Handling of bf16/bfloat16
  • Handling of fp16 in RL
  • Handling of tf32 in RL
  • Default of use_reentrant
  • Default of remove_unused_columns

Addresses a lot of points in #2105

Motivation and Context

How has this been tested?

Untested!

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Added new trainer builder modules for both causal language modeling and reinforcement learning workflows, offering more modular and extensible configuration options.
    • Introduced support for a new configuration parameter torch_compile_mode and an optional field dpo_label_smoothing for advanced training customization.
  • Refactor

    • Restructured trainer builder logic into dedicated modules, improving maintainability and clarity.
    • Updated import paths throughout the codebase to reflect the new modular trainer builder structure.
    • Enhanced optimizer integration for trainers using mixins for more flexible initialization and configuration.
  • Bug Fixes

    • Corrected reward function handling and argument blocklisting in GRPO strategy to improve reliability.
  • Documentation

    • Updated documentation configuration to reflect the new builder module structure and clarified certain configuration options.
  • Tests

    • Added comprehensive unit and integration tests for the new trainer builder modules, covering various training configurations and optimizer scenarios.

@NanoCode012 NanoCode012 force-pushed the fix/orpo_feature_parity branch 2 times, most recently from 0719188 to 4b8f65b Compare February 3, 2025 10:44
@NanoCode012 NanoCode012 force-pushed the fix/orpo_feature_parity branch from 1b15a11 to 62d04e4 Compare February 14, 2025 13:10
@NanoCode012 NanoCode012 marked this pull request as ready for review February 14, 2025 13:21
@bursteratom bursteratom force-pushed the fix/orpo_feature_parity branch from 65a83b7 to 93a2ecc Compare February 18, 2025 04:13
Copy link
Copy Markdown
Contributor

@bursteratom bursteratom left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@winglian
Copy link
Copy Markdown
Collaborator

Thanks @NanoCode012 ! This should be good to go once the multi gpu tests pass too https://github.com/axolotl-ai-cloud/axolotl/actions/runs/13457763772

@winglian winglian force-pushed the fix/orpo_feature_parity branch from fc04dcf to 4321607 Compare February 22, 2025 03:08
@winglian
Copy link
Copy Markdown
Collaborator

@winglian
Copy link
Copy Markdown
Collaborator

Comment thread src/axolotl/core/trainer_builder.py Outdated
Comment thread src/axolotl/core/trainer_builder.py Outdated
@winglian winglian force-pushed the fix/orpo_feature_parity branch from 4321607 to 8331312 Compare February 23, 2025 17:42
@winglian
Copy link
Copy Markdown
Collaborator

@NanoCode012 NanoCode012 force-pushed the fix/orpo_feature_parity branch from 8143e16 to 6a37ec7 Compare March 26, 2025 07:21
@winglian winglian force-pushed the fix/orpo_feature_parity branch from 9078cfe to b5ca6a7 Compare April 28, 2025 19:10
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 28, 2025

@github-actions github-actions Bot temporarily deployed to preview April 28, 2025 19:13 Inactive
@github-actions github-actions Bot temporarily deployed to preview May 6, 2025 08:48 Inactive
@codecov
Copy link
Copy Markdown

codecov Bot commented May 6, 2025

@NanoCode012 NanoCode012 force-pushed the fix/orpo_feature_parity branch from 998de17 to 79917ac Compare May 6, 2025 08:55
@github-actions github-actions Bot temporarily deployed to preview May 6, 2025 08:57 Inactive
@djsaunde
Copy link
Copy Markdown
Collaborator

One suggestion: can we break out the trainer_builder.py file into smaller modules? E.g., something like:

builders
├── __init__.py
├── base.py
├── causal.py
└── rl.py

@NanoCode012
Copy link
Copy Markdown
Collaborator Author

One suggestion: can we break out the trainer_builder.py file into smaller modules? E.g., something like:

builders
├── __init__.py
├── base.py
├── causal.py
└── rl.py

Yep, I can do that tomorrow!

@github-actions github-actions Bot temporarily deployed to preview May 27, 2025 07:36 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/core/test_trainer_builder.py (2)

12-12: Import path needs updating after refactor.

The import statement is using the old path structure. Based on the refactor that moved trainer builders to modular submodules, the imports should be updated.

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder

579-579: Avoid directly modifying fixture in test.

Directly modifying the fixture can affect other tests that use the same fixture. Create a copy of the configuration instead.

-        kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
+        cfg = kto_cfg.copy()
+        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

And update the builder instantiation:

-            builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
+            builder = HFRLTrainerBuilder(cfg, model, tokenizer)
🧹 Nitpick comments (1)
tests/core/test_trainer_builder.py (1)

485-487: Simplify nested if statements.

The nested if statements can be combined for better readability.

-            if cfg_string == "grpo_cfg":
-                # remove imported module from path
-                if str(rewards_dir) in sys.path:
-                    sys.path.remove(str(rewards_dir))
+            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
+                # remove imported module from path
+                sys.path.remove(str(rewards_dir))
🧰 Tools
🪛 Ruff (0.11.9)

485-487: Use a single if statement instead of nested if statements

(SIM102)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e52b4d and 1f71e3a.

📒 Files selected for processing (6)
  • _quarto.yml (1 hunks)
  • src/axolotl/core/trainer_builder/rl.py (1 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (3 hunks)
  • src/axolotl/core/trainers/trl.py (2 hunks)
  • src/axolotl/utils/data/rl.py (1 hunks)
  • tests/core/test_trainer_builder.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • src/axolotl/utils/data/rl.py
  • src/axolotl/core/trainers/grpo/trainer.py
  • _quarto.yml
  • src/axolotl/core/trainers/trl.py
  • src/axolotl/core/trainer_builder/rl.py
🧰 Additional context used
🪛 Ruff (0.11.9)
tests/core/test_trainer_builder.py

485-487: Use a single if statement instead of nested if statements

(SIM102)

🔇 Additional comments (5)
tests/core/test_trainer_builder.py (5)

22-235: Well-structured fixture organization.

The comprehensive fixture setup with a base configuration and specialized variants for different training types (DPO, ORPO, KTO, GRPO, IPO, SIMPO, SFT, RM, PRM) provides excellent test coverage and maintainability.


253-282: Excellent helper method for common test assertions.

The _test_common_training_arguments helper method effectively reduces code duplication and ensures consistent testing of shared training argument properties across different RL variants.


312-357: Comprehensive GRPO test with proper cleanup.

The test properly handles dynamic reward function creation and cleanup, which is essential for testing GRPO functionality. The try/finally block ensures proper cleanup of the Python path.


378-489: Thorough parameterized testing for custom optimizer.

The parameterized test comprehensively covers optimizer configuration across different RL types, ensuring the Muon optimizer is correctly instantiated and configured.

🧰 Tools
🪛 Ruff (0.11.9)

485-487: Use a single if statement instead of nested if statements

(SIM102)


491-565: Comprehensive SFT trainer testing.

The TestHFCausalTrainerBuilder class provides thorough coverage of SFT trainer functionality, including parameterized testing for different model types (SFT, RM, PRM) with custom optimizers.

@github-actions github-actions Bot temporarily deployed to preview May 27, 2025 08:35 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/core/test_trainer_builder.py (2)

12-12: Fix outdated import path after refactor

The import path needs to be updated to reflect the new modular trainer builder structure.

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainers.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainers.trainer_builder.rl import HFRLTrainerBuilder

576-595: Avoid directly modifying fixture in test

Directly modifying the fixture can affect other tests that use the same fixture. Create a copy of the configuration instead.

 def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
     """
     Test that the trainer cls is not none with plugin

     Fixes #2693
     """
-    kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
+    cfg = kto_cfg.copy()
+    cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

     # Expected AttributeError as we don't pass regular model configs to RL trainer builder
     # If it throws `TypeError: None is not a callable object`, trainer_cls could be None
     try:
-        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)

         builder.build(100)
🧹 Nitpick comments (1)
tests/core/test_trainer_builder.py (1)

488-491: Simplify nested if statements

The nested if statements can be combined into a single condition for better readability.

-        if cfg_string == "grpo_cfg":
-            # remove imported module from path
-            if str(rewards_dir) in sys.path:
-                sys.path.remove(str(rewards_dir))
+        if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
+            # remove imported module from path
+            sys.path.remove(str(rewards_dir))
🧰 Tools
🪛 Ruff (0.11.9)

488-490: Use a single if statement instead of nested if statements

(SIM102)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1f71e3a and 132a1dd.

📒 Files selected for processing (3)
  • src/axolotl/core/trainer_builder/rl.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/core/test_trainer_builder.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/utils/schemas/config.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/core/trainer_builder/rl.py
🧰 Additional context used
🪛 Ruff (0.11.9)
tests/core/test_trainer_builder.py

488-490: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (9)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
🔇 Additional comments (4)
tests/core/test_trainer_builder.py (4)

22-88: Well-structured base configuration fixture

The base configuration fixture provides comprehensive coverage of common training parameters. The organization and parameter choices look appropriate for testing various trainer scenarios.


91-181: Comprehensive RL algorithm fixtures

Excellent coverage of different RL algorithms (DPO, ORPO, KTO, GRPO, IPO, SIMPO) with algorithm-specific parameters. The fixture design promotes good test maintainability.


314-327: Good test utility for reward function setup

The _write_rewards_file helper method is well-designed for creating temporary reward functions needed for GRPO testing. The approach is clean and self-contained.


381-442: Excellent use of parametrized testing

The parametrized test approach for different RL configurations is well-structured and provides comprehensive coverage while avoiding code duplication.

@NanoCode012
Copy link
Copy Markdown
Collaborator Author

NanoCode012 commented May 27, 2025

In e07c33c and 0e52b4d , I have added parity for setting custom optimizer in RL and RM/PRM.

I think I addressed most of the comments left as well with one under discussion for rename from trainer_builder->builders and refactoring to private functions within the builders.

@github-actions github-actions Bot temporarily deployed to preview May 27, 2025 08:40 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/core/test_trainer_builder.py (1)

12-12: Import path still needs to be updated after refactor

Based on the past review comments and the AI summary indicating that trainer builder was restructured into modular submodules, this import path is still outdated and needs to be fixed.

-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
🧹 Nitpick comments (2)
tests/core/test_trainer_builder.py (2)

254-283: Good helper method with minor improvement opportunity

The helper method for testing common training arguments is well-designed. Consider extracting the magic numbers into named constants for better maintainability.

+# Test constants
+EXPECTED_GRPO_BATCH_SIZE = 4
+EXPECTED_DEFAULT_BATCH_SIZE = 2
+
 def _test_common_training_arguments(self, training_arguments, rl: str):
     """Helper to test common arguments across all variants"""
     # Basic training settings
     if rl == "grpo":
         # grpo_cfg's micro_batch_size is diff from others
-        assert training_arguments.per_device_train_batch_size == 4
+        assert training_arguments.per_device_train_batch_size == EXPECTED_GRPO_BATCH_SIZE
     else:
-        assert training_arguments.per_device_train_batch_size == 2
+        assert training_arguments.per_device_train_batch_size == EXPECTED_DEFAULT_BATCH_SIZE

328-360: Proper resource cleanup but consider using context manager

The GRPO test properly handles sys.path manipulation with try/finally, but consider using a context manager for cleaner resource management.

from contextlib import contextmanager

@contextmanager
def temp_sys_path(path):
    """Context manager for temporarily adding a path to sys.path"""
    sys.path.insert(0, str(path))
    try:
        yield
    finally:
        if str(path) in sys.path:
            sys.path.remove(str(path))

def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
    rewards_dir = tmp_path / "rewards_test"
    self._write_rewards_file(rewards_dir)
    
    with temp_sys_path(rewards_dir):
        builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
        # ... rest of test
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 132a1dd and 704e7d0.

📒 Files selected for processing (1)
  • tests/core/test_trainer_builder.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: preview
🔇 Additional comments (5)
tests/core/test_trainer_builder.py (5)

22-89: Excellent comprehensive base configuration fixture

The base configuration fixture is well-structured with logical grouping of settings and good coverage of common parameters. The normalization at the end ensures proper configuration handling.


314-327: Good reward function setup with proper file handling

The reward function file creation is well-implemented for testing GRPO functionality. The temporary file approach is appropriate for isolated testing.


381-491: Comprehensive parameterized test with proper dataset handling

The parameterized test for custom optimizers excellently covers all RL configurations with appropriate dataset setups. The conditional logic for different configurations is well-handled, and the mock/real dataset approach is appropriate.


527-568: Well-structured SFT tests with good coverage

The SFT trainer tests provide good coverage of training arguments and custom optimizer functionality. The parameterization across different configuration types (sft, rm, prm) is comprehensive.


575-596: Good plugin test with proper fixture handling

The plugin test correctly addresses the previous review comment by creating a copy of the fixture instead of modifying it directly. The error handling logic properly distinguishes between different types of failures.

This reverts commit 99c8859.
@github-actions github-actions Bot temporarily deployed to preview May 27, 2025 09:42 Inactive
@djsaunde djsaunde self-requested a review May 27, 2025 15:30
Copy link
Copy Markdown
Collaborator

@djsaunde djsaunde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO good to go once the aforementioned naming change is in. Great stuff!

@github-actions github-actions Bot temporarily deployed to preview May 30, 2025 02:51 Inactive
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (7)
src/axolotl/core/builders/causal.py (5)

77-77: Track or address the TODO comment

The TODO comment should be tracked in an issue or addressed directly. If the loss watchdog callback can be moved to the base class, it would improve code organization.

Would you like me to help evaluate whether the loss watchdog callback can be moved to the base class?


89-106: Refactor duplicated logging callback creation

The code for creating logging prediction callbacks is duplicated for wandb, mlflow, and comet_ml. Consider refactoring this into a helper method.

-        if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "wandb"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
-        if (
-            self.cfg.use_mlflow
-            and is_mlflow_available()
-            and self.cfg.eval_table_size > 0
-        ):
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "mlflow"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
-        if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "comet_ml"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
+        if self.cfg.eval_table_size > 0:
+            log_providers = []
+            if self.cfg.use_wandb:
+                log_providers.append("wandb")
+            if self.cfg.use_mlflow and is_mlflow_available():
+                log_providers.append("mlflow")
+            if self.cfg.use_comet and is_comet_available():
+                log_providers.append("comet_ml")
+            
+            for provider in log_providers:
+                LogPredictionCallback = log_prediction_callback_factory(
+                    trainer, self.tokenizer, provider
+                )
+                callbacks.append(LogPredictionCallback(self.cfg))

365-368: Simplify nested if statements

Combine the nested if statements for better readability.

-        if eval_data_collator := self.build_collator(
-            training_args, is_eval=True, **data_collator_kwargs
-        ):
-            if not (self.cfg.reward_model or self.cfg.process_reward_model):
-                trainer_kwargs["eval_data_collator"] = eval_data_collator
+        if (eval_data_collator := self.build_collator(
+            training_args, is_eval=True, **data_collator_kwargs
+        )) and not (self.cfg.reward_model or self.cfg.process_reward_model):
+            trainer_kwargs["eval_data_collator"] = eval_data_collator
🧰 Tools
🪛 Ruff (0.11.9)

365-368: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


382-383: Use 'not in' for membership test

For better readability, use not in instead of negating the in operator.

-        if (
-            not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
-            and self.cfg.datasets is not None
-        ):
+        if (
+            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
+            and self.cfg.datasets is not None
+        ):
🧰 Tools
🪛 Ruff (0.11.9)

382-382: Test for membership should be not in

Convert to not in

(E713)


401-405: Consider safer DeepSpeed configuration update

Directly modifying the accelerator state's DeepSpeed configuration might be fragile if the internal structure changes. Consider using a more robust approach or adding error handling.

         if self.cfg.deepspeed and self.cfg.sample_packing:
-            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
-                "train_micro_batch_size_per_gpu"
-            ] = self.cfg.micro_batch_size
+            try:
+                if hasattr(trainer.accelerator.state, 'deepspeed_plugin') and \
+                   hasattr(trainer.accelerator.state.deepspeed_plugin, 'deepspeed_config'):
+                    trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
+                        "train_micro_batch_size_per_gpu"
+                    ] = self.cfg.micro_batch_size
+                else:
+                    LOG.warning("Unable to set DeepSpeed micro batch size: plugin structure not found")
+            except Exception as e:
+                LOG.warning(f"Failed to set DeepSpeed micro batch size: {e}")
src/axolotl/core/builders/base.py (2)

168-183: Consider implementing or documenting the hook methods.

All four hook methods have TODO comments. Consider either implementing them if needed, documenting their intended use, or removing them if they're not required.

Would you like help implementing these hook methods or documenting their intended usage?


413-416: Consider more selective error handling for torch._dynamo.

Suppressing all errors might hide real issues during compilation. Consider handling specific expected errors or logging suppressed errors for debugging.

Consider logging suppressed errors:

         if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
+            # Log suppressed errors for debugging if needed
+            if LOG.isEnabledFor(logging.DEBUG):
+                torch._dynamo.config.verbose = True
             torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
                 True
             )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 704e7d0 and f26d63a.

📒 Files selected for processing (18)
  • _quarto.yml (1 hunks)
  • docs/config.qmd (2 hunks)
  • src/axolotl/core/builders/__init__.py (1 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/core/builders/causal.py (1 hunks)
  • src/axolotl/core/builders/rl.py (1 hunks)
  • src/axolotl/core/trainers/grpo/__init__.py (2 hunks)
  • src/axolotl/core/trainers/mixins/optimizer.py (1 hunks)
  • src/axolotl/train.py (1 hunks)
  • src/axolotl/utils/callbacks/__init__.py (1 hunks)
  • src/axolotl/utils/callbacks/comet_.py (1 hunks)
  • src/axolotl/utils/callbacks/lisa.py (1 hunks)
  • src/axolotl/utils/callbacks/mlflow_.py (1 hunks)
  • src/axolotl/utils/data/rl.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/trainer.py (1 hunks)
  • tests/core/test_builders.py (1 hunks)
  • tests/e2e/test_imports.py (1 hunks)
✅ Files skipped from review due to trivial changes (4)
  • src/axolotl/utils/callbacks/init.py
  • tests/e2e/test_imports.py
  • src/axolotl/train.py
  • src/axolotl/core/builders/init.py
🚧 Files skipped from review as they are similar to previous changes (9)
  • src/axolotl/utils/callbacks/comet_.py
  • src/axolotl/utils/callbacks/mlflow_.py
  • _quarto.yml
  • src/axolotl/utils/data/rl.py
  • src/axolotl/utils/callbacks/lisa.py
  • docs/config.qmd
  • src/axolotl/core/trainers/mixins/optimizer.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/core/trainers/grpo/init.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/utils/trainer.py (2)
src/axolotl/core/builders/causal.py (1)
  • HFCausalTrainerBuilder (58-487)
src/axolotl/core/builders/rl.py (1)
  • HFRLTrainerBuilder (28-228)
src/axolotl/core/builders/causal.py (14)
src/axolotl/core/builders/base.py (13)
  • TrainerBuilderBase (49-503)
  • get_callbacks (107-148)
  • get_post_trainer_create_callbacks (150-166)
  • build (104-105)
  • _set_base_training_args (439-503)
  • hook_pre_create_training_args (168-170)
  • hook_post_create_training_args (172-174)
  • hook_pre_create_trainer (176-178)
  • train_dataset (80-81)
  • train_dataset (84-85)
  • eval_dataset (88-89)
  • eval_dataset (92-93)
  • hook_post_create_trainer (180-182)
src/axolotl/core/trainers/mamba.py (1)
  • AxolotlMambaTrainer (8-32)
src/axolotl/core/trainers/trl.py (2)
  • AxolotlPRMTrainer (117-124)
  • AxolotlRewardTrainer (107-114)
src/axolotl/core/training_args.py (3)
  • AxolotlPRMConfig (283-286)
  • AxolotlRewardConfig (276-279)
  • AxolotlTrainingArguments (240-246)
src/axolotl/integrations/base.py (1)
  • PluginManager (281-550)
src/axolotl/monkeypatch/relora.py (1)
  • ReLoRACallback (83-255)
src/axolotl/utils/__init__.py (2)
  • is_comet_available (16-17)
  • is_mlflow_available (12-13)
src/axolotl/utils/callbacks/__init__.py (7)
  • EvalFirstStepCallback (56-72)
  • LossWatchDogCallback (133-160)
  • SaveBetterTransformerModelCallback (75-108)
  • bench_eval_callback_factory (163-364)
  • causal_lm_bench_eval_callback_factory (367-575)
  • colab_inference_post_train_callback (880-902)
  • log_prediction_callback_factory (578-790)
src/axolotl/utils/callbacks/lisa.py (1)
  • lisa_callback_factory (23-92)
src/axolotl/utils/callbacks/qat.py (1)
  • QATCallback (33-50)
src/axolotl/utils/chat_templates.py (1)
  • get_chat_template_from_config (120-132)
src/axolotl/utils/collators/batching.py (3)
  • BatchSamplerDataCollatorForSeq2Seq (127-153)
  • DataCollatorForSeq2Seq (12-123)
  • V2BatchSamplerDataCollatorForSeq2Seq (157-183)
src/axolotl/utils/collators/mamba.py (1)
  • MambaDataCollator (15-39)
src/axolotl/utils/collators/mm_chat.py (1)
  • MultiModalChatDataCollator (18-92)
🪛 Ruff (0.11.9)
src/axolotl/core/builders/causal.py

365-368: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


382-382: Test for membership should be not in

Convert to not in

(E713)

⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (16)
src/axolotl/core/builders/causal.py (3)

132-147: LGTM!

The trainer class selection logic is clear and follows a well-defined precedence order.


217-218: Verify eval_accumulation_steps configuration

Setting eval_accumulation_steps to gradient_accumulation_steps might not be appropriate for all training scenarios. According to the PR objectives, there's a remaining TODO to decide whether this should be set for RL training similar to SFT.

Could you confirm if this setting is intentional for all trainer types, including reward models and RL trainers?


408-487: LGTM!

The collator selection logic is comprehensive and well-structured, handling all necessary scenarios appropriately.

tests/core/test_builders.py (7)

1-21: Test file structure and imports look good.

The imports are well-organized and appropriate for testing the builder classes.


22-89: Base configuration fixture is comprehensive and well-structured.

Good use of a base fixture that other configurations can extend. The configuration covers all essential training parameters.


91-236: RL configuration fixtures are properly defined.

Each RL variant fixture correctly extends the base configuration with appropriate algorithm-specific parameters. Good coverage of different training paradigms.


249-380: RLHF trainer builder tests are comprehensive.

Good test coverage for all RL variants with proper validation of both common and algorithm-specific training arguments.


493-568: Causal trainer builder tests are well-implemented.

Good coverage of SFT, reward model, and process reward model configurations with proper dataset handling.


570-596: Plugin integration test effectively guards against regression.

Good test design that specifically checks for the trainer_cls None bug (#2693) while allowing other expected exceptions.


488-491: ⚠️ Potential issue

Fix potential NameError in finally block.

The rewards_dir variable is only defined when cfg_string == "grpo_cfg", but the finally block unconditionally references it, which could cause a NameError for other test cases.

Apply this fix to prevent the NameError:

         finally:
             # remove imported module from path
-            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
-                sys.path.remove(str(rewards_dir))
+            if cfg_string == "grpo_cfg":
+                if 'rewards_dir' in locals() and str(rewards_dir) in sys.path:
+                    sys.path.remove(str(rewards_dir))

Likely an incorrect or invalid review comment.

src/axolotl/core/builders/base.py (3)

49-102: Base class initialization and properties are well-designed.

Good design with the axolotl tagging for model tracking and proper property encapsulation.


107-167: Callback management is comprehensive and well-organized.

Good separation of concerns between callbacks that can be created before and after trainer instantiation.


439-504: Main configuration method is well-structured.

The _set_base_training_args method effectively orchestrates all configuration methods and provides a clean interface for building training arguments.

src/axolotl/utils/trainer.py (1)

19-19: LGTM! Import path correctly updated.

The import path has been properly updated to reflect the new modular structure where trainer_builder.py has been split into the builders package.

src/axolotl/core/builders/rl.py (2)

40-77: Well-structured trainer class selection logic.

The method properly prioritizes the plugin system and handles all supported RL types with appropriate error handling for unsupported types.


195-200: Excellent API compatibility handling.

Great use of inspect.signature to dynamically handle the different parameter names (tokenizer vs processing_class) across trainer implementations. This makes the code more robust to API variations.

if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))

if any("COLAB_" in key for key in os.environ):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix Colab environment detection

The current check would match any environment variable containing "COLAB_" anywhere in the key name. Use str.startswith() for more precise detection.

-        if any("COLAB_" in key for key in os.environ):
+        if any(key.startswith("COLAB_") for key in os.environ):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if any("COLAB_" in key for key in os.environ):
if any(key.startswith("COLAB_") for key in os.environ):
🤖 Prompt for AI Agents
In src/axolotl/core/builders/causal.py at line 125, replace the condition that
checks if any environment variable key contains "COLAB_" with a condition that
checks if any key starts with "COLAB_" using str.startswith(). This change
ensures more precise detection of the Colab environment.

Comment on lines +271 to +281
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit

optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)

LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove deprecated ao_adamw_4bit optimizer.

The deprecation comment indicates this code should have been removed by April 1, 2025, but it's still present in May 2025.

Remove the deprecated optimizer case:

-            elif self.cfg.optimizer == "ao_adamw_4bit":
-                # TODO remove 20250401
-                from torchao.prototype.low_bit_optim import AdamW4bit
-
-                optimizer_cls = AdamW4bit
-                optimizer_kwargs.update(adam_kwargs)
-
-                LOG.warning(
-                    f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
-                )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
elif self.cfg.optimizer == "ao_adamw_8bit":
🤖 Prompt for AI Agents
In src/axolotl/core/builders/base.py between lines 271 and 281, remove the
entire conditional block handling the deprecated "ao_adamw_4bit" optimizer,
including the import, optimizer class assignment, kwargs update, and warning
log, as it is no longer needed past the deprecation date.

Comment on lines +101 to +104
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Verify the alpha to beta mapping is intentional.

The comment indicates TRL does "odd mapping of alpha to beta". Please confirm this is the intended behavior and consider adding a more detailed explanation or a link to TRL documentation explaining this mapping.


🌐 Web query:

How does TRL handle the alpha parameter in ORPO training and why is it mapped to beta?

💡 Result:

In the context of TRL (Transformer Reinforcement Learning) using the ORPO (Off-Policy Regularized Policy Optimization) algorithm, the "alpha" parameter typically refers to a coefficient that controls some form of regularization or weighting in the loss function.

Here's why and how TRL handles the alpha parameter and its mapping to beta:

  1. Alpha as a general regularization coefficient:
    In the original ORPO formulation, there is an alpha parameter used to weigh the influence of the regularization term that helps balance exploration and exploitation or constrains policy updates.

  2. Mapping alpha to beta in TRL:
    In the TRL implementation, what is conceptually the "alpha" parameter in ORPO is renamed or mapped to "beta." This is often done for clarity or to conform to the naming conventions used within the TRL codebase or framework. Essentially, beta in TRL corresponds to the same hyperparameter role as alpha in the original ORPO algorithm. It's just a matter of naming alignment.

  3. Practical implications:
    When tuning or configuring ORPO training via TRL, you set the beta value knowing that it acts as the alpha in the ORPO paper's equations — controlling how strongly the regularization impacts the policy optimization.

So, the parameter isn't functionally changed — it's just that TRL uses the name beta for what ORPO calls alpha, to keep consistent with their code structure or to avoid naming conflicts.

If you want, I can dive deeper into the specific equations or parts of the code where this mapping happens. Would you like me to?


Clarify ORPO “alpha→beta” mapping

TRL deliberately renames the ORPO algorithm’s alpha hyperparameter to “beta” for naming consistency. Please update the inline comment to reflect this and, if possible, link to the TRL documentation for ORPO:

• File: src/axolotl/core/builders/rl.py
• Lines: 101–104

Suggested diff:

-        elif self.cfg.orpo_alpha is not None:
-            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
-            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+        elif self.cfg.orpo_alpha is not None:
+            # TRL uses “beta” for what the ORPO paper calls “alpha” to align naming.
+            # See TRL ORPO docs for details: https://github.com/CarperAI/trl#off-policy-regularized-policy-optimization
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
elif self.cfg.orpo_alpha is not None:
# TRL uses “beta” for what the ORPO paper calls “alpha” to align naming.
# See TRL ORPO docs for details: https://github.com/CarperAI/trl#off-policy-regularized-policy-optimization
training_args_kwargs["beta"] = self.cfg.orpo_alpha
🤖 Prompt for AI Agents
In src/axolotl/core/builders/rl.py around lines 101 to 104, update the inline
comment to clarify that TRL intentionally renames the ORPO algorithm's alpha
hyperparameter to "beta" for naming consistency. Add a brief explanation that
beta in TRL corresponds to alpha in the original ORPO formulation and, if
possible, include a link to the TRL documentation for ORPO to provide context
for this mapping.

@NanoCode012 NanoCode012 merged commit 6778856 into main May 30, 2025
18 checks passed
@NanoCode012 NanoCode012 deleted the fix/orpo_feature_parity branch May 30, 2025 04:21
@NanoCode012
Copy link
Copy Markdown
Collaborator Author

Thanks for all the reviews. It was big PR. If there are any missing gaps, I'll address them in a follow up PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants