Fix: RL base feature parity #2133
Thanks @NanoCode012! This should be good to go once the multi-GPU tests pass too: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/13457763772
Looks like the multi-GPU GRPO tests are failing: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/13469023440/job/37640091555
One suggestion: can we break out the builders into separate modules?
├── __init__.py
├── base.py
├── causal.py
└── rl.py
Yep, I can do that tomorrow!
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/core/test_trainer_builder.py (2)
12-12: Import path needs updating after refactor.
The import statement uses the old path structure. After the refactor that moved the trainer builders into modular submodules, the imports should be updated:
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
579-579: Avoid directly modifying fixture in test.
Directly modifying the fixture can affect other tests that use the same fixture. Create a copy of the configuration instead:
-        kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
+        cfg = kto_cfg.copy()
+        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
And update the builder instantiation:
-        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
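The copy-before-mutate advice can be illustrated without pytest or axolotl. The sketch below uses a plain nested dict as a hypothetical stand-in for the `kto_cfg` fixture (the real fixture returns an axolotl config object, whose `.copy()` semantics are not shown in this thread), and `copy.deepcopy` so that nested values are isolated as well. Note that pytest fixtures are function-scoped by default, so leakage only matters when the fixture has a wider scope or hands out a shared object.

```python
import copy


def make_kto_cfg():
    """Hypothetical stand-in for the `kto_cfg` fixture: a plain nested dict."""
    return {"rl": "kto", "plugins": [], "optim": {"lr": 1e-5}}


def with_plugins(cfg, plugins):
    """Return a modified deep copy; the caller's object is left untouched."""
    new_cfg = copy.deepcopy(cfg)
    new_cfg["plugins"] = list(plugins)
    return new_cfg


shared = make_kto_cfg()
patched = with_plugins(shared, ["axolotl.integrations.liger.LigerPlugin"])
# Mutating a nested value on the copy does not reach the shared object.
patched["optim"]["lr"] = 1.0
```

A shallow `.copy()` would be enough for the top-level `plugins` assignment in the review suggestion; the deep copy here is a more defensive choice for nested settings.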
🧹 Nitpick comments (1)
tests/core/test_trainer_builder.py (1)
485-487: Simplify nested if statements.
The nested if statements can be combined for better readability.
-            if cfg_string == "grpo_cfg":
-                # remove imported module from path
-                if str(rewards_dir) in sys.path:
-                    sys.path.remove(str(rewards_dir))
+            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
+                # remove imported module from path
+                sys.path.remove(str(rewards_dir))
🧰 Tools
🪛 Ruff (0.11.9)
485-487: Use a single if statement instead of nested if statements (SIM102)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- _quarto.yml (1 hunks)
- src/axolotl/core/trainer_builder/rl.py (1 hunks)
- src/axolotl/core/trainers/grpo/trainer.py (3 hunks)
- src/axolotl/core/trainers/trl.py (2 hunks)
- src/axolotl/utils/data/rl.py (1 hunks)
- tests/core/test_trainer_builder.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- src/axolotl/utils/data/rl.py
- src/axolotl/core/trainers/grpo/trainer.py
- _quarto.yml
- src/axolotl/core/trainers/trl.py
- src/axolotl/core/trainer_builder/rl.py
🧰 Additional context used
🪛 Ruff (0.11.9)
tests/core/test_trainer_builder.py
485-487: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (5)
tests/core/test_trainer_builder.py (5)
22-235: Well-structured fixture organization.
The comprehensive fixture setup with a base configuration and specialized variants for different training types (DPO, ORPO, KTO, GRPO, IPO, SIMPO, SFT, RM, PRM) provides excellent test coverage and maintainability.
253-282: Excellent helper method for common test assertions.
The _test_common_training_arguments helper method effectively reduces code duplication and ensures consistent testing of shared training argument properties across different RL variants.
312-357: Comprehensive GRPO test with proper cleanup.
The test properly handles dynamic reward function creation and cleanup, which is essential for testing GRPO functionality. The try/finally block ensures proper cleanup of the Python path.
378-489: Thorough parameterized testing for custom optimizer.
The parameterized test comprehensively covers optimizer configuration across different RL types, ensuring the Muon optimizer is correctly instantiated and configured.
491-565: Comprehensive SFT trainer testing.
The TestHFCausalTrainerBuilder class provides thorough coverage of SFT trainer functionality, including parameterized testing for different model types (SFT, RM, PRM) with custom optimizers.
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/core/test_trainer_builder.py (2)
12-12: Fix outdated import path after refactor
The import path needs to be updated to reflect the new modular trainer builder structure.
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainers.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainers.trainer_builder.rl import HFRLTrainerBuilder
576-595: Avoid directly modifying fixture in test
Directly modifying the fixture can affect other tests that use the same fixture. Create a copy of the configuration instead.
     def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
         """
         Test that the trainer cls is not none with plugin

         Fixes #2693
         """
-        kto_cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
+        cfg = kto_cfg.copy()
+        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

         # Expected AttributeError as we don't pass regular model configs to RL trainer builder
         # If it throws `TypeError: None is not a callable object`, trainer_cls could be None
         try:
-            builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
+            builder = HFRLTrainerBuilder(cfg, model, tokenizer)

             builder.build(100)
🧹 Nitpick comments (1)
tests/core/test_trainer_builder.py (1)
488-491: Simplify nested if statements
The nested if statements can be combined into a single condition for better readability.
-            if cfg_string == "grpo_cfg":
-                # remove imported module from path
-                if str(rewards_dir) in sys.path:
-                    sys.path.remove(str(rewards_dir))
+            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
+                # remove imported module from path
+                sys.path.remove(str(rewards_dir))
🧰 Tools
🪛 Ruff (0.11.9)
488-490: Use a single if statement instead of nested if statements (SIM102)
📒 Files selected for processing (3)
- src/axolotl/core/trainer_builder/rl.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- tests/core/test_trainer_builder.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/schemas/config.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/core/trainer_builder/rl.py
🧰 Additional context used
🪛 Ruff (0.11.9)
tests/core/test_trainer_builder.py
488-490: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (4)
tests/core/test_trainer_builder.py (4)
22-88: Well-structured base configuration fixture
The base configuration fixture provides comprehensive coverage of common training parameters. The organization and parameter choices look appropriate for testing various trainer scenarios.
91-181: Comprehensive RL algorithm fixtures
Excellent coverage of different RL algorithms (DPO, ORPO, KTO, GRPO, IPO, SIMPO) with algorithm-specific parameters. The fixture design promotes good test maintainability.
314-327: Good test utility for reward function setup
The _write_rewards_file helper method is well-designed for creating temporary reward functions needed for GRPO testing. The approach is clean and self-contained.
381-442: Excellent use of parametrized testing
The parametrized test approach for different RL configurations is well-structured and provides comprehensive coverage while avoiding code duplication.
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/core/test_trainer_builder.py (1)
12-12: Import path still needs to be updated after refactor
Based on the past review comments and the AI summary indicating that the trainer builder was restructured into modular submodules, this import path is still outdated and needs to be fixed.
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.core.trainer_builder.sft import HFCausalTrainerBuilder
+from axolotl.core.trainer_builder.rl import HFRLTrainerBuilder
🧹 Nitpick comments (2)
tests/core/test_trainer_builder.py (2)
254-283: Good helper method with minor improvement opportunity
The helper method for testing common training arguments is well-designed. Consider extracting the magic numbers into named constants for better maintainability.
+# Test constants
+EXPECTED_GRPO_BATCH_SIZE = 4
+EXPECTED_DEFAULT_BATCH_SIZE = 2
+
 def _test_common_training_arguments(self, training_arguments, rl: str):
     """Helper to test common arguments across all variants"""
     # Basic training settings
     if rl == "grpo":
         # grpo_cfg's micro_batch_size is diff from others
-        assert training_arguments.per_device_train_batch_size == 4
+        assert training_arguments.per_device_train_batch_size == EXPECTED_GRPO_BATCH_SIZE
     else:
-        assert training_arguments.per_device_train_batch_size == 2
+        assert training_arguments.per_device_train_batch_size == EXPECTED_DEFAULT_BATCH_SIZE
328-360: Proper resource cleanup but consider using context manager
The GRPO test properly handles sys.path manipulation with try/finally, but consider using a context manager for cleaner resource management.
from contextlib import contextmanager

@contextmanager
def temp_sys_path(path):
    """Context manager for temporarily adding a path to sys.path"""
    sys.path.insert(0, str(path))
    try:
        yield
    finally:
        if str(path) in sys.path:
            sys.path.remove(str(path))

def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
    rewards_dir = tmp_path / "rewards_test"
    self._write_rewards_file(rewards_dir)

    with temp_sys_path(rewards_dir):
        builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
        # ... rest of test
📒 Files selected for processing (1)
tests/core/test_trainer_builder.py(1 hunks)
🔇 Additional comments (5)
tests/core/test_trainer_builder.py (5)
22-89: Excellent comprehensive base configuration fixture
The base configuration fixture is well-structured with logical grouping of settings and good coverage of common parameters. The normalization at the end ensures proper configuration handling.
314-327: Good reward function setup with proper file handling
The reward function file creation is well-implemented for testing GRPO functionality. The temporary file approach is appropriate for isolated testing.
381-491: Comprehensive parameterized test with proper dataset handling
The parameterized test for custom optimizers covers all RL configurations with appropriate dataset setups. The conditional logic for different configurations is well-handled, and the mock/real dataset approach is appropriate.
527-568: Well-structured SFT tests with good coverage
The SFT trainer tests provide good coverage of training arguments and custom optimizer functionality. The parameterization across different configuration types (sft, rm, prm) is comprehensive.
575-596: Good plugin test with proper fixture handling
The plugin test correctly addresses the previous review comment by creating a copy of the fixture instead of modifying it directly. The error handling logic properly distinguishes between different types of failures.
This reverts commit 99c8859.
djsaunde left a comment:
IMO good to go once the aforementioned naming change is in. Great stuff!
Actionable comments posted: 3
🧹 Nitpick comments (7)
src/axolotl/core/builders/causal.py (5)
77-77: Track or address the TODO comment
The TODO comment should be tracked in an issue or addressed directly. If the loss watchdog callback can be moved to the base class, it would improve code organization.
Would you like me to help evaluate whether the loss watchdog callback can be moved to the base class?
89-106: Refactor duplicated logging callback creation
The code for creating logging prediction callbacks is duplicated for wandb, mlflow, and comet_ml. Consider refactoring this into a helper method.
-        if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "wandb"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
-        if (
-            self.cfg.use_mlflow
-            and is_mlflow_available()
-            and self.cfg.eval_table_size > 0
-        ):
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "mlflow"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
-        if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
-            LogPredictionCallback = log_prediction_callback_factory(
-                trainer, self.tokenizer, "comet_ml"
-            )
-            callbacks.append(LogPredictionCallback(self.cfg))
+        if self.cfg.eval_table_size > 0:
+            log_providers = []
+            if self.cfg.use_wandb:
+                log_providers.append("wandb")
+            if self.cfg.use_mlflow and is_mlflow_available():
+                log_providers.append("mlflow")
+            if self.cfg.use_comet and is_comet_available():
+                log_providers.append("comet_ml")
+
+            for provider in log_providers:
+                LogPredictionCallback = log_prediction_callback_factory(
+                    trainer, self.tokenizer, provider
+                )
+                callbacks.append(LogPredictionCallback(self.cfg))
365-368: Simplify nested if statements
Combine the nested if statements for better readability.
-        if eval_data_collator := self.build_collator(
-            training_args, is_eval=True, **data_collator_kwargs
-        ):
-            if not (self.cfg.reward_model or self.cfg.process_reward_model):
-                trainer_kwargs["eval_data_collator"] = eval_data_collator
+        if (eval_data_collator := self.build_collator(
+            training_args, is_eval=True, **data_collator_kwargs
+        )) and not (self.cfg.reward_model or self.cfg.process_reward_model):
+            trainer_kwargs["eval_data_collator"] = eval_data_collator
🧰 Tools
🪛 Ruff (0.11.9)
365-368: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
382-383: Use 'not in' for membership test
For better readability, use not in instead of negating the in operator.
-        if (
-            not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
-            and self.cfg.datasets is not None
-        ):
+        if (
+            trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
+            and self.cfg.datasets is not None
+        ):
🧰 Tools
🪛 Ruff (0.11.9)
382-382: Test for membership should be not in
Convert to not in
(E713)
401-405: Consider safer DeepSpeed configuration update
Directly modifying the accelerator state's DeepSpeed configuration might be fragile if the internal structure changes. Consider using a more robust approach or adding error handling.
 if self.cfg.deepspeed and self.cfg.sample_packing:
-    trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
-        "train_micro_batch_size_per_gpu"
-    ] = self.cfg.micro_batch_size
+    try:
+        if hasattr(trainer.accelerator.state, 'deepspeed_plugin') and \
+           hasattr(trainer.accelerator.state.deepspeed_plugin, 'deepspeed_config'):
+            trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
+                "train_micro_batch_size_per_gpu"
+            ] = self.cfg.micro_batch_size
+        else:
+            LOG.warning("Unable to set DeepSpeed micro batch size: plugin structure not found")
+    except Exception as e:
+        LOG.warning(f"Failed to set DeepSpeed micro batch size: {e}")
src/axolotl/core/builders/base.py (2)
168-183: Consider implementing or documenting the hook methods.
All four hook methods have TODO comments. Consider either implementing them if needed, documenting their intended use, or removing them if they're not required.
Would you like help implementing these hook methods or documenting their intended usage?
413-416: Consider more selective error handling for torch._dynamo.
Suppressing all errors might hide real issues during compilation. Consider handling specific expected errors or logging suppressed errors for debugging.
Consider logging suppressed errors:
 if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
+    # Log suppressed errors for debugging if needed
+    if LOG.isEnabledFor(logging.DEBUG):
+        torch._dynamo.config.verbose = True
     torch._dynamo.config.suppress_errors = (  # pylint: disable=protected-access
         True
     )
📒 Files selected for processing (18)
- _quarto.yml (1 hunks)
- docs/config.qmd (2 hunks)
- src/axolotl/core/builders/__init__.py (1 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/core/builders/causal.py (1 hunks)
- src/axolotl/core/builders/rl.py (1 hunks)
- src/axolotl/core/trainers/grpo/__init__.py (2 hunks)
- src/axolotl/core/trainers/mixins/optimizer.py (1 hunks)
- src/axolotl/train.py (1 hunks)
- src/axolotl/utils/callbacks/__init__.py (1 hunks)
- src/axolotl/utils/callbacks/comet_.py (1 hunks)
- src/axolotl/utils/callbacks/lisa.py (1 hunks)
- src/axolotl/utils/callbacks/mlflow_.py (1 hunks)
- src/axolotl/utils/data/rl.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/trainer.py (1 hunks)
- tests/core/test_builders.py (1 hunks)
- tests/e2e/test_imports.py (1 hunks)
✅ Files skipped from review due to trivial changes (4)
- src/axolotl/utils/callbacks/__init__.py
- tests/e2e/test_imports.py
- src/axolotl/train.py
- src/axolotl/core/builders/__init__.py
🚧 Files skipped from review as they are similar to previous changes (9)
- src/axolotl/utils/callbacks/comet_.py
- src/axolotl/utils/callbacks/mlflow_.py
- _quarto.yml
- src/axolotl/utils/data/rl.py
- src/axolotl/utils/callbacks/lisa.py
- docs/config.qmd
- src/axolotl/core/trainers/mixins/optimizer.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/core/trainers/grpo/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/utils/trainer.py (2)
- src/axolotl/core/builders/causal.py (1): HFCausalTrainerBuilder(58-487)
- src/axolotl/core/builders/rl.py (1): HFRLTrainerBuilder(28-228)
src/axolotl/core/builders/causal.py (14)
- src/axolotl/core/builders/base.py (13): TrainerBuilderBase(49-503), get_callbacks(107-148), get_post_trainer_create_callbacks(150-166), build(104-105), _set_base_training_args(439-503), hook_pre_create_training_args(168-170), hook_post_create_training_args(172-174), hook_pre_create_trainer(176-178), train_dataset(80-81), train_dataset(84-85), eval_dataset(88-89), eval_dataset(92-93), hook_post_create_trainer(180-182)
- src/axolotl/core/trainers/mamba.py (1): AxolotlMambaTrainer(8-32)
- src/axolotl/core/trainers/trl.py (2): AxolotlPRMTrainer(117-124), AxolotlRewardTrainer(107-114)
- src/axolotl/core/training_args.py (3): AxolotlPRMConfig(283-286), AxolotlRewardConfig(276-279), AxolotlTrainingArguments(240-246)
- src/axolotl/integrations/base.py (1): PluginManager(281-550)
- src/axolotl/monkeypatch/relora.py (1): ReLoRACallback(83-255)
- src/axolotl/utils/__init__.py (2): is_comet_available(16-17), is_mlflow_available(12-13)
- src/axolotl/utils/callbacks/__init__.py (7): EvalFirstStepCallback(56-72), LossWatchDogCallback(133-160), SaveBetterTransformerModelCallback(75-108), bench_eval_callback_factory(163-364), causal_lm_bench_eval_callback_factory(367-575), colab_inference_post_train_callback(880-902), log_prediction_callback_factory(578-790)
- src/axolotl/utils/callbacks/lisa.py (1): lisa_callback_factory(23-92)
- src/axolotl/utils/callbacks/qat.py (1): QATCallback(33-50)
- src/axolotl/utils/chat_templates.py (1): get_chat_template_from_config(120-132)
- src/axolotl/utils/collators/batching.py (3): BatchSamplerDataCollatorForSeq2Seq(127-153), DataCollatorForSeq2Seq(12-123), V2BatchSamplerDataCollatorForSeq2Seq(157-183)
- src/axolotl/utils/collators/mamba.py (1): MambaDataCollator(15-39)
- src/axolotl/utils/collators/mm_chat.py (1): MultiModalChatDataCollator(18-92)
🪛 Ruff (0.11.9)
src/axolotl/core/builders/causal.py
365-368: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
382-382: Test for membership should be not in
Convert to not in
(E713)
🔇 Additional comments (16)
src/axolotl/core/builders/causal.py (3)
132-147: LGTM!
The trainer class selection logic is clear and follows a well-defined precedence order.
217-218: Verify eval_accumulation_steps configuration
Setting eval_accumulation_steps to gradient_accumulation_steps might not be appropriate for all training scenarios. According to the PR objectives, there's a remaining TODO to decide whether this should be set for RL training similar to SFT.
Could you confirm if this setting is intentional for all trainer types, including reward models and RL trainers?
408-487: LGTM!
The collator selection logic is comprehensive and well-structured, handling all necessary scenarios appropriately.
tests/core/test_builders.py (7)
1-21: Test file structure and imports look good.
The imports are well-organized and appropriate for testing the builder classes.
22-89: Base configuration fixture is comprehensive and well-structured.
Good use of a base fixture that other configurations can extend. The configuration covers all essential training parameters.
91-236: RL configuration fixtures are properly defined.
Each RL variant fixture correctly extends the base configuration with appropriate algorithm-specific parameters. Good coverage of different training paradigms.
249-380: RLHF trainer builder tests are comprehensive.
Good test coverage for all RL variants with proper validation of both common and algorithm-specific training arguments.
493-568: Causal trainer builder tests are well-implemented.
Good coverage of SFT, reward model, and process reward model configurations with proper dataset handling.
570-596: Plugin integration test effectively guards against regression.
Good test design that specifically checks for the trainer_cls None bug (#2693) while allowing other expected exceptions.
488-491: ⚠️ Potential issue
Fix potential NameError in finally block.
The rewards_dir variable is only defined when cfg_string == "grpo_cfg", but the finally block unconditionally references it, which could cause a NameError for other test cases.
Apply this fix to prevent the NameError:
         finally:
             # remove imported module from path
-            if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
-                sys.path.remove(str(rewards_dir))
+            if cfg_string == "grpo_cfg":
+                if 'rewards_dir' in locals() and str(rewards_dir) in sys.path:
+                    sys.path.remove(str(rewards_dir))
Likely an incorrect or invalid review comment.
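The "likely incorrect" verdict is consistent with how Python evaluates `and`: the right operand is evaluated only when the left one is truthy, so `rewards_dir` is never looked up for non-GRPO cases. A minimal sketch, using a hypothetical `cleanup` function and a made-up path rather than the actual test body:

```python
import sys


def cleanup(cfg_string):
    """Mimic the finally block: `rewards_dir` is bound only for the GRPO case."""
    if cfg_string == "grpo_cfg":
        rewards_dir = "/tmp/rewards_test"  # hypothetical path, for illustration
        sys.path.insert(0, rewards_dir)
    # `and` short-circuits: `rewards_dir` is read only when the first test passes,
    # so the unbound local is never touched for other configs.
    if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
        sys.path.remove(str(rewards_dir))
        return "removed"
    return "skipped"


result_dpo = cleanup("dpo_cfg")
result_grpo = cleanup("grpo_cfg")
```

The extra `'rewards_dir' in locals()` guard from the suggestion is therefore redundant in this shape of code, which matches the bot's own retraction.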
src/axolotl/core/builders/base.py (3)
49-102: Base class initialization and properties are well-designed.
Good design with the axolotl tagging for model tracking and proper property encapsulation.
107-167: Callback management is comprehensive and well-organized.
Good separation of concerns between callbacks that can be created before and after trainer instantiation.
439-504: Main configuration method is well-structured.
The _set_base_training_args method effectively orchestrates all configuration methods and provides a clean interface for building training arguments.
src/axolotl/utils/trainer.py (1)
19-19: LGTM! Import path correctly updated.
The import path has been properly updated to reflect the new modular structure where trainer_builder.py has been split into the builders package.
src/axolotl/core/builders/rl.py (2)
40-77: Well-structured trainer class selection logic.
The method properly prioritizes the plugin system and handles all supported RL types with appropriate error handling for unsupported types.
195-200: Excellent API compatibility handling.
Great use of inspect.signature to dynamically handle the different parameter names (tokenizer vs processing_class) across trainer implementations. This makes the code more robust to API variations.
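The signature-based dispatch praised above can be sketched as follows. This is an illustration under assumptions, not the code in rl.py: the two trainer classes and `build_trainer` are hypothetical, and only `inspect.signature` and the two parameter names come from the review comment.

```python
import inspect


class OldStyleTrainer:
    """Hypothetical trainer that still accepts `tokenizer`."""

    def __init__(self, model, tokenizer=None):
        self.model, self.tok = model, tokenizer


class NewStyleTrainer:
    """Hypothetical trainer that accepts `processing_class` instead."""

    def __init__(self, model, processing_class=None):
        self.model, self.tok = model, processing_class


def build_trainer(trainer_cls, model, tokenizer):
    """Pass the tokenizer under whichever keyword the trainer's __init__ declares."""
    params = inspect.signature(trainer_cls.__init__).parameters
    key = "processing_class" if "processing_class" in params else "tokenizer"
    return trainer_cls(model, **{key: tokenizer})


old = build_trainer(OldStyleTrainer, "model", "tok")
new = build_trainer(NewStyleTrainer, "model", "tok")
```

Inspecting the signature keeps the builder working across library versions without pinning to either parameter name.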
        if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
            callbacks.append(lisa_callback_factory(trainer))

        if any("COLAB_" in key for key in os.environ):
Fix Colab environment detection
The current check would match any environment variable containing "COLAB_" anywhere in the key name. Use str.startswith() for more precise detection.
- if any("COLAB_" in key for key in os.environ):
+ if any(key.startswith("COLAB_") for key in os.environ):
🤖 Prompt for AI Agents
In src/axolotl/core/builders/causal.py at line 125, replace the condition that
checks if any environment variable key contains "COLAB_" with a condition that
checks if any key starts with "COLAB_" using str.startswith(). This change
ensures more precise detection of the Colab environment.
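The difference between the two checks is easy to see on sample data. The sketch below takes the environment as a plain mapping so it can be exercised without touching `os.environ`; the variable names are illustrative only.

```python
def is_colab_substring(env):
    """Original check: true if "COLAB_" appears anywhere in any key."""
    return any("COLAB_" in key for key in env)


def is_colab_prefix(env):
    """Suggested check: true only if some key starts with "COLAB_"."""
    return any(key.startswith("COLAB_") for key in env)


# Illustrative environments; the key names are assumptions for the example.
unrelated_env = {"MY_COLAB_CACHE": "1", "PATH": "/usr/bin"}
colab_like_env = {"COLAB_GPU": "1", "PATH": "/usr/bin"}

substring_false_positive = is_colab_substring(unrelated_env)
prefix_on_unrelated = is_colab_prefix(unrelated_env)
prefix_on_colab = is_colab_prefix(colab_like_env)
```

In real code either function would be called with `os.environ`, which iterates over its keys like any mapping.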
        elif self.cfg.optimizer == "ao_adamw_4bit":
            # TODO remove 20250401
            from torchao.prototype.low_bit_optim import AdamW4bit

            optimizer_cls = AdamW4bit
            optimizer_kwargs.update(adam_kwargs)

            LOG.warning(
                f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
            )
        elif self.cfg.optimizer == "ao_adamw_8bit":
Remove deprecated ao_adamw_4bit optimizer.
The deprecation comment indicates this code should have been removed by April 1, 2025, but it's still present in May 2025.
Remove the deprecated optimizer case:
- elif self.cfg.optimizer == "ao_adamw_4bit":
-     # TODO remove 20250401
-     from torchao.prototype.low_bit_optim import AdamW4bit
-
-     optimizer_cls = AdamW4bit
-     optimizer_kwargs.update(adam_kwargs)
-
-     LOG.warning(
-         f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
-     )
🤖 Prompt for AI Agents
In src/axolotl/core/builders/base.py between lines 271 and 281, remove the
entire conditional block handling the deprecated "ao_adamw_4bit" optimizer,
including the import, optimizer class assignment, kwargs update, and warning
log, as it is no longer needed past the deprecation date.
        elif self.cfg.orpo_alpha is not None:
            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
            training_args_kwargs["beta"] = self.cfg.orpo_alpha
💡 Verification agent
🧩 Analysis chain
Verify the alpha to beta mapping is intentional.
The comment indicates TRL does "odd mapping of alpha to beta". Please confirm this is the intended behavior and consider adding a more detailed explanation or a link to TRL documentation explaining this mapping.
🌐 Web query:
How does TRL handle the alpha parameter in ORPO training and why is it mapped to beta?
💡 Result:
In TRL (Transformer Reinforcement Learning), ORPO stands for Odds Ratio Preference Optimization. The ORPO loss adds an odds-ratio preference term to the standard supervised fine-tuning loss, and a single coefficient weights that term.
Here's how TRL handles the alpha parameter and its mapping to beta:
- Alpha as the weighting coefficient: the coefficient (written λ in the ORPO paper and commonly called alpha in configs) controls how strongly the odds-ratio term contributes relative to the supervised loss.
- Mapping alpha to beta in TRL: TRL exposes this coefficient under the name beta in its ORPO configuration, which appears to follow the naming already used by its other preference trainers. Beta in TRL plays the same role as alpha; only the name differs.
- Practical implications: when configuring ORPO training via TRL, you set beta knowing that it acts as the alpha (λ) weight on the odds-ratio term.
So the parameter isn't functionally changed; TRL simply uses the name beta for what is otherwise called alpha.
Clarify ORPO “alpha→beta” mapping
TRL names the ORPO weighting coefficient “beta”, while the axolotl config calls it alpha. Please update the inline comment to reflect this and, if possible, link to the TRL documentation for ORPO:
• File: src/axolotl/core/builders/rl.py
• Lines: 101–104
Suggested diff:
- elif self.cfg.orpo_alpha is not None:
-     # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
-     training_args_kwargs["beta"] = self.cfg.orpo_alpha
+ elif self.cfg.orpo_alpha is not None:
+     # TRL uses “beta” for the ORPO weighting coefficient that axolotl calls “alpha”.
+     # See the TRL ORPO docs for details: https://huggingface.co/docs/trl/orpo_trainer
+     training_args_kwargs["beta"] = self.cfg.orpo_alpha
🤖 Prompt for AI Agents
In src/axolotl/core/builders/rl.py around lines 101 to 104, update the inline
comment to clarify that TRL intentionally renames the ORPO algorithm's alpha
hyperparameter to "beta" for naming consistency. Add a brief explanation that
beta in TRL corresponds to alpha in the original ORPO formulation and, if
possible, include a link to the TRL documentation for ORPO to provide context
for this mapping.
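The mapping under discussion amounts to forwarding one config field under a different keyword. A rough sketch, assuming a hypothetical `Cfg` dataclass and helper: only `orpo_alpha` and the `"beta"` key come from the snippet above, and the `rl_beta` field standing in for the preceding branch is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    """Hypothetical slice of the training config, for illustration only."""

    rl_beta: Optional[float] = None  # assumed name for the preceding branch
    orpo_alpha: Optional[float] = None


def beta_kwargs(cfg):
    """Build the `beta` keyword for the TRL training args from whichever field is set."""
    kwargs = {}
    if cfg.rl_beta is not None:
        kwargs["beta"] = cfg.rl_beta
    elif cfg.orpo_alpha is not None:
        # ORPO's alpha is forwarded under TRL's `beta` name.
        kwargs["beta"] = cfg.orpo_alpha
    return kwargs


orpo_only = beta_kwargs(Cfg(orpo_alpha=0.1))
both_set = beta_kwargs(Cfg(rl_beta=0.2, orpo_alpha=0.1))
neither = beta_kwargs(Cfg())
```

Because of the `elif`, an explicit beta takes precedence over `orpo_alpha` when both are set in this sketch.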
Thanks for all the reviews. It was a big PR. If there are any gaps, I'll address them in a follow-up PR.
Description
RL trainer was not loading some basic configs like logging_steps etc. This PR consolidates the setting of these params and cleans them up.
This PR also fixes a case where we did not call .map with num_proc.
To discuss:
- use_reentrant
- remove_unused_columns
Addresses a lot of points in #2105
Motivation and Context
How has this been tested?
Untested!
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
torch_compile_mode and an optional field dpo_label_smoothing for advanced training customization.
Refactor
Bug Fixes
Documentation
Tests