From 8cb87ea1fdb19484d3251582e351476ab92cef00 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 7 Aug 2025 17:36:20 +0000 Subject: [PATCH 1/7] use nanmean for loss aggregation (CP fix) --- src/axolotl/loaders/model.py | 4 +- src/axolotl/loaders/patch_manager.py | 6 + .../transformers/trainer_loss_calc.py | 143 ++++++++++++++++++ tests/monkeypatch/test_trainer_loss_calc.py | 38 +++++ 4 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 src/axolotl/monkeypatch/transformers/trainer_loss_calc.py create mode 100644 tests/monkeypatch/test_trainer_loss_calc.py diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 7061e1ff38..c738fd2309 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -1,5 +1,5 @@ -"""Model loader class implementation for loading, configuring, and patching various -models. +""" +Model loader class implementation for loading, configuring, and patching various models. """ import gc diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 4273f3ccea..e6b5743264 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -76,8 +76,14 @@ def _apply_transformers_patches(self): from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import ( patch_prepare_from_posids, ) + from axolotl.monkeypatch.transformers.trainer_loss_calc import ( + patch_evaluation_loop, + patch_maybe_log_save_evaluate, + ) patch_prepare_from_posids() + patch_evaluation_loop() + patch_maybe_log_save_evaluate() def apply_post_model_load_patches(self, model: PreTrainedModel): """Apply patches that require the model instance.""" diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py new file mode 100644 index 0000000000..f0d5e76928 --- /dev/null +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -0,0 +1,143 @@ +""" +Module for patching 
transformers Trainer loss calculation to use nanmean. + +This is needed for context parallelism since chunks of the input sequences may be fully +masked and return NaNs in the loss calculation. +""" + +import importlib +import inspect + +from transformers import Trainer + +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +# pylint: disable=protected-access +def patch_evaluation_loop(): + """Patch the evaluation_loop method.""" + # Check if already patched + if hasattr(Trainer, "_original_evaluation_loop"): + LOG.info("Trainer.evaluation_loop already patched") + return + + # Get the original method source + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + Trainer._original_evaluation_loop = evaluation_loop_source + evaluation_loop_source, _ = detab_code(evaluation_loop_source) + + # Define the patterns to replace + original_list_pattern = 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()' + patched_list_pattern = 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()' + + original_array_pattern = ( + 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()' + ) + patched_array_pattern = ( + 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()' + ) + + # Check if the patterns exist + if ( + original_list_pattern not in evaluation_loop_source + or original_array_pattern not in evaluation_loop_source + ): + LOG.warning( + "Original loss calculation patterns not found in Trainer.evaluation_loop" + ) + return + + # Apply the patches + evaluation_loop_source = evaluation_loop_source.replace( + original_list_pattern, patched_list_pattern + ) + evaluation_loop_source = evaluation_loop_source.replace( + original_array_pattern, patched_array_pattern + ) + + # Rename the function to avoid conflicts + evaluation_loop_source = evaluation_loop_source.replace( + "def evaluation_loop(", + "def 
axolotl_evaluation_loop(", + 1, + ) + + # Get the module for necessary imports + module_name = Trainer.__module__ + module = importlib.import_module(module_name) + + # Import necessary items from the module + items_to_import = [] + for item in dir(module): + if item in evaluation_loop_source: + items_to_import.append(item) + + # Execute the imports and patched method + exec( # pylint: disable=exec-used # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102 + + LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation") + Trainer.evaluation_loop = ( + axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821 + ) + + +# pylint: disable=protected-access +def patch_maybe_log_save_evaluate(): + """Patch the _maybe_log_save_evaluate method.""" + # Check if already patched + if hasattr(Trainer, "_original_maybe_log_save_evaluate"): + LOG.info("Trainer._maybe_log_save_evaluate already patched") + return + + # Get the original method source + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + Trainer._original_maybe_log_save_evaluate = maybe_log_source + maybe_log_source, _ = detab_code(maybe_log_source) + + # Define the pattern to replace + original_pattern = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" + patched_pattern = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" + + # Check if the pattern exists + if original_pattern not in maybe_log_source: + LOG.warning( + "Original tr_loss_scalar pattern not found in Trainer._maybe_log_save_evaluate" + ) + return + + # Apply the patch + maybe_log_source = maybe_log_source.replace(original_pattern, patched_pattern) + + # Rename the function to avoid conflicts + maybe_log_source = maybe_log_source.replace( + "def _maybe_log_save_evaluate(", + "def axolotl_maybe_log_save_evaluate(", + 1, + ) + + # Get the module for necessary 
imports + module_name = Trainer.__module__ + module = importlib.import_module(module_name) + + # Import necessary items from the module + items_to_import = [] + for item in dir(module): + if item in maybe_log_source: + items_to_import.append(item) + + # Execute the imports and patched method + exec( # pylint: disable=exec-used # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102 + + LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation") + Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821 diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py new file mode 100644 index 0000000000..d56fb5517d --- /dev/null +++ b/tests/monkeypatch/test_trainer_loss_calc.py @@ -0,0 +1,38 @@ +"""Unit tests for trainer loss calc monkeypatch.""" + +import unittest + +from transformers import Trainer + +from axolotl.monkeypatch.transformers.trainer_loss_calc import ( + patch_evaluation_loop, + patch_maybe_log_save_evaluate, +) + + +class TestTrainerLossCalc(unittest.TestCase): + """ + Unit test class for trainer loss calc monkeypatch + """ + + def test_patch_evaluation_loop_applies(self): + """ + Test that patch_evaluation_loop applies successfully + """ + # Ensure we start with a clean state + if hasattr(Trainer, "_original_evaluation_loop"): + delattr(Trainer, "_original_evaluation_loop") + + patch_evaluation_loop() + self.assertTrue(hasattr(Trainer, "_original_evaluation_loop")) + + def test_patch_maybe_log_save_evaluate_applies(self): + """ + Test that patch_maybe_log_save_evaluate applies successfully + """ + # Ensure we start with a clean state + if hasattr(Trainer, "_original_maybe_log_save_evaluate"): + delattr(Trainer, "_original_maybe_log_save_evaluate") + + patch_maybe_log_save_evaluate() + self.assertTrue(hasattr(Trainer, 
"_original_maybe_log_save_evaluate")) From e2fea627456ecf94a108b4b50c11b383e2c8bae4 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 7 Aug 2025 21:10:37 +0000 Subject: [PATCH 2/7] use regular asserts --- tests/monkeypatch/test_trainer_loss_calc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py index d56fb5517d..ef71b9089f 100644 --- a/tests/monkeypatch/test_trainer_loss_calc.py +++ b/tests/monkeypatch/test_trainer_loss_calc.py @@ -24,7 +24,7 @@ def test_patch_evaluation_loop_applies(self): delattr(Trainer, "_original_evaluation_loop") patch_evaluation_loop() - self.assertTrue(hasattr(Trainer, "_original_evaluation_loop")) + assert hasattr(Trainer, "_original_evaluation_loop") def test_patch_maybe_log_save_evaluate_applies(self): """ @@ -35,4 +35,4 @@ def test_patch_maybe_log_save_evaluate_applies(self): delattr(Trainer, "_original_maybe_log_save_evaluate") patch_maybe_log_save_evaluate() - self.assertTrue(hasattr(Trainer, "_original_maybe_log_save_evaluate")) + assert hasattr(Trainer, "_original_maybe_log_save_evaluate") From 7222c590daa340b35295f9fdac97ccf5c47cfb99 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 7 Aug 2025 21:54:53 +0000 Subject: [PATCH 3/7] small changes to make tests isolate --- .../transformers/trainer_loss_calc.py | 74 +++++++++---------- tests/monkeypatch/test_trainer_loss_calc.py | 30 +++----- 2 files changed, 44 insertions(+), 60 deletions(-) diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py index f0d5e76928..f353f8c978 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -15,6 +15,23 @@ LOG = get_logger(__name__) +ORIGINAL_EVAL_CODE = { + "list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()', + "array": 
'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()', +} +PATCHED_EVAL_CODE = { + "list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()', + "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', +} + +ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" +PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" + + +def check_evaluation_loop_is_patchable() -> bool: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values()) + # pylint: disable=protected-access def patch_evaluation_loop(): @@ -24,38 +41,20 @@ def patch_evaluation_loop(): LOG.info("Trainer.evaluation_loop already patched") return - # Get the original method source - evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) - Trainer._original_evaluation_loop = evaluation_loop_source - evaluation_loop_source, _ = detab_code(evaluation_loop_source) - - # Define the patterns to replace - original_list_pattern = 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()' - patched_list_pattern = 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()' - - original_array_pattern = ( - 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()' - ) - patched_array_pattern = ( - 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()' - ) - # Check if the patterns exist - if ( - original_list_pattern not in evaluation_loop_source - or original_array_pattern not in evaluation_loop_source - ): - LOG.warning( - "Original loss calculation patterns not found in Trainer.evaluation_loop" - ) + try: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + except OSError: return + Trainer.evaluation = evaluation_loop_source + evaluation_loop_source, _ = detab_code(evaluation_loop_source) # 
Apply the patches evaluation_loop_source = evaluation_loop_source.replace( - original_list_pattern, patched_list_pattern + ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"] ) evaluation_loop_source = evaluation_loop_source.replace( - original_array_pattern, patched_array_pattern + ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"] ) # Rename the function to avoid conflicts @@ -88,6 +87,11 @@ def patch_evaluation_loop(): ) +def check_maybe_log_save_evaluate_is_patchable() -> bool: + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + return ORIGINAL_MAYBE_CODE in maybe_log_source + + # pylint: disable=protected-access def patch_maybe_log_save_evaluate(): """Patch the _maybe_log_save_evaluate method.""" @@ -96,24 +100,16 @@ def patch_maybe_log_save_evaluate(): LOG.info("Trainer._maybe_log_save_evaluate already patched") return - # Get the original method source - maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + # Check if the patterns exist + try: + maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate) + except OSError: + return Trainer._original_maybe_log_save_evaluate = maybe_log_source maybe_log_source, _ = detab_code(maybe_log_source) - # Define the pattern to replace - original_pattern = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" - patched_pattern = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" - - # Check if the pattern exists - if original_pattern not in maybe_log_source: - LOG.warning( - "Original tr_loss_scalar pattern not found in Trainer._maybe_log_save_evaluate" - ) - return - # Apply the patch - maybe_log_source = maybe_log_source.replace(original_pattern, patched_pattern) + maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE) # Rename the function to avoid conflicts maybe_log_source = maybe_log_source.replace( diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py 
index ef71b9089f..c72cb621b3 100644 --- a/tests/monkeypatch/test_trainer_loss_calc.py +++ b/tests/monkeypatch/test_trainer_loss_calc.py @@ -2,11 +2,9 @@ import unittest -from transformers import Trainer - from axolotl.monkeypatch.transformers.trainer_loss_calc import ( - patch_evaluation_loop, - patch_maybe_log_save_evaluate, + check_evaluation_loop_is_patchable, + check_maybe_log_save_evaluate_is_patchable, ) @@ -15,24 +13,14 @@ class TestTrainerLossCalc(unittest.TestCase): Unit test class for trainer loss calc monkeypatch """ - def test_patch_evaluation_loop_applies(self): + def test_trainer_loss_calc_is_patchable(self): """ - Test that patch_evaluation_loop applies successfully + Test that the upstream transformers code is still patchable. This will fail if + the patched code changes upstream. """ - # Ensure we start with a clean state - if hasattr(Trainer, "_original_evaluation_loop"): - delattr(Trainer, "_original_evaluation_loop") - - patch_evaluation_loop() - assert hasattr(Trainer, "_original_evaluation_loop") + assert check_evaluation_loop_is_patchable() + assert check_maybe_log_save_evaluate_is_patchable() - def test_patch_maybe_log_save_evaluate_applies(self): - """ - Test that patch_maybe_log_save_evaluate applies successfully - """ - # Ensure we start with a clean state - if hasattr(Trainer, "_original_maybe_log_save_evaluate"): - delattr(Trainer, "_original_maybe_log_save_evaluate") - patch_maybe_log_save_evaluate() - assert hasattr(Trainer, "_original_maybe_log_save_evaluate") +if __name__ == "__main__": + unittest.main() From c3dac65dd1876f4ba2f92b18a09ebc910c2a400a Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 7 Aug 2025 23:44:42 +0000 Subject: [PATCH 4/7] combining evaluation_loop patches --- src/axolotl/loaders/patch_manager.py | 8 ++++- .../transformers/trainer_loss_calc.py | 29 +++++++++++++++++-- src/axolotl/utils/trainer.py | 3 -- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py 
b/src/axolotl/loaders/patch_manager.py index e6b5743264..7c56b32e15 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -81,8 +81,14 @@ def _apply_transformers_patches(self): patch_maybe_log_save_evaluate, ) + patch_fsdp2 = ( + self.cfg.torch_compile + and self.cfg.fsdp_config + and self.cfg.fsdp_version == 2 + ) + patch_prepare_from_posids() - patch_evaluation_loop() + patch_evaluation_loop(patch_fsdp2) patch_maybe_log_save_evaluate() def apply_post_model_load_patches(self, model: PreTrainedModel): diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py index f353f8c978..bdc5ccdc44 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -3,6 +3,10 @@ This is needed for context parallelism since chunks of the input sequences may be fully masked and return NaNs in the loss calculation. + +Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with +the other evaluation_loop patch because we can't patch the same code twice without +raising an OSError. 
""" import importlib @@ -24,6 +28,15 @@ "array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()', } +ORIGINAL_FSDP2_CODE = """ + model.eval() +""" + +PATCHED_FSDP2_CODE = """ + if hasattr(model, "eval") and callable(model.eval): + self.model.eval() +""" + ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()" PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()" @@ -33,8 +46,13 @@ def check_evaluation_loop_is_patchable() -> bool: return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values()) +def check_evaluation_loop_is_fsdp2_patchable() -> bool: + evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + return ORIGINAL_FSDP2_CODE in evaluation_loop_source + + # pylint: disable=protected-access -def patch_evaluation_loop(): +def patch_evaluation_loop(patch_fsdp2: bool): """Patch the evaluation_loop method.""" # Check if already patched if hasattr(Trainer, "_original_evaluation_loop"): @@ -49,7 +67,7 @@ def patch_evaluation_loop(): Trainer.evaluation = evaluation_loop_source evaluation_loop_source, _ = detab_code(evaluation_loop_source) - # Apply the patches + # Apply the nanmean patches evaluation_loop_source = evaluation_loop_source.replace( ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"] ) @@ -57,6 +75,13 @@ def patch_evaluation_loop(): ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"] ) + # Apply FSDP2 eval guard patch if needed + if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source: + evaluation_loop_source = evaluation_loop_source.replace( + ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE + ) + LOG.info("Applied FSDP2 eval guard patch to evaluation_loop") + # Rename the function to avoid conflicts evaluation_loop_source = evaluation_loop_source.replace( "def evaluation_loop(", diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 90ae1a8892..48fa6064c4 100644 --- a/src/axolotl/utils/trainer.py +++ 
b/src/axolotl/utils/trainer.py @@ -15,7 +15,6 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.logging import get_logger @@ -667,8 +666,6 @@ def setup_trainer( """ from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder - if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2: - patch_evaluation_loop_for_fsdp2() if cfg.rl: trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor) trainer_builder.model_ref = model_ref From 2fac201660aa8f74f08e97c2da5a8d63e81c5ccb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 7 Aug 2025 23:47:15 +0000 Subject: [PATCH 5/7] fix --- tests/monkeypatch/test_trainer_loss_calc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/monkeypatch/test_trainer_loss_calc.py b/tests/monkeypatch/test_trainer_loss_calc.py index c72cb621b3..de3e926210 100644 --- a/tests/monkeypatch/test_trainer_loss_calc.py +++ b/tests/monkeypatch/test_trainer_loss_calc.py @@ -3,6 +3,7 @@ import unittest from axolotl.monkeypatch.transformers.trainer_loss_calc import ( + check_evaluation_loop_is_fsdp2_patchable, check_evaluation_loop_is_patchable, check_maybe_log_save_evaluate_is_patchable, ) @@ -19,6 +20,7 @@ def test_trainer_loss_calc_is_patchable(self): the patched code changes upstream. 
""" assert check_evaluation_loop_is_patchable() + assert check_evaluation_loop_is_fsdp2_patchable() assert check_maybe_log_save_evaluate_is_patchable() From deee775f127ded1c1deb19116e2f98102a809f0d Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 8 Aug 2025 00:08:19 +0000 Subject: [PATCH 6/7] delete unused --- src/axolotl/monkeypatch/trainer_eval_guard.py | 78 ------------------- 1 file changed, 78 deletions(-) delete mode 100644 src/axolotl/monkeypatch/trainer_eval_guard.py diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py deleted file mode 100644 index 8488a16df9..0000000000 --- a/src/axolotl/monkeypatch/trainer_eval_guard.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -fix for FSDP2 evals when using torch.compile -""" - -import inspect - -from transformers import Trainer - -from axolotl.monkeypatch.utils import detab_code -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - -ORIGINAL_TRAINER_CODE = """ - model.eval() -""" - -PATCHED_TRAINER_CODE = """ - if hasattr(model, "eval") and callable(model.eval): - self.model.eval() -""" - - -def get_evaluation_loop_code() -> str: - training_loop = inspect.getsource(Trainer.evaluation_loop) - return training_loop - - -def check_evaluation_loop_is_patchable() -> bool: - eval_loop = get_evaluation_loop_code() - eval_loop, _ = detab_code(eval_loop) - return ORIGINAL_TRAINER_CODE in eval_loop - - -def patch_evaluation_loop_for_fsdp2(): - """ - monkeypatch for fixing the eval loop for fsdp2 with torch.compile - """ - - try: - evaluation_loop = get_evaluation_loop_code() - except OSError: - return - Trainer._original_evaluation_loop = ( # pylint: disable=protected-access - evaluation_loop - ) - evaluation_loop, _ = detab_code(evaluation_loop) - if ORIGINAL_TRAINER_CODE not in evaluation_loop: - return - - evaluation_loop = evaluation_loop.replace( - ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE - ) - evaluation_loop = evaluation_loop.replace( - "def 
evaluation_loop(", - "def _fixed_evaluation_loop(", - 1, - ) - - # load imports necessary - import transformers.trainer - - items_to_import = [] - for item in dir(transformers.trainer): - if item in evaluation_loop: - items_to_import.append(item) - - exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" - + ", ".join(x for x in items_to_import) - + ")", - globals(), - ) - exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _inner_training_loop for fsdp optimizer save") - Trainer.evaluation_loop = ( # pylint: disable=protected-access - _fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821 - ) From 4bef02d225b2a7634a075aa6ba728fbfb8948283 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 8 Aug 2025 00:09:30 +0000 Subject: [PATCH 7/7] fix check --- src/axolotl/monkeypatch/transformers/trainer_loss_calc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py index bdc5ccdc44..75f4158b3b 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -48,6 +48,7 @@ def check_evaluation_loop_is_patchable() -> bool: def check_evaluation_loop_is_fsdp2_patchable() -> bool: evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop) + evaluation_loop_source, _ = detab_code(evaluation_loop_source) return ORIGINAL_FSDP2_CODE in evaluation_loop_source